In [None]:
import os
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np
from tqdm import tqdm
from PIL import Image
from collections import namedtuple

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
!wget https://thinkautonomous-segmentation.s3.eu-west-3.amazonaws.com/archive.zip && unzip archive.zip

In [None]:
# basic imports
import os
import cv2
import numpy as np
from collections import namedtuple

# DL library imports
import torch
from torchvision import transforms
from torch.utils.data import Dataset


###################################
# FILE CONSTANTS
###################################

# Convert to torch tensor and normalize images using the ImageNet channel statistics.
# (The green-channel mean is 0.456; the previous 0.56 was a typo.)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

preprocess = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
                ])



# One record per Cityscapes class: display name, training id and RGB color.
cs_labels = namedtuple('CityscapesClass', ['name', 'train_id', 'color'])

# The table is listed in train-id order (0..19); the last entry is the ignore class.
cs_classes = [
    cs_labels._make(row)
    for row in (
        ('road', 0, (128, 64, 128)),
        ('sidewalk', 1, (244, 35, 232)),
        ('building', 2, (70, 70, 70)),
        ('wall', 3, (102, 102, 156)),
        ('fence', 4, (190, 153, 153)),
        ('pole', 5, (153, 153, 153)),
        ('traffic light', 6, (250, 170, 30)),
        ('traffic sign', 7, (220, 220, 0)),
        ('vegetation', 8, (107, 142, 35)),
        ('terrain', 9, (152, 251, 152)),
        ('sky', 10, (70, 130, 180)),
        ('person', 11, (220, 20, 60)),
        ('rider', 12, (255, 0, 0)),
        ('car', 13, (0, 0, 142)),
        ('truck', 14, (0, 0, 70)),
        ('bus', 15, (0, 60, 100)),
        ('train', 16, (0, 80, 100)),
        ('motorcycle', 17, (0, 0, 230)),
        ('bicycle', 18, (119, 11, 32)),
        ('ignore_class', 19, (0, 0, 0)),
    )
]

# Color lookup table: row i is the RGB color of train id i.
# Entries with train id -1 or 255 are skipped (none exist in the table above).
train_id_to_color = np.array(
    [entry.color for entry in cs_classes if entry.train_id not in (-1, 255)]
)

In [None]:
class CityscapesDataset(Dataset):
  """Image / label-mask pairs read from `<root_dir>/images/<mode>` and
  `<root_dir>/groundtruth/<mode>`.

  Image i is paired with mask i after sorting both directory listings by
  file name, so the two folders are assumed to use file names that sort into
  the same order -- TODO confirm against the extracted archive.

  Args:
    root_dir: folder containing the `images` and `groundtruth` sub-folders.
    mode: sub-folder name of the split to load (the notebook uses 'train' and 'val').
    tf: optional callable applied to the opened PIL image. When None, a copy
      of the PIL image is returned unchanged.
  """

  def __init__(self, root_dir: str, mode: str, tf=None):
    self.tf = tf

    # os.listdir returns entries in arbitrary order; sort both listings so the
    # i-th image and the i-th target refer to the same sample.
    self.images_dir = os.path.join(root_dir, 'images', mode)
    self.images_src = [os.path.join(self.images_dir, img_src)
                       for img_src in sorted(os.listdir(self.images_dir))]

    self.target_dir = os.path.join(root_dir, 'groundtruth', mode)
    self.targets_src = [os.path.join(self.target_dir, trgt_src)
                        for trgt_src in sorted(os.listdir(self.target_dir))]

  def __len__(self):
    """Number of images in the split."""
    return len(self.images_src)

  def __getitem__(self, index):
    """Return `(image, target)` for sample `index`.

    `image` is whatever `tf` produces (or a PIL image when `tf` is None);
    `target` is a long tensor of class ids with the value 255 remapped to 19.
    """
    # Context managers close the underlying file handles; the pixel data is
    # materialized (by tf / copy / np.array) before the handle is released.
    with Image.open(self.images_src[index]) as img:
      image = self.tf(img) if self.tf is not None else img.copy()

    with Image.open(self.targets_src[index]) as mask:
      target = np.array(mask)

    # Fold the 255 label into class 19 ('ignore_class' in cs_classes).
    target[target == 255] = 19
    target = torch.from_numpy(target).long()

    return image, target


