In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torchsummary import summary
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict, defaultdict
from PIL import Image


In [None]:
class NormalizeMask:
    """Turn a ``ToTensor``-scaled mask back into zero-based class indices.

    ``transforms.ToTensor`` divides 8-bit pixel values by 255, so a stored
    label ``k`` arrives here as ``k / 255``. Multiplying by exactly 255 (the
    previous ``1/0.0039`` factor was only an approximation that mis-rounds
    larger pixel values) and rounding recovers ``k``; subtracting 1 shifts the
    labels so they start at 0.
    """

    def __call__(self, mask):
        """Return a new float tensor of class indices; ``mask`` is not modified.

        Args:
            mask: float tensor produced by ``ToTensor`` from the mask image.

        Returns:
            Tensor of the same shape holding ``round(mask * 255) - 1``.
        """
        # Make values be 0, 1 and 2 (out-of-place so the caller's tensor is untouched)
        return torch.round(mask * 255) - 1
    

class Flip:
    '''
    Random horizontal flip driven by its own seeded random generator.

    Be careful with this class: it is designed to flip an image and its mask
    at the same time, which relies on the image pipeline and the mask pipeline
    each holding a Flip built with the same seed and being called in lockstep.
    After an interruption the two generators may get out of step, in other
    words the image gets flipped and the mask does not (or the reverse).
    If that happens, just rerun the cell that creates the transforms.
    '''
    def __init__(self, seed=42):
        # Same seed -> same sequence of flip decisions
        self.rng = np.random.RandomState(seed)

    def __call__(self, img):
        draw = self.rng.rand(1)
        # Keep the input as-is unless the draw is strictly above one half
        if draw <= 0.5:
            return img
        return transforms.functional.hflip(img)


# ToTensor already scales the image pixel values into the range 0 to 1
img_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    Flip(),
])

# Nearest-neighbour resizing keeps the mask values discrete (no blended labels).
# torchvision expects its own InterpolationMode enum here; passing the PIL
# integer constant triggers a deprecation warning on some releases.
mask_transform = transforms.Compose([
    transforms.Resize((128, 128),
                      interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
    NormalizeMask(),
    Flip(),
    # Class indices must be integer typed for the loss
    transforms.Lambda(lambda x: x.long())
])

In [None]:
# Load the Oxford-IIIT Pet dataset
train_dataset = datasets.OxfordIIITPet(root='./data', target_types="segmentation",
                                 transform=img_transform,
                                 target_transform=mask_transform,
                                 split="trainval",
                                 download=True)

# Evaluation data must not be randomly augmented: reuse the training
# pipelines with the Flip step filtered out. This also keeps the test split
# from consuming draws of the seeded Flip generators used for training.
eval_img_transform = transforms.Compose(
    [step for step in img_transform.transforms if not isinstance(step, Flip)])
eval_mask_transform = transforms.Compose(
    [step for step in mask_transform.transforms if not isinstance(step, Flip)])

test_dataset = datasets.OxfordIIITPet(root='./data', target_types="segmentation",
                                 transform=eval_img_transform,
                                 target_transform=eval_mask_transform,
                                 split="test",
                                 download=True)
# Create a DataLoader
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)


