### Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

import torch.optim as optim

import os
import shutil
from PIL import Image, ImageOps

import random

#import any other library you need below this line

### Defining the Dataset Class

In [12]:
from cv2 import transform
import cv2

class Cell_data(Dataset):
  """In-memory dataset of cell scans paired with their segmentation masks.

  ``data_dir`` must contain a ``scans`` and a ``labels`` sub-directory in
  which an image and its mask share the same file name. Every sample is
  resized to ``size`` x ``size``, converted to a tensor and kept in memory.

  Args:
    data_dir: directory holding the ``scans`` and ``labels`` folders.
    size: side length (pixels) every image and mask is resized to.
    train: truthy selects the training part of the split, falsy the test
      part. The strings ``'false'``, ``'0'``, ``'no'`` and ``''`` (any case)
      count as falsy; every other value follows normal truthiness.
    train_test_split: fraction of the files assigned to the training part.
    augment_data: when true, each training sample is followed by a
      horizontally flipped, a vertically flipped and a rotated copy.
    seed: seed for the split and for the choice of rotation angle. Datasets
      built with the same ``seed`` and ``train_test_split`` share one split,
      so the training and test parts never overlap.
  """

  # Rotation angles (degrees) the rotation augmentation picks from.
  ROTATION_ANGLES = (30, 60, 90, 120)

  def __init__(self, data_dir, size, train = 'True', train_test_split = 0.6, augment_data = True, seed = 0):
    super(Cell_data, self).__init__()
    self.root_dir = data_dir
    self.t = self._as_flag(train)
    self.size = size
    self.split = train_test_split
    self.augmentmode = augment_data
    self.image = sorted(os.listdir(os.path.join(self.root_dir,'scans')))
    self.mask = sorted(os.listdir(os.path.join(self.root_dir,'labels')))

    self.images = []
    self.masks = []

    # The split is derived from the sorted file list and a fixed seed, so the
    # training and the test dataset objects agree on who owns which file.
    self.train, self.test = self._split_files(self.image, self.split, seed)

    rng = random.Random(seed)
    names = self.train if self.t else self.test
    augment = self.t and bool(self.augmentmode)

    for name in names:
      x, y = self._load_pair(name)
      self.images.append(x)
      self.masks.append(y)
      if augment:
        for aug_x, aug_y in self._augment(x, y, rng):
          self.images.append(aug_x)
          self.masks.append(aug_y)

  @staticmethod
  def _as_flag(value):
    """Interpret ``value`` as a boolean, understanding 'False'-like strings."""
    if isinstance(value, str):
      return value.strip().lower() not in ('false', '0', 'no', '')
    return bool(value)

  @staticmethod
  def _split_files(files, fraction, seed):
    """Deterministically split ``files`` into disjoint (train, test) lists."""
    files = list(files)
    picked = set(random.Random(seed).sample(files, int(len(files) * fraction)))
    train_files = [name for name in files if name in picked]
    test_files = [name for name in files if name not in picked]
    return train_files, test_files

  def _load_pair(self, name):
    """Load, resize and tensorise the scan and the mask called ``name``."""
    dim = (self.size, self.size)
    to_tensor = transforms.ToTensor()
    # Context managers close the underlying files once the pixels are read.
    with Image.open(os.path.join(self.root_dir, 'scans', name)) as scan:
      x = to_tensor(scan.resize(dim))
    with Image.open(os.path.join(self.root_dir, 'labels', name)) as mask:
      # Nearest-neighbour keeps label values intact; a smoothing filter would
      # invent intermediate values along the object borders.
      y = to_tensor(mask.resize(dim, Image.NEAREST))
    return x, y

  def _augment(self, x, y, rng):
    """Return augmented (image, mask) pairs with identical geometry for both."""
    # One angle is drawn per sample and applied to image AND mask so the two
    # stay aligned (independent random rotations would misalign them).
    angle = rng.choice(self.ROTATION_ANGLES)
    return [
      (TF.hflip(x), TF.hflip(y)),
      (TF.vflip(x), TF.vflip(y)),
      (TF.rotate(x, angle), TF.rotate(y, angle)),
    ]

  def __getitem__(self, idx):
    """Return the ``(image, mask)`` tensor pair stored at position ``idx``."""
    return self.images[idx], self.masks[idx]

  def __len__(self):
    """Number of stored samples, augmented copies included."""
    return len(self.images)


### Define the Model
1. Define the Convolution blocks
2. Define the down path
3. Define the up path
4. Combine the down and up paths to get the final model

In [3]:
from torch import conv2d