def data_loader(rootDir, batch_size=1, train_fraction=0.8, seed=1):
  """Build the train, test and validation loaders.

  The 'train' folder is split into a training and a validation subset with a
  seeded `random_split`; the 'val' folder is used as the held-out test set.

  Args:
    rootDir: dataset root passed to `CityscapesDataset`.
    batch_size: batch size used by all three loaders.
    train_fraction: share of the 'train' folder kept for training; the rest
      becomes the validation subset.
    seed: seed of the generator that drives the train/validation split.

  Returns:
    `(train_loader, test_loader, val_loader)` -- note that the test loader
    comes BEFORE the validation loader.
  """
  data = CityscapesDataset(rootDir, mode='train', tf=preprocess)
  test_set = CityscapesDataset(rootDir, mode='val', tf=preprocess)

  total_count = len(data)
  train_count = int(train_fraction * total_count)
  split_rng = torch.Generator().manual_seed(seed)
  train_set, val_set = random_split(data, (train_count, total_count - train_count),
          generator=split_rng)

  # Only the training loader is shuffled; evaluation order stays fixed.
  train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
  test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
  val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

  return train_loader, test_loader, val_loader

In [None]:
class Block(nn.Module):
  """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

  Both convolutions use stride 1 and padding 1, so the spatial size of the
  input is kept; only the channel count changes to `out_channels`.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    stages = []
    # First stage maps in_channels -> out_channels, second keeps out_channels.
    for stage_in in (in_channels, out_channels):
      stages.extend([
          # The conv bias is dropped because a BatchNorm layer follows directly.
          nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True),
      ])

    self.dconv3x3 = nn.Sequential(*stages)

  def forward(self, x):
    """Apply both conv stages to `x` and return the result."""
    return self.dconv3x3(x)

In [None]:
# Scratch notes kept as a bare string literal (never executed as code): two
# alternative ways of merging a skip connection with the upsampled tensor --
# centre-cropping the skip tensor, or symmetrically padding the upsampled one.
# NOTE(review): in the first sketch `dh` subtracts y.size()[2], and in the second
# `dX` subtracts x2.size(3) from itself; both look like index typos -- confirm
# before reusing either sketch.
'''

def forward(self, x, y):
    y = self.up(y)
    dw = (x.size()[2] - y.size()[2]) // 2
    dh = (x.size()[3] - y.size()[2]) // 2

    #Centered Cropping
    x = x[:, :, dw:x.size()[2] - dw, dh:x.size()[3] - dh]

    return self.conv(torch.cat((x, y), 1))


def forward(self, x1, x2):
    x1 = self.up(x1)

    dY = x2.size(2) - x1.size(2)
    dX = x2.size(3) - x2.size(3)

    #Symmetric Padding
    x1 = F.pad(x1, [dX // 2, dX - dX // 2,
                  dY // 2, dY - dY //2])

    x = self.conv(torch.cat([x2, x1], dim=1))

    return x
'''

'\n\ndef forward(self, x, y):\n    y = self.up(y)\n    dw = (x.size()[2] - y.size()[2]) // 2\n    dh = (x.size()[3] - y.size()[2]) // 2\n\n    #Centered Cropping\n    x = x[:, :, dw:x.size()[2] - dw, dh:x.size()[3] - dh]\n\n    return self.conv(torch.cat((x, y), 1))\n\n\ndef forward(self, x1, x2):\n    x1 = self.up(x1)\n\n    dY = x2.size(2) - x1.size(2)\n    dX = x2.size(3) - x2.size(3)\n\n    #Symmetric Padding\n    x1 = F.pad(x1, [dX // 2, dX - dX // 2,\n                  dY // 2, dY - dY //2])\n\n    x = self.conv(torch.cat([x2, x1], dim=1))\n\n    return x\n'

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder built from `Block` stages.

    Encoder: one `Block` per entry of `layer_channels`, each followed by a
    2x2 max-pool. Bottleneck: a `Block` that doubles the last encoder width.
    Decoder: for every encoder width (deepest first) a stride-2 transposed
    convolution followed by a `Block` applied to the upsampled tensor
    concatenated with the matching encoder output. A final 1x1 convolution
    maps to `out_channels`.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the returned map.
        layer_channels: encoder widths, shallowest first. Any sequence is
            accepted; the default is a tuple so no mutable object is shared
            between calls.
    """

    def __init__(self, in_channels, out_channels, layer_channels=(64, 128, 256, 512, 1024)):
        super().__init__()
        layer_channels = list(layer_channels)

        self.encoder = nn.ModuleList()

        prev_channels = in_channels
        for ch in layer_channels:
            self.encoder.append(Block(prev_channels, ch))
            prev_channels = ch

        self.pool = nn.MaxPool2d(2, 2)

        self.bottleneck = Block(layer_channels[-1], layer_channels[-1]*2)

        # Decoder entries alternate: even index = upsampling, odd index = Block.
        self.decoder = nn.ModuleList()
        rev_channels = layer_channels[::-1]
        for ch in rev_channels:
            self.decoder.append(nn.ConvTranspose2d(ch*2, ch, kernel_size=2, stride=2))
            self.decoder.append(Block(ch*2, ch))

        self.out_conv = nn.Conv2d(layer_channels[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on `x` and return the `out_channels`-channel map."""
        skip_connections = []

        for down in self.encoder:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Deepest encoder output first, matching the decoder order.
        skip_connections = skip_connections[::-1]

        for i in range(0, len(self.decoder), 2):
            x = self.decoder[i](x)
            skip = skip_connections[i // 2]

            # Max-pooling floors odd sizes, so the upsampled tensor can be
            # smaller than the skip tensor; pad it symmetrically to match.
            dY = skip.size(2) - x.size(2)
            dX = skip.size(3) - x.size(3)
            if dY != 0 or dX != 0:
                x = torch.nn.functional.pad(x, [dX // 2, dX - dX // 2,
                      dY // 2, dY - dY // 2])

            x = torch.cat([x, skip], dim=1)
            x = self.decoder[i + 1](x)

        return self.out_conv(x)




In [None]:
class BiDiceLoss(nn.Module):
  """Binary soft Dice loss computed from raw logits.

  `pred` is passed through a sigmoid, the Dice coefficient is computed per
  sample over all non-batch dimensions, and `1 - mean coefficient` is returned.
  `smooth` is added to numerator and denominator to avoid division by zero.
  """

  def __init__(self, smooth=1e-6):
    super().__init__()
    self.smooth = smooth

  def forward(self, pred, target):
    probs = pred.sigmoid()

    # Collapse everything but the batch dimension, then reduce per sample.
    overlap = torch.sum((probs * target).flatten(start_dim=1), dim=1)
    total = torch.sum((probs + target).flatten(start_dim=1), dim=1)

    dice = (2 * overlap + self.smooth) / (total + self.smooth)
    return 1 - dice.mean()


class MultiDiceLoss(nn.Module):
  """Multi-class soft Dice loss for 2-D segmentation logits.

  Args (forward):
    pred: logits of shape (B, C, H, W); softmax is applied over dim 1.
    target: either a class-index map of shape (B, H, W) with integer values
      in [0, C), which is one-hot encoded internally, or an already one-hot /
      soft target of shape (B, C, H, W).

  Returns:
    Scalar tensor: `1 - Dice` averaged over classes and batch.
  """

  def __init__(self, smooth=1e-6):
    super().__init__()
    self.smooth = smooth

  def forward(self, pred, target):
    probs = torch.nn.functional.softmax(pred, dim=1)
    num_classes = probs.size(1)

    # A target with one dimension fewer than the logits is a class-index map:
    # expand it to (B, C, H, W) so every class is compared against its own mask.
    if target.dim() == probs.dim() - 1:
      target = torch.nn.functional.one_hot(target.long(), num_classes).permute(0, 3, 1, 2)
    target = target.to(probs.dtype)

    # Per-sample, per-class reductions over the spatial dims -> shape (B, C).
    intersection = (probs * target).sum(dim=(2, 3))
    union = (probs + target).sum(dim=(2, 3))
    coeff = (2 * intersection + self.smooth) / (union + self.smooth)

    # Mean over (B, C) == mean over the batch of the class-averaged loss.
    return (1 - coeff).mean()



class MultiDiceLoss3D(nn.Module):
    """Multi-class soft Dice loss for volumetric (3-D) segmentation logits.

    Args (forward):
      pred: logits of shape (B, C, D, H, W); softmax is applied over dim 1.
      target: either a class-index volume of shape (B, D, H, W) with integer
        values in [0, C), which is one-hot encoded internally, or a one-hot /
        soft target of shape (B, C, D, H, W).

    Returns:
      Scalar tensor: `1 - Dice` averaged over classes and batch.
    """

    def __init__(self, smooth=1e-6):
      super().__init__()
      self.smooth = smooth

    def forward(self, pred, target):
      probs = torch.nn.functional.softmax(pred, dim=1)
      num_classes = probs.size(1)

      # Class-index volumes have one dimension fewer than the logits.
      if target.dim() == probs.dim() - 1:
        target = torch.nn.functional.one_hot(target.long(), num_classes).permute(0, 4, 1, 2, 3)
      target = target.to(probs.dtype)

      # Reduce over the three spatial dims of the full 5-D tensors. Slicing a
      # class first (pred[:, c]) would leave a 4-D tensor with no dim 4.
      spatial_dims = (2, 3, 4)
      intersection = (probs * target).sum(dim=spatial_dims)
      union = probs.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
      coeff = (2. * intersection + self.smooth) / (union + self.smooth)

      return 1 - coeff.mean()


class PixelAcc(nn.Module):
  """Fraction of elements whose predicted class equals the target class.

  The metric is stateless: every call returns the accuracy of that call only.
  """

  def __init__(self):
    super().__init__()

  def forward(self, outputs, targets, batch_size=None):
    """Return the pixel accuracy as a Python float.

    Args:
      outputs: predicted class ids with the same shape as `targets`. If it
        has exactly one more dimension than `targets` it is treated as
        per-class scores and reduced with argmax over dim 1 -- this assumes
        the class axis is dim 1, as in the losses of this notebook.
      targets: ground-truth class ids.
      batch_size: number of leading samples to evaluate; defaults to all.
    """
    if outputs.dim() == targets.dim() + 1:
      outputs = outputs.argmax(dim=1)

    if batch_size is None:
      batch_size = outputs.size(0)
    if batch_size == 0:
      return 0.0

    # All samples have the same number of elements, so the mean over the
    # selected elements equals the mean of the per-sample accuracies.
    matches = torch.eq(outputs[:batch_size], targets[:batch_size])
    return matches.float().mean().item()



class IOU(nn.Module):
  """Per-class intersection-over-union, averaged over the samples of a batch."""

  def __init__(self):
    super().__init__()
    # Added to numerator and denominator so a class absent from both the
    # prediction and the target scores 1 instead of dividing by zero.
    self.eps = 1e-6

  def forward(self, outputs, targets, batch_size, n_classes):
    """Return a float numpy array of length `n_classes` with the mean IoU per class.

    Args:
      outputs: tensor of predicted class ids, indexed by sample on dim 0.
      targets: tensor of ground-truth class ids with the same shape.
      batch_size: number of leading samples to evaluate (must be positive).
      n_classes: class ids 0 .. n_classes - 1 are scored.
    """
    if batch_size <= 0:
      raise ValueError('batch_size must be positive')

    class_iou = np.zeros(n_classes, dtype=np.float64)

    for idx in range(batch_size):
      sample_out = outputs[idx].detach().cpu().numpy()
      sample_tgt = targets[idx].detach().cpu().numpy()

      for c in range(n_classes):
        # Boolean masks keep the pixel positions intact; comparing the index
        # arrays returned by np.where would mix row and column indices.
        out_mask = sample_out == c
        tgt_mask = sample_tgt == c
        intersection = np.logical_and(out_mask, tgt_mask).sum()
        union = np.logical_or(out_mask, tgt_mask).sum()
        class_iou[c] += (intersection + self.eps) / (union + self.eps)

    return class_iou / batch_size


# Scratch note kept as a bare string literal (never executed as code): a
# mask-based alternative for an IoU `forward` that compares boolean class masks
# and returns the mean IoU over classes as a tensor. It is not attached to any class.
"""def forward(self, outputs, targets, n_classes):
  iou = []
  for c in range(n_classes):
      pred_c = (outputs == c)
      target_c = (targets == c)
      intersection = (pred_c & target_c).sum().float()
      union = (pred_c | target_c).sum().float()
      iou.append((intersection + self.eps) / (union + self.eps))
  return torch.tensor(iou).mean()"""

'def forward(self, outputs, targets, n_classes):\n  iou = []\n  for c in range(n_classes):\n      pred_c = (outputs == c)\n      target_c = (targets == c)\n      intersection = (pred_c & target_c).sum().float()\n      union = (pred_c | target_c).sum().float()\n      iou.append((intersection + self.eps) / (union + self.eps))\n  return torch.tensor(iou).mean()'

In [None]:
# RGB input (3 channels); 20 output channels, one per entry of cs_classes
# (train ids 0-18 plus the ignore class 19).
model = UNet(in_channels=3, out_channels=20)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
epochs = 1

# Per-epoch histories. The losses are stored as the MEAN loss per batch so the
# differently sized train and validation subsets are directly comparable.
loss_train_history = []
acc_train_history = []
loss_val_history = []
acc_val_history = []

# data_loader returns (train, test, val) in that order: the test loader is the
# second element, the validation loader the third.
train_loader, test_loader, val_loader = data_loader(rootDir='/content/')

# One criterion for both phases: the network emits multi-class logits, so the
# binary Dice loss does not apply to validation either.
criterion = MultiDiceLoss()

for epoch in range(epochs):
  running_train_loss = 0.
  running_val_loss = 0.

  model.train()
  progress = tqdm(train_loader, desc=f'epoch {epoch + 1}/{epochs}')
  for images, targets in progress:
    images = images.to(device)
    targets = targets.to(device)

    classification_map = model(images)
    loss = criterion(classification_map, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    running_train_loss += loss.item()
    # Show the current loss on the progress bar instead of printing one line per step.
    progress.set_postfix(loss=loss.item())

  model.eval()
  with torch.no_grad():
    for val_images, val_targets in val_loader:
      val_images = val_images.to(device)
      val_targets = val_targets.to(device)

      classification_map_val = model(val_images)
      val_loss = criterion(classification_map_val, val_targets)
      running_val_loss += val_loss.item()

  # max(..., 1) guards against an empty loader.
  loss_train_history.append(running_train_loss / max(len(train_loader), 1))
  loss_val_history.append(running_val_loss / max(len(val_loader), 1))
  # TODO: accuracy tracking is not wired up yet; acc_train_history and
  # acc_val_history stay empty.