In [None]:
def display(display_list):
    """Show the given images side by side in a single row.

    The first three panels are titled 'Input Image', 'True Mask' and
    'Predicted Mask'; any further panel gets a generic 'Image N' title
    instead of raising an IndexError.

    Args:
        display_list: sequence of image-like objects accepted by
            ``plt.imshow``. Torch tensors are detached and moved to the CPU
            first, because matplotlib cannot convert tensors that track
            gradients or live on an accelerator.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']

    plt.figure(figsize=(15, 15))
    for position, panel in enumerate(display_list):
        if isinstance(panel, torch.Tensor):
            panel = panel.detach().cpu()
        plt.subplot(1, len(display_list), position + 1)
        plt.title(titles[position] if position < len(titles)
                  else f'Image {position + 1}')
        plt.imshow(panel)
        plt.axis('off')
    plt.show()
    

In [None]:
# Fetch a single batch for visual checks. `images` and `masks` are reused by
# later cells, so an empty loader should fail right here (StopIteration)
# rather than silently leaving those names undefined.
images, masks = next(iter(train_loader))
# (C, H, W) -> (H, W, C) so the tensors can be shown as images
sample_img, sample_mask = images[0].permute(1, 2, 0), masks[0].permute(1, 2, 0)
display([sample_img, sample_mask])

In [None]:
import torch.nn.init as init
class UpBlock(nn.Module):
    """Upsampling stage: transposed convolution -> batch norm -> ReLU.

    With the default kernel size 3 and stride 2 (padding and output padding
    are fixed at 1) the spatial size of the input is doubled.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size: kernel size of the transposed convolution.
        stride: stride of the transposed convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            output_padding=1,
            bias=False,
        )
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.initialize_weights()

    def initialize_weights(self):
        """Draw the transposed-conv weights from N(0, 0.02); zero a bias if one exists."""
        kernel = getattr(self.upconv, 'weight', None)
        if kernel is not None:
            init.normal_(kernel, mean=0., std=0.02)

        # The layer built in __init__ has bias=False, so this branch only
        # matters if the layer is ever swapped for one that carries a bias.
        offset = getattr(self.upconv, 'bias', None)
        if offset is not None:
            init.constant_(offset, 0.)

    def forward(self, x):
        """Apply upconv, batch norm and ReLU in that order."""
        return self.relu(self.batchnorm(self.upconv(x)))

# Decoder stages as (in_channels, out_channels), listed in the order they run.
# NOTE(review): the in_channels of stages 2-4 presumably equal the previous
# stage's output plus the channels of the skip tensor concatenated in the
# model's forward pass — confirm against the chosen backbone layers.
up_stack = nn.Sequential(*(
    UpBlock(c_in, c_out)
    for c_in, c_out in [(320, 512), (1088, 256), (448, 128), (272, 64)]
))


In [None]:
import pytorch_lightning as L
from torchmetrics import Accuracy
from pytorch_lightning.callbacks import Callback

In [None]:
class ShowPicCallback(Callback):
    """Plot the model's prediction for one fixed sample at the start of every training epoch.

    Args:
        img: image tensor of shape (C, H, W).
        mask: ground-truth mask tensor of shape (C, H, W).
    """

    def __init__(self, img, mask):
        super().__init__()
        #img and mask shape = C, H, W
        self.img = img.unsqueeze(0)        # add a batch dimension for the model
        self.mask = mask.permute(1, 2, 0)  # (H, W, C) for plotting

    def on_train_epoch_start(self, trainer, pl_module):
        """Run the model on the stored sample and display image, true mask and prediction."""
        # The preview must not influence training: run it in eval mode and
        # without autograd, then restore every submodule's previous mode.
        previous_modes = [(module, module.training) for module in pl_module.modules()]
        pl_module.eval()
        try:
            with torch.no_grad():
                # The stored sample lives on the CPU; the model may not.
                logits = pl_module(self.img.to(pl_module.device))
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        prediction = pl_module.create_mask(logits, 'whc')[0].cpu()

        display([self.img[0].permute(1, 2, 0),
                 self.mask,
                 prediction])