class twoConvBlock(nn.Module):
  """Two unpadded 3x3 convolutions applied back to back.

  The pipeline is Conv -> ReLU -> Conv -> BatchNorm -> ReLU. Neither
  convolution uses padding, so each one trims two pixels from the height and
  the width of its input.
  """

  def __init__(self,in_channels, out_channels):
    super().__init__()
    stages = [
      nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
      nn.ReLU(),
      nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(),
    ]
    # Keep the attribute name and the layer order: both determine the
    # parameter names stored in checkpoints.
    self.layer = nn.Sequential(*stages)

  def forward(self, x):
    """Run ``x`` through the convolution pipeline and return the result."""
    features = self.layer(x)
    return features

class downStep(nn.Module):
  """Contracting stage: a 2x2 max-pool followed by a ``twoConvBlock``."""

  def __init__(self, in_channels, out_channels):
    super().__init__()
    halve = nn.MaxPool2d(2)
    convs = twoConvBlock(in_channels, out_channels)
    # Attribute name and ordering are kept so checkpoint keys stay the same.
    self.maxp = nn.Sequential(halve, convs)

  def forward(self, x):
    """Pool ``x`` and run the convolution block on the pooled map."""
    return self.maxp(x)

class upStep(nn.Module):
  """Expanding stage: upsample, concatenate the cropped skip map, convolve.

  Args:
    in_channels: channels of the incoming feature map. After upsampling to
      ``out_channels`` and concatenating the skip map, the result is fed to
      ``twoConvBlock(in_channels, out_channels)``, so the skip map must carry
      ``in_channels - out_channels`` channels.
    out_channels: channels produced by the transposed convolution and by the
      final convolution block.
  """

  def __init__(self, in_channels, out_channels):
    super(upStep, self).__init__()
    self.expanding = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2,stride = 2)
    self.c = twoConvBlock(in_channels, out_channels)

  def forward(self, x, y):
    """Upsample ``x``, centre-crop the skip map ``y`` to match, and fuse them.

    ``y`` is expected to be at least as large as the upsampled ``x`` in both
    spatial dimensions.
    """
    x1 = self.expanding(x)
    height, width = x1.shape[2], x1.shape[3]

    # Crop to the exact target size. Trimming the same margin from both sides
    # leaves one extra pixel whenever the size difference is odd, which makes
    # the concatenation below fail.
    top = (y.shape[2] - height) // 2
    left = (y.shape[3] - width) // 2
    y = y[:, :, top:top + height, left:left + width]

    blk = torch.cat([y, x1], dim=1)
    return self.c(blk)


class UNet(nn.Module):
  """U-Net built from ``twoConvBlock``, ``downStep`` and ``upStep``.

  Four contracting stages (64 -> 1024 channels) are mirrored by four
  expanding stages that reuse the contracting feature maps as skip
  connections; a final 1x1 convolution produces per-pixel class scores.
  The convolution blocks are unpadded, so the output is spatially smaller
  than the input.

  Args:
    in_channels: number of channels of the input image (default 1).
    n_classes: number of output score channels (default 2).
  """

  def __init__(self, in_channels = 1, n_classes = 2):
    super(UNet, self).__init__()

    # Submodule names are part of the checkpoint format - do not rename.
    self.c1 = twoConvBlock(in_channels, 64)
    self.d1 = downStep(64, 128)
    self.d2 = downStep(128, 256)
    self.d3 = downStep(256, 512)
    self.d4 = downStep(512, 1024)
    self.u1 = upStep(1024, 512)
    self.u2 = upStep(512, 256)
    self.u3 = upStep(256, 128)
    self.u4 = upStep(128, 64)
    self.c2 = nn.Conv2d(64, n_classes, kernel_size=1)

  def forward(self, x):
    """Return raw (un-normalised) class scores for the image batch ``x``."""
    # Contracting path; every intermediate map is kept as a skip connection.
    y = self.c1(x)
    l1 = self.d1(y)
    l2 = self.d2(l1)
    l3 = self.d3(l2)
    l4 = self.d4(l3)

    # Expanding path; each step fuses the matching contracting feature map.
    l6 = self.u1(l4, l3)
    l7 = self.u2(l6, l2)
    l8 = self.u3(l7, l1)
    l9 = self.u4(l8, y)

    out = self.c2(l9)
    return out







### Training

In [None]:
#Parameters
lr =  0.0005

#number of training epochs
epoch_n = 100
#input image-mask size
image_size = 572
#root directory of project
root_dir = os.getcwd()

#training batch size
batch_size = 4

#use checkpoint model for training
load = False

#use GPU for training
gpu = True

# Join the path components individually: a literal such as 'A2\data\cells'
# only works on Windows and relies on invalid escape sequences.
data_dir = os.path.join(root_dir, 'A2', 'data', 'cells')


