In [32]:
import os
from glob import glob
import time

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 
from sklearn import model_selection

import torchvision
from torchvision.transforms import transforms

In [141]:
class VesselDataset(Dataset):
    def __init__(self, root_dir, im_transforms, m_transforms):
        self.root_dir = root_dir
        self.im_transforms = im_transforms
        self.m_transforms = m_transforms
        self.images = data[0]
        self.masks = data[1]
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        image = image.convert("RGB")
        image = np.array(image, dtype=np.uint8)
        #r,g,b=cv2.split(image)
        #image=g
        mask = Image.open(self.masks[idx])
        
        mask = np.array(mask, dtype=np.uint8)
        
        image = self.im_transforms(image)
        mask = self.m_transforms(mask)
        
        return image, mask

In [142]:
img_aug = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((312, 312)),
    transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
])

mask_aug = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((312, 312)),
    transforms.ToTensor()
    #!!!
])

# train/test dataset & dataloader
data = list(zip(sorted(glob('/content/drive/My Drive//DRIVE/training/images_png/*.png')), sorted(glob('/content/drive/My Drive//DRIVE/training/mask/*.gif'))))

train_data, test_data = model_selection.train_test_split(
      data, random_state=42, test_size=0.1
      )
train_dataset = VesselDataset(root_dir=train_data,
                        im_transforms=img_aug, 
                        m_transforms=mask_aug)

test_dataset = VesselDataset(root_dir=test_data,
                        im_transforms=img_aug, 
                        m_transforms=mask_aug)

train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=5, shuffle=True)

In [143]:


data, target= next(iter(train_loader))
data.shape, target.shape



(torch.Size([2, 3, 312, 312]), torch.Size([2, 1, 312, 312]))

In [144]:
def double_conv(in_c,out_c):
  conv=nn.Sequential(
      nn.Conv2d(in_c,out_c,kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_c,out_c,kernel_size=3),
      nn.ReLU(inplace=True)      
  )
  return conv

def crop_tensor(tensor,target_tensor):

  target_size = target_tensor.size()[2]
  tensor_size = tensor.size()[2]
  delta = tensor_size - target_size
  delta = delta // 2
  return tensor[:,:, delta:tensor_size-delta, delta:tensor_size-delta]

In [153]:
import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True))   


class UNet(nn.Module):

    def __init__(self, n_classes):
        super().__init__()
                
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)        

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)        
        
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        
        self.conv_last = nn.Conv2d(64, n_classes,1)#,1
        
        
    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)   
        
        x = self.dconv_down4(x)
        
        x = self.upsample(x)        
        x = torch.cat([x, conv3], dim=1)
        
        x = self.dconv_up3(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv2], dim=1)       

        x = self.dconv_up2(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv1], dim=1)   
        
        x = self.dconv_up1(x)
        
        out = self.conv_last(x)
        out = torch.sigmoid(out)
        
        return out
    
model = UNet(n_classes=1).cuda()

In [154]:
#model = UNet().cuda()

In [155]:
def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()


def soft_dice_loss(inputs, targets):
        num = targets.size(0)
        m1  = inputs.view(num,-1)
        m2  = targets.view(num,-1)
        intersection = (m1 * m2)
        score = 2. * (intersection.sum(1)+1) / (m1.sum(1) + m2.sum(1)+1)
        score = 1 - score.sum()/num
        return score
    
optimizer = torch.optim.Adamax(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=40, verbose=True)

In [156]:


# Training model
num_epochs=200
for epoch in range(num_epochs):
    #print(time.ctime(), 'Epoch:', epoch)
    for batch_i, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        # Forward pass
        outputs = model(data)
        loss = soft_dice_loss(outputs, target)
        # Backward and optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step(loss.item())
    # logging
    print('Epoch {}/{},  Loss: {:.4f}'
              .format(epoch+1, num_epochs,  loss.item()))

# Weights
torch.save(model.state_dict(), 'unet.pt')



Epoch 1/200,  Loss: 0.6431
Epoch 2/200,  Loss: 0.6156
Epoch 3/200,  Loss: 0.5166
Epoch 4/200,  Loss: 0.4993
Epoch 5/200,  Loss: 0.4962
Epoch 6/200,  Loss: 0.4859
Epoch 7/200,  Loss: 0.4627
Epoch 8/200,  Loss: 0.4529
Epoch 9/200,  Loss: 0.4430
Epoch 10/200,  Loss: 0.4310
Epoch 11/200,  Loss: 0.4213
Epoch 12/200,  Loss: 0.4151
Epoch 13/200,  Loss: 0.4115
Epoch 14/200,  Loss: 0.5033
Epoch 15/200,  Loss: 0.4101
Epoch 16/200,  Loss: 0.4125
Epoch 17/200,  Loss: 0.4139
Epoch 18/200,  Loss: 0.4133
Epoch 19/200,  Loss: 0.4123
Epoch 20/200,  Loss: 0.4114
Epoch 21/200,  Loss: 0.4108
Epoch 22/200,  Loss: 0.4103
Epoch 23/200,  Loss: 0.4100
Epoch 24/200,  Loss: 0.4097
Epoch 25/200,  Loss: 0.4094
Epoch 26/200,  Loss: 0.4091
Epoch 27/200,  Loss: 0.4089
Epoch 28/200,  Loss: 0.4086
Epoch 29/200,  Loss: 0.4083
Epoch 30/200,  Loss: 0.4081
Epoch 31/200,  Loss: 0.4078
Epoch 32/200,  Loss: 0.4074
Epoch 33/200,  Loss: 0.4071
Epoch 34/200,  Loss: 0.4067
Epoch 35/200,  Loss: 0.4064
Epoch 36/200,  Loss: 0.4060
E