[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/analytics-club-iitm/DL-Marathon/blob/main/seg/unet-seg.ipynb)

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from glob import glob
import os
import random
from PIL import Image
import torch
import numpy as np
# Smoothing epsilon. NOTE(review): not referenced anywhere in the visible code;
# the Dice loss takes its own `epsilon` argument instead — confirm before removing.
SMOOTH = 1e-6
class SemSegData(Dataset):
    """Paired image / semantic-label dataset read from a directory tree.

    Expects ``<root_dir>/original_images/*`` and
    ``<root_dir>/label_images_semantic/*``. Both listings are sorted by path
    and paired by position, so image and label files must share a common
    sort order (e.g. matching file stems).

    Each item is ``(img, mask)``:
        img:  float tensor ``(3, H, W)`` with values scaled to [0, 1].
        mask: tensor ``(1, H, W)`` holding the label image's raw pixel values,
              resized with nearest-neighbour so class ids are never blended.

    Args:
        root_dir: dataset root directory.
        size: target ``(width, height)`` in PIL convention; default 128x128.

    Raises:
        ValueError: if the number of images and label files differ.
    """

    def __init__(self, root_dir, size=(128, 128)):
        self.root_dir = root_dir
        self.size = tuple(size)
        self._init_dataset()

    def _init_dataset(self):
        # glob() returns entries in arbitrary, filesystem-dependent order;
        # sort both listings so image i and mask i describe the same sample.
        self.img_list = sorted(glob(os.path.join(self.root_dir, 'original_images', '*')))
        self.mask_list = sorted(glob(os.path.join(self.root_dir, 'label_images_semantic', '*')))
        if len(self.img_list) != len(self.mask_list):
            raise ValueError(
                f"found {len(self.img_list)} images but {len(self.mask_list)} "
                f"label files under {self.root_dir!r}"
            )

    def __getitem__(self, index):
        # Context managers release the file handles PIL keeps open lazily.
        with Image.open(self.img_list[index]) as img_file:
            # convert('RGB') guarantees 3 channels (grayscale/RGBA would break permute).
            img_arr = np.array(img_file.convert('RGB').resize(self.size))
        with Image.open(self.mask_list[index]) as mask_file:
            mask_arr = np.array(mask_file.resize(self.size, Image.NEAREST))

        img = torch.from_numpy(img_arr).permute(2, 0, 1) / 255  # HWC uint8 -> CHW float
        mask = torch.from_numpy(mask_arr).unsqueeze(0)
        return img, mask

    def __len__(self):
        return len(self.img_list)

In [None]:

import torch.nn as nn
from torchvision import transforms
from torch.nn import functional as F
# Conv -> BatchNorm -> activation unit used throughout the U-Net (kernel size is the `k` argument; 3x3 everywhere here)
class conv3x3(nn.Module):
    """Conv2d -> BatchNorm2d -> activation.

    Despite the name, the kernel size is whatever ``k`` is.

    Args:
        k, p, s: kernel size, padding and stride of the convolution.
        in_channels, out_channels: channel counts (default 256 each).
        activation: module applied after BatchNorm. When None, a fresh
            ``nn.ReLU()`` is created for this instance (the old default was a
            single ReLU built at definition time and shared by every instance).
    """

    def __init__(self, k, p, s, in_channels=256, out_channels=256, activation=None):
        super().__init__()
        # No conv bias: the following BatchNorm supplies its own learned shift.
        self.layer1 = nn.Conv2d(in_channels, out_channels, kernel_size=k,
                                padding=p, stride=s, bias=False)
        self.layer2 = nn.BatchNorm2d(out_channels)
        self.layer3 = nn.ReLU() if activation is None else activation

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))

# Double_Conv: two stacked 3x3 conv3x3 units (padding 1, stride 1); the building block of each encoder stage
class Double_Conv(nn.Module):
    """Two stacked conv3x3 units with a 3x3 kernel, padding 1 and stride 1.

    Spatial size is preserved; channels go
    ``in_channels -> out_channels -> out_channels``.

    Args:
        in_channels, out_channels: channel counts (default 256 each).
        activation: activation module forwarded to both units. When None, each
            unit gets its own fresh ``nn.ReLU()``.
    """

    def __init__(self, in_channels=256, out_channels=256, activation=None):
        super().__init__()
        # Forward `activation` (it used to be accepted but ignored); resolve the
        # None default here so each unit receives a real module.
        self.layer1 = conv3x3(3, 1, 1, in_channels, out_channels,
                              nn.ReLU() if activation is None else activation)
        self.layer2 = conv3x3(3, 1, 1, out_channels, out_channels,
                              nn.ReLU() if activation is None else activation)

    def forward(self, x):
        return self.layer2(self.layer1(x))

