# Model Implementation

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 with padding 1 keeps the spatial size unchanged. The convolutions
    omit a bias term because each is immediately followed by a BatchNorm.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both convolution stages.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in->out channels, second stage keeps out channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, 3, 1, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.conv(x)

class UNet(nn.Module):
    """U-Net producing per-pixel logits.

    The encoder applies a ``DoubleConv`` for each width in ``features``, each
    followed by 2x2 max-pooling. A bottleneck ``DoubleConv`` doubles the
    deepest width. The decoder mirrors the encoder with transposed-conv
    upsampling plus skip connections. A final 1x1 convolution maps to
    ``out_channels``; no activation is applied.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output logits.
        features: encoder channel widths, shallowest first (non-empty).

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; copy for indexing.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling part of UNet: each stage maps the previous width to `feature`.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Upsampling part of UNet. Each step is a transposed conv (feature*2 ->
        # feature channels, 2x spatial) followed by a DoubleConv. The DoubleConv
        # takes feature*2 channels because the skip connection is concatenated.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        # Built once, after the decoder loop (previously rebuilt on every iteration).
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder on ``x`` and return logits."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # `ups` alternates (upsample, DoubleConv); skips are consumed deepest-first.
        for upsample, fuse, skip in zip(self.ups[0::2], self.ups[1::2], reversed(skip_connections)):
            x = upsample(x)
            # Max-pooling floors odd sizes (e.g. 161 -> 80 -> 160), so resize to the skip's H/W.
            if x.shape[2:] != skip.shape[2:]:
                x = TF.resize(x, size=skip.shape[2:])
            x = fuse(torch.cat((x, skip), dim=1))

        return self.final_conv(x)

In [None]:
def test():
    """Smoke-test that UNet preserves the input shape for an odd size (161x161)."""
    sample = torch.randn((3, 1, 161, 161))
    net = UNet(in_channels=1, out_channels=1)
    output = net(sample)
    for shape in (output.shape, sample.shape):
        print(shape)

    assert output.shape == sample.shape

# Dataset Loading

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class SegmentationDataset(Dataset):
    """Dataset pairing RGB images with single-channel masks.

    For an image named ``stem.jpg`` in ``image_dir`` the mask is read from
    ``mask_dir/stem.png``. Any other image name is looked up unchanged in
    ``mask_dir``.

    Args:
        image_dir: directory whose entries are the images (every entry is used).
        mask_dir: directory holding the corresponding masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``.
            It must return a mapping with ``'image'`` and ``'mask'`` keys.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so indices are reproducible.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _mask_name(image_name):
        """Map an image file name to its mask file name (only a trailing .jpg becomes .png)."""
        if image_name.endswith('.jpg'):
            return image_name[:-len('.jpg')] + '.png'
        return image_name

    def __getitem__(self, index):
        name = self.images[index]
        img_path = os.path.join(self.image_dir, name)
        mask_path = os.path.join(self.mask_dir, self._mask_name(name))

        # Context managers close the underlying files deterministically.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'), dtype=np.float32)
        with Image.open(mask_path) as msk:
            mask = np.array(msk.convert('L'), dtype=np.float32)
        # Pixels equal to 255 become 1 (presumably 0/255 binary masks); other values are kept.
        mask[mask == 255] = 1

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations['image']
            mask = augmentations['mask']

        return image, mask

