In [None]:
# Import the required libraries

import torch
import os
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torch.nn.functional import interpolate
import torch.nn as nn
import math
import numpy as np
from natsort import natsorted
from PIL import Image
import pandas as pd

In [None]:
# Code to upload dataset to google colab
# Only a missing `google.colab` package means we are not running on Colab.
# Any other failure (missing zip, corrupt archive, missing CSV) is a real
# error and is allowed to propagate instead of being reported as "not on colab".
try:
    from google.colab import drive
except ImportError:
    print("not on google colab")
else:
    import shutil
    import zipfile

    drive.mount('/content/drive')

    path_to_zip_file = 'drive/MyDrive/data.zip'
    directory_to_extract_to = 'VISCHEMA_PLUS/'
    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
        zip_ref.extractall(directory_to_extract_to)

    # Copy the train/val listings next to the extracted data
    for csv_name in ("viscplus_train.csv", "viscplus_val.csv"):
        shutil.copy(f"drive/MyDrive/{csv_name}", directory_to_extract_to)

In [None]:
# Create datasets and dataloaders
class VISCHEMA_PLUS(Dataset):
    """Paired image / label dataset read from a VISCHEMA_PLUS directory.

    File names come from the first column of ``viscplus_train.csv``
    (``train=True``) or ``viscplus_val.csv`` (``train=False``) located in
    ``dataset_dir``. The same file name is looked up in both ``image_dir``
    and ``label_dir``.

    Args:
        dataset_dir: Root directory; concatenated as a string prefix, so it
            must end with a path separator.
        image_dir: Sub-folder (with trailing separator) holding the inputs.
        label_dir: Sub-folder (with trailing separator) holding the labels.
        train: Use the training listing when true, else the validation one.
        transform: Optional callable applied to the image tensor and then,
            in a separate call, to the label tensor. A random transform
            would therefore not be synchronised between the two.
    """

    def __init__(self, dataset_dir = '../VISCHEMA_PLUS/', image_dir = 'images/', label_dir = 'vms/', train = True, transform = None):
        csv_name = "viscplus_train.csv" if train else "viscplus_val.csv"
        listing = pd.read_csv(f"{dataset_dir}{csv_name}", header = None)

        self.transform = transform

        self.image_dir = image_dir
        self.label_dir = label_dir
        self.dataset_dir = dataset_dir

        # natsorted gives a natural (human) ordering of the file names
        self.all_images = natsorted(listing[0].values.tolist())

    def __len__(self):
        """Return the number of listed image/label pairs."""
        return len(self.all_images)

    def _load_rgb_tensor(self, sub_dir, file_name):
        """Open one file as RGB and convert it to a tensor, closing the file afterwards."""
        with Image.open(f"{self.dataset_dir}{sub_dir}{file_name}") as handle:
            return transforms.ToTensor()(handle.convert("RGB"))

    def __getitem__(self, idx):
        """Return the ``(image, label)`` tensor pair for index ``idx``."""
        file_name = self.all_images[idx]

        image = self._load_rgb_tensor(self.image_dir, file_name)
        label = self._load_rgb_tensor(self.label_dir, file_name)

        # Identity check: `!= None` would invoke a custom __ne__ if the transform defines one
        if self.transform is not None:
            image = self.transform(image)
            label = self.transform(label)

        return image, label

batch_size = 8

# Resize to 128, then map ToTensor's [0, 1] range onto [-1, 1]
image_transforms = transforms.Compose(
    [transforms.Resize(128), transforms.Normalize(0.5, 0.5)]
)

# Train split first, validation split second; both share the same transform
train_dataset, val_dataset = (
    VISCHEMA_PLUS(transform=image_transforms, train=use_train_split)
    for use_train_split in (True, False)
)

train_loader, val_loader = (
    DataLoader(dataset=split, batch_size=batch_size, shuffle=True)
    for split in (train_dataset, val_dataset)
)

print(f'{len(train_dataset)} Items in Train dataset')
print(f'{len(val_dataset)}  Items in Validation dataset')

In [None]:
# Test that the images and labels look good
# Each row shows one input image (left) next to its label (right).
import matplotlib.pyplot as plt

images, labels = next(iter(train_loader))
figure = plt.figure(figsize=(15,15))
cols, rows = 1, 4
for i in range(cols * rows):
    for slot, batch in enumerate((images, labels), start=1):
        figure.add_subplot(rows, cols*2, 2*i + slot)
        plt.axis("off")
        # Undo Normalize(0.5, 0.5) and move channels last for imshow
        sample = batch[i,:].squeeze().permute(1, 2, 0)
        plt.imshow((sample + 1) / 2)
plt.tight_layout()
plt.show()

In [None]:
# UNET Model