#Downsampling module for the U-Net
class Down(nn.Module):
    """Encoder stage: two Double_Conv blocks followed by 2x2 max-pooling.

    ``forward`` returns ``(pooled, features)``. ``pooled`` is the
    half-resolution tensor fed to the next stage, and ``features`` is the
    pre-pool output kept as the skip connection for the decoder.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names (model, down) are kept so saved state_dicts still load.
        self.model = nn.Sequential(
            Double_Conv(in_channels, out_channels),
            Double_Conv(out_channels, out_channels),
        )
        self.down = nn.MaxPool2d(2)

    def forward(self, x):
        features = self.model(x)
        return self.down(features), features

#Upsampling module for the U-Net
class Up(nn.Module):
    """Decoder stage: learned upsampling, skip concatenation, two 3x3 convs.

    Args:
        k, p, s: kernel size, padding and stride of the ConvTranspose2d. With
            k=4, p=1, s=2 (as used by Unet) it exactly doubles H and W.
        in_channels: channels of the tensor being upsampled.
        mid: channels after upsampling. The skip tensor passed to ``forward``
            must also have ``mid`` channels, because the concatenation feeds a
            conv expecting ``mid * 2`` inputs.
        out: output channels.
        activation: activation used after every BatchNorm. When None, a fresh
            ``nn.ReLU()`` is used at each position. A supplied module is
            deep-copied per position, so parametric activations do not share
            weights.
    """

    def __init__(self, k, p, s, in_channels, mid, out, activation=None):
        super().__init__()
        import copy

        def make_act():
            return nn.ReLU() if activation is None else copy.deepcopy(activation)

        # Sequential indices (hence state_dict keys) match the original layout.
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, mid, kernel_size=k, padding=p, stride=s, bias=False),
            nn.BatchNorm2d(mid),
            make_act(),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(mid * 2, out, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out),
            make_act(),
            nn.Conv2d(out, out, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out),
            make_act(),
        )

    def forward(self, x, inp):
        """Upsample ``x``, concatenate the skip ``inp`` after it on dim 1, refine."""
        x = self.layer2(x)
        x = torch.cat((x, inp), 1)
        return self.layer3(x)

#The U-net Model
class Unet(nn.Module):
    """U-Net with six Down encoder stages and six Up decoder stages.

    Input ``(B, in_channels, H, W)`` with H and W divisible by 64 (six 2x
    poolings). Output is ``(B, num_classes, H, W)`` raw logits from a final
    1x1 conv.

    Args:
        img_dim: unused; kept for backward compatibility (the network is
            fully convolutional).
        activation: unused; kept for backward compatibility. All stages use
            ReLU.
        num_classes: number of output channels (default 24).
        in_channels: number of input image channels (default 3).
    """

    def __init__(self, img_dim=572, activation=None, num_classes=24, in_channels=3):
        super().__init__()

        self.layer1 = nn.Sequential(nn.Conv2d(in_channels, 64, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(64),
                                    nn.ReLU(),
                                    nn.Conv2d(64, 64, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(64),
                                    nn.ReLU()
                                    )
        # Trailing comments: skip-feature resolution for a 128x128 input.
        self.layer2 = nn.ModuleList([Down(64, 64),     # 128
                                     Down(64, 128),    # 64
                                     Down(128, 256),   # 32
                                     Down(256, 512),   # 16
                                     Down(512, 512),   # 8
                                     Down(512, 512),   # 4
                                     ])
        self.conv = conv3x3(3, 1, 1, 512, 512)
        # Attribute names layer8..layer14 are kept so saved checkpoints still load.
        self.layer8 = Up(4, 1, 2, 512, 512, 512)   # -> 4
        self.layer9 = Up(4, 1, 2, 512, 512, 512)   # -> 8
        self.layer10 = Up(4, 1, 2, 512, 512, 256)  # -> 16
        self.layer11 = Up(4, 1, 2, 256, 256, 128)  # -> 32
        self.layer12 = Up(4, 1, 2, 128, 128, 64)   # -> 64
        self.layer13 = Up(4, 1, 2, 64, 64, 64)     # -> 128
        self.layer14 = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        factor = 2 ** len(self.layer2)
        if x.shape[-2] % factor or x.shape[-1] % factor:
            # Otherwise the floor in max-pooling makes upsampled and skip sizes
            # disagree, which surfaces as an opaque torch.cat error.
            raise ValueError(
                f"input height/width must be divisible by {factor}, got {tuple(x.shape[-2:])}"
            )

        x = self.layer1(x)
        skips = []
        for down in self.layer2:
            x, skip = down(x)
            skips.append(skip)

        x = self.conv(x)
        decoder = (self.layer8, self.layer9, self.layer10,
                   self.layer11, self.layer12, self.layer13)
        # The deepest skip pairs with the first decoder stage.
        for up, skip in zip(decoder, reversed(skips)):
            x = up(x, skip)
        return self.layer14(x)
        


In [None]:
import torch.optim as optim
import torchvision
import os
from torch.utils.data import DataLoader, random_split
import datasets
import random
import numpy as np 
from torch.utils.tensorboard import SummaryWriter
# Module-level TensorBoard writer used by train() for loss scalars and sample images.
# It is constructed when this cell runs, not when train() is called.
# NOTE(review): with no log_dir this presumably creates a fresh default run directory
# on every execution — confirm that side effect is intended.
writer = SummaryWriter()

def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    """Return the batch-mean soft-Dice loss between integer labels and logits.

    Channels-first layout:
        y_true: integer class ids, shape ``(B, 1, *spatial)`` or ``(B, *spatial)``.
        y_pred: raw logits, shape ``(B, C, *spatial)``; soft-maxed over dim 1 here.
        epsilon: added to numerator and denominator for numerical stability.

    For each sample and class the Dice ratio
    ``(2 * sum(p * g) + eps) / (sum(p**2 + g**2) + eps)`` is computed over the
    spatial dims. The loss is ``1 - mean_over_classes(ratio)``, averaged over
    the batch.

    References:
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical
        Image Segmentation, https://arxiv.org/abs/1606.04797
        More details on the Dice formulation:
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)
        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    """
    num_classes = y_pred.size(1)
    labels = y_true.long()
    if labels.dim() == y_pred.dim():
        labels = labels.squeeze(1)
    # Build the one-hot target out of place. Writing ones through
    # permute(...).contiguous().view(...) targets a *copy*, so the target
    # would stay all zeros.
    target = torch.movedim(
        torch.nn.functional.one_hot(labels, num_classes), -1, 1
    ).to(y_pred.dtype)
    probs = torch.softmax(y_pred, 1)

    spatial_dims = tuple(range(2, y_pred.dim()))
    numerator = 2.0 * (probs * target).sum(spatial_dims)
    denominator = (probs.square() + target.square()).sum(spatial_dims)
    dice = (numerator + epsilon) / (denominator + epsilon)
    return torch.mean(1 - dice.mean(1))


def _log_val_images(img, mask, logits, num_classes, step):
    """Write input images, predicted and ground-truth label maps to TensorBoard."""
    # add_image scales non-uint8 tensors by 255, so raw class ids >= 1 would all
    # saturate to white; map ids into [0, 1] first.
    scale = float(max(num_classes - 1, 1))
    pred = torch.argmax(logits, 1, keepdim=True).float() / scale
    writer.add_image("img", torchvision.utils.make_grid(img), step)
    writer.add_image("pred", torchvision.utils.make_grid(pred), step)
    writer.add_image("mask", torchvision.utils.make_grid(mask.float() / scale), step)


def train():
    """Train the U-Net on the semantic drone dataset with CE + soft-Dice loss.

    Tries to resume from ``best4.ckpt`` / ``optim_best4.ckpt``. Each epoch it
    saves the best-validation weights to ``best5.ckpt`` / ``optim_best5.ckpt``
    and the latest weights to ``last.ckpt`` / ``optim_last.ckpt``. Losses and
    sample predictions are logged to the module-level TensorBoard ``writer``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_epochs = 200
    batch_size = 8
    learning_rate = 0.001
    num_classes = 24  # must match Unet's output channels

    dataset = SemSegData(root_dir='E:/DL-Week/archive/dataset/semantic_drone_dataset')
    data_len = len(dataset)
    train_len = int(0.8 * data_len)
    trainset, valset = random_split(dataset, [train_len, data_len - train_len])

    trainloader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
    valloader = DataLoader(dataset=valset, batch_size=batch_size, shuffle=False)

    model = Unet().to(device)
    criterion = nn.CrossEntropyLoss()
    best = float("inf")
    train_step = 0
    val_step = 0

    def lr_lbmd(e):
        # Multiply the LR by 0.7 every 20 epochs, floored at 0.01x (1e-5 absolute).
        return max(0.7 ** (e // 20), 0.00001 / 0.001)

    optimiser = optim.Adam(model.parameters(), lr=learning_rate)
    try:
        # map_location lets a GPU-saved checkpoint resume on a CPU-only machine.
        model.load_state_dict(torch.load('best4.ckpt', map_location=device))
        optimiser.load_state_dict(torch.load("optim_best4.ckpt", map_location=device))
        print("model_loaded")
    except (OSError, RuntimeError, ValueError, KeyError) as err:
        # Missing or incompatible checkpoint: continue with a fresh optimiser.
        # (If only the optimiser failed, the model weights stay loaded.)
        optimiser = optim.Adam(model.parameters(), lr=learning_rate)
        print(f"save not found: {err}")

    lr_scheduler = optim.lr_scheduler.LambdaLR(optimiser, lr_lbmd)

    for epoch in range(num_epochs):
        print(f"Epoch: {epoch}")
        print("---------------")
        model.train()
        loss_cntr = []
        torch.cuda.empty_cache()

        for indx, (img, mask) in enumerate(trainloader):
            img = img.to(device)
            mask = mask.to(device).long()

            outputs = model(img)
            # CrossEntropyLoss accepts (B, C, H, W) logits with (B, H, W) targets.
            loss = criterion(outputs, mask[:, 0])
            loss2 = soft_dice_loss(mask, outputs)
            t_loss = loss + loss2

            optimiser.zero_grad()
            t_loss.backward()
            optimiser.step()

            # Track the combined loss so train and val numbers are comparable.
            loss_cntr.append(t_loss.item())
            train_step += 1
            datasets.progress_bar(progress=indx / len(trainloader),
                                  status=f"loss: {round(np.mean(loss_cntr), 4)}")
            writer.add_scalar("Loss/train", loss_cntr[-1], train_step)

        datasets.progress_bar(progress=1, status=f"loss: {round(np.mean(loss_cntr), 4)}")
        print("\n")

        loss_cntr = []
        model.eval()
        torch.cuda.empty_cache()
        last_val_batch = None
        with torch.no_grad():
            for indx, (img2, mask2) in enumerate(valloader):
                img2 = img2.to(device)
                mask2 = mask2.to(device).long()
                out2 = model(img2)

                loss = criterion(out2, mask2[:, 0]) + soft_dice_loss(mask2, out2)

                loss_cntr.append(loss.item())
                val_step += 1
                datasets.progress_bar(progress=indx / len(valloader),
                                      status=f"loss: {round(np.mean(loss_cntr), 4)}")
                writer.add_scalar("Loss/val", loss_cntr[-1], val_step)
                last_val_batch = (img2, mask2, out2)
            datasets.progress_bar(progress=1, status=f"loss: {round(np.mean(loss_cntr), 4)}")

        print("\n")

        val_loss = np.mean(loss_cntr) if loss_cntr else float("inf")
        if val_loss < best:
            best = val_loss
            torch.save(model.state_dict(), "best5.ckpt")
            torch.save(optimiser.state_dict(), "optim_best5.ckpt")
            print(f"current best: {round(best, 4)}")
        print("\n")

        torch.save(model.state_dict(), "last.ckpt")
        torch.save(optimiser.state_dict(), "optim_last.ckpt")
        lr_scheduler.step()

        # Visualise the last validation batch (skipped if the val split is empty).
        if last_val_batch is not None:
            img2, mask2, out2 = last_val_batch
            _log_val_images(img2, mask2, out2, num_classes, val_step)
        
# Start training only when executed as a script (not when imported).
if __name__ == "__main__":
    train()