In [None]:
import math
import os

import numpy as np
import pandas as pd
import PIL
from PIL import Image
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.models import VGG16_BN_Weights, vgg16_bn
from torchvision.transforms import Compose, Normalize, ToTensor

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu", 0)

In [None]:
# Inspect one training pair: image size, array shapes and the label values in the mask.
img = Image.open('/kaggle/input/sai-vessel-segmentation2/all/train/21_training.tif')
mask = Image.open('/kaggle/input/sai-vessel-segmentation2/all/train/21_manual1.gif')

img_array, mask_array = np.array(img), np.array(mask)
print(img.size)
print(img_array.shape, mask_array.shape)
print(np.unique(mask_array))

In [None]:
class _DriveSegDS(Dataset):
    """Image/mask pairs named ``<id>_training.tif`` / ``<id>_manual1.gif`` in one folder.

    Shared implementation for ``TrainCustomDS`` and ``ValCustomDS``; subclasses choose
    the first id (``first_id``) and how many consecutive ids they expose
    (``_num_items``).

    Each item is ``(img, mask)``. ``img`` is the transformed image tensor. ``mask`` is
    a ``torch.long`` tensor in which the value 255 is mapped to class id 1. Both are
    resized to 256 x 256 with nearest-neighbour resampling.
    """

    first_id = 21  # overridden by subclasses

    def __init__(self, path, transform=None):
        super().__init__()
        # File names are appended with plain string concatenation, so `path`
        # must end with a path separator.
        self.path = path
        _, _, self.filepaths = next(os.walk(path))
        self.length = self._num_items()
        # Honour a caller-supplied callable transform. Any other value (None or a
        # non-callable flag) falls back to ToTensor + ImageNet mean/std
        # normalization, which is what these datasets always used before.
        if callable(transform):
            self.transform = transform
        else:
            self.transform = Compose([
                ToTensor(),
                Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

    def _num_items(self):
        """Number of consecutive ids (starting at ``first_id``) this split exposes."""
        raise NotImplementedError

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        file_id = str(idx + self.first_id)

        img = self.transform(self._get_image(self.path + file_id + "_training.tif"))

        mask = np.array(self._get_image(self.path + file_id + "_manual1.gif"))
        mask = torch.from_numpy(mask).type(torch.long)
        # Vessel pixels are assumed to be stored as 255; CrossEntropyLoss needs
        # class ids, so map them to 1 (background presumably already 0).
        mask[mask == 255] = 1

        return img, mask

    def _get_image(self, path, size=256):
        # Nearest-neighbour keeps mask labels discrete. The context manager closes
        # the file handle immediately; resize() returns a new, fully loaded image
        # that no longer depends on the file.
        with Image.open(path) as img:
            return img.resize((size, size), PIL.Image.NEAREST)


class TrainCustomDS(_DriveSegDS):
    """Training split: the 16 pairs with ids 21..36."""

    first_id = 21

    def _num_items(self):
        return 16


class ValCustomDS(_DriveSegDS):
    """Validation split: the pairs from id 37 onward."""

    first_id = 37

    def _num_items(self):
        # Assumes the folder holds exactly one image and one mask per id; the first
        # 16 ids belong to the training split. Clamp so __len__ never goes negative.
        return max(0, len(self.filepaths) // 2 - 16)

In [None]:
trainDS=TrainCustomDS("/kaggle/input/sai-vessel-segmentation2/all/train/")
valDS=ValCustomDS("/kaggle/input/sai-vessel-segmentation2/all/train/")

batch_size=4

trainDL=DataLoader(dataset=trainDS,batch_size=batch_size,shuffle=True)
valDL=DataLoader(dataset=valDS,batch_size=batch_size)

In [None]:
class UNET_Full(nn.Module):
    """Wire an encoder, a center (bottleneck) block and a decoder into one U-Net.

    The encoder returns a sequence of intermediate feature maps. Its last element
    feeds the center block, and the whole sequence is passed to the decoder for
    skip connections.
    """

    def __init__(self, encoder, center, decoder):
        super().__init__()
        self.encoder, self.center, self.decoder = encoder, center, decoder

    def forward(self, x):
        skips = self.encoder(x)
        return self.decoder(self.center(skips[-1]), skips)

In [None]:
class Encoder(nn.Module):
    """Run ``pretrained_network.features`` layer by layer and keep every layer's output.

    ``forward`` returns a list with one tensor per layer of ``features`` (in order),
    so downstream code can pick skip connections by index and use the last entry as
    the bottleneck input.
    """

    def __init__(self, pretrained_network):
        super().__init__()
        self.encoder = pretrained_network

    def forward(self, x):
        outputs = []
        for module in self.encoder.features:
            x = module(x)
            outputs += [x]
        return outputs

In [None]:
class Center(nn.Sequential):
    """Bottleneck: two (3x3 conv -> BatchNorm -> ReLU) blocks, 512 -> 1024 -> 1024 channels."""

    def __init__(self):
        layers = []
        for in_channels in (512, 1024):
            layers += [
                nn.Conv2d(in_channels, 1024, 3, padding=1),
                nn.BatchNorm2d(1024),
                nn.ReLU(),
            ]
        super().__init__(*layers)

In [None]:
class Decoder(nn.Module):
    """U-Net decoder that upsamples the 1024-channel center output in five stages.

    ``forward(x, encoder_features_output)`` takes the center output ``x`` and the
    per-layer output list produced by ``Encoder``. Each stage does the following:

    - nearest-neighbour 2x upsample, then a 3x3 conv + ReLU to reduce channels;
    - concatenation with one encoder skip tensor;
    - two or three (3x3 conv -> BatchNorm -> ReLU) blocks.

    A final 1x1 conv maps 64 channels to 2 class logits.
    """

    # Indices into the encoder output list used as skips, deepest stage first.
    # NOTE(review): these correspond to the last ReLU of each block of torchvision's
    # vgg16_bn().features; they must change if the backbone changes.
    SKIP_INDICES = (42, 32, 22, 12, 5)

    def __init__(self):
        super().__init__()

        def conv3x3(c_in, c_out):
            return nn.Conv2d(c_in, c_out, 3, padding=1)

        self.rl = nn.ReLU()

        # Stage 5: up-conv 1024 -> 512, concat 512-channel skip.
        self.conv5_up = conv3x3(1024, 512)
        self.conv5_1 = conv3x3(1024, 512)
        self.bn5_1 = nn.BatchNorm2d(512)
        self.conv5_2 = conv3x3(512, 512)
        self.bn5_2 = nn.BatchNorm2d(512)
        self.conv5_3 = conv3x3(512, 512)
        self.bn5_3 = nn.BatchNorm2d(512)

        # Stage 4: up-conv 512 -> 512, concat 512-channel skip.
        self.conv4_up = conv3x3(512, 512)
        self.conv4_1 = conv3x3(1024, 512)
        self.bn4_1 = nn.BatchNorm2d(512)
        self.conv4_2 = conv3x3(512, 512)
        self.bn4_2 = nn.BatchNorm2d(512)
        self.conv4_3 = conv3x3(512, 512)
        self.bn4_3 = nn.BatchNorm2d(512)

        # Stage 3: up-conv 512 -> 256, concat 256-channel skip.
        self.conv3_up = conv3x3(512, 256)
        self.conv3_1 = conv3x3(512, 256)
        self.bn3_1 = nn.BatchNorm2d(256)
        self.conv3_2 = conv3x3(256, 256)
        self.bn3_2 = nn.BatchNorm2d(256)
        self.conv3_3 = conv3x3(256, 256)
        self.bn3_3 = nn.BatchNorm2d(256)

        # Stage 2: up-conv 256 -> 128, concat 128-channel skip.
        self.conv2_up = conv3x3(256, 128)
        self.conv2_1 = conv3x3(256, 128)
        self.bn2_1 = nn.BatchNorm2d(128)
        self.conv2_2 = conv3x3(128, 128)
        self.bn2_2 = nn.BatchNorm2d(128)

        # Stage 1: up-conv 128 -> 64, concat 64-channel skip.
        self.conv1_up = conv3x3(128, 64)
        self.conv1_1 = conv3x3(128, 64)
        self.bn1_1 = nn.BatchNorm2d(64)
        self.conv1_2 = conv3x3(64, 64)
        self.bn1_2 = nn.BatchNorm2d(64)

        self.convfinal = nn.Conv2d(64, 2, kernel_size=1)

    def _up_block(self, x, skip, up_conv, conv_bn_pairs):
        """One decoder stage: upsample, reduce, concat the skip, then conv/BN/ReLU blocks."""
        x = self.rl(up_conv(F.interpolate(x, scale_factor=2, mode="nearest")))
        x = torch.cat((x, skip), dim=1)
        for conv, bn in conv_bn_pairs:
            x = self.rl(bn(conv(x)))
        return x

    def forward(self, x, encoder_features_output):
        stages = (
            (self.conv5_up, ((self.conv5_1, self.bn5_1), (self.conv5_2, self.bn5_2), (self.conv5_3, self.bn5_3))),
            (self.conv4_up, ((self.conv4_1, self.bn4_1), (self.conv4_2, self.bn4_2), (self.conv4_3, self.bn4_3))),
            (self.conv3_up, ((self.conv3_1, self.bn3_1), (self.conv3_2, self.bn3_2), (self.conv3_3, self.bn3_3))),
            (self.conv2_up, ((self.conv2_1, self.bn2_1), (self.conv2_2, self.bn2_2))),
            (self.conv1_up, ((self.conv1_1, self.bn1_1), (self.conv1_2, self.bn1_2))),
        )
        for skip_idx, (up_conv, conv_bn_pairs) in zip(self.SKIP_INDICES, stages):
            x = self._up_block(x, encoder_features_output[skip_idx], up_conv, conv_bn_pairs)
        return self.convfinal(x)

In [None]:
def _dice_counts(preds, masks):
    """Pixel counts for the vessel-class Dice score of one batch.

    Args:
        preds: model logits; the predicted class is ``argmax`` over dim 1.
        masks: integer targets, assumed 0 = background / 1 = vessel.

    Returns:
        ``(intersection, pred_total, mask_total)`` as Python ints.
        ``intersection`` counts pixels predicted 1 that are also labelled 1.
        ``pred_total`` / ``mask_total`` are the sums of the predicted / target class
        maps, which are the foreground pixel counts for 0/1 labels.
    """
    predclass = torch.argmax(preds, dim=1)
    intersection = ((predclass == 1) & (masks == 1)).sum().item()
    return intersection, predclass.sum().item(), masks.sum().item()


def _dice_coefficient(intersection, pred_total, mask_total):
    """Return 2|P∩M| / (|P| + |M|), or 1.0 when both P and M are empty.

    Returning 1.0 for empty-vs-empty treats that case as perfect agreement and
    avoids a ZeroDivisionError.
    """
    denominator = pred_total + mask_total
    if denominator == 0:
        return 1.0
    return (2 * intersection) / denominator


def _run_epoch(dataloader, model, loss_fn, optimizer, phase):
    """Make one pass over `dataloader`, printing the running loss and Dice after each batch.

    If `optimizer` is not None, every batch also runs zero_grad/backward/step.
    Returns the final running ``(loss, dice)``, each rounded to 2 decimals (these
    equal the epoch averages). Returns ``(nan, nan)`` when the loader yields no batches.
    """
    track_loss = 0.0
    inter_total = pred_total = mask_total = 0
    running_loss = running_dice_coef = float("nan")
    n_batches = len(dataloader)

    for i, (imgs, masks) in enumerate(dataloader):
        imgs = imgs.to(device)
        masks = masks.to(device)

        preds = model(imgs)
        loss = loss_fn(preds, masks)
        track_loss += loss.item()

        inter, pred_sum, mask_sum = _dice_counts(preds, masks)
        inter_total += inter
        pred_total += pred_sum
        mask_total += mask_sum

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        running_loss = round(track_loss / (i + 1), 2)
        running_dice_coef = round(_dice_coefficient(inter_total, pred_total, mask_total), 2)

        print(phase, "Batch", i + 1, ":", "/", n_batches, "Running Loss:", running_loss, "Running Dice_Coef:", running_dice_coef)

    return running_loss, running_dice_coef


def train_one_epoch(dataloader, model, loss_fn, optimizer):
    """Train `model` for one epoch; return the epoch's (mean loss, Dice) rounded to 2 decimals."""
    model.train()
    return _run_epoch(dataloader, model, loss_fn, optimizer, "Training")


def val_one_epoch(dataloader, model, loss_fn):
    """Evaluate `model` without gradients; return the epoch's (mean loss, Dice) rounded to 2 decimals."""
    model.eval()
    with torch.no_grad():
        return _run_epoch(dataloader, model, loss_fn, None, "Validation")

## Training on Split Data

In [None]:
pretrained_network = vgg16_bn(weights=VGG16_BN_Weights.DEFAULT)

# Phase 1: freeze the pretrained VGG16-BN feature extractor so only center + decoder learn.
# NOTE(review): train_one_epoch calls model.train(), so the frozen encoder's BatchNorm
# running statistics are still updated during this phase.
for param in pretrained_network.features.parameters():
    param.requires_grad = False

encoder = Encoder(pretrained_network).to(device)
center = Center().to(device)
decoder = Decoder().to(device)
model = UNET_Full(encoder, center, decoder).to(device)

loss_fn = nn.CrossEntropyLoss()
lr = 0.001
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)


def run_train_val_epochs(num_epochs):
    """Alternate one training and one validation pass per epoch, printing both summaries."""
    for epoch in range(num_epochs):
        print("Epoch No:", epoch + 1)
        tr_loss, tr_dice = train_one_epoch(trainDL, model, loss_fn, optimizer)
        print("Training Epoch Loss:", tr_loss, "Training Epoch Dice_Coef:", tr_dice)
        va_loss, va_dice = val_one_epoch(valDL, model, loss_fn)
        print("Validation Epoch Loss:", va_loss, "Validation Epoch Dice_Coef:", va_dice)
        print("--------------------------------------------------")


n_epochs = 40
run_train_val_epochs(n_epochs)

# Fine-Tuning
# Phase 2: unfreeze the encoder so the whole network trains with the same optimizer.
# NOTE(review): the pretrained encoder now trains at lr=0.001 as well; a smaller
# encoder learning rate (separate param group) is a common choice for fine-tuning.
for param in pretrained_network.features.parameters():
    param.requires_grad = True

n_epochs = 40
run_train_val_epochs(n_epochs)


In [None]:
def plotres(img, pred, mask=None):
    """Plot a de-normalized input image, its ground-truth mask (optional) and a prediction.

    Args:
        img:  (3, H, W) tensor normalized with ImageNet mean/std (as the datasets do).
        pred: (H, W) predicted class map.
        mask: optional (H, W) ground-truth class map; its panel is skipped when None.

    The caller's `img` tensor is left untouched; a de-normalized copy is plotted.
    """
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=img.dtype, device=img.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=img.dtype, device=img.device).view(3, 1, 1)
    # Undo Normalize() out-of-place, then clamp so values slightly outside [0, 1]
    # do not make imshow clip with a warning.
    img = (img * std + mean).clamp(0.0, 1.0)

    if mask is not None:
        print("Image Shape:", img.shape, "Mask Shape:", mask.shape, "Pred Shape:", pred.shape, "Image dtype", img.dtype, "Mask dtype", mask.dtype, "Pred dtype", pred.dtype)
        print("Mask Unique:", mask.unique())
    else:
        print("Image Shape:", img.shape, "Pred Shape:", pred.shape, "Image dtype:", img.dtype, "Pred dtype:", pred.dtype)

    print("Pred Unique:", pred.unique())

    # Report the real resolution instead of a hard-coded size.
    height, width = img.shape[-2:]
    size_label = f"{height} x {width}"

    plt.figure(figsize=(10, 5))

    plt.subplot(1, 3, 1)
    plt.title(f"Original Image {size_label}")
    plt.imshow(torch.permute(img.cpu(), (1, 2, 0)))

    if mask is not None:
        plt.subplot(1, 3, 2)
        plt.title(f"Mask Image {size_label}")
        plt.imshow(mask.cpu())

    plt.subplot(1, 3, 3)
    plt.title(f"Predicted Image {size_label}")
    plt.imshow(pred.cpu())
    plt.show()

In [None]:
# Score and plot one validation batch.
imgs, masks = next(iter(valDL))
model.eval()

imgs = imgs.to(device)
masks = masks.to(device)

with torch.no_grad():
    predclass = torch.argmax(model(imgs), dim=1)

# Vessel-class Dice: 2|P∩M| / (|P| + |M|) with 0/1 label maps.
intersection = ((predclass == 1) & (masks == 1)).sum().item()
denominator = predclass.sum().item() + masks.sum().item()
# No foreground in either prediction or mask counts as perfect agreement
# (and avoids a ZeroDivisionError).
dice_coef = round((2 * intersection) / denominator, 2) if denominator else 1.0

print("Validation Dice Coef:", dice_coef)

# Plot every sample actually present; the batch may hold fewer than batch_size items.
for img, pred, mask in zip(imgs, predclass, masks):
    plotres(img, pred, mask)

## Training on Full Data

In [None]:
# Final fit on every labelled image: the training ids (21-36) plus the validation
# ids (37 onward), so nothing is held out in this phase.
train_dir = "/kaggle/input/sai-vessel-segmentation2/all/train/"
trainDS = ConcatDataset([TrainCustomDS(train_dir), ValCustomDS(train_dir)])
batch_size = 4
n_epochs = 40
trainDL = DataLoader(dataset=trainDS, batch_size=batch_size, shuffle=True)

for i in range(n_epochs):
    print("Epoch No:", i + 1)
    train_epoch_loss, train_epoch_dice_coef = train_one_epoch(trainDL, model, loss_fn, optimizer)
    print("Training Epoch Loss:", train_epoch_loss, "Training Epoch Dice_Coef:", train_epoch_dice_coef)
    print("--------------------------------------------------")

In [None]:
class TestCustomDS(Dataset):
    """Unlabelled test images named ``01_test.tif``, ``02_test.tif``, ... in one folder.

    Each item is the image resized to 256 x 256 (nearest-neighbour) and passed through
    ToTensor + ImageNet mean/std normalization.
    """

    def __init__(self, path):
        super().__init__()
        # File names are appended with plain string concatenation, so `path`
        # must end with a path separator.
        self.path = path
        _, _, self.filepaths = next(os.walk(path))
        # Count only the files __getitem__ can actually load, so stray files in the
        # folder do not inflate the length and cause reads of nonexistent ids.
        self.length = sum(1 for name in self.filepaths if name.endswith("_test.tif"))
        self.transform = Compose([ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Ids start at 1 and are zero-padded to at least two digits ("01", ..., "20").
        path = self.path + f"{idx + 1:02d}_test.tif"
        return self.transform(self._get_image(path))

    def _get_image(self, path, size=256):
        # The context manager releases the file handle right away; resize() returns
        # a new, fully loaded image.
        with Image.open(path) as img:
            return img.resize((size, size), PIL.Image.NEAREST)
    

In [None]:
# Release cached, unused GPU memory before inference; skipped when CUDA is unavailable.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Keep the loader unshuffled: submission rows follow the loader order.
batch_size = 2
testDS = TestCustomDS("/kaggle/input/sai-vessel-segmentation2/all/test/")
testDL = DataLoader(testDS, batch_size=batch_size, shuffle=False)

In [None]:
def rle_encode(mask):
    """Run-length encode the pixels equal to 1 in `mask` as "start length start length ...".

    Pixels are numbered from 1 in the row-major order of ``mask`` flattened
    (NOTE(review): confirm this matches the competition's expected pixel order).
    Returns an empty string when no pixel equals 1.
    """
    pixels = np.flatnonzero(np.asarray(mask).ravel() == 1) + 1
    if pixels.size == 0:
        return ""
    # A new run starts wherever the gap to the previous foreground pixel exceeds 1.
    breaks = np.flatnonzero(np.diff(pixels) > 1)
    starts = np.concatenate((pixels[:1], pixels[breaks + 1]))
    ends = np.concatenate((pixels[breaks], pixels[-1:]))
    return " ".join(f"{s} {e - s + 1}" for s, e in zip(starts, ends))


def eval_one_epoch(dataloader, model):
    """Predict every batch, plot each prediction and return one RLE string per image.

    Outputs follow the loader order. `dataloader` yields image batches only (no masks).
    """
    model.eval()
    outputs = []
    # Inference only: wrapping the forward pass in no_grad avoids building and
    # holding an autograd graph for every test batch.
    with torch.no_grad():
        for imgs in dataloader:
            imgs = imgs.to(device)
            predclass = torch.argmax(model(imgs), dim=1).cpu()
            for img, pred in zip(imgs, predclass):
                plotres(img, pred)
                outputs.append(rle_encode(pred))
    return outputs

outputs = eval_one_epoch(testDL, model)
# One row per test image in loader order. Ids are 0-based strings, as before;
# deriving the count from `outputs` avoids a length mismatch if the test set
# is not exactly 20 images.
# NOTE(review): confirm the expected Id scheme (0-based vs the 01..20 file numbers).
df = pd.DataFrame({
    "Id": [str(i) for i in range(len(outputs))],
    "Predicted": outputs,
})
df.to_csv("submission.csv", index=False)