def labels_for_prediction(label, pred, device):
  """Turn a batch of mask tensors into class-index targets aligned with ``pred``.

  ``label`` is assumed to be a (batch, 1, H, W) float mask batch as produced
  by the data loader - TODO confirm the masks are single-channel. ``pred`` is
  the (batch, classes, h, w) model output. The result is a (batch, h, w)
  integer tensor on ``device``, centre-cropped to the size of ``pred``.
  """
  # Any non-zero mask value becomes class 1. Only the channel axis is squeezed
  # so that a final batch holding a single sample keeps its batch axis.
  label = torch.ceil(label).long().squeeze(1).to(device)
  height, width = pred.shape[2], pred.shape[3]
  # Crop to the exact prediction size; a symmetric margin leaves one extra
  # pixel when the size difference is odd.
  top = (label.shape[1] - height) // 2
  left = (label.shape[2] - width) // 2
  return label[:, top:top + height, left:left + width]


trainset = Cell_data(data_dir = data_dir, size = image_size)
trainloader = DataLoader(trainset, batch_size = batch_size, shuffle=True)

testset = Cell_data(data_dir = data_dir, size = image_size, train = False)
testloader = DataLoader(testset, batch_size = batch_size)

# Fall back to the CPU when a GPU is requested but CUDA is not available.
device = torch.device('cuda:0' if gpu and torch.cuda.is_available() else 'cpu')
model = UNet().to(device)

if load:
  print('loading model')
  # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
  model.load_state_dict(torch.load(os.path.join('A2', 'checkpoint2.pt'), map_location=device))


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0005)

for e in range(epoch_n):
  epoch_loss = 0
  model.train()
  for i, data in enumerate(trainloader):

    image, label = data
    image = image.to(device)

    optimizer.zero_grad()
    pred = model(image)
    label = labels_for_prediction(label, pred, device)

    loss = criterion(pred, label)
    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()

  # The criterion already averages over a batch, so the epoch mean is taken
  # over the number of batches, not the number of samples.
  print('Epoch %d / %d --- Loss: %.4f' % (e + 1, epoch_n, epoch_loss / max(len(trainloader), 1)))

  model.eval()

  total = 0
  correct = 0
  total_loss = 0
  with torch.no_grad():
      for i, data in enumerate(testloader):
          image, label = data
          image = image.to(device)

          pred = model(image)
          label = labels_for_prediction(label, pred, device)

          loss = criterion(pred, label)
          total_loss += loss.item()

          _, pred_labels = torch.max(pred, dim = 1)

          # Pixel accuracy: count every label pixel of the batch.
          total += label.numel()
          correct += (pred_labels == label).sum().item()

      print('Accuracy: %.4f ---- Loss: %.4f' % (correct / max(total, 1), total_loss / max(len(testloader), 1)))

# Persist the weights once all epochs have finished.
torch.save(model.state_dict(), 'checkpoint3.pt')

 

### Testing and Visualization

In [14]:
# Run the trained model over every test sample and collect, per sample, the
# input image, the predicted class map and the ground-truth mask cropped to
# the prediction size.
model.eval()

output_masks, output_labels, inputs = [], [], []

with torch.no_grad():
  for i in range(len(testset)):
    image, labels = testset[i]
    inputs.append(image.squeeze())

    # The model expects a batch axis; add one for this single sample.
    batch = image.unsqueeze(0).to(device)
    pred = model(batch)

    # Index of the highest-scoring class for every pixel.
    _, winners = torch.max(pred, dim = 1)
    output_mask = winners.cpu().numpy().squeeze()

    # Trim the same margin from both sides of the mask so that it lines up
    # with the (smaller) prediction.
    margin_rows = (labels.shape[1] - output_mask.shape[0]) // 2
    margin_cols = (labels.shape[2] - output_mask.shape[1]) // 2
    rows = slice(margin_rows, labels.shape[1] - margin_rows)
    cols = slice(margin_cols, labels.shape[2] - margin_cols)
    labels = labels[:, rows, cols].numpy().squeeze()

    output_masks.append(output_mask)
    output_labels.append(labels)
    



In [None]:
# One row per test sample: ground truth | input image | prediction.
n_rows = len(testset)

# squeeze=False keeps `axes` two-dimensional even for a single test sample,
# so the [row, column] indexing below always works.
fig, axes = plt.subplots(n_rows, 3, figsize = (10, 50), squeeze = False)

# Column titles only need to be set once, on the first row.
axes[0, 0].set_title('Ground Truth Label')
axes[0, 1].set_title('Input Images')
axes[0, 2].set_title('Predicted Labels')

for i in range(n_rows):
  columns = (output_labels[i], inputs[i], output_masks[i])
  for col, picture in enumerate(columns):
    axes[i, col].imshow(picture)
    axes[i, col].axis('off')