In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch import Tensor

from torchvision import transforms, datasets
import torchvision.transforms.functional as fn
import torchmetrics

from torch.utils.data import DataLoader

In [None]:
import wandb
# Feature toggles read by later cells.
WANDB_LOGGING = False             # when True, `fit` calls wandb.log with per-epoch metrics
FREEZE_FEATURE_EXTRACTOR = True   # when True, CardDetector sets requires_grad=False on its ResNet-18 backbone

# Run hyper-parameters. Not consumed by the cells shown here; presumably meant
# for wandb.init / the DataLoader / the optimizer -- TODO confirm.
CONFIG = dict(
    project_name="name",
    dataloader=dict(batch_size=32),
    bias=True,
    lr=1e-4,
)

In [None]:
class CardDetector(nn.Module):
    """Single-head card detector on top of a ResNet-18 backbone.

    The backbone is torchvision's ResNet-18 with only its final fully-connected
    layer removed; the global average pool is kept, so the 512-channel feature
    map handed to the detection head is 1x1 spatially. The head emits
    ``num_cells * num_anchors`` predictions of 5 values each. Value 0 of every
    prediction is passed through a sigmoid (presumably an objectness score --
    TODO confirm); values 1-4 are returned unbounded (presumably box
    coordinates -- TODO confirm).

    Args:
        num_cells: number of grid cells predicted per image.
        num_anchors: number of anchors per cell.
        freeze_feature_extractor: if True, the backbone's parameters get
            ``requires_grad = False`` and the backbone is kept in eval mode.
            ``None`` (the default) falls back to the notebook-level
            ``FREEZE_FEATURE_EXTRACTOR`` flag.
        pretrained: load the default ImageNet weights for the backbone.
    """

    def __init__(self, num_cells, num_anchors, freeze_feature_extractor=None, pretrained=True):
        super().__init__()

        self.num_cells = num_cells
        self.num_anchors = num_anchors
        if freeze_feature_extractor is None:
            freeze_feature_extractor = FREEZE_FEATURE_EXTRACTOR
        self.freeze_feature_extractor = freeze_feature_extractor

        # `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13);
        # ResNet18_Weights.DEFAULT is the same ImageNet checkpoint.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)
        # Drop only the final fc layer; avgpool stays, so features are (B, 512, 1, 1).
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        if self.freeze_feature_extractor:
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
            self.feature_extractor.eval()

        self.detection_head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, self.num_cells * self.num_anchors * 5, kernel_size=1)
        )

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        ``requires_grad = False`` stops gradient updates but not BatchNorm
        running-statistic updates, which happen whenever a BatchNorm layer is
        in training mode. Pinning the frozen backbone to eval mode keeps its
        pretrained statistics intact during ``model.train()``.
        """
        super().train(mode)
        if self.freeze_feature_extractor:
            self.feature_extractor.eval()
        return self

    def forward(self, input):
        """Return detections of shape ``(batch, num_cells * num_anchors, 5)``."""
        features = self.feature_extractor(input)

        detection = self.detection_head(features)      # (B, C, H, W)
        detection = detection.permute(0, 2, 3, 1)      # (B, H, W, C)

        # Explicit batch size: a mismatched head/feature size raises instead of
        # silently folding spatial positions into the batch axis. `reshape`
        # also copes with the non-contiguous output of `permute`.
        detection = detection.reshape(detection.shape[0], self.num_cells * self.num_anchors, 5)

        # Sigmoid on value 0 only, built out-of-place rather than writing into
        # a view of a tensor that is part of the autograd graph.
        detection = torch.cat((torch.sigmoid(detection[..., :1]), detection[..., 1:]), dim=-1)

        return detection

In [None]:
class CardDetectorMultiBox(nn.Module):
    """CardDetector plus a box-regression head that shares its backbone.

    Work in progress: ``forward`` computes both heads, prints their shapes and
    currently returns an empty list.

    Args:
        num_anchors: number of anchors per cell, forwarded to ``CardDetector``.
        num_cells: number of grid cells, forwarded to ``CardDetector``.
        max_boxes: the regression head outputs ``max_boxes * 4`` values per image.
    """

    def __init__(self, num_anchors, num_cells, max_boxes=5):
        super().__init__()

        self.num_anchors = num_anchors
        self.num_cells = num_cells
        self.max_boxes = max_boxes

        self.detector = CardDetector(num_anchors=num_anchors, num_cells=num_cells)

        # Same 512-channel input as the detector's head, since both consume the
        # detector's backbone features; Flatten yields (B, max_boxes * 4) when
        # the feature map is 1x1.
        self.box_regression_head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, max_boxes * 4, kernel_size=1),
            nn.Flatten()
        )

    def _detect(self, features):
        """Apply the wrapped detector's head to precomputed backbone features.

        Mirrors ``CardDetector.forward`` after its feature-extraction step, so
        the backbone is not evaluated a second time. Keep the two in sync.
        Returns a tensor of shape ``(batch, num_cells * num_anchors, 5)`` with
        a sigmoid applied to value 0 of each prediction.
        """
        detection = self.detector.detection_head(features).permute(0, 2, 3, 1)
        detection = detection.reshape(detection.shape[0], self.num_cells * self.num_anchors, 5)
        return torch.cat((torch.sigmoid(detection[..., :1]), detection[..., 1:]), dim=-1)

    def forward(self, input):
        # Run the backbone exactly once and feed both heads from it; calling
        # `self.detector(input)` here would evaluate the backbone a second time.
        features = self.detector.feature_extractor(input)

        detection = self._detect(features)
        box_regression = self.box_regression_head(features)

        # NOTE(review): debug output; this prints on every forward pass.
        print(f"detection:{detection.shape}, box_reg:{box_regression.shape}")

        # TODO: decode `detection` / `box_regression` into boxes. Until then the
        # model returns an empty list.
        boxes = []
        return boxes


In [None]:
from tqdm.auto import tqdm  # We use tqdm to display a simple progress bar, allowing us to observe the learning progression.

def fit(
  model: nn.Module,
  num_epochs: int,
  optimizer: torch.optim.Optimizer,
  train_dataloader: DataLoader,
  val_dataloader: DataLoader,
  print_rate: int = 100
  ):
    """Train ``model`` for ``num_epochs`` epochs, running a validation pass after each.

    ``model`` must provide more than a plain ``nn.Module`` does. None of the
    following is defined in the model classes shown here -- TODO confirm
    against the model that is actually passed in:
      * ``model.device``: the device the model is moved to;
      * ``model.train_step(X, y)``: returns a scalar loss tensor to backpropagate;
      * ``model.val_step(X, y, accuracy)``: returns ``(loss_tensor, accuracy_value)``.

    Args:
        model: the model to train (see the interface above).
        num_epochs: number of passes over ``train_dataloader``.
        optimizer: optimizer that updates the model's trainable parameters.
        train_dataloader: yields ``(X, y)`` training batches.
        val_dataloader: yields ``(X, y)`` validation batches.
        print_rate: print training progress every ``print_rate`` batches; the
            final batch of each epoch is always reported.

    Prints the per-epoch average train loss, validation loss and validation
    accuracy, and sends them to ``wandb.log`` when ``WANDB_LOGGING`` is True.
    """
    # TODO: figure out accuracy
    # accuracy = torchmetrics.Accuracy(task='multiclass', average="weighted").to(model.device)
    accuracy = None
    model = model.to(model.device)
    num_train_batches = len(train_dataloader)
    num_train_samples = len(train_dataloader.dataset)

    for epoch in tqdm(range(num_epochs)):
        print(f"Epoch: {epoch}\n")
        train_loss = 0.0
        samples_seen = 0
        model.train()

        for batch, (X, y) in enumerate(train_dataloader):
            loss = model.train_step(X, y)
            train_loss += loss.item()

            # Clear stale gradients, backpropagate, then update the parameters.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # `batch` is 0-based, so batch + 1 batches have been processed now.
            # Counting samples directly keeps the total exact when the last
            # batch is shorter than the others.
            batches_seen = batch + 1
            samples_seen += len(X)
            if batch % print_rate == 0 or batches_seen == num_train_batches:
                print(f"Looked at {batches_seen} Batches\t---\t{samples_seen}/{num_train_samples} Samples")

        # Average the summed per-batch losses over the number of batches.
        avg_train_loss = train_loss / num_train_batches

        # Validation: accumulate loss and accuracy over all validation batches.
        val_loss = 0.0
        val_acc = 0
        # eval() switches layers like BatchNorm to inference behaviour;
        # inference_mode() disables autograd tracking.
        model.eval()
        with torch.inference_mode():
            for X_val, y_val in val_dataloader:
                loss, acc = model.val_step(X_val, y_val, accuracy)
                val_loss += loss.item()
                val_acc += acc

        avg_val_loss = val_loss / len(val_dataloader)
        avg_val_acc = val_acc / len(val_dataloader)

        print(f"Train loss: {avg_train_loss} | Val Loss: {avg_val_loss} | Val Accuracy: {avg_val_acc}")
        if WANDB_LOGGING:
            wandb.log({"Train Loss": avg_train_loss,"Val Loss": avg_val_loss, "Val Accuracy": avg_val_acc})

In [None]:
# Seed so the random input and the randomly initialised heads are reproducible.
torch.manual_seed(0)
input_tensor = torch.randn(2, 3, 224, 224)

model = CardDetectorMultiBox(num_cells=16, num_anchors=2)
model.eval()

# Shape smoke test only: inference_mode avoids building an autograd graph.
with torch.inference_mode():
    output_tensor = model(input_tensor)