In [1]:
!pip install pycocotools
!pip install scikit-image



In [2]:
import torch.nn as nn, numpy as np, torchvision.transforms.v2 as T
import torch, dotenv, os

from torch.utils.data import Dataset, DataLoader, Subset
from matplotlib import pyplot as plt
from pycocotools.coco import COCO
from pathlib import Path
from skimage import io

# Pull variables from a local .env file (if any) into the process environment.
dotenv.load_dotenv()

# Root directory of the dataset; DS expects "annotations/" and the image
# folder underneath it. Fail early with a readable message instead of the
# opaque TypeError that Path(None) would raise.
datapath_env = os.getenv("DATAPATH")
if datapath_env is None:
    raise RuntimeError(
        "Environment variable DATAPATH is not set; "
        "export it or define it in a .env file next to this notebook."
    )
data_path = Path(datapath_env)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Number of samples the DataLoader puts into one batch.
batch_size = 1

# Example categories; the length of this list sizes the UNet output head.
categories = [
    'person',
    'car',
    'horse',
]

In [4]:
class DS(Dataset):
    """COCO-backed dataset yielding a padded image and its stacked instance masks.

    The index is built from the image ids that ``COCO.getImgIds`` returns for
    the requested categories; for every image, only annotations belonging to
    those categories are converted to masks.
    """

    def __init__(self, annotation_file, categories=('horse',), image_dir="val2017"):
        """
        Args:
            annotation_file: File name of the COCO annotation JSON, resolved
                against ``data_path / "annotations"``.
            categories: COCO category names used to select images and
                annotations.
            image_dir: Sub-directory of ``data_path`` that contains the image
                files referenced by the annotation file.
        """
        super().__init__()
        self.coco = COCO(data_path/"annotations"/annotation_file)
        self.image_dir = image_dir

        self.category_ids = self.coco.getCatIds(categories)
        self.image_ids = self.coco.getImgIds(catIds=self.category_ids)

        # Image pipeline: uint8 HxWxC array -> float32 CxHxW tensor in [0, 1].
        self.T = T.Compose([
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True)
        ])
        # Mask pipeline: same conversion but WITHOUT rescaling. Masks are 0/1
        # uint8 arrays (pycocotools annToMask), so scale=True would turn the
        # foreground value into 1/255 instead of 1.0.
        self.T_mask = T.Compose([
            T.ToImage(),
            T.ToDtype(torch.float32, scale=False)
        ])
        # 30 px border on every side: mirrored content for the image,
        # zeros (Pad's default fill) for the masks.
        self.P = T.Pad(padding=30, padding_mode='symmetric')
        self.P2 = T.Pad(padding=30)

    def __len__(self):
        """Number of images selected for the requested categories."""
        return len(self.image_ids)

    def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
        """Load one sample.

        Returns:
            A pair ``(image, masks)``: the padded float image, and the padded
            per-annotation masks stacked along a new leading dimension.

        Raises:
            RuntimeError: from ``torch.stack`` if the image has no annotation
                of the selected categories.
        """
        img_id = [self.image_ids[index]]
        annotation_ids = self.coco.getAnnIds(imgIds=img_id, catIds=self.category_ids, iscrowd=None)

        file_name = self.coco.loadImgs(img_id)[0]['file_name']
        image = io.imread(data_path/self.image_dir/file_name)
        if image.ndim == 2:
            # Single-channel (grayscale) file: replicate to three channels so
            # the sample matches the 3-channel input the networks expect.
            image = np.stack((image,) * 3, axis=-1)

        masks = [
            self.P2(self.T_mask(self.coco.annToMask(ann)))
            for ann in self.coco.loadAnns(annotation_ids)
        ]

        return self.P(self.T(image)), torch.stack(masks, dim=0)

In [5]:
# Validation split; DS is left on its default category selection here.
ds_val = DS(annotation_file='instances_val2017.json')

# Default collation, batch size taken from the config cell.
dl_val = DataLoader(dataset=ds_val, batch_size=batch_size)

loading annotations into memory...
Done (t=0.65s)
creating index...
index created!


In [6]:
def crop(tensor: torch.Tensor, target):
    """Center-crop the last two dimensions of ``tensor``.

    Args:
        tensor: Tensor with at least two dimensions; any number of leading
            dimensions is left untouched.
        target: Shape-like sequence whose last two entries give the desired
            size of the last two dimensions (e.g. another tensor's ``.shape``).

    Returns:
        A view of ``tensor`` whose last two dimensions equal ``target[-2:]``.
        When the size difference is odd, the extra row/column is dropped from
        the bottom/right.

    Raises:
        ValueError: if the requested size exceeds the tensor in either of the
            last two dimensions.
    """
    target_h, target_w = target[-2], target[-1]
    height, width = tensor.shape[-2], tensor.shape[-1]

    # A larger target would give negative offsets, which Python slicing
    # silently interprets as "from the end" and returns a wrong-sized result.
    if target_h > height or target_w > width:
        raise ValueError(
            f"cannot crop tensor of size ({height}, {width}) "
            f"to larger target ({target_h}, {target_w})"
        )

    top = (height - target_h) // 2
    left = (width - target_w) // 2
    return tensor[..., top:top + target_h, left:left + target_w]

