
Notes:
- Technically, UNet can be used on input images of various sizes. Would you recommend applying the same transforms?
- What does `torch.nn.ModuleList` do? What is it similar to in TensorFlow?
    - An `nn.Module` is usually a layer / group of layers. `nn.ModuleList` is a container for modules, and it can be indexed like a regular list. Unlike a plain Python list, the layers it holds are registered in the model, so calls like `model.to()`, `model.train()`, `model.eval()` reach them.

- [ConvTranspose2d](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#convtranspose2d) - TODO
    ```python
    conv_transpose = nn.ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=2, stride=2, padding=0)
    input_tensor = torch.randn(1, 3, 5, 5)  # 1 batch, 3 channels, 5x5 image
    output_tensor = conv_transpose(input_tensor)
    output_tensor.shape
    ```
- How does concat work? horizontally: `torch.cat((x, x, x), 1)`, vertically: `torch.cat((x, x, x), 0)`

- How to get model summary in pytorch? 
```
pip install torchsummary
from torchsummary import summary
# Assuming `model` is your neural network
summary(model, input_size=(3, 224, 224))  # For an input image of size 224x224 with 3 channels (RGB)
```
- Here, we need to do `crop()` actually. The best way is `from torchvision.transforms import CenterCrop`

- increase dimension at dim=0: `new_t = t.unsqueeze(0)`

- What do you do when the final output is `torch.Size([1, 16, 216, 216])`, while your input is `torch.Size([1, 3, 256, 256])`?

In [18]:
import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torchvision.transforms import v2, CenterCrop
import matplotlib.pyplot as plt
from torchvision.transforms.functional import InterpolationMode
from functools import cached_property

DATA_DIR='./data'

def replace_tensor_val(tensor, a, b):
    """Overwrite every element of ``tensor`` equal to ``a`` with ``b``.

    The replacement is done in place, and the very same tensor object is
    returned so the call can be used inside an expression.
    """
    matches = tensor == a
    tensor[matches] = b
    return tensor

# Input-image pipeline: resize to a fixed 256x256, convert to a tensor, then
# normalize each channel (the mean/std values look like the commonly used
# ImageNet statistics -- confirm if that matters for your backbone).
image_seg_transforms = transforms.Compose(
    [
        v2.Resize(size=(256, 256)),
        v2.ToTensor(),
        v2.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Mask pipeline: resize with nearest-neighbour interpolation (so label values
# are never blended into new ones), convert the PIL mask to a tensor, cast it
# to int64 and remap the value 255 to 21.
target_seg_transforms = transforms.Compose(
    [
        v2.Resize(size=(256, 256), interpolation=InterpolationMode.NEAREST),
        v2.PILToTensor(),
        v2.Lambda(lambda mask: replace_tensor_val(mask.long(), 255, 21)),
    ]
)

class VOCSegmentationClass(Dataset):
    """Thin ``Dataset`` wrapper around torchvision's PASCAL VOC 2007 segmentation data.

    Items are exactly what the wrapped ``datasets.VOCSegmentation`` yields
    after ``image_seg_transforms`` / ``target_seg_transforms`` are applied:
    an ``(image, target)`` pair.
    """

    def __init__(self, image_set):
        """Build the wrapped dataset for ``image_set``, downloading into ``DATA_DIR`` if needed."""
        self._dataset = datasets.VOCSegmentation(
            root=DATA_DIR,
            year='2007',
            image_set=image_set,
            download=True,
            transform=image_seg_transforms,
            target_transform=target_seg_transforms,
        )
        # Filled lazily by the ``classes`` property.
        self._classes = set()

    @cached_property
    def classes(self):
        """Set of distinct label values found across every target in the dataset.

        Computed on first access by iterating the whole dataset (which can be
        slow), then cached.
        """
        if not self._classes:
            for _, mask in self._dataset:
                labels = torch.unique(mask)
                self._classes.update(labels.tolist())
        return self._classes

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``."""
        # TODO: more transforms?
        return self._dataset[index]

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self._dataset)

# Training split and its loader.
train_dataset = VOCSegmentationClass(image_set='train')
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Test split and its loader. The loader must wrap ``test_dataset``; it
# previously wrapped ``train_dataset``, so "test" batches were training data.
test_dataset = VOCSegmentationClass(image_set='test')
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)



Using downloaded and verified file: ./data/VOCtrainval_06-Nov-2007.tar
Extracting ./data/VOCtrainval_06-Nov-2007.tar to ./data
Using downloaded and verified file: ./data/VOCtest_06-Nov-2007.tar
Extracting ./data/VOCtest_06-Nov-2007.tar to ./data


In [3]:
# # print("classes: ", train_dataset.classes)
# for image, target in train_dataset:
# #     # To see what our data looks like
# #     # # See torch.Size([3, 281, 500]) torch.Size([1, 281, 500])
#     # # print(image.shape, target.shape)

# #     plt.subplot(1,2,1)
# #     # Making channels the last dimension
# #     plt.imshow(image.permute(1,2,0))
# #     plt.title('image')

# #     plt.subplot(1,2,2)
# #     # Making channels the last dimension
# #     plt.imshow(target.permute(1,2,0))
# #     plt.title('mask')
# #     # See tensor([  0,   1,  15, 255], dtype=torch.uint8)
# #     print("unique: ", torch.unique(target[0]))
# #     plt.show()
# #     break
       

In [24]:
# This is a regular conv block
from torch import nn as nn
from collections import deque

class ConvBlock(nn.Module):
    """Two 3x3 convolutions with a single ReLU between them.

    Neither convolution specifies padding, so with ``nn.Conv2d``'s defaults
    each one trims 2 pixels from the height and width (4 in total per block).
    No activation is applied after the second convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 to ``x`` and return the result."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)

class Encoder(nn.Module):
    """Contracting path: a chain of ``ConvBlock``s separated by 2x2 max-pooling.

    ``in_channels`` lists the channel widths along the path; consecutive
    entries become one ``ConvBlock`` each. The final block acts as the
    bottleneck and is not followed by pooling.
    """

    def __init__(self, in_channels):
        super().__init__()
        # This should include the bottleneck.
        widths = zip(in_channels, in_channels[1:])
        self._layers = nn.ModuleList([ConvBlock(c_in, c_out) for c_in, c_out in widths])
        self._maxpool = nn.MaxPool2d(2, stride=2)

    def forward(self, x):
        """Run the encoder.

        Returns:
            A pair ``(bottleneck_output, skips)`` where ``skips`` is a deque of
            the un-pooled outputs of every block except the last, ordered
            deepest first (the most recent block's output is at index 0).
        """
        skips = deque()
        for block in list(self._layers)[:-1]:
            x = block(x)
            # appendleft keeps the deepest feature map at the front.
            skips.appendleft(x)
            x = self._maxpool(x)
        bottleneck = self._layers[-1](x)
        return bottleneck, skips

class Decoder(nn.Module):
    """Expanding path: transposed-conv upsampling, skip concatenation, ``ConvBlock``.

    ``channels`` lists the channel widths along the path. For each consecutive
    pair ``(c_in, c_out)`` the decoder owns one ``ConvTranspose2d(c_in, c_out)``
    that doubles the spatial size and one ``ConvBlock(c_in, c_out)`` that
    consumes the concatenation of the skip tensor and the upsampled tensor.
    """

    def __init__(self, channels):
        super().__init__()
        self._upward_conv_blocks = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=channels[i], out_channels=channels[i + 1],
                kernel_size=2, stride=2
            ) for i in range(len(channels) - 1)
        ])
        # Each conv block runs after the concat step, so its input width is
        # the skip tensor's channels plus the upsampled tensor's channels.
        self._conv_blocks = nn.ModuleList([
            ConvBlock(in_channels=channels[i], out_channels=channels[i + 1])
            for i in range(len(channels) - 1)
        ])

    def forward(self, x, skip_inputs):
        """Decode ``x`` using ``skip_inputs`` (deepest skip first).

        Args:
            x: output of the encoder's last block.
            skip_inputs: sized iterable of encoder feature maps, one per
                decoder stage, ordered to match the decoder stages.

        Returns:
            The feature map produced by the last ``ConvBlock``.

        Raises:
            ValueError: if the number of skip inputs differs from the number
                of decoder stages.
        """
        if len(skip_inputs) != len(self._conv_blocks):
            raise ValueError(
                "Please check implementation. Length of skip inputs and _conv_blocks should be the same! "
                f"skip inputs, blocks inputs: {len(skip_inputs), len(self._conv_blocks)}"
            )
        # x is smaller than skip inputs, because there's no padding in the conv layers
        for skip_input, up_block, conv_block in zip(skip_inputs, self._upward_conv_blocks, self._conv_blocks):
            x = up_block(x)
            # torch.cat needs identical spatial sizes, so the (larger) skip
            # tensor is center-cropped to the upsampled tensor's H x W.
            skip_input = self.crop(skip_input=skip_input, x=x)
            # TODO: the paper doesn't specify whether the skip goes first or
            # last; here the skip channels come first.
            x = torch.cat((skip_input, x), 1)
            x = conv_block(x)
        return x

    def crop(self, skip_input, x):
        """Center-crop ``skip_input`` to the height and width of ``x``."""
        _, _, H, W = x.shape
        return CenterCrop((H, W))(skip_input)
        
class UNet(nn.Module):
    """Small U-Net: ``Encoder`` -> ``Decoder`` -> 1x1 conv head -> resize to input size.

    Because the conv blocks are unpadded, the decoder output is spatially
    smaller than the input; the head's output is therefore resized back to the
    input's height and width with nearest-neighbour interpolation.
    """

    def __init__(self):
        super().__init__()
        # The last encoder width (64) is the bottleneck.
        encoder_in_channels = [3, 16, 32, 64]
        decoder_channels = [64, 32, 16]
        self._encoder = Encoder(in_channels=encoder_in_channels)
        self._decoder = Decoder(channels=decoder_channels)
        # 1x1 convolution mapping the decoder features to a single output channel.
        self._head = nn.Conv2d(in_channels=decoder_channels[-1], out_channels=1, kernel_size=1)

    def forward(self, x):
        """Return a single-channel map with the same height/width as ``x``.

        Args:
            x: a 4-D tensor; its last two dimensions are taken as H and W.
        """
        # Remember the input resolution before it is reduced by the network.
        # (H and W were previously used below without ever being defined.)
        _, _, H, W = x.shape
        x, intermediate = self._encoder(x)
        output = self._decoder(x, intermediate)
        output = self._head(output)
        output = torch.nn.functional.interpolate(output, size=(H, W), mode='nearest')
        return output

def forward_pass_poc():
    """Proof of concept: push one training image through encoder, decoder and head.

    Builds fresh (untrained) modules, runs ``train_dataset[0]`` through them
    and resizes the result back to the input resolution.

    Returns:
        The resized single-channel output tensor, so the caller can inspect it
        (e.g. its ``.shape``). Previously the result was computed and dropped.
    """
    image, _ = train_dataset[0]
    # Add a batch dimension.
    image = image.unsqueeze(0)
    _, _, H, W = image.shape
    enc = Encoder([3, 16, 32, 64])
    # Call the module itself rather than .forward() so that any registered
    # hooks are run as well.
    x, intermediate_outputs = enc(image)
    dec = Decoder(channels=[64, 32, 16])
    # torch.Size([1, 16, 216, 216])
    output = dec(x, intermediate_outputs)
    # 1x1
    head = nn.Conv2d(
        in_channels=16,
        out_channels=1,
        kernel_size=1,
    )
    output = head(output)
    output = torch.nn.functional.interpolate(output, size=(H, W), mode='nearest')
    return output

torch.Size([1, 1, 256, 256])

# TODO: 
- Do I need a scheduler if I'm changing the step size? (Optional)
- What is `BCEWithLogitsLoss`?

In [9]:
import time
from torch import optim

# Define the training function
# Presumably the file the trained weights are meant to be saved to -- it is
# not referenced by the code below yet.
MODEL_PATH = 'unet_pascal.pth'
# Presumably the number of batches to accumulate gradients over -- TODO
# confirm; it is not referenced by the code below yet.
ACCUMULATION_STEPS = 8
# Check against example
def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs=25, device='cpu'):
    """Training-loop skeleton (work in progress).

    As written, this moves ``model`` to ``device`` and, for each epoch, prints
    a progress line, switches the model to training mode and resets the
    per-epoch bookkeeping variables. It does not yet iterate over
    ``train_loader``; ``test_loader``, ``criterion``, ``optimizer`` and
    ``scheduler`` are accepted but unused, and nothing is returned.
    """
    model.to(device)
    for epoch in range(num_epochs):
        # Training phase
        # Epoch start timestamp (not read anywhere yet).
        start = time.time()
        print(f'Epoch [{epoch + 1}/{num_epochs}] ')
        model.train()
        # Per-epoch accumulators; reset here but not updated yet.
        running_loss = 0.0
        correct_train = 0
        total_train = 0

# Optimization hyperparameters.
learning_rate = 0.001
weight_decay = 0.0001
# momentum=0.9
num_epochs = 50
batch_size = 16

# Model, loss and optimizer.
model = UNet()
# NOTE(review): BCEWithLogitsLoss expects float targets shaped like the
# logits -- confirm this matches the integer label masks the dataset yields.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

ValueError: optimizer got an empty parameter list