
Notes:
- Technically, UNet can be applied to input images of various sizes. Would you recommend applying the same transforms (here, resizing everything to 256x256)?
- What does `torch.nn.ModuleList` do? What is it similar to in TensorFlow?
    - An `nn.Module` is usually a layer or a group of layers. `nn.ModuleList` is a container for modules that can be indexed like a regular Python list. Unlike a plain list, it registers its layers as submodules of the model, so calls like `model.to()`, `model.train()`, `model.eval()`, and `model.parameters()` reach them.

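- [ConvTranspose2d](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#convtranspose2d): a transposed convolution, used in the UNet decoder to upsample feature maps. With `kernel_size=2, stride=2, padding=0`, each spatial dimension doubles.
    ```python
    import torch
    from torch import nn

    conv_transpose = nn.ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=2, stride=2, padding=0)
    input_tensor = torch.randn(1, 3, 5, 5)  # 1 batch, 3 channels, 5x5 image
    output_tensor = conv_transpose(input_tensor)
    print(output_tensor.shape)  # torch.Size([1, 1, 10, 10])
    ```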
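- How does concatenation work? For a 2-D tensor `x`, `torch.cat((x, x, x), 1)` concatenates horizontally (along columns) and `torch.cat((x, x, x), 0)` concatenates vertically (along rows). For `(N, C, H, W)` feature maps, `dim=1` is the channel dimension, which is what the UNet skip connections use.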

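- How do I get a model summary in PyTorch?
```python
# Install first, in a shell: pip install torchsummary
from torchsummary import summary
# Assuming `model` is your neural network.
# torchsummary defaults to device="cuda"; pass device="cpu" when the model is on the CPU.
summary(model, input_size=(3, 224, 224))  # For an input image of size 224x224 with 3 channels (RGB)
```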
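To make the registration point from the `nn.ModuleList` note concrete, the sketch below compares two toy modules (hypothetical names `WithModuleList` and `WithPlainList`, not from the notebook) that hold the same three `nn.Linear` layers, one in an `nn.ModuleList` and one in a plain Python list, and counts the parameters each parent module can see.

```python
import torch
from torch import nn

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # Layers inside nn.ModuleList are registered as submodules
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list hides the layers from the parent module
        self.layers = [nn.Linear(4, 4) for _ in range(3)]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

registered = sum(p.numel() for p in WithModuleList().parameters())
hidden = sum(p.numel() for p in WithPlainList().parameters())
print(registered, hidden)  # 60 0  (3 layers x (16 weights + 4 biases))
```

With the plain list, an optimizer built from `model.parameters()` would never update those layers, and `model.to(device)` would leave them on the CPU.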
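For the `ConvTranspose2d` note, the output size follows the formula in the PyTorch documentation, per spatial dimension: `(H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`. A small sketch (the helper `conv_transpose2d_out` is hypothetical) checks the formula against the layer itself; with `kernel_size=2, stride=2, padding=0` it reduces to `2 * H_in`.

```python
import torch
from torch import nn

def conv_transpose2d_out(size, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    # Output-size formula from the nn.ConvTranspose2d documentation (one spatial dimension)
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

layer = nn.ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=2, stride=2, padding=0)
out = layer(torch.randn(1, 3, 5, 5))
predicted = conv_transpose2d_out(5, kernel_size=2, stride=2)
print(out.shape, predicted)  # torch.Size([1, 1, 10, 10]) 10
```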
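For the concatenation note, a quick shape check: on a 2-D tensor `dim=1` grows the columns and `dim=0` the rows, while on NCHW feature maps `dim=1` stacks channels, which requires the spatial sizes to match first (hence the cropping in the decoder). The tensor names here are illustrative.

```python
import torch

x = torch.arange(6).reshape(2, 3)
side_by_side = torch.cat((x, x, x), 1)  # along columns
stacked = torch.cat((x, x, x), 0)       # along rows
print(side_by_side.shape, stacked.shape)  # torch.Size([2, 9]) torch.Size([6, 3])

# For NCHW feature maps, dim=1 stacks channels; H and W must already match
skip = torch.zeros(1, 32, 56, 56)
up = torch.ones(1, 32, 56, 56)
merged = torch.cat((skip, up), 1)
print(merged.shape)  # torch.Size([1, 64, 56, 56])
```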
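If installing `torchsummary` is not an option, PyTorch forward hooks give a rough, dependency-free substitute. The helper `layer_output_shapes` below is hypothetical (it is not part of torchsummary or PyTorch): it attaches `register_forward_hook` to every leaf module, runs one dummy forward pass on the CPU, records each output shape, and then removes the hooks.

```python
import torch
from torch import nn

def layer_output_shapes(model, input_size):
    # Hypothetical helper: records the output shape of every leaf module via forward hooks
    shapes = []
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(
                # hooks receive (module, inputs, output); name=name binds the current name
                lambda mod, inputs, output, name=name: shapes.append((name, tuple(output.shape)))
            ))
    try:
        with torch.no_grad():
            model(torch.zeros(1, *input_size))  # assumes the model lives on the CPU
    finally:
        for handle in handles:
            handle.remove()
    return shapes

net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.Conv2d(8, 4, kernel_size=3))
for name, shape in layer_output_shapes(net, (3, 32, 32)):
    print(name, shape)
```

A module that is called more than once per forward pass (such as a shared `nn.ReLU`) shows up once per call.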
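- For the skip connections we actually need to `crop()` the encoder feature maps to the decoder's spatial size. A convenient way is `from torchvision.transforms import CenterCrop`, which also works on tensors and crops the last two dimensions.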

- Add a new dimension at dim=0 (e.g., a batch dimension): `new_t = t.unsqueeze(0)`

- What do you do when the final output is `torch.Size([1, 16, 216, 216])` while the input is `torch.Size([1, 3, 256, 256])`?

- My labels are `uint8`, and I'm about to train the network. What data type should the output have?

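- Interpolation matters both when downsizing the masks and when upsizing the final output to the input size. We choose nearest-neighbor interpolation, because bilinear or spline interpolation blends neighboring class indices into values that are not valid labels (halfway between class 0 and class 15 is 7.5). The input images themselves are resized with `v2.Resize`'s default bilinear interpolation, which is fine because pixel intensities are continuous.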
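For the 216 vs. 256 question: every unpadded 3x3 convolution removes 2 pixels per dimension, and max pooling floors odd sizes, so the decoder output is smaller than the input. The hypothetical helper below reproduces the arithmetic for an unpadded UNet with two convs per block; `depth` is the number of pooling steps (2 for `Encoder([3, 16, 32, 64])`). The notebook's answer is to apply the 1x1 head and then resize to `(H, W)` with nearest interpolation; another common option, not used here, is `padding=1` in the convs so each block preserves the size.

```python
def unet_output_size(size, depth, convs_per_block=2, kernel_size=3):
    # Hypothetical helper: spatial size after an unpadded UNet
    shrink = convs_per_block * (kernel_size - 1)  # each unpadded 3x3 conv removes 2 pixels
    for _ in range(depth):            # encoder block, then 2x2 max pooling (floors)
        size = (size - shrink) // 2
    size -= shrink                    # bottleneck block
    for _ in range(depth):            # ConvTranspose2d(kernel_size=2, stride=2) doubles, then a block
        size = size * 2 - shrink
    return size

print(unet_output_size(256, depth=2))  # 216
print(unet_output_size(572, depth=4))  # 388, the input/output sizes in the original UNet paper
```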
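For the data-type question: `nn.CrossEntropyLoss` takes the raw float logits of shape `(N, C, H, W)` (it applies log-softmax internally, so no softmax in the model) and int64 class indices of shape `(N, H, W)` as targets. The sketch uses random tensors in place of real model outputs; the `uint8` mask is converted with `.long()`, as `target_seg_transforms` does.

```python
import torch
from torch import nn

IGNORE_INDEX = 21
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

logits = torch.randn(2, 22, 4, 4)  # model output: float32, shape (N, C, H, W)
raw_mask = torch.randint(0, 21, (2, 4, 4), dtype=torch.uint8)  # uint8 mask as loaded from PIL
raw_mask[0, 0, 0] = IGNORE_INDEX  # one pixel that should not count

# Targets must be int64 class indices of shape (N, H, W); .long() converts from uint8
loss = criterion(logits, raw_mask.long())
print(logits.dtype, raw_mask.long().dtype, loss.item())
```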
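For the interpolation note, resizing a tiny two-class mask shows the problem directly: nearest-neighbor keeps only the original labels, while bilinear produces in-between values that are not class indices. The mask values are illustrative.

```python
import torch
import torch.nn.functional as F

# A tiny 2x2 mask with classes 0 and 15, as (N, C, H, W) float for F.interpolate
mask = torch.tensor([[0.0, 15.0], [15.0, 0.0]]).reshape(1, 1, 2, 2)

nearest = F.interpolate(mask, size=(4, 4), mode="nearest")
bilinear = F.interpolate(mask, size=(4, 4), mode="bilinear", align_corners=False)

print(torch.unique(nearest))   # only 0. and 15.
print(torch.unique(bilinear))  # also in-between values such as 3.75, which are not valid labels
```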

In [1]:
import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torchvision.transforms import v2, CenterCrop
import matplotlib.pyplot as plt
from torchvision.transforms.functional import InterpolationMode
from functools import cached_property

DATA_DIR='./data'
BATCH_SIZE = 16
IGNORE_INDEX = 21

def replace_tensor_val(tensor, a, b):
    tensor[tensor==a] = b
    return tensor

image_seg_transforms = transforms.Compose([
   v2.Resize((256, 256)),
   v2.ToTensor(),
   v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

target_seg_transforms = transforms.Compose([
    v2.Resize((256, 256), interpolation=InterpolationMode.NEAREST),
    v2.PILToTensor(),
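    v2.Lambda(lambda tensor: tensor.squeeze()),  # [1, H, W] -> [H, W]
    # VOC marks object borders and "void" pixels with 255; map them to IGNORE_INDEX
    v2.Lambda(lambda x: replace_tensor_val(x.long(), 255, IGNORE_INDEX)),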
])

class VOCSegmentationClass(Dataset):
    def __init__(self, image_set):
        # Load PASCAL VOC 2007 dataset for segmentation
        self._dataset = datasets.VOCSegmentation(
            root=DATA_DIR,  # Specify where to store the data
            year='2007',    # Specify the year of the dataset (2007 in this case)
            image_set=image_set,  # 'train', 'val', 'trainval', or (for 2007 only) 'test'
            download=True,  # Automatically download if not available
            transform=image_seg_transforms,  # Apply transformations to the images
            target_transform=target_seg_transforms  # Apply transformations to the masks
        )

    @cached_property
    def classes(self):
        # Computed once, then cached. Iterates over the whole transformed dataset, so it is slow.
        # Note: the result includes IGNORE_INDEX as well as the real classes.
        classes = set()
        for _, target in self._dataset:
            classes.update(torch.unique(target).tolist())
        return classes

    def __getitem__(self, index):
        # Returns (image, mask): a normalized float image [3, 256, 256] and an int64 mask [256, 256]
        # of class indices, with 255 remapped to IGNORE_INDEX
        return self._dataset[index]
        # TODO: more transforms?

    def __len__(self):
        return len(self._dataset)

train_dataset = VOCSegmentationClass(image_set='train')
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = BATCH_SIZE,
    shuffle=True,
    num_workers = 2,
    pin_memory = True
)

test_dataset = VOCSegmentationClass(image_set='test')
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size = BATCH_SIZE,
    shuffle=False,  # evaluation does not need shuffling
    num_workers = 2,
    pin_memory = True
)



Using downloaded and verified file: ./data/VOCtrainval_06-Nov-2007.tar
Extracting ./data/VOCtrainval_06-Nov-2007.tar to ./data
Using downloaded and verified file: ./data/VOCtest_06-Nov-2007.tar
Extracting ./data/VOCtest_06-Nov-2007.tar to ./data


In [2]:
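def visualize_image_and_target(image, target):
    # image: normalized float tensor [3, H, W]
    # target: class-index mask [H, W] or [1, H, W] (e.g. predicted_test[0].unsqueeze(0))
    image = image.detach().cpu()
    target = target.detach().cpu()
    if target.dim() == 3:
        target = target[0]

    # Undo v2.Normalize so matplotlib shows the real colors instead of clipping
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    plt.subplot(1,2,1)
    # Making channels the last dimension
    plt.imshow((image * std + mean).clamp(0, 1).permute(1,2,0))
    plt.title('image')

    plt.subplot(1,2,2)
    plt.imshow(target)
    plt.title('mask')
    # Before remapping 255 -> IGNORE_INDEX, raw masks printed e.g. tensor([  0,   1,  15, 255], dtype=torch.uint8)
    print("unique: ", torch.unique(target))
    plt.show()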
    
# # print("classes: ", train_dataset.classes)
# for image, target in train_dataset:
    # visualize_image_and_target(image, target)
       

In [11]:
# This is a regular conv block
from torch import nn
from collections import deque

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
        )
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
        )
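    def forward(self, x):
        # Two unpadded 3x3 convs, each followed by a ReLU as in the UNet paper
        # (each conv shrinks H and W by 2)
        x = self.relu(self.conv1(x))
        return self.relu(self.conv2(x))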

class Encoder(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # `in_channels` lists the channel sizes, e.g. [3, 32, 64, 128]; the last ConvBlock is the bottleneck.
        self._layers = nn.ModuleList([ConvBlock(in_channels[i], in_channels[i+1]) for i in range(len(in_channels) - 1)])
        self._maxpool = nn.MaxPool2d(2, stride=2)
    def forward(self, x):
        # Returns the bottleneck output plus the unpooled output of every other block,
        # deepest first (appendleft), which is the order the decoder consumes them in.
        intermediate_outputs = deque([])
        for i in range(len(self._layers) - 1):
            x = self._layers[i](x)
            intermediate_outputs.appendleft(x)
            x = self._maxpool(x)
        x = self._layers[-1](x)
        return x, intermediate_outputs

class Decoder(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self._upward_conv_blocks = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels = channels[i], out_channels = channels[i+1], 
                kernel_size=2, stride=2
            ) for i in range(len(channels) - 1)
        ])
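        # Between the two steps the cropped skip input is concatenated, doubling the channels
        # back to channels[i] (this assumes channels[i] == 2 * channels[i + 1])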
        self._conv_blocks = nn.ModuleList([
            ConvBlock(in_channels= channels[i], out_channels=channels[i+1]) 
            for i in range(len(channels) - 1)
        ])
    
    def forward(self, x, skip_inputs):
        if len(skip_inputs) != len(self._conv_blocks):
            raise ValueError(
                "Please check implementation. Length of skip inputs and _conv_blocks should be the same! "
                f"skip inputs: {len(skip_inputs)}, blocks: {len(self._conv_blocks)}"
            )
        # x is smaller than skip inputs, because there's no padding in the conv layers
        for skip_input, up_block, conv_block in zip(skip_inputs, self._upward_conv_blocks, self._conv_blocks):
            # print("x shape before upsampling: ", x.shape)
            x = up_block(x)
            # print(skip_input.shape, x.shape)
            # The paper doesn't specify whether the skip input goes first or second. Either works:
            # the next conv learns a weight per input channel, so the order only permutes those weights.
            # Cropping is required (as in the paper) because the unpadded convs make skip_input larger than x.
            skip_input = self.crop(skip_input=skip_input, x=x)
            x = torch.cat((skip_input, x), 1)
            x = conv_block(x)
        return x

    def crop(self, skip_input, x):
        _, _, H, W = x.shape
        return CenterCrop((H,W))(skip_input)
        
class UNet(nn.Module):
    def __init__(self, class_num):
        super().__init__()
        encoder_in_channels= [3, 32, 64, 128]    # bottleneck is 128
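        decoder_channels = [128, 64, 32]  # mirrors the encoder, from the bottleneck back up
        self._encoder = Encoder(in_channels=encoder_in_channels)
        self._decoder = Decoder(channels=decoder_channels)
        # 1x1 conv: maps the final features to one logit per class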
        self._head = nn.Conv2d(in_channels=decoder_channels[-1], out_channels=class_num, kernel_size=1)
        self._init_weight()

    def forward(self, x):
        _, _, H, W = x.shape
        x, intermediate = self._encoder(x)
        output = self._decoder(x, intermediate)
        output = self._head(output)
        output = torch.nn.functional.interpolate(output, size=(H,W),  mode='nearest')
        return output

    def _init_weight(self):
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

def forward_pass_poc():
    image, target = train_dataset[0]
    print(target.shape)
    class_num = len(train_dataset.classes)
    image = image.unsqueeze(0)
    _, _, H, W = image.shape
    enc = Encoder([3, 16, 32, 64])
    # # print(image.shape)
    x, intermediate_outputs = enc(image)  # call the module rather than .forward() so hooks run
    dec = Decoder(channels=[64, 32, 16])
    # torch.Size([1, 16, 216, 216])
    output = dec(x, intermediate_outputs)
    # 1x1
    head = nn.Conv2d(
        in_channels=16,
        out_channels=class_num,
        kernel_size=1,
    )
    output = head(output)
    output = torch.nn.functional.interpolate(output, size=(H,W),  mode='nearest')
    print(output.shape)
forward_pass_poc()

torch.Size([256, 256])
torch.Size([1, 22, 256, 256])


In [20]:
def eval_model(model, test_loader, device, visualize: bool = False):
    # Evaluation phase: pixel accuracy over pixels whose label is not IGNORE_INDEX
    model.eval()
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        # TODO: as a sanity check, this loop can also be run over the train loader
        for inputs_test, labels_test in test_loader:
            inputs_test = inputs_test.to(device)
            labels_test = labels_test.to(device)
            outputs_test = model(inputs_test)
            _, predicted_test = outputs_test.max(1)  # class index per pixel
            mask = (labels_test != IGNORE_INDEX)
            local_total = mask.sum().item()
            local_correct = (predicted_test.eq(labels_test) & mask).sum().item()
            total_test += local_total
            correct_test += local_correct

            if visualize and local_total > 0:
                #TODO Remember to remove
                print(f'Rico: predicted test acc {100. * local_correct/local_total}%')
                print(torch.unique(predicted_test[0]), torch.unique(labels_test[0]))
                # visualize_image_and_target(inputs_test[0], predicted_test[0].unsqueeze(0))

    test_acc = 100. * correct_test / total_test
    print(f'Test Acc: {test_acc:.2f}%')
    return test_acc

Rico: predicted test acc 70.9432556500004%
tensor([0]) tensor([ 0,  1, 21])
Rico: predicted test acc 70.03683435595154%
tensor([0]) tensor([ 0,  9, 11, 21])
Rico: predicted test acc 73.19779074843441%
tensor([0]) tensor([ 0, 19, 21])
Rico: predicted test acc 64.70813676279762%
tensor([0]) tensor([ 0,  9, 16, 18, 21])
Rico: predicted test acc 77.57816696235291%
tensor([0]) tensor([ 0,  2, 21])
Rico: predicted test acc 74.02743796863264%
tensor([0]) tensor([ 0,  7, 10, 15, 21])
Rico: predicted test acc 74.92632208384943%
tensor([0]) tensor([ 0,  6,  7, 21])
Rico: predicted test acc 69.30629351665686%
tensor([0]) tensor([ 0,  3, 21])
Rico: predicted test acc 76.45517492016228%
tensor([0]) tensor([ 0,  8, 16, 21])
Rico: predicted test acc 68.10569397054084%
tensor([0]) tensor([ 0, 13, 15, 21])
Rico: predicted test acc 68.20071602996565%
tensor([0]) tensor([ 0, 20, 21])
Rico: predicted test acc 70.07979902016217%
tensor([0]) tensor([ 0, 13, 21])
Rico: predicted test acc 71.58363302118126%
t
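In the log above, the predictions printed for the first image of each batch contain only class 0 (background), yet the pixel accuracy is around 70%, because background covers most of the valid pixels. A per-class intersection-over-union averaged over classes (mean IoU) exposes this. The sketch below is an illustration, not part of the original notebook: `mean_iou` is a hypothetical helper that skips `ignore_index` and any class absent from both prediction and label, and the toy tensors stand in for a real batch.

```python
import torch

def mean_iou(predicted, labels, num_classes, ignore_index):
    # predicted, labels: int64 tensors of the same shape holding class indices
    valid = labels != ignore_index
    ious = []
    for c in range(num_classes):
        if c == ignore_index:
            continue
        pred_c = (predicted == c) & valid
        label_c = (labels == c) & valid
        union = (pred_c | label_c).sum().item()
        if union == 0:
            continue  # class absent from both prediction and label
        intersection = (pred_c & label_c).sum().item()
        ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else float("nan")

# An all-background prediction, like the outputs above
labels = torch.tensor([[0, 0, 0, 15], [0, 0, 15, 21]])
predicted = torch.zeros_like(labels)
valid = labels != 21
pixel_acc = ((predicted == labels) & valid).sum().item() / valid.sum().item()
miou = mean_iou(predicted, labels, num_classes=22, ignore_index=21)
print(round(pixel_acc, 3), round(miou, 3))  # 0.714 0.357
```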

In [17]:
import time
import os
from torch import optim

# Define the training function
MODEL_PATH = 'unet_pascal.pth'
ACCUMULATION_STEPS = 8

# Check against example
def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs=25, device='cpu'):
    model.to(device)
    for epoch in range(num_epochs):
        # Training phase
        start = time.time()
        print(f'Epoch [{epoch + 1}/{num_epochs}] ')
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        
        # for i, (inputs, labels) in enumerate(train_loader):
        #     inputs = inputs.to(device)
        #     labels = labels.to(device)
        #     # Forward pass
        #     outputs = model(inputs)
        #     # nn.CrossEntropyLoss uses reduction='mean' by default, so to simulate a batch
        #     # ACCUMULATION_STEPS times larger, divide each micro-batch loss by ACCUMULATION_STEPS
        #     loss = criterion(outputs, labels) / ACCUMULATION_STEPS
        #     # Backward pass; gradients accumulate until the optimizer step
        #     loss.backward()
        #     if (i + 1) % ACCUMULATION_STEPS == 0 or (i + 1) == len(train_loader):
        #         optimizer.step()
        #         # Zero the parameter gradients
        #         optimizer.zero_grad()
        #     # Statistics: undo the 1/ACCUMULATION_STEPS scaling so the logged loss is the real one
        #     running_loss += loss.item() * ACCUMULATION_STEPS * inputs.size(0)
        #     _, predicted = outputs.max(1)
        #     mask = (labels != IGNORE_INDEX)
        #     total_train += mask.sum().item()
        #     correct_train += ((predicted == labels) & mask).sum().item()

        # # adjust after every epoch
        # scheduler.step()  # TODO: disabled for now; MultiStepLR also works with Adam if re-enabled
        # current_lr = optimizer.param_groups[0]['lr']
        # print(f"Current learning rate: {current_lr}")
        # epoch_loss = running_loss / len(train_loader.dataset)
        # epoch_acc = 100. * correct_train / total_train
        # print("correct train: ", correct_train, " total train: ", total_train)
        # end = time.time()

        # print("elapsed: ", end-start)

        # torch.save(model.state_dict(), MODEL_PATH)
        # print(f"epoch: {epoch}, saved the model. "
        #       f'Train Loss: {epoch_loss:.4f} '
        #       f'Train Acc: {epoch_acc:.2f}% ')
    # eval_model(model, test_loader=test_loader, device=device)
    print('Training complete')
    return model

model = UNet(class_num = len(train_dataset.classes))
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
weight_decay = 0.0001
# momentum=0.9
learning_rate=0.001
num_epochs=70
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 40], gamma=0.1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Train the model
if os.path.exists(MODEL_PATH):
    model.load_state_dict(torch.load(MODEL_PATH, weights_only=True, map_location=device))  # a state_dict only needs weights_only=True
    print("loaded model")
model.to(device)

model = train_model(model, train_dataloader, test_dataloader, criterion, optimizer, scheduler,
                    num_epochs=num_epochs, device=device)

loaded model
Epoch [1/70] 
Epoch [2/70] 
Epoch [3/70] 
Epoch [4/70] 
Epoch [5/70] 
Epoch [6/70] 
Epoch [7/70] 
Epoch [8/70] 
Epoch [9/70] 
Epoch [10/70] 
Epoch [11/70] 
Epoch [12/70] 
Epoch [13/70] 
Epoch [14/70] 
Epoch [15/70] 
Epoch [16/70] 
Epoch [17/70] 
Epoch [18/70] 
Epoch [19/70] 
Epoch [20/70] 
Epoch [21/70] 
Epoch [22/70] 
Epoch [23/70] 
Epoch [24/70] 
Epoch [25/70] 
Epoch [26/70] 
Epoch [27/70] 
Epoch [28/70] 
Epoch [29/70] 
Epoch [30/70] 
Epoch [31/70] 
Epoch [32/70] 
Epoch [33/70] 
Epoch [34/70] 
Epoch [35/70] 
Epoch [36/70] 
Epoch [37/70] 
Epoch [38/70] 
Epoch [39/70] 
Epoch [40/70] 
Epoch [41/70] 
Epoch [42/70] 
Epoch [43/70] 
Epoch [44/70] 
Epoch [45/70] 
Epoch [46/70] 
Epoch [47/70] 
Epoch [48/70] 
Epoch [49/70] 
Epoch [50/70] 
Epoch [51/70] 
Epoch [52/70] 
Epoch [53/70] 
Epoch [54/70] 
Epoch [55/70] 
Epoch [56/70] 
Epoch [57/70] 
Epoch [58/70] 
Epoch [59/70] 
Epoch [60/70] 
Epoch [61/70] 
Epoch [62/70] 
Epoch [63/70] 
Epoch [64/70] 
Epoch [65/70] 
Epoch [66/70] 
Epoch 
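The commented-out loop in `train_model` divides each micro-batch loss by `ACCUMULATION_STEPS`. A minimal check of why, using a toy `nn.Linear` model as a stand-in for the UNet (an assumption made to keep it fast): with `reduction='mean'` and equal-sized micro-batches, the mean of the micro-batch means equals the full-batch mean, so the accumulated gradient matches the full-batch gradient.

```python
import torch
from torch import nn

torch.manual_seed(0)
ACCUMULATION_STEPS = 4
model = nn.Linear(4, 3)              # toy stand-in for the UNet
criterion = nn.CrossEntropyLoss()    # reduction='mean' by default
inputs = torch.randn(16, 4)
labels = torch.randint(0, 3, (16,))

# Gradient from one full batch of 16
model.zero_grad()
criterion(model(inputs), labels).backward()
full_batch_grad = model.weight.grad.clone()

# Gradient accumulated over 4 micro-batches of 4, each loss divided by ACCUMULATION_STEPS
model.zero_grad()
for x, y in zip(inputs.chunk(ACCUMULATION_STEPS), labels.chunk(ACCUMULATION_STEPS)):
    loss = criterion(model(x), y) / ACCUMULATION_STEPS
    loss.backward()                  # .grad accumulates across backward() calls
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))  # True
```

With `ignore_index`, each segmentation micro-batch averages over a different number of valid pixels, so for this notebook the match is only approximate.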

In [19]:
# from torchsummary import summary
# summary(model, (3, 356, 356), device="cuda" if torch.cuda.is_available() else "cpu")

def calculate_average_weights(model):
    total_sum = 0
    total_elements = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            weight_mean = param.mean().item()
            total_sum += param.sum().item()
            total_elements += param.numel()
            print(f"Layer: {name} | Average Weight: {weight_mean:.6f}")
    
    overall_average = total_sum / total_elements if total_elements > 0 else 0
    print(f"Overall Average Weight in the Network: {overall_average:.6f}")

# calculate_average_weights(model)

eval_model(model, test_dataloader, device=device, visualize=True)

Layer: _encoder._layers.0.conv1.weight | Average Weight: -0.001730
Layer: _encoder._layers.0.conv2.weight | Average Weight: 0.000583
Layer: _encoder._layers.1.conv1.weight | Average Weight: 0.000499
Layer: _encoder._layers.1.conv2.weight | Average Weight: -0.000026
Layer: _encoder._layers.2.conv1.weight | Average Weight: 0.000001
Layer: _encoder._layers.2.conv2.weight | Average Weight: 0.000086
Layer: _decoder._upward_conv_blocks.0.weight | Average Weight: -0.000054
Layer: _decoder._upward_conv_blocks.1.weight | Average Weight: 0.000328
Layer: _decoder._conv_blocks.0.conv1.weight | Average Weight: -0.000004
Layer: _decoder._conv_blocks.0.conv2.weight | Average Weight: -0.000198
Layer: _decoder._conv_blocks.1.conv1.weight | Average Weight: 0.000534
Layer: _decoder._conv_blocks.1.conv2.weight | Average Weight: -0.000918
Layer: _head.weight | Average Weight: 0.001126
Overall Average Weight in the Network: 0.000044
