In [1]:
import os
from PIL import Image
from typing import Literal

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

## Define Dataset

In [2]:
class BCSSDataset(Dataset):
    """Paired image / segmentation-mask dataset read from a directory tree.

    Images are listed from ``<path>/<split>`` and masks from
    ``<path>/<split>_mask``. Both listings are sorted so that the image and
    the mask at the same index belong together; this relies on matching
    files sorting identically in the two directories (e.g. sharing a name).
    """
    # Target (height, width) every image and mask is resized to.
    SIZE = (224, 224)

    def __init__(self, path: str, split: Literal['train', 'val', 'test'] = 'train'):
        """Collect the image and mask file paths for one split.

        Args:
            path: Dataset root directory (made absolute internally).
            split: Name of the image sub-directory; masks are read from the
                sibling ``f'{split}_mask'`` directory.

        Raises:
            OSError: If either directory cannot be listed.
        """
        path = os.path.abspath(path)

        # os.listdir returns entries in arbitrary order, so both listings are
        # sorted; otherwise images[idx] and masks[idx] may not correspond.
        image_path = os.path.join(path, split)
        self.images = sorted(
            os.path.join(image_path, filename) for filename in os.listdir(image_path)
        )
        mask_path = os.path.join(path, f'{split}_mask')
        self.masks = sorted(
            os.path.join(mask_path, filename) for filename in os.listdir(mask_path)
        )

        self.transformer = transforms.Compose([
            transforms.Resize(self.SIZE),
            transforms.ToTensor()
        ])
        # Masks are resized with nearest-neighbour sampling so that resizing
        # never blends neighbouring pixel values into values that were not
        # present in the original mask.
        self.mask_transformer = transforms.Compose([
            transforms.Resize(self.SIZE,
                              interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()
        ])

        # TODO: handle test set

    def __len__(self):
        """Return the number of image files found for this split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, resize and tensorize the image / mask pair at ``idx``.

        Returns:
            A ``(image, mask)`` tuple of tensors; the mask is converted to
            single-channel ('L') mode before being transformed.
        """
        # Context managers close the underlying file handles once the pixel
        # data has been converted, instead of leaving them to the GC.
        with Image.open(self.images[idx]) as image_file:
            image = self.transformer(image_file)
        with Image.open(self.masks[idx]) as mask_file:
            mask = self.mask_transformer(mask_file.convert('L'))

        return image, mask

In [3]:
# Root directory of the BCSS data; each split is read from a sub-directory.
data_path = './data/bcss'

# Build one dataset per split (training first, then validation).
train_split = BCSSDataset(path=data_path, split='train')
val_split = BCSSDataset(path=data_path, split='val')

## Define Unet

In [4]:
class DoubleConv(nn.Module):
    """Two consecutive Conv2d -> ReLU -> BatchNorm2d stages.

    The first stage maps ``in_channels`` to ``mid_channels`` and the second
    maps ``mid_channels`` to ``out_channels``. Both convolutions share the
    same ``kernel_size``, ``stride`` and ``padding``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 mid_channels: int = None,
                 kernel_size: int = 2,
                 stride: int = 1,
                 padding: int = 0):
        super().__init__()

        # A falsy mid_channels (None or 0) falls back to out_channels.
        hidden_channels = mid_channels if mid_channels else out_channels

        # Layers are created stage by stage, in the order they are applied.
        stages = []
        for stage_in, stage_out in ((in_channels, hidden_channels),
                                    (hidden_channels, out_channels)):
            stages.append(nn.Conv2d(stage_in,
                                    stage_out,
                                    kernel_size=kernel_size,
                                    stride=stride,
                                    padding=padding))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.BatchNorm2d(stage_out))

        self.conv_ops = nn.Sequential(*stages)

    def forward(self, X):
        """Run ``X`` through both convolution stages and return the result."""
        features = self.conv_ops(X)
        return features

In [5]:
class DownSample(nn.Module):
    """Thin wrapper around ``nn.MaxPool2d``.

    Note that the defaults (``kernel_size=2``, ``stride=1``, ``padding=0``)
    shrink each spatial dimension by one pixel rather than halving it.
    """

    def __init__(self,
                 kernel_size: int = 2,
                 stride: int = 1,
                 padding: int = 0):
        super().__init__()

        self.pool = nn.MaxPool2d(kernel_size=kernel_size,
                                 stride=stride,
                                 padding=padding)

    def forward(self, X):
        """Apply max pooling to ``X``."""
        pooled = self.pool(X)
        return pooled

In [6]:
class UpSample(nn.Module):
    """Thin wrapper around ``nn.ConvTranspose2d``.

    Maps ``in_channels`` feature maps to ``out_channels`` with a transposed
    convolution configured by ``kernel_size``, ``stride`` and ``padding``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 2,
                 stride: int = 1,
                 padding: int = 0):
        super().__init__()

        deconv_options = dict(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding)
        self.up_conv = nn.ConvTranspose2d(**deconv_options)

    def forward(self, X):
        """Apply the transposed convolution to ``X``."""
        upsampled = self.up_conv(X)
        return upsampled

