In [706]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch import Tensor

from torchvision import transforms, datasets
from torchvision.ops import nms
import torchvision.transforms.functional as fn
import torchmetrics

from torch.utils.data import DataLoader

In [707]:
import wandb
WANDB_LOGGING = False
FREEZE_FEATURE_EXTRACTOR = True
CONFIG = {
    "project_name": "name",
    "dataloader": {
        "batch_size": 32
    },
    "bias": True,
    "lr": 0.0001
}

In [708]:
def intersection_over_union(box1_cx, box1_cy, box1_w, box1_h, box2_cx, box2_cy, box2_w, box2_h):
    area1 = box1_w * box1_h
    area2 = box2_w * box2_h

    top_left1 = (box1_cx - (box1_w/2), box1_cy - (box1_h/2))
    top_left2 = (box2_cx - (box2_w/2), box2_cy - (box2_h/2))

    bottom_right1 = (box1_cx + (box1_w/2), box1_cy + (box1_h/2))
    bottom_right2 = (box2_cx + (box2_w/2), box2_cy + (box2_h/2))

    xx = max(top_left1[0], top_left2[0])
    yy = max(top_left1[1], top_left2[1])
    aa = min(bottom_right1[0], bottom_right2[0])
    bb = min(bottom_right1[1], bottom_right2[1])

    w = max(0, aa - xx)
    h = max(0, bb - yy)

    intersection_area = w*h
    union_area = area1 + area2 - intersection_area

    print(f"top_left1:{top_left1}, bottom_right1:{bottom_right1}, top_left2:{top_left2}, bottom_right2:{bottom_right2}")

    return intersection_area / union_area

In [709]:
box1 = (3,3, 3, 2)
box2 = (4,4, 1, 1)
intersection_over_union(box1[0],box1[1],box1[2],box1[3], box2[0], box2[1], box2[2], box2[3])

top_left1:(1.5, 2.0), bottom_right1:(4.5, 4.0), top_left2:(3.5, 3.5), bottom_right2:(4.5, 4.5)


0.07692307692307693

In [710]:
class CardDetector(nn.Module):
    def __init__(self, num_cells, num_anchors):
        super(CardDetector, self).__init__()

        self.num_cells = num_cells
        self.num_anchors = num_anchors
        
        self.feature_extractor = models.resnet18(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(self.feature_extractor.children())[:-1])
        if FREEZE_FEATURE_EXTRACTOR:
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
        
        self.detection_head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, self.num_cells * self.num_anchors * 5, kernel_size=1)
        )

    def forward(self, input):
        features = self.feature_extractor(input)
        print(features.shape)

        detection = self.detection_head(features)
        print(detection.shape)
        detection = detection.permute(0, 2, 3, 1)
        print(detection.shape)
        detection = detection.view(-1, self.num_cells * self.num_anchors, 5)
        print(detection.shape)
        detection[:, :, 0] = torch.sigmoid(detection[:, 0, :1])
        
        return detection

In [711]:
class CardDetectorMultiBox(nn.Module):
    def __init__(self, num_anchors, num_cells, max_boxes=5):
        super(CardDetectorMultiBox, self).__init__()

        self.num_anchors = num_anchors
        self.num_cells = num_cells
        self.max_boxes = max_boxes
        
        self.detector = CardDetector(num_anchors=num_anchors, num_cells=num_cells)

        self.box_regression_head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, max_boxes * 4, kernel_size=1),
            nn.Flatten()
        )
        self.attention = nn.Sequential(
            nn.Linear(max_boxes, 256),
            nn.ReLU(),
            nn.Linear(256, max_boxes),
            nn.Softmax(dim=1)
        )

    def forward(self, input):
        
        features = self.detector.feature_extractor(input)

        detection = self.detector(input)
        box_regression = self.box_regression_head(features)

        #print(f"detection:{detection.shape}, box_reg:{box_regression.shape}")

        attention_weights = self.attention(detection)
        #print(attention_weights)
        #num_boxes_float = torch.sum(attention_weights, dim=1)
        #print(num_boxes_float)
        #num_boxes_float = torch.clamp(num_boxes_float, min=1)
        #print(num_boxes_float)
        #num_boxes_float = num_boxes_float / torch.sum(num_boxes_float, dim=0, keepdim=True)
        #print(num_boxes_float)
        #num_boxes = (num_boxes_float * self.max_boxes).round().to(torch.int)

        #print(f"num_boxes:{num_boxes}")

        print(f"detection:{detection.shape}, box_reg:{box_regression.shape}, attention_weights:{attention_weights.shape}") #, num_boxes:{num_boxes.shape}")

        #detection_scores = detection[:, :, :1]
        #_, topk_indices = torch.topk(detection_scores, k=self.max_boxes, dim=1)
        #detection_topk = torch.gather(detection, dim=1, index=topk_indices)