def check_accuracy(loader, model, device='cuda'):
    """Print (and return) pixel accuracy and Dice score of ``model`` on ``loader``.

    Logits go through a sigmoid and are thresholded at 0.5. Pixel accuracy is
    pooled over all pixels. The Dice score is computed per batch and then
    averaged over batches.

    Args:
        loader: iterable of ``(x, y)`` batches. ``y`` gets a channel axis
            inserted at dim 1 to line up with the predictions.
        model: module returning logits. It is put back in train mode on exit,
            even if evaluation raises.
        device: device the batches are moved to.

    Returns:
        ``(accuracy_percent, mean_dice)`` as Python floats, or ``None`` when the
        loader produced no pixels to evaluate.
    """
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0

    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)

                preds = (torch.sigmoid(model(x)) > 0.5).float()

                num_correct += (preds == y).sum().item()
                num_pixels += preds.numel()
                # The epsilon keeps an all-empty batch at 0 instead of 0/0.
                dice_total += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()
                num_batches += 1
    finally:
        model.train()

    if num_pixels == 0:
        # Previously this divided by zero.
        print("No samples to evaluate")
        return None

    accuracy = num_correct / num_pixels * 100
    mean_dice = dice_total / num_batches
    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"Dice score: {mean_dice}")
    return accuracy, mean_dice