class UNet(nn.Module):
    """U-Net with four contracting stages, a center block and four expanding stages.

    All 3x3 convolutions are unpadded, so the output is spatially smaller than
    the input; skip connections are center-cropped with ``crop`` before being
    concatenated with the upsampled features.
    """

    def __init__(self, categories):
        """
        Args:
            categories: Number of foreground categories; the output head gets
                ``categories + 1`` channels (one extra for background).
        """
        super().__init__()

        def conv_pair(in_ch, out_ch):
            """Two unpadded 3x3 convolutions, each followed by a ReLU."""
            return [
                nn.Conv2d(in_ch, out_ch, 3),
                nn.ReLU(),
                nn.Conv2d(out_ch, out_ch, 3),
                nn.ReLU(),
            ]

        # Contracting path: 3 -> 64 -> 128 -> 256 -> 512 channels; every stage
        # after the first starts by halving the resolution.
        down_stages = [nn.Sequential(*conv_pair(3, 64))]
        for width in (64, 128, 256):
            down_stages.append(
                nn.Sequential(nn.MaxPool2d(2), *conv_pair(width, 2 * width))
            )
        self.contracting_layers = nn.ModuleList(down_stages)

        # Bottleneck: downsample, widen to 1024 channels, upsample back to 512.
        self.center = nn.Sequential(
            nn.MaxPool2d(2),
            *conv_pair(512, 1024),
            nn.ConvTranspose2d(1024, 512, 2, 2),
        )

        # Expanding path: each stage consumes the concatenated (skip, upsampled)
        # features, i.e. twice its own width, and all but the last upsample.
        up_stages = [
            nn.Sequential(
                *conv_pair(2 * width, width),
                nn.ConvTranspose2d(width, width // 2, 2, 2),
            )
            for width in (512, 256, 128)
        ]
        up_stages.append(nn.Sequential(*conv_pair(128, 64)))
        self.expanding_layers = nn.ModuleList(up_stages)

        # 1x1 convolution to per-pixel scores; +1 for background
        self.final = nn.Conv2d(64, categories + 1, 1)

    def forward(self, x):
        """Run the network and return the output of the final 1x1 convolution."""
        skips = []
        for stage in self.contracting_layers:
            x = stage(x)
            skips.append(x)

        # After the loop x is the deepest contracting output.
        x = self.center(x)

        # Pair each expanding stage with the skip of matching depth,
        # deepest first.
        for stage, skip in zip(self.expanding_layers, reversed(skips)):
            merged = torch.concat([crop(skip, x.shape), x], dim=1)
            x = stage(merged)

        return self.final(x)

In [7]:
# One output channel per configured category, plus background.
model = UNet(len(categories))
model = model.to(device)

# Grab a single batch and move it to the model's device.
img, masks = next(iter(dl_val))
img = img.to(device)
masks = [m.to(device) for m in masks]

y = model(img)

# Shape sanity check: network output, input image, raw masks, and the masks
# center-cropped to the network's output size.
print(
    y.shape,
    img.shape,
    masks[0].shape,
    crop(masks[0], y.shape).shape,
    sep="\n",
)

torch.Size([1, 4, 356, 516])
torch.Size([1, 3, 540, 700])
torch.Size([1, 1, 540, 700])
torch.Size([1, 1, 356, 516])


In [9]:
class Base_CNN(nn.Module):
    """Baseline CNN: four unpadded double-conv stages and a 2-channel softmax head.

    There is no upsampling path, so the output is a low-resolution map with
    two channels that sum to one at every position.
    """

    def __init__(self):
        super().__init__()

        def conv_pair(in_ch, out_ch):
            """Two unpadded 3x3 convolutions, each followed by a ReLU."""
            return [
                nn.Conv2d(in_ch, out_ch, 3),
                nn.ReLU(),
                nn.Conv2d(out_ch, out_ch, 3),
                nn.ReLU(),
            ]

        # 3 -> 64 -> 128 -> 256 -> 512 channels; every stage after the first
        # starts with a 2x2 max-pool.
        stages = [nn.Sequential(*conv_pair(3, 64))]
        stages.extend(
            nn.Sequential(nn.MaxPool2d(2), *conv_pair(width, 2 * width))
            for width in (64, 128, 256)
        )
        self.layers = nn.ModuleList(stages)

        # output layer
        self.output = nn.Conv2d(512, 2, kernel_size=1)
        self.activation = nn.Softmax(dim=1)

    def forward(self, x):
        """Apply all stages, then the 1x1 head and a softmax over the channel dim."""
        for stage in self.layers:
            x = stage(x)
        return self.activation(self.output(x))

# Instantiate the baseline on the selected device and display the output shape
# for the batch loaded in the U-Net check cell.
base_model = Base_CNN()
base_model = base_model.to(device)
base_model(img).shape

torch.Size([1, 2, 60, 80])