class Block(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by a ReLU.

    With ``norm=True`` a ``norm_func(out_channels)`` layer follows each ReLU.
    With ``norm=False`` only the first normalisation is dropped; the layer
    after the second ReLU is still present.
    NOTE(review): confirm that keeping the trailing norm for ``norm=False``
    is intended.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        norm: Whether to normalise after the first convolution.
        norm_func: Factory called with ``out_channels`` to build a norm layer.
    """

    def __init__(self, in_channels, out_channels, norm = True, norm_func = nn.InstanceNorm2d):
        super().__init__()
        # Layers are created in forward order
        layers = [nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU()]
        if norm:
            layers.append(norm_func(out_channels))
        layers.append(nn.Conv2d(out_channels, out_channels, 3, padding=1))
        layers.append(nn.ReLU())
        layers.append(norm_func(out_channels))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.main(x)
    
class Encoder(nn.Module):
    """Contracting path: a chain of ``Block`` modules separated by 2x2 max-pooling.

    Args:
        channels: Channel sizes; block ``i`` maps ``channels[i]`` to
            ``channels[i+1]``. The first block is built with ``norm=False``.
        norm_func: Normalisation layer factory forwarded to every block.
    """

    def __init__(self, channels=(3,64,128,256,512), norm_func = nn.InstanceNorm2d):
        super().__init__()
        self.encoding_blocks = nn.ModuleList(
            [Block(channels[i], channels[i+1], norm = (i != 0), norm_func = norm_func) for i in range(len(channels)-1)]
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the output of every block, shallowest first.

        Each block's output is stored before pooling. Pooling is applied only
        between blocks: the pooled result of the last block was never used,
        and pooling a 1x1 map would raise.
        """
        features = []
        last_index = len(self.encoding_blocks) - 1

        for index, block in enumerate(self.encoding_blocks):
            x = block(x)
            features.append(x)
            if index != last_index:
                x = self.pool(x)

        return features
        
class Decoder(nn.Module):
    """Expanding path: transposed-conv upsampling, skip concatenation, then a ``Block``.

    Args:
        channels: Channel sizes; stage ``i`` upsamples ``channels[i]`` to
            ``channels[i+1]`` and its block maps ``channels[i]`` to
            ``channels[i+1]``.
        norm_func: Normalisation layer factory forwarded to every block.
    """

    def __init__(self, channels=(512,256,128,64), norm_func = nn.InstanceNorm2d):
        super().__init__()
        self.channels = channels
        # (in, out) channel pair for every decoder stage
        stage_io = list(zip(channels[:-1], channels[1:]))
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in stage_io]
        )
        self.decoder_blocks = nn.ModuleList(
            [Block(c_in, c_out, norm_func = norm_func) for c_in, c_out in stage_io]
        )

    def forward(self, x, encoder_features):
        """Decode ``x``, using ``encoder_features[i]`` as the skip input of stage ``i``."""
        for stage, (upconv, block) in enumerate(zip(self.upconvs, self.decoder_blocks)):
            x = upconv(x)
            skip = self.crop(encoder_features[stage], x)
            # Concatenate along the channel dimension before the conv block
            x = block(torch.cat([x, skip], dim=1))
        return x

    def crop(self, features, x):
        """Centre-crop ``features`` to the height and width of ``x``."""
        _, _, height, width = x.shape
        cropper = transforms.CenterCrop([height, width])
        return cropper(features)
        
class Generator(nn.Module):
    """U-Net style generator: ``Encoder`` -> ``Decoder`` -> 1x1 conv head -> tanh.

    Args:
        encode_channels: Channel sizes passed to ``Encoder``.
        decode_channels: Channel sizes passed to ``Decoder``.
        num_class: Number of output channels of the 1x1 head.
        retain_dim: Accepted for call compatibility; not used by this class.
        output_size: Accepted for call compatibility; not used by this class.
        norm_func: Normalisation layer factory for encoder and decoder blocks.
    """

    def __init__(self, 
                 encode_channels=(3,64,128,256,512,1024), 
                 decode_channels=(1024,512,256,128,64), 
                 num_class=3, 
                 retain_dim=True, 
                 output_size=(572,572),
                 norm_func = nn.InstanceNorm2d):

        super().__init__()

        self.encoder = Encoder(encode_channels, norm_func = norm_func)
        self.decoder = Decoder(decode_channels, norm_func = norm_func)

        self.head = nn.Conv2d(decode_channels[-1], num_class, 1)

        # NOTE: an explicit Xavier-uniform initialisation of the Conv2d /
        # ConvTranspose2d weights (with zero biases) used to be sketched here
        # but was disabled; the PyTorch default initialisation is in effect.

    def forward(self, x):
        """Return the tanh-activated prediction for ``x`` (values in [-1, 1])."""
        # The deepest feature map seeds the decoder; the shallower ones are the
        # skip connections, ordered deepest first.
        *skips, bottleneck = self.encoder(x)
        decoded = self.decoder(bottleneck, skips[::-1])
        return torch.tanh(self.head(decoded))

model = Generator()

# Count only the parameters that will receive gradients
total_params = 0
for param in model.parameters():
    if param.requires_grad:
        total_params += param.numel()