In [None]:
class IdentityLayer(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""
    def forward(self, x):
        return x
    
# `pretrained=True` is deprecated in newer torchvision releases in favour of
# the `weights=` argument (IMAGENET1K_V1 is what the legacy flag loads).
# Fall back to the legacy flag when the weights enum is not available.
_mobilenet_weights = getattr(models, 'MobileNet_V2_Weights', None)
if _mobilenet_weights is not None:
    mobilenet_v2 = models.mobilenet_v2(weights=_mobilenet_weights.IMAGENET1K_V1)
else:
    mobilenet_v2 = models.mobilenet_v2(pretrained=True)

# Layers whose outputs we want access to in order to build skip connections
down_stack = [
    mobilenet_v2.features[2].conv[0],
    mobilenet_v2.features[4].conv[0],
    mobilenet_v2.features[7].conv[0],
    mobilenet_v2.features[14].conv[0],
    mobilenet_v2.features[17].conv[2],
]
# Replace the layers we do not need with pass-through modules
mobilenet_v2.classifier = IdentityLayer()
mobilenet_v2.features[18] = IdentityLayer()

# Freeze the backbone: none of its parameters receive gradients
mobilenet_v2.requires_grad_(False)
        
class SeqmentationModel(L.LightningModule):
    """Encoder/decoder segmentation model built on a backbone with tapped layers.

    Forward hooks on the ``down_stack`` layers capture intermediate backbone
    outputs. The output of the last tapped layer is upsampled through
    ``up_stack``; after every stage the next-earlier tapped output is
    concatenated along the channel dimension. A final transposed convolution
    maps the 160-channel result to 3 class logits.

    Args:
        mob_net: backbone module; it is run only so that the hooks fire.
        down_stack: list of backbone layers to tap, in execution order.
        up_stack: iterable of upsampling modules, one per skip connection.
        padding: padding of the final transposed convolution.
        output_padding: output padding of the final transposed convolution.
        loss: criterion called as ``loss(logits, target)``.
    """

    def __init__(self, mob_net, down_stack, up_stack,
                 padding=1, output_padding=1,
                 loss=nn.CrossEntropyLoss()):
        super().__init__()
        self.mob_net = mob_net
        self.down_stack = down_stack
        self.up_stack = up_stack

        self.last = nn.ConvTranspose2d(160, 3, kernel_size=3, stride=2, padding=padding, output_padding=output_padding)

        self.criterion = loss
        self.accuracy_metric = Accuracy(num_classes=3, task="multiclass")
        # Per-step metric values keyed by metric name (`list` keeps it picklable)
        self.history = defaultdict(list)

        self.add_hooks()

    def train(self, mode=True):
        """Switch train/eval mode, keeping a fully frozen backbone in eval mode.

        A backbone without any trainable parameter would otherwise still
        update its BatchNorm running statistics whenever the model is put in
        train mode, so it would not really be frozen.
        """
        super().train(mode)
        if mode and not any(p.requires_grad for p in self.mob_net.parameters()):
            self.mob_net.eval()
        return self

    def add_hooks(self):
        """Register one forward hook per ``down_stack`` layer to record its output."""
        if getattr(self, 'hooks', None):
            # Avoid stacking duplicate hooks when called more than once
            self.remove_hooks()
        self.skips = OrderedDict()
        self.hooks = []
        for index, layer in enumerate(self.down_stack):
            self.hooks.append(layer.register_forward_hook(self._make_skip_hook(index)))

    def _make_skip_hook(self, index):
        """Build a hook storing a layer's output under its ``down_stack`` position."""
        # A factory binds `index` per layer; a lambda created inside the loop
        # would see only the final loop value.
        def store_output(module, inputs, output):
            self.skips[index] = output
        return store_output

    def remove_hooks(self):
        """Detach every registered forward hook from the backbone layers."""
        # Remove first, then forget the handles (the reverse order removes nothing)
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def forward(self, images):
        """Return per-pixel class logits for a batch of images."""
        # down sample image: the backbone's own output is ignored, only the
        # hook-captured activations are used
        self.skips.clear()
        self.mob_net(images)

        if len(self.skips) != len(self.down_stack):
            raise RuntimeError(
                "skip activations were not captured; call add_hooks() before forward()")
        skips = [self.skips[index] for index in range(len(self.down_stack))]
        # Do not keep the activations alive on the module between calls
        self.skips.clear()

        # The deepest tapped output is the decoder input; the others are skips
        x = skips.pop()

        # up sample image
        for up, skip in zip(self.up_stack, reversed(skips)):
            x = up(x)
            x = torch.cat([x, skip], dim=1)

        return self.last(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss."""
        accuracy, loss, predictions = self._common_step(batch, batch_idx)

        self.training_step_outputs = {
            'training_loss':loss,
            'training_accuracy':accuracy
        }
        self.log_dict(self.training_step_outputs, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy; return the loss."""
        accuracy, loss, predictions = self._common_step(batch, batch_idx)

        self.validation_step_outputs = {
            'validation_loss':loss,
            'validation_accuracy':accuracy
        }
        self.log_dict(self.validation_step_outputs, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/accuracy; return the loss."""
        accuracy, loss, predictions = self._common_step(batch, batch_idx)

        self.test_step_outputs = {
            'test_loss':loss,
            'test_accuracy':accuracy
        }
        self.log_dict(self.test_step_outputs, prog_bar=True)
        return loss

    def _common_step(self, batch, batch_idx):
        """Run the model on a batch and return ``(accuracy, loss, logits)``."""
        images, masks = batch
        predictions = self(images)
        # The criterion gets the masks without their channel dimension
        loss = self.criterion(predictions, masks.squeeze(1))
        accuracy = self.accuracy_metric(self.create_mask(predictions), masks)
        return accuracy, loss, predictions

    def on_train_batch_end(self, outputs, batch, batch_idx):
        self._common_on_batch_end(self.training_step_outputs)

    # `dataloader_idx` is part of the documented hook signature; the default
    # keeps the hook callable whether or not the trainer passes it.
    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        self._common_on_batch_end(self.validation_step_outputs)

    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        self._common_on_batch_end(self.test_step_outputs)

    def _common_on_batch_end(self, step_outputs):
        """Append the latest step metrics to ``self.history``."""
        for name, value in step_outputs.items():
            # Store detached CPU copies so the history neither references the
            # autograd graph nor occupies accelerator memory.
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu()
            self.history[name].append(value)

    def create_mask(self, pred_mask, shape='cwh'):
        """Convert logits of shape (N, C, H, W) into a class-index mask.

        Args:
            pred_mask: logits with the class scores along dimension 1.
            shape: 'cwh' returns shape (N, 1, H, W); any other value returns
                shape (N, H, W, 1).
        """
        pred_mask = torch.argmax(pred_mask, dim=1)
        if shape == 'cwh':
            pred_mask = pred_mask.unsqueeze(1)
        else:
            pred_mask = pred_mask.unsqueeze(-1)
        return pred_mask

    def configure_optimizers(self):
        """Use Adam with a learning rate of 0.001 over the module's parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

# Preview the first sample of the previously fetched batch at every epoch start
preview_callback = ShowPicCallback(images[0], masks[0])

model = SeqmentationModel(mobilenet_v2, down_stack, up_stack)
trainer = L.Trainer(min_epochs=10, max_epochs=20, callbacks=[preview_callback])
# Note: the test loader doubles as the validation loader during fitting
trainer.fit(model, train_loader, test_loader)
            

In [None]:
# Inference: eval mode, no autograd, input on the model's device. Only the
# first image is needed, so the rest of the batch is not run.
model.eval()
with torch.no_grad():
    sample_logits = model(images[:1].to(model.device))
predicted = model.create_mask(sample_logits, 'whc')[0].cpu()
display([sample_img, sample_mask, predicted])

In [None]:
# One curve per metric recorded in model.history, labelled with the metric name
for curve_name in ('training_loss', 'training_accuracy',
                   'validation_loss', 'validation_accuracy'):
    curve = torch.tensor(model.history[curve_name]).numpy()
    plt.plot(curve, label=curve_name)
plt.legend()

In [None]:
# Persist only the weights (state dict), not the whole module object
model_path = 'model_torch.pth'
trained_weights = model.state_dict()
torch.save(trained_weights, model_path)
