In [None]:
!pip install torch_snippets

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_snippets import *
from torchvision import transforms
from tqdm import tqdm 
import matplotlib.pyplot as plt
import numpy as np
from torch.optim import Adam, lr_scheduler

# Compute target shared by the model, the collate function and checkpoint
# loading: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
class SegData(Dataset):
  """Segmentation dataset that pairs each image with its label mask by file stem.

  For every stem returned by ``stems(image_directory)`` the image
  ``<image_directory>/<stem>.png`` and the mask ``<annotation_directory>/<stem>.png``
  are read and resized to ``image_size`` x ``image_size``.

  Args:
    image_directory: folder holding the input ``.png`` images.
    annotation_directory: folder holding the masks, using the same file stems.
    image_size: side length that both image and mask are resized to
      (default 224, the size the pipeline was written for).
  """
  def __init__(self, image_directory, annotation_directory, image_size=224):
    self.items = stems(image_directory)
    self.image_directory = image_directory
    self.annotation_directory = annotation_directory
    self.image_size = image_size
    # Images are scaled to [0, 1] in collate_fn, then normalised per channel
    # with the usual ImageNet mean/std values.
    self.transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

  def __len__(self):
    """Number of image/mask pairs (one per discovered stem)."""
    return len(self.items)

  def __getitem__(self, index):
    """Return the resized ``(image, mask)`` numpy arrays for ``self.items[index]``."""
    stem = self.items[index]
    size = (self.image_size, self.image_size)
    image = read(f'{self.image_directory}/{stem}.png', 1)
    image = cv2.resize(image, size)
    mask = read(f'{self.annotation_directory}/{stem}.png')
    # Mask pixels are discrete class indices. Nearest-neighbour resizing keeps
    # every pixel a real label. The default bilinear interpolation would blend
    # neighbouring class ids into meaningless in-between ids at region borders.
    mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
    return image, mask

  def collate_fn(self, batch):
    """Stack a list of ``(image, mask)`` pairs into tensors on ``device``.

    Returns:
      ims: float tensor of normalised images, (B, C, H, W).
      ce_masks: long tensor of class indices for CrossEntropyLoss; assumes
        masks are single-channel (H, W) arrays, giving (B, H, W) -- TODO confirm.

    NOTE(review): tensors are moved to ``device`` here, inside the DataLoader.
    This presumably relies on the default ``num_workers=0`` when ``device`` is
    CUDA.
    """
    ims, masks = list(zip(*batch))
    ims = torch.cat([self.transformer(im.copy()/255.)[None] for im in ims]).float().to(device)
    ce_masks = torch.cat([torch.Tensor(mask[None]) for mask in masks]).long().to(device)
    return ims, ce_masks