print(f'{total_params} Parameters in UNet Generator')

# Push one training batch through the untrained model to check the output shape
images, _ = next(iter(train_loader))
output = model(images)

print(output.shape)

In [None]:
# default output test
# Left column: image/label overlay; right column: output of the current model.
import matplotlib.pyplot as plt

images, labels = next(iter(train_loader))
figure = plt.figure(figsize=(15,15))
# Visual check only, so no autograd graph is needed
with torch.no_grad():
    output = model(images)

# Plain tensor addition instead of np.add, which depended on numpy/torch interop.
# The sum is halved below, so the overlay is the mean of image and label.
ideal = images + labels

cols, rows = 2, 4
for i in range(rows):
    figure.add_subplot(rows,cols, cols*i+1)
    plt.axis("off")
    plt.imshow(((ideal[i,:]*0.5).cpu().detach().squeeze().permute(1, 2, 0) +1 )/ 2 )
    figure.add_subplot(rows,cols, cols*i+2)
    plt.axis("off")
    plt.imshow((output[i,:].cpu().detach().squeeze().permute(1, 2, 0) +1 )/ 2 )
plt.tight_layout()
plt.show()

In [None]:
# Set up our training environment
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

images, _ = next(iter(train_loader))
_ , _ , height, width = images.shape
# The model class in this notebook is `Generator`; `UNet` is not defined anywhere here
model = Generator(retain_dim=True, output_size=(height,width))
model = model.to(device)

loss_func = nn.L1Loss(reduction = 'mean')

optim = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 40

# These variables will store the data for analysis
training_losses = []
val_losses = []

# os.path.dirname("checkpoints") is "", which os.makedirs rejects; create the folder itself
os.makedirs("checkpoints", exist_ok=True)

In [None]:
# Loss of the current model on one training batch
images, labels = (tensor.to(device) for tensor in next(iter(train_loader)))
output = model(images)
loss = loss_func(output, labels)
print(loss)

In [None]:
print('Starting Training')

# Make sure the checkpoint folder exists before the first save
os.makedirs('checkpoints', exist_ok=True)

# Epochs are numbered 1..num_epochs inclusive
for epoch in range(1, num_epochs + 1):

    # Go into training mode
    model.train()

    # Running sums of per-sample losses, kept as plain Python floats
    total_train_loss = 0.0
    total_val_loss = 0.0

    for images, labels in train_loader:

        # Move images to device and create an image prediction
        images, labels = images.to(device), labels.to(device)
        output = model(images)

        # Evaluate the loss of our model and take a step
        loss = loss_func(output,labels)
        optim.zero_grad()
        loss.backward()
        optim.step()

        # .item() drops the autograd graph so it is not kept alive across steps.
        # Weight by batch size because the last batch may be smaller.
        total_train_loss += loss.item()*images.shape[0]

        del images, labels, output

    total_train_loss /= len(train_dataset)
    training_losses.append(total_train_loss)

    # Evaluate the model on the val set
    model.eval()

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            loss = loss_func(output,labels)
            total_val_loss += loss.item()*images.shape[0]

            del images, labels, output

    total_val_loss /= len(val_dataset)
    val_losses.append(total_val_loss)

    print(f'Epoch [{epoch}],Train Loss: {total_train_loss}, Val Loss: {total_val_loss}')

    # Save after the epoch so checkpoints/<N>.pkl holds the weights after N epochs
    if epoch % 10 == 0:
        torch.save(model.state_dict(), f'checkpoints/{epoch}.pkl')

In [None]:
plt.title("Training curve")

print(training_losses)
print(val_losses)

# Coerce every entry to a plain float: the history may hold 0-d tensors
# (possibly still attached to autograd), which matplotlib cannot convert.
train_curve = [float(value) for value in training_losses]
val_curve = [float(value) for value in val_losses]

plt.plot(range(len(train_curve)), train_curve, 'r', label="Train")
plt.plot(range(len(val_curve)), val_curve, 'g', label="Validation")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Persist the final weights; create the folder first so the save cannot fail on a missing directory
os.makedirs('checkpoints', exist_ok=True)
torch.save(model.state_dict(), 'checkpoints/final_weights.pkl')

In [None]:
import matplotlib.pyplot as plt

# One validation batch: input, label, and model output side by side
images, labels = next(iter(val_loader))
images, labels = images.to(device), labels.to(device)
figure = plt.figure(figsize=(15,15))
output = model(images)
cols, rows = 1, 4
for i in range(cols * rows):
    for slot, batch in enumerate((images, labels, output), start=1):
        figure.add_subplot(rows, cols*3, 3*i + slot)
        plt.axis("off")
        # Back to CPU, channels last, and from [-1, 1] to [0, 1] for imshow
        sample = batch[i,:].cpu().detach().squeeze().permute(1, 2, 0)
        plt.imshow((sample + 1) / 2)

plt.tight_layout()
plt.show()