def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Write thresholded predictions and ground-truth masks for every batch.

    For batch ``idx``, predictions go to ``<folder>/pred_<idx>.png`` and the
    target masks to ``<folder>/<idx>.png``. The folder is created if needed.

    Args:
        loader: iterable of ``(x, y)`` batches. ``y`` gets a channel axis
            inserted at dim 1 before saving.
        model: module returning logits. It is put back in train mode on exit.
        folder: output directory, with or without a trailing separator.
        device: device the inputs are moved to.
    """
    # Local import: this cell may run before the cell that imports torchvision.
    import torchvision

    os.makedirs(folder, exist_ok=True)
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = (torch.sigmoid(model(x)) > 0.5).float()
            torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
            # os.path.join: the old f"{folder}{idx}.png" wrote outside `folder`
            # when it had no trailing slash.
            torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"{idx}.png"))
    finally:
        model.train()

# Utils

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce on stdout, then serialize ``state`` to ``filename`` via ``torch.save``."""
    announcement = "=> Saving checkpoint"
    print(announcement)
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Copy ``checkpoint["state_dict"]`` into ``model`` after announcing it on stdout."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build the training and validation DataLoaders over SegmentationDataset.

    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``.
    Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``.
    """
    def make_loader(image_dir, mask_dir, transform, shuffle):
        # One dataset + loader pair; shuffle differs between train and val.
        dataset = SegmentationDataset(
            image_dir=image_dir,
            mask_dir=mask_dir,
            transform=transform,
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

    training = make_loader(train_dir, train_maskdir, train_transform, shuffle=True)
    validation = make_loader(val_dir, val_maskdir, val_transform, shuffle=False)
    return training, validation

# Training

In [None]:
import torch
import albumentations  as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

# Hyperparameters
LR = 1e-4
BATCH_SIZE = 32
EPOCHS = 50
WORKERS = 2
HEIGHT = 360
WIDTH = 240
PIN_MEMORY = True
LOAD_MODEL = False
# TODO: point these at the dataset before running; os.listdir("") fails.
TRAIN_IMG_DIR = ""
TRAIN_MASK_DIR = ""
TEST_IMG_DIR = ""
TEST_MASK_DIR = ""
# Must be an exact device string: the previous fallback ' cpu' (leading
# space) is rejected by torch.device / Tensor.to.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def train_fun(loader, model, optimizer, loss_fun, scaler):
    """Run one training epoch over ``loader`` with a tqdm progress bar.

    Batches are moved to the module-level ``DEVICE``. The forward pass runs
    under CUDA autocast only when the data is on a CUDA device. The backward
    pass goes through ``scaler`` (a GradScaler).

    Args:
        loader: iterable of ``(data, targets)`` batches. Targets are cast to
            float and get a channel axis at dim 1 (assumes a single-channel
            model output, as with ``out_channels=1``).
        model: module producing logits for ``data``.
        optimizer: optimizer stepped once per batch.
        loss_fun: callable ``loss_fun(predictions, targets)``.
        scaler: gradient scaler providing ``scale``/``step``/``update``.
    """
    loop = tqdm(loader)

    for data, targets in loop:
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward: mixed precision only applies to CUDA tensors; disable it elsewhere.
        with torch.autocast(device_type="cuda", enabled=data.is_cuda):
            predictions = model(data)
            loss = loss_fun(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Call .item(): passing the bound method printed "<built-in method item ...>".
        loop.set_postfix(loss=loss.item())


# Training pipeline: resize to HEIGHT x WIDTH, then random rotation (limit 35,
# p=0.3) and flips (horizontal p=0.2, vertical p=0.1). Normalize with
# mean 0 / std 1 / max_pixel_value 255 amounts to dividing pixel values by 255
# under albumentations' Normalize formula. ToTensorV2 converts to a tensor.
train_transform = A.Compose([
    A.Resize(height=HEIGHT, width=WIDTH),
    A.Rotate(limit=35, p=0.3),
    A.HorizontalFlip(p=0.2),
    A.VerticalFlip(p=0.1),
    A.Normalize(mean=[0.0] * 3, std=[1.0] * 3, max_pixel_value=255.0),
    ToTensorV2(),
])

# Validation pipeline: same resize and scaling, no random augmentation.
val_transforms = A.Compose([
    A.Resize(height=HEIGHT, width=WIDTH),
    A.Normalize(mean=[0.0] * 3, std=[1.0] * 3, max_pixel_value=255.0),
    ToTensorV2(),
])

model = UNet(in_channels=3, out_channels=1).to(DEVICE)
loss_fun = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    TEST_IMG_DIR,
    TEST_MASK_DIR,
    BATCH_SIZE,
    train_transform,
    val_transforms,
    WORKERS,
    PIN_MEMORY,
)

if LOAD_MODEL:
    # map_location lets a checkpoint written on GPU be restored on a CPU-only machine.
    checkpoint = torch.load("my_checkpoint.pth.tar", map_location=DEVICE)
    load_checkpoint(checkpoint, model)
    # The loop below saves the optimizer state too; restore it so Adam's
    # running moments resume instead of restarting from zero.
    if "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])

# Loss scaling is only meaningful with CUDA mixed precision; disable it otherwise.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

# Prediction images are written here every epoch; make sure the folder exists.
os.makedirs("saved_images", exist_ok=True)

for epoch in range(EPOCHS):
    train_fun(train_loader, model, optimizer, loss_fun, scaler)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy
    check_accuracy(val_loader, model, device=DEVICE)

    # print some of predictions
    save_predictions_as_imgs(val_loader, model, folder="saved_images", device=DEVICE)

100%|██████████| 32/32 [00:47<00:00,  1.48s/it, loss=<built-in method item of Tensor object at 0x7aed955f1670>]


=> Saving checkpoint
Got 13185045/17280000 with acc 76.30
Dice score: 0.2066970318555832


100%|██████████| 32/32 [00:41<00:00,  1.30s/it, loss=<built-in method item of Tensor object at 0x7aed9673a930>]


=> Saving checkpoint
Got 14411395/17280000 with acc 83.40
Dice score: 0.6968648433685303


100%|██████████| 32/32 [00:42<00:00,  1.32s/it, loss=<built-in method item of Tensor object at 0x7aed95604fe0>]


=> Saving checkpoint
Got 14371263/17280000 with acc 83.17
Dice score: 0.7177923917770386


100%|██████████| 32/32 [00:42<00:00,  1.32s/it, loss=<built-in method item of Tensor object at 0x7aed95606610>]


=> Saving checkpoint
Got 14819068/17280000 with acc 85.76
Dice score: 0.7359612584114075


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95605210>]


=> Saving checkpoint
Got 14850891/17280000 with acc 85.94
Dice score: 0.7441447973251343


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95605080>]


=> Saving checkpoint
Got 14392460/17280000 with acc 83.29
Dice score: 0.7260320782661438


100%|██████████| 32/32 [00:42<00:00,  1.33s/it, loss=<built-in method item of Tensor object at 0x7aed95599b70>]


=> Saving checkpoint
Got 14640726/17280000 with acc 84.73
Dice score: 0.7476402521133423


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed955b6e80>]


=> Saving checkpoint
Got 14926270/17280000 with acc 86.38
Dice score: 0.7528188824653625


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95635850>]


=> Saving checkpoint
Got 14967407/17280000 with acc 86.62
Dice score: 0.7675405740737915


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95598540>]


=> Saving checkpoint
Got 14557441/17280000 with acc 84.24
Dice score: 0.7398164868354797


100%|██████████| 32/32 [00:41<00:00,  1.30s/it, loss=<built-in method item of Tensor object at 0x7aed95681710>]


=> Saving checkpoint
Got 15192387/17280000 with acc 87.92
Dice score: 0.7517632246017456


100%|██████████| 32/32 [00:42<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed955b66b0>]


=> Saving checkpoint
Got 15198816/17280000 with acc 87.96
Dice score: 0.7792037725448608


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed956822a0>]


=> Saving checkpoint
Got 15064833/17280000 with acc 87.18
Dice score: 0.7756943106651306


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95681120>]


=> Saving checkpoint
Got 15286621/17280000 with acc 88.46
Dice score: 0.7702258229255676


100%|██████████| 32/32 [00:41<00:00,  1.30s/it, loss=<built-in method item of Tensor object at 0x7aed95490310>]


=> Saving checkpoint
Got 15330925/17280000 with acc 88.72
Dice score: 0.7838887572288513


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95491b20>]


=> Saving checkpoint
Got 15366340/17280000 with acc 88.93
Dice score: 0.7864099144935608


100%|██████████| 32/32 [00:42<00:00,  1.32s/it, loss=<built-in method item of Tensor object at 0x7aed954901d0>]


=> Saving checkpoint
Got 14649067/17280000 with acc 84.77
Dice score: 0.7629921436309814


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed9549ca90>]


=> Saving checkpoint
Got 15358853/17280000 with acc 88.88
Dice score: 0.773702085018158


100%|██████████| 32/32 [00:41<00:00,  1.30s/it, loss=<built-in method item of Tensor object at 0x7aed95680bd0>]


=> Saving checkpoint
Got 15329581/17280000 with acc 88.71
Dice score: 0.7885964512825012


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed9549d120>]


=> Saving checkpoint
Got 15421051/17280000 with acc 89.24
Dice score: 0.8028994798660278


100%|██████████| 32/32 [00:41<00:00,  1.30s/it, loss=<built-in method item of Tensor object at 0x7aed9549e480>]


=> Saving checkpoint
Got 14983437/17280000 with acc 86.71
Dice score: 0.7578513026237488


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95490a90>]


=> Saving checkpoint
Got 15149905/17280000 with acc 87.67
Dice score: 0.7257282137870789


100%|██████████| 32/32 [00:42<00:00,  1.33s/it, loss=<built-in method item of Tensor object at 0x7aed954c5e90>]


=> Saving checkpoint
Got 15387181/17280000 with acc 89.05
Dice score: 0.7788859009742737


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed954c63e0>]


=> Saving checkpoint
Got 15457150/17280000 with acc 89.45
Dice score: 0.7869301438331604


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed954c5210>]


=> Saving checkpoint
Got 15261742/17280000 with acc 88.32
Dice score: 0.7964499592781067


100%|██████████| 32/32 [00:42<00:00,  1.33s/it, loss=<built-in method item of Tensor object at 0x7aed954dd3f0>]


=> Saving checkpoint
Got 15364194/17280000 with acc 88.91
Dice score: 0.7656365036964417


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed954c7ba0>]


=> Saving checkpoint
Got 15586694/17280000 with acc 90.20
Dice score: 0.810836136341095


100%|██████████| 32/32 [00:42<00:00,  1.34s/it, loss=<built-in method item of Tensor object at 0x7aed954df010>]


=> Saving checkpoint
Got 15337752/17280000 with acc 88.76
Dice score: 0.7911452054977417


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed954c7a60>]


=> Saving checkpoint
Got 14134897/17280000 with acc 81.80
Dice score: 0.7124385237693787


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed954c44a0>]


=> Saving checkpoint
Got 15529798/17280000 with acc 89.87
Dice score: 0.7930610179901123


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed954ed0d0>]


=> Saving checkpoint
Got 15450669/17280000 with acc 89.41
Dice score: 0.8000096082687378


100%|██████████| 32/32 [00:42<00:00,  1.32s/it, loss=<built-in method item of Tensor object at 0x7aed954ee200>]


=> Saving checkpoint
Got 15418983/17280000 with acc 89.23
Dice score: 0.7965098023414612


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed954ed1c0>]


=> Saving checkpoint
Got 15230521/17280000 with acc 88.14
Dice score: 0.7833796739578247


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed954edda0>]


=> Saving checkpoint
Got 15474389/17280000 with acc 89.55
Dice score: 0.7856608033180237


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95520cc0>]


=> Saving checkpoint
Got 15556225/17280000 with acc 90.02
Dice score: 0.8051518201828003


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed954de7a0>]


=> Saving checkpoint
Got 15510019/17280000 with acc 89.76
Dice score: 0.7944649457931519


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95521300>]


=> Saving checkpoint
Got 15557385/17280000 with acc 90.03
Dice score: 0.8058380484580994


100%|██████████| 32/32 [00:42<00:00,  1.33s/it, loss=<built-in method item of Tensor object at 0x7aed95523ab0>]


=> Saving checkpoint
Got 15449872/17280000 with acc 89.41
Dice score: 0.7843325734138489


100%|██████████| 32/32 [00:41<00:00,  1.30s/it, loss=<built-in method item of Tensor object at 0x7aed95535440>]


=> Saving checkpoint
Got 15407807/17280000 with acc 89.17
Dice score: 0.8051965236663818


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95522250>]


=> Saving checkpoint
Got 15352603/17280000 with acc 88.85
Dice score: 0.766822874546051


100%|██████████| 32/32 [00:42<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed955361b0>]


=> Saving checkpoint
Got 15502458/17280000 with acc 89.71
Dice score: 0.7892621159553528


100%|██████████| 32/32 [00:42<00:00,  1.33s/it, loss=<built-in method item of Tensor object at 0x7aed95535710>]


=> Saving checkpoint
Got 15424461/17280000 with acc 89.26
Dice score: 0.7705010175704956


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95522f70>]


=> Saving checkpoint
Got 15231861/17280000 with acc 88.15
Dice score: 0.7993014454841614


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95521800>]


=> Saving checkpoint
Got 15469559/17280000 with acc 89.52
Dice score: 0.780971884727478


100%|██████████| 32/32 [00:42<00:00,  1.33s/it, loss=<built-in method item of Tensor object at 0x7aed95554090>]


=> Saving checkpoint
Got 15560244/17280000 with acc 90.05
Dice score: 0.8177642226219177


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed95555530>]


=> Saving checkpoint
Got 15452665/17280000 with acc 89.43
Dice score: 0.8167678117752075


100%|██████████| 32/32 [00:42<00:00,  1.33s/it, loss=<built-in method item of Tensor object at 0x7aed95568b80>]


=> Saving checkpoint
Got 15580998/17280000 with acc 90.17
Dice score: 0.8048250079154968


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed9805fe20>]


=> Saving checkpoint
Got 15631906/17280000 with acc 90.46
Dice score: 0.8105065226554871


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed955233d0>]


=> Saving checkpoint
Got 15552350/17280000 with acc 90.00
Dice score: 0.8187683820724487


100%|██████████| 32/32 [00:41<00:00,  1.31s/it, loss=<built-in method item of Tensor object at 0x7aed9556ba60>]


=> Saving checkpoint
Got 15506793/17280000 with acc 89.74
Dice score: 0.8148431181907654