In [None]:
class UNET(nn.Module):
    """Compact three-level U-Net that outputs per-pixel class logits.

    Encoder: three ``contract_block``s, each ending in a stride-2 max-pool
    (halves H and W). Decoder: three ``expand_block``s, each ending in a
    stride-2 transposed convolution (doubles H and W). The outputs of the first
    two decoder stages are concatenated along channels with the matching
    encoder features (skip connections).

    Input height and width should be divisible by 8 so the skip tensors line
    up. The output then has the same spatial size as the input.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output, i.e. one logit per class.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder stages (each halves the spatial size).
        self.conv1 = self.contract_block(in_channels, 32, 7, 3)
        self.conv2 = self.contract_block(32, 64, 3, 1)
        self.conv3 = self.contract_block(64, 128, 3, 1)

        # Decoder stages (each doubles the spatial size). upconv2 and upconv1
        # take twice the channels because their input includes a skip connection.
        self.upconv3 = self.expand_block(128, 64, 3, 1)
        self.upconv2 = self.expand_block(64*2, 32, 3, 1)
        self.upconv1 = self.expand_block(32*2, out_channels, 3, 1)

    def forward(self, x):
        """Map an input batch (B, in_channels, H, W) to logits (B, out_channels, H, W).

        Defined as ``forward`` (not an override of ``__call__``) so that
        ``nn.Module.__call__`` still runs registered hooks; ``model(x)`` works
        as before.
        """
        conv1 = self.conv1(x)      # (B, 32, H/2, W/2)
        conv2 = self.conv2(conv1)  # (B, 64, H/4, W/4)
        conv3 = self.conv3(conv2)  # (B, 128, H/8, W/8)

        upconv3 = self.upconv3(conv3)                           # (B, 64, H/4, W/4)
        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))  # (B, 32, H/2, W/2)
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))  # (B, out_channels, H, W)

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Two conv-BN-ReLU layers, then a 3x3 stride-2 max-pool (halves H and W).

        The convolutions keep the spatial size when ``padding == (kernel_size - 1) // 2``,
        which holds for every call in ``__init__``.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Two conv-BN-ReLU layers, then a stride-2 transposed conv (doubles H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # k=3, s=2, p=1, output_padding=1 gives exactly 2x the input size.
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
        )

In [None]:
def loss_function(prediction, target):
  """Compute the cross-entropy loss and the arg-max accuracy for one batch.

  Args:
    prediction: raw logits with the class dimension at index 1
      (e.g. (B, C, H, W) from the UNET).
    target: integer class indices shaped like ``prediction`` without dim 1.

  Returns:
    (loss, accuracy) as 0-dim tensors. Accuracy is the fraction of elements
    whose highest-scoring class equals the target.
  """
  loss = nn.functional.cross_entropy(prediction, target)
  predicted_class = prediction.max(dim=1).indices
  accuracy = predicted_class.eq(target).float().mean()
  return loss, accuracy

def train_batch(model, data, optimizer):
  """Perform one optimisation step on a single ``(images, masks)`` batch.

  Puts the model in training mode, computes loss/accuracy with
  ``loss_function``, back-propagates and steps the optimiser.

  Returns:
    (loss, accuracy) as Python floats.
  """
  images, targets = data
  model.train()
  logits = model(images)
  optimizer.zero_grad()
  batch_loss, batch_acc = loss_function(logits, targets)
  batch_loss.backward()
  optimizer.step()
  return batch_loss.item(), batch_acc.item()

def valid_batch(model, data):
  """Evaluate the model on a single ``(images, masks)`` batch without training it.

  Switches the model to eval mode (so BatchNorm uses its running statistics)
  and runs the forward pass under ``torch.no_grad()``.

  Returns:
    (loss, accuracy) as Python floats.
  """
  model.eval()
  ims, ce_mask = data
  # No gradients are needed here. Skipping autograd graph construction saves
  # memory and time during validation.
  with torch.no_grad():
    prediction = model(ims)
    loss_value, acc = loss_function(prediction, ce_mask)
  return loss_value.item(), acc.item()

def get_model(in_channels=3, n_classes=12, lr=1e-3):
  """Build a UNET on ``device`` together with an Adam optimiser for it.

  Args:
    in_channels: channels of the input images (default 3).
    n_classes: number of per-pixel output classes (default 12).
    lr: Adam learning rate (default 1e-3).

  Returns:
    (model, optimizer)
  """
  model = UNET(in_channels, n_classes).to(device)
  optimizer = Adam(model.parameters(), lr=lr)
  return model, optimizer

def save_plot(tr_list, val_list, title, ylabel=None):
  """Plot training vs. validation curves per epoch, save to ``<title>.png`` and show.

  Args:
    tr_list: per-epoch training values.
    val_list: per-epoch validation values.
    title: figure title; also used as the output file name.
    ylabel: y-axis label. If omitted, "Loss" is used when the title mentions
      loss and "Accuracy" otherwise.
  """
  if ylabel is None:
    ylabel = "Loss" if "loss" in title.lower() else "Accuracy"

  fig, ax = plt.subplots()
  ax.plot(tr_list, label="Training")
  ax.plot(val_list, label="Validation")
  ax.set_xlabel("Epoch")
  ax.set_ylabel(ylabel)
  ax.set_title(title)
  ax.legend()

  # Save BEFORE showing: the notebook inline backend closes figures once they
  # are displayed, so saving after plt.show() writes a blank image.
  fig.savefig(f"{title}.png")
  print(f"Saved {title}.png")
  plt.show()
  plt.close(fig)

def get_dataLoaders(image_directory, annotation_directory, batch_size=4, shuffle=True):
  """Wrap a ``SegData`` dataset in a DataLoader that uses its ``collate_fn``.

  Args:
    image_directory: folder of input images.
    annotation_directory: folder of masks with matching file stems.
    batch_size: samples per batch (default 4).
    shuffle: reshuffle every epoch (default True, including for validation).

  Returns:
    The configured ``DataLoader``.
  """
  ds_set = SegData(image_directory, annotation_directory)
  return DataLoader(ds_set, batch_size=batch_size, shuffle=shuffle, collate_fn=ds_set.collate_fn)

def _average_over_loader(loader, step, desc):
  """Apply ``step`` to every batch of ``loader`` and return the mean (loss, accuracy).

  ``step`` takes one batch and returns a ``(loss, accuracy)`` pair of floats.
  ``desc`` labels the tqdm progress bar.
  Raises ValueError if the loader yields no batches.
  """
  total_loss = 0.0
  total_acc = 0.0
  n_batches = 0
  for data in tqdm(loader, desc = desc):
    loss, acc = step(data)
    total_loss += loss
    total_acc += acc
    n_batches += 1
  if n_batches == 0:
    raise ValueError(f"{desc}: data loader yielded no batches (empty directory?)")
  return total_loss / n_batches, total_acc / n_batches


def _show_prediction(model, loader):
  """Predict on one batch from ``loader`` and plot the first image next to its predicted mask.

  Only the first (normalised) channel of the input image is shown.
  """
  model.eval()
  im, _ = next(iter(loader))
  with torch.no_grad():
    logits = model(im)
  pred_mask = torch.max(logits, dim=1)[1]  # (B, H, W) arg-max class per pixel
  subplots([im[0, 0].cpu(), pred_mask[0].cpu()], nc=2,
           titles=['Original image',
                   'Predicted mask'])


def run(epoch_range, train_dir, train_anno_dir, val_dir, val_anno_dir, real_dir = None, model_dict = None):
  """Train the UNET, checkpoint on best validation accuracy, plot curves, optionally visualise.

  Args:
    epoch_range: number of training epochs.
    train_dir, train_anno_dir: training images and masks.
    val_dir, val_anno_dir: validation images and masks.
    real_dir: optional folder of images to show one prediction for after
      training. It is also passed as the annotation folder; those "masks"
      are loaded but unused.
    model_dict: optional path to a saved state_dict to start from.

  Side effects:
    Writes ``Epoch_<n>_model.pth`` whenever validation accuracy improves, and
    the two curve PNGs produced by ``save_plot``.
  """
  model, optimizer = get_model()
  if model_dict:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine. NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(model_dict, map_location=device))

  tr_dl = get_dataLoaders(train_dir, train_anno_dir)
  vl_dl = get_dataLoaders(val_dir, val_anno_dir)
  real_dl = get_dataLoaders(real_dir, real_dir) if real_dir else None

  train_loss = []
  train_acc = []
  val_loss = []
  val_acc = []
  max_acc = 0

  print("---------------------------------------------------------")
  for epoch in range(epoch_range):
    tl, ta = _average_over_loader(tr_dl, lambda data: train_batch(model, data, optimizer), "TRAINING")
    train_loss.append(tl)
    train_acc.append(ta)

    vl, va = _average_over_loader(vl_dl, lambda data: valid_batch(model, data), "VALIDATE")
    val_loss.append(vl)
    val_acc.append(va)

    print("\n Epoch: {}/{} | Average Training Loss: {:.4f} | Average Validation Loss: {:.4f} | Validation Accuracy: {:.4f} | Learning Rate: {}".format(
      epoch+1,
      epoch_range,
      tl,
      vl,
      va,
      optimizer.param_groups[0]['lr'],  # the optimiser's actual learning rate
    ))

    if va > max_acc:
      max_acc = va
      torch.save(model.state_dict(), f'Epoch_{epoch+1}_model.pth')
      print("New Model Saved!")

    print("---------------------------------------------------------")
  save_plot(train_loss, val_loss, "Training and Validation Loss")
  save_plot(train_acc, val_acc, "Training and Validation Accuracy")

  # A loader for visualisation exists only when real_dir was supplied.
  if real_dl is not None:
    _show_prediction(model, real_dl)

    

In [None]:
if __name__ == '__main__':
  # Train for 30 epochs on dataset1 (paths look like a Colab Google Drive mount).
  # The "prepped_test" split doubles as the validation set, and one prediction
  # is visualised on the images in the Test folder.
  run(
      epoch_range=30,
      train_dir="/content/drive/MyDrive/Data/dataset1/images_prepped_train",
      train_anno_dir="/content/drive/MyDrive/Data/dataset1/annotations_prepped_train",
      val_dir="/content/drive/MyDrive/Data/dataset1/images_prepped_test",
      val_anno_dir="/content/drive/MyDrive/Data/dataset1/annotations_prepped_test",
      real_dir="/content/drive/MyDrive/Data/dataset1/Test",
  )