#
        #print(f"detection_scores:{detection_scores.shape}, topk_indices:{topk_indices.shape}, detection_topk:{detection_topk.shape}")
        #print(f"box_regression:{box_regression.shape}")

        #box_regression_topk = torch.gather(box_regression, dim=1, index=topk_indices)
        
        boxes = detection[:,0,1:]
        scores = detection[:,0, :1]
        print(boxes.shape)

        
        box_indices = nms(boxes=boxes, scores=scores, iou_threshold=0.2)
        print(box_indices)
        detection = detection[box_indices]

        return detection


In [712]:
from tqdm.auto import tqdm  # We use tqdm to display a simple progress bar, allowing us to observe the learning progression.

def fit(
  model: nn.Module,
  num_epochs: int,
  optimizer: torch.optim.Optimizer,
  train_dataloader: DataLoader,
  val_dataloader: DataLoader,
  print_rate: int = 100
  ):
    # TODO: figure out accuacy
    #accuracy = torchmetrics.Accuracy(task='multiclass', average="weighted").to(model.device)
    accuracy = None
    model = model.to(model.device)
    # Iterate through epochs with tqdm
    for epoch in tqdm(range(num_epochs)):
        print(f"Epoch: {epoch}\n")
        train_loss = 0
        model.train()  # Set mode of model to train
        
        for batch, (X, y) in enumerate(train_dataloader):
            loss = model.train_step(X, y)
            train_loss += loss.item()
            
            # Getting the loss gradient and making an optimizer step
            optimizer.zero_grad()  
            loss.backward()
            optimizer.step()

            if batch % print_rate == 0: 
                print(f"Looked at {batch} Batches\t---\t{batch * len(X)}/{len(train_dataloader.dataset)} Samples")
            elif batch == len(train_dataloader) - 1:
                print(f"Looked at {batch} Batches\t---\t{len(train_dataloader.dataset)}/{len(train_dataloader.dataset)} Samples")
        
        # Divide the train_loss by the number of batches to get the average train_loss
        avg_train_loss = train_loss / len(train_dataloader)

        # Validation
        # Setup the Val Loss and Accuracy to accumulate over the batches in the val dataset
        val_loss = 0
        val_acc = 0
        # Set model to evaluation mode and use torch.inference_mode to remove unnecessary training operations 
        model.eval()
        with torch.inference_mode():
            for X_val, y_val in val_dataloader:
                loss, acc = model.val_step(X_val, y_val, accuracy)
                val_loss += loss.item()
                val_acc += acc

        # Get the average Val Loss and Accuracy
        avg_val_loss = val_loss / len(val_dataloader)
        avg_val_acc = val_acc / len(val_dataloader)

        print(f"Train loss: {avg_train_loss} | Val Loss: {avg_val_loss} | Val Accuracy: {avg_val_acc}")
        if WANDB_LOGGING:
            wandb.log({"Train Loss": avg_train_loss,"Val Loss": avg_val_loss, "Val Accuracy": avg_val_acc})

In [713]:
input_tensor = torch.randn(2, 3, 224, 224)

model = CardDetector(num_cells=4, num_anchors=4)
model.eval()

output_tensor = model(input_tensor)

torch.Size([2, 512, 1, 1])
torch.Size([2, 80, 1, 1])
torch.Size([2, 1, 1, 80])
torch.Size([2, 16, 5])