In [7]:
class CropAndConcat(nn.Module):
    """Center-crop a skip-connection tensor to match ``X`` and concatenate.

    ``contracting_X`` is center-cropped to the size given by dims 2 and 3 of
    ``X`` and the two tensors are then joined along dim 1.
    """

    def forward(self, X, contracting_X):
        """Return ``X`` concatenated with the cropped ``contracting_X``."""
        target_size = (X.shape[2], X.shape[3])
        cropped = transforms.functional.center_crop(img=contracting_X,
                                                    output_size=target_size)
        return torch.cat((X, cropped), dim=1)

In [8]:
class Unet(nn.Module):
    """U-Net style encoder / decoder built from the notebook's helper blocks.

    The contracting path is four ``DoubleConv`` stages (channels
    ``in_channels -> 64 -> 128 -> 256 -> 512``), each followed by a
    ``DownSample``. A ``DoubleConv`` bottleneck maps 512 to 1024 channels.
    The expansive path repeats four times: ``UpSample``, ``CropAndConcat``
    with the matching contracting-path output, then ``DoubleConv``
    (channels ``1024 -> 512 -> 256 -> 128 -> 64``). A final 1x1 convolution
    maps 64 channels to ``output_classes``.

    Args:
        in_channels: Number of channels of the input tensor.
        output_classes: Number of output channels of the final 1x1 conv.
        down_conv_kwargs: Extra keyword arguments for the contracting-path
            ``DoubleConv`` blocks.
        down_sample_kwargs: Extra keyword arguments for ``DownSample``.
        up_conv_kwargs: Extra keyword arguments for the expansive-path
            ``DoubleConv`` blocks; also applied to the bottleneck block.
        up_sample_kwargs: Extra keyword arguments for ``UpSample``.
        expansive_kwargs: Accepted for interface compatibility but
            currently unused.
    """
    # TODO: Customize the conv blocks for easy scalability
    def __init__(self,
                 in_channels: int,
                 output_classes: int,
                 down_conv_kwargs: dict = None,
                 down_sample_kwargs: dict = None,
                 up_conv_kwargs: dict = None,
                 up_sample_kwargs: dict = None,
                 expansive_kwargs: dict = None):
        super().__init__()

        # Normalize the optional kwargs once instead of at every call site.
        down_conv_kwargs = down_conv_kwargs or {}
        down_sample_kwargs = down_sample_kwargs or {}
        up_conv_kwargs = up_conv_kwargs or {}
        up_sample_kwargs = up_sample_kwargs or {}

        encoder_channels = ((in_channels, 64), (64, 128), (128, 256), (256, 512))
        decoder_channels = ((1024, 512), (512, 256), (256, 128), (128, 64))

        # Sub-modules are created in a fixed order so that parameter
        # registration (and therefore state_dict layout) stays stable.
        self.down_conv = nn.ModuleList([
            DoubleConv(in_channels=i, out_channels=o, **down_conv_kwargs)
            for i, o in encoder_channels
        ])

        self.down_sample = nn.ModuleList([
            DownSample(**down_sample_kwargs) for _ in encoder_channels
        ])

        self.up_conv = nn.ModuleList([
            DoubleConv(in_channels=i, out_channels=o, **up_conv_kwargs)
            for i, o in decoder_channels
        ])

        self.up_sample = nn.ModuleList([
            UpSample(in_channels=i, out_channels=o, **up_sample_kwargs)
            for i, o in decoder_channels
        ])

        self.crop_concat = nn.ModuleList([CropAndConcat() for _ in decoder_channels])

        # NOTE: the attribute name keeps its original spelling because
        # state_dict keys (and thus saved checkpoints) depend on it.
        self.bottlekneck = DoubleConv(in_channels=512,
                                      out_channels=1024,
                                      **up_conv_kwargs)

        self.output = nn.Conv2d(in_channels=64, out_channels=output_classes, kernel_size=1)

    def forward(self, X):
        """Run the full network on ``X`` and return the final conv output."""
        # Contracting path: keep each stage's pre-pooling output for the
        # skip connections.
        skips = []
        for conv, pool in zip(self.down_conv, self.down_sample):
            X = conv(X)
            skips.append(X)
            X = pool(X)

        X = self.bottlekneck(X)

        # Expansive path: the deepest skip connection is consumed first,
        # hence the reversed iteration over the collected outputs.
        for up, merge, conv, skip in zip(self.up_sample,
                                         self.crop_concat,
                                         self.up_conv,
                                         reversed(skips)):
            X = up(X)
            X = merge(X, skip)
            X = conv(X)

        return self.output(X)