In [42]:
! pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-1.4.7-py3-none-any.whl (923 kB)
[?25l[K     |▍                               | 10 kB 23.2 MB/s eta 0:00:01[K     |▊                               | 20 kB 29.5 MB/s eta 0:00:01[K     |█                               | 30 kB 22.6 MB/s eta 0:00:01[K     |█▍                              | 40 kB 18.1 MB/s eta 0:00:01[K     |█▊                              | 51 kB 8.2 MB/s eta 0:00:01[K     |██▏                             | 61 kB 8.6 MB/s eta 0:00:01[K     |██▌                             | 71 kB 7.3 MB/s eta 0:00:01[K     |██▉                             | 81 kB 8.2 MB/s eta 0:00:01[K     |███▏                            | 92 kB 8.3 MB/s eta 0:00:01[K     |███▌                            | 102 kB 7.7 MB/s eta 0:00:01[K     |████                            | 112 kB 7.7 MB/s eta 0:00:01[K     |████▎                           | 122 kB 7.7 MB/s eta 0:00:01[K     |████▋                           | 133 kB 7.7

In [43]:
import torch
import torch.nn as nn
from typing import List, Union
import pytorch_lightning as pl

In [2]:
""" 
Information about architecture config:
- Tuple is structured by (kernel_size, filters, stride, padding) 
- "M" is simply maxpooling with stride 2x2 and kernel 2x2
- List is structured by tuples and lastly int with number of repeats
"""

# Layer specification, read top to bottom. Tuples are
# (kernel_size, filters, stride, padding); "M" is a 2x2 max-pool with stride 2;
# a list is a group of tuples followed by the number of times to repeat it.
architecture_config = [
    (7, 64, 2, 3),  # 7x7 conv, 64 filters, stride 2, padding 3
    "M",
    (3, 192, 1, 1),  # 3x3 conv, 192 filters
    "M",
    (1, 128, 1, 0),  # 1x1 conv, 128 filters
    (3, 256, 1, 1),
    (1, 256, 1, 0),
    (3, 512, 1, 1),
    "M",
    # (1x1 conv -> 256, 3x3 conv -> 512) pair, repeated 4 times
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0),
    (3, 1024, 1, 1),
    "M",
    # (1x1 conv -> 512, 3x3 conv -> 1024) pair, repeated 2 times
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1),
    (3, 1024, 2, 1),  # the only stride-2 3x3 conv in the list
    (3, 1024, 1, 1),
    (3, 1024, 1, 1),
]

In [39]:
class CNNBlock(nn.Module):
    """Convolution followed by batch normalisation and a LeakyReLU activation.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of filters produced by the convolution.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (for example
            ``kernel_size``, ``stride`` and ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The convolution is created without a bias term; it is immediately
        # followed by a batch-norm layer.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        # LeakyReLU with a negative slope of 0.1.
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Return ``x`` after conv -> batch norm -> LeakyReLU."""
        convolved = self.conv(x)
        normalised = self.batchnorm(convolved)
        return self.leakyrelu(normalised)


class SimpleCNN(nn.Module):
    """Convolutional backbone assembled from a declarative layer list.

    Each entry of ``architecture`` must be one of:

    * a tuple ``(kernel_size, filters, stride, padding)`` - one ``CNNBlock``;
    * the string ``"M"`` - a 2x2 max-pool with stride 2;
    * a list ``[tuple, ..., repeats]`` - the tuples are applied in order and
      the whole group is repeated ``repeats`` times (``repeats`` is the last
      element of the list).

    Args:
        architecture: Layer specification as described above.
        in_channels: Channel count of the input tensor.

    Raises:
        ValueError: If an entry is none of the supported forms. Such entries
            used to be skipped silently, which let a typo (e.g. ``"m"``)
            produce a different network without any warning.
    """

    def __init__(
        self,
        architecture: List[Union[tuple, str, list]],
        in_channels: int,
    ):
        super().__init__()
        layers = []
        for module in architecture:
            if isinstance(module, tuple):
                layers.append(self._get_cnn_block(module, in_channels))
                # The next layer consumes this block's filter count.
                in_channels = module[1]
            elif module == "M":
                layers.append(nn.MaxPool2d(
                    kernel_size=(2, 2),
                    stride=(2, 2),
                ))
            elif isinstance(module, list):
                # Everything but the last element is a conv spec; the last
                # element is the repeat count.
                *block_specs, repeats = module
                for _ in range(repeats):
                    for spec in block_specs:
                        layers.append(self._get_cnn_block(spec, in_channels))
                        in_channels = spec[1]
            else:
                raise ValueError(
                    f"Unsupported architecture entry: {module!r}"
                )
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _get_cnn_block(module: tuple, in_channels):
        """Build a ``CNNBlock`` from a ``(kernel_size, filters, stride, padding)`` tuple."""
        kernel_size, filters, stride, padding = module
        return CNNBlock(
            in_channels,
            filters,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        """Run ``x`` through every layer in order."""
        return self.model(x)


class YoloV1(nn.Module):
    """Detector made of a ``SimpleCNN`` backbone and a fully connected head.

    The backbone is built from the module-level ``architecture_config``.

    Args:
        in_channels: Channel count of the input images.
        split_size: Grid size ``S``; the output has ``S x S`` cells.
        num_boxes: Number of boxes ``B`` predicted per cell.
        num_classes: Number of classes ``C``.

    Note:
        The first linear layer is sized for ``1024 * S * S`` features, so the
        backbone output must flatten to exactly that many values per sample.
    """

    def __init__(self, in_channels, split_size, num_boxes, num_classes):
        super().__init__()
        self.darknet = SimpleCNN(architecture_config, in_channels)

        # Every grid cell predicts C class scores plus 5 numbers per box.
        cell_dim = num_classes + num_boxes * 5
        grid_cells = split_size * split_size
        self.fcs = nn.Sequential(
            nn.Flatten(),

            nn.Linear(1024 * grid_cells, 496),
            nn.Dropout(0.1),
            nn.LeakyReLU(0.1),

            nn.Linear(496, grid_cells * cell_dim),
        )
        # -1 keeps the batch dimension free when reshaping the head output.
        self.final_shape = (-1, split_size, split_size, cell_dim)

    def forward(self, x):
        """Return predictions reshaped to ``self.final_shape``."""
        features = self.darknet(x)
        flat_features = features.flatten(start_dim=1)
        predictions = self.fcs(flat_features)
        return predictions.view(self.final_shape)

In [40]:
# Smoke test: push a random batch of two 3-channel 448x448 images through a
# freshly initialised model and display the shape of the result.
model = YoloV1(in_channels=3, split_size=7, num_boxes=2, num_classes=21)
random_batch = torch.rand((2, 3, 448, 448))
random_output = model(random_batch)
# The model reshapes to (-1, S, S, C + B * 5); here the last dim is 21 + 2 * 5 = 31.
random_output.shape

torch.Size([2, 1519])
torch.Size([2, 7, 7, 31])
torch.Size([2, 7, 7, 31])
(-1, 7, 7, 31)


torch.Size([2, 7, 7, 31])

In [None]:
class YoloV1Loss(nn.Module):
    """Placeholder for the YOLOv1 training loss.

    Args:
        num_boxes: Stored as ``self.num_boxes``; presumably the same value as
            the model's ``num_boxes`` (to be confirmed once implemented).
        num_classes: Stored as ``self.num_classes``; presumably the same value
            as the model's ``num_classes``.
    """

    def __init__(self, num_boxes: int, num_classes: int):
        super().__init__()

        self.num_boxes = num_boxes
        self.num_classes = num_classes

    def forward(self, pred, true):
        """Compute the loss between predictions and targets.

        The computation has not been written yet. Rather than silently
        returning ``None`` (which only fails later, far from the cause), the
        method fails immediately.

        Raises:
            NotImplementedError: Always, until the loss is implemented.
        """
        raise NotImplementedError("YoloV1Loss.forward is not implemented yet")


In [None]:
class YoloV1PL(pl.LightningModule):
    """Lightning wrapper that trains ``YoloV1`` against ``YoloV1Loss``.

    The previous training/validation steps referenced ``self.encoder``,
    ``self.decoder`` and ``F``, none of which exist here, and flattened the
    images before use. Both steps now run the wrapped detector and hand its
    output to the loss module.

    Args:
        in_channels: Channel count of the input images.
        split_size: Grid size passed to ``YoloV1``.
        num_boxes: Boxes per grid cell, passed to ``YoloV1`` and ``YoloV1Loss``.
        num_classes: Number of classes, passed to ``YoloV1`` and ``YoloV1Loss``.
        lr: Learning rate for the Adam optimizer.

    Note:
        ``YoloV1Loss.forward`` must be implemented before training can work.
    """

    def __init__(
        self,
        in_channels=3,
        split_size=7,
        num_boxes=2,
        num_classes=21,
        lr=1e-3,
    ):
        super().__init__()
        self.lr = lr
        self.yolo_v1 = YoloV1(
            in_channels=in_channels,
            split_size=split_size,
            num_boxes=num_boxes,
            num_classes=num_classes,
        )
        self.criterion = YoloV1Loss(
            num_boxes=num_boxes,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the detector's predictions for a batch of images."""
        return self.yolo_v1(x)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch):
        """Compute the loss for one ``(inputs, targets)`` batch."""
        # Assumes targets are encoded in whatever layout YoloV1Loss expects -
        # TODO confirm once the loss is implemented.
        x, y = batch
        return self.criterion(self(x), y)

    def training_step(self, train_batch, batch_idx):
        """Log and return the training loss for one batch."""
        loss = self._shared_step(train_batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log the validation loss for one batch."""
        loss = self._shared_step(val_batch)
        self.log('val_loss', loss)
    
    
