# ResNet implementation
Credits: 
- [ResNet paper](https://arxiv.org/abs/1512.03385)
- [Aladdin Persson on YouTube](https://www.youtube.com/watch?v=DkNIBBBvcPs&ab_channel=AladdinPersson)

In [1]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import torch.optim as optim
from tqdm import trange
from datetime import datetime
from skimage.io import imread

from data_process import label_2_colour, resize_2_256

In [2]:
# # Make the datasets (only needed once)
# test_path = './test_set/'
# train_path = './train_set/'
# test_set = make_datasets(test_path)   # saves the dataset as dataset{len(dataset)}.pth
# train_set = make_datasets(train_path) # saves the dataset as dataset{len(dataset)}.pth

In [3]:
# Load the serialized datasets from disk (the commented-out cell above creates
# them once). torch.load unpickles the files, so only load files you trust.
test_set, train_set = (
    torch.load(dataset_path)
    for dataset_path in ('../data/dataset30.pth', '../data/dataset2973.pth')
)

In [4]:
# Hold out 10% of the training set for validation; the rest is used for training.
train_size = int(0.9 * len(train_set))
val_size = len(train_set) - train_size
train_dataset, val_dataset = random_split(train_set, lengths=[train_size, val_size])

# Both loaders share the same settings. Note that `test_dataloader` iterates
# the validation split, not `test_set`.
train_dataloader, test_dataloader = (
    DataLoader(subset, batch_size=12, shuffle=True, num_workers=2)
    for subset in (train_dataset, val_dataset)
)

In [5]:
class Block(nn.Module):     # ResNet block
    def __init__(self, in_c, out_c, downsample=None, stride=1):
        """Bottleneck residual block as per the ResNet architecture.

        The main branch is a 1x1 conv, a 3x3 conv (which carries the stride)
        and a 1x1 conv that widens the output to ``out_c * self.expansion``
        channels. Every convolution is followed by batch normalisation.

        Args:
            in_c (int): number of channels of the incoming feature map
            out_c (int): width of the two inner convolutions; the block
                outputs ``out_c * self.expansion`` channels
            downsample (nn.Module, optional): module applied to the block input
                to build the shortcut branch when its shape must change.
                Defaults to None (plain identity shortcut).
            stride (int, optional): stride of the 3x3 convolution. Defaults to 1.
        """
        super().__init__()
        self.expansion = 4
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.conv3 = nn.Conv2d(out_c, out_c*self.expansion, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_c*self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        """Run the bottleneck branch, add the shortcut and apply the final ReLU."""
        # No copy is needed: `x` is rebound below and never modified in place.
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        # The results must be assigned: neither module works in place, so
        # calling them without rebinding `x` silently skips BN and ReLU.
        x = self.bn2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn3(x)

        if self.downsample is not None:
            # Project the input so it matches the main branch's shape.
            identity = self.downsample(identity)
        x += identity
        x = self.relu(x)

        return x
    
class ResNet(nn.Module):
    def __init__(self, block, layers, img_c, n_classes):
        """Build a ResNet backbone with a per-pixel classification head.

        Args:
            block (callable): factory for one residual block, called as
                ``block(in_c, out_c, downsample, stride)`` or ``block(in_c, out_c)``
            layers (list[int]): how many blocks each of the four stages holds
            img_c (int): number of channels in the images (3 for RGB)
            n_classes (int): number of classes in the dataset
        """
        super().__init__()

        # Stem: 7x7 strided convolution followed by max pooling (not residual).
        self.in_c = 64
        self.conv1 = nn.Conv2d(img_c, self.in_c, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(self.in_c)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages, registered as layer1 .. layer4 in this order.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, stage_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, layers[idx], out_c=width, stride=stage_stride)
            setattr(self, f"layer{idx + 1}", stage)

        # Head: 1x1 conv to class scores, then resize to a fixed 256x256 map.
        self.upsample = nn.Upsample((256,256), mode='bilinear', align_corners=True)
        self.conv2 = nn.Conv2d(512*4, n_classes, kernel_size=1)

    def _make_layer(self, block, n_res_blocks, out_c, stride):
        """Internal function to assemble one stage of residual blocks.

        Args:
            block (callable): factory for one residual block
            n_res_blocks (int): number of blocks stacked in this stage
            out_c (int): inner width of the blocks; the stage outputs
                ``out_c * 4`` channels
            stride (int): stride handed to the first block of the stage

        Returns:
            nn.Sequential: the stacked blocks of this stage
        """
        expanded_c = out_c * 4

        # The first block needs a projection shortcut whenever it changes
        # the resolution or the channel count.
        shortcut = None
        if stride != 1 or self.in_c != expanded_c:
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_c, expanded_c, kernel_size=1, stride=stride),
                nn.BatchNorm2d(expanded_c)
            )

        stage_blocks = [block(self.in_c, out_c, shortcut, stride)]
        self.in_c = expanded_c

        # The remaining blocks keep the shape, so they use the default shortcut.
        stage_blocks.extend(block(self.in_c, out_c) for _ in range(n_res_blocks - 1))

        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Return class scores resized to 256x256 for a batch of images."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        scores = self.conv2(features)
        return self.upsample(scores)

In [6]:
# The three models differ only in depth: the list gives the number of blocks in each of the four layers.
def ResNet50(img_c=3, n_classes=10):
    """Build the 50-layer variant: 3, 4, 6 and 3 blocks in the four layers."""
    blocks_per_layer = [3, 4, 6, 3]
    return ResNet(Block, blocks_per_layer, img_c, n_classes)

def ResNet101(img_c=3, n_classes=10):
    """Build the 101-layer variant: 3, 4, 23 and 3 blocks in the four layers."""
    blocks_per_layer = [3, 4, 23, 3]
    return ResNet(Block, blocks_per_layer, img_c, n_classes)

def ResNet152(img_c=3, n_classes=10):
    """Build the 152-layer variant: 3, 8, 36 and 3 blocks in the four layers."""
    blocks_per_layer = [3, 8, 36, 3]
    return ResNet(Block, blocks_per_layer, img_c, n_classes)

In [7]:
def test():
    """Smoke test: run a random batch through ResNet50 and print both shapes."""
    # Batch of 2 RGB images, 256x256 each.
    batch = torch.randn((2, 3, 256, 256))
    net = ResNet50(img_c=3, n_classes=10)
    scores = net(batch)
    # The spatial size of the prediction matches the input.
    print(f"Predicted: {scores.shape}\nInput: {batch.shape}")

test()

Predicted: torch.Size([2, 10, 256, 256])
Input: torch.Size([2, 3, 256, 256])


## Training

In [8]:
def train(epo_num=50, visualise=False):
    """Train a ResNet50 segmentation model on the module-level ``train_dataloader``.

    Uses cross-entropy loss and SGD (lr=1e-2, momentum=0.7), on the GPU when
    one is available. After every epoch the mean batch loss and the epoch's
    wall-clock duration are printed.

    Args:
        epo_num (int, optional): number of epochs to train for. Defaults to 50.
        visualise (bool, optional): if True, plot the ground-truth and the
            predicted mask of the first sample of every batch. Defaults to False.

    Returns:
        ResNet: the trained model, left on the device used for training.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = ResNet50(img_c=3, n_classes=10)    # RGB input, 10 output classes
    model = model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.7)

    # Guard the epoch average against an empty loader.
    n_batches = max(len(train_dataloader), 1)

    # start timing
    prev_time = datetime.now()
    for epo in trange(epo_num): # trange to show progress bar

        train_loss = 0.0
        model.train()
        for car, car_msk in train_dataloader:
            # Expected shapes per the original notebook (TODO confirm against
            # the dataset): car is Nx3x256x256, car_msk is Nx10x256x256.
            car = car.to(device)
            car_msk = car_msk.to(device)

            optimizer.zero_grad()
            output = model(car)                    # class scores, Nx10x256x256
            # CrossEntropyLoss expects class indices, so collapse the channel
            # dimension of the mask to the index of its largest entry.
            car_msk = torch.argmax(car_msk, dim=1) # Nx256x256
            loss = criterion(output, car_msk)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # Display the predicted mask and the ground truth mask
            if visualise:
                # Scale class indices by 10 before colouring; presumably
                # label_2_colour expects labels on that scale - verify in data_process.
                car_msk = (car_msk.cpu()*10)                              # Nx256x256
                predictions = torch.argmax(output.detach().cpu(), dim=1)*10  # Nx256x256
                print(f"Unique mask: {torch.unique(car_msk)}")
                print(f"Unique predictions: {torch.unique(predictions)}")
                _, ax = plt.subplots(1, 2, figsize=(15, 10))
                ax[0].imshow(label_2_colour(car_msk[0]))
                ax[0].set_title('Ground truth')
                ax[1].imshow(label_2_colour(predictions[0]))
                ax[1].set_title('Prediction')
                plt.show()

        # Update the time. total_seconds() is used because timedelta.seconds
        # ignores the days component and would wrap after 24 hours.
        cur_time = datetime.now()
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time

        # Print information about epochs and loss
        print('epoch:', epo, '/', epo_num)
        print('epoch train loss = %f, %s' %(train_loss/n_batches, time_str))

    return model


if __name__ == "__main__":
    # Short two-epoch run without plots; `model` is kept at module level so
    # later cells can use it.
    model = train(visualise=False, epo_num=2)

  0%|          | 0/2 [00:00<?, ?it/s]

In [1]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__ =='__main__':
    # Predict a mask for one image and show it next to the ground truth.
    # NOTE: relies on `model` created by the training cell.

    # A file path must not end with '/', otherwise it is treated as a directory.
    imgA = imread('drive/My Drive/carseg_data/images/black_5_doors/no_segmentation/0001.png')
    # .npy files are NumPy arrays, not images, so load them with np.load.
    array = np.load('drive/My Drive/carseg_data/arrays/black_5_doors_0001.npy')
    imgA = resize_2_256(imgA)

    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])
    imgA = transform(imgA)
    imgA = imgA.to(device)
    imgA = imgA.unsqueeze(0)    # add the batch dimension

    # Inference: use the running batch-norm statistics (a batch of one image
    # gives poor batch statistics) and skip gradient tracking.
    model.eval()
    with torch.no_grad():
        output = model(imgA)
        # Sigmoid is monotonic, so it does not change the argmax below.
        output = torch.sigmoid(output)

    output_np = output.cpu().numpy()  # output_np.shape = (1, 10, 256, 256)
    # Class index per pixel, scaled by 10 as in the training visualisation.
    output_np = (np.argmax(output_np, axis=1) * 10).astype(np.uint8)  # (1, 256, 256)
    plt.subplot(1, 2, 1)
    # Presumably channel 3 of the array holds the label mask - verify against the data.
    plt.imshow(label_2_colour(array[:,:,3]))
    plt.subplot(1, 2, 2)
    # Drop the batch dimension so the prediction is 2-D like the ground truth.
    plt.imshow(label_2_colour(output_np[0]))


NameError: name 'torch' is not defined