In [5]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch import Tensor

from torchvision import transforms, datasets
from torchvision.ops import nms
import torchvision.transforms.functional as fn
import torchmetrics

from torch.utils.data import DataLoader

In [6]:
import wandb
WANDB_LOGGING = False
FREEZE_FEATURE_EXTRACTOR = True
CONFIG = {
    "project_name": "name",
    "dataloader": {
        "batch_size": 32
    },
    "bias": True,
    "lr": 0.0001
}

In [7]:
class CardDetector(nn.Module):
    def __init__(self, num_cells):
        super(CardDetector, self).__init__()

        self.num_cells = num_cells
        
        self.feature_extractor = models.resnet18(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(self.feature_extractor.children())[:-2])
        if FREEZE_FEATURE_EXTRACTOR:
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
        
        self.detection_head = nn.Sequential(
            nn.Conv2d(512 * 7 * 7, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, self.num_cells * self.num_cells * 5, kernel_size=1)
        )

    def forward(self, input):
        print(f"input: {input.shape}")
        features = self.feature_extractor(input)
        print(f"features: {features.shape}")

        detection = self.detection_head(features)
        print(f"detection: {detection.shape}")
        detection = detection.view(-1, self.num_cells, self.num_cells, 5)
        print(f"detection: {detection.shape}")

        detection[:, :, 0] = torch.sigmoid(detection[:, 0, :1])
        
        return detection

In [8]:
input_tensor = torch.randn(2, 3, 224, 224)

model = CardDetector(num_cells=4)
model.eval()

output_tensor = model(input_tensor)
print(output_tensor)



input: torch.Size([2, 3, 224, 224])
features: torch.Size([2, 512, 7, 7])
detection: torch.Size([2, 80, 7, 7])
detection: torch.Size([98, 4, 4, 5])
tensor([[[[ 0.4989,  0.4733,  0.4711,  0.4709,  0.4839],
          [-0.0770, -0.0300, -0.0981, -0.1603, -0.1365],
          [-0.0797, -0.0961, -0.0783, -0.0081, -0.2310],
          [-0.2456, -0.1857, -0.1283, -0.1168, -0.0495]],

         [[ 0.4989,  0.4733,  0.4711,  0.4709,  0.4839],
          [-0.0319,  0.0114, -0.0092, -0.1977, -0.2036],
          [-0.1989, -0.2077, -0.0897, -0.1439, -0.1085],
          [-0.1779, -0.1860, -0.1008, -0.1315, -0.0086]],

         [[ 0.4989,  0.4733,  0.4711,  0.4709,  0.4839],
          [ 0.0263,  0.0444,  0.0090, -0.0127,  0.0638],
          [-0.1440, -0.1632, -0.1389, -0.0470, -0.1371],
          [-0.1456,  0.2172,  0.0686,  0.0333,  0.1153]],

         [[ 0.4989,  0.4733,  0.4711,  0.4709,  0.4839],
          [ 0.0360,  0.1366,  0.2164,  0.1828, -0.0227],
          [ 0.1802,  0.1220,  0.0747,  0.1417,  0