In [None]:
import torch
import torch.nn as nn

In [None]:
def crop_image(tensor, target):
  """Center-crop `tensor` so dimensions 2 and 3 match those of `target`.

  Args:
    tensor: tensor to crop; dimensions 2 and 3 are treated as height and width.
    target: tensor whose dimensions 2 and 3 give the desired output size.

  Returns:
    A view of `tensor` whose dimensions 2 and 3 equal those of `target`.

  Raises:
    ValueError: if `target` is larger than `tensor` along height or width.
  """
  tensor_h, tensor_w = tensor.size()[2], tensor.size()[3]
  target_h, target_w = target.size()[2], target.size()[3]
  if target_h > tensor_h or target_w > tensor_w:
    raise ValueError(
        f"cannot crop a {tensor_h}x{tensor_w} tensor to a larger "
        f"{target_h}x{target_w} target"
    )

  # Slice by start offset + target length rather than trimming the same
  # amount from both ends: when the size difference is odd, symmetric
  # trimming would leave the result one pixel larger than the target.
  top = (tensor_h - target_h) // 2
  left = (tensor_w - target_w) // 2
  return tensor[:, :, top:top + target_h, left:left + target_w]

def double_conv(in_ch, out_ch):
  """Build a block of two unpadded 3x3 convolutions, each followed by ReLU.

  Args:
    in_ch: number of channels fed to the first convolution.
    out_ch: number of channels produced by both convolutions.

  Returns:
    An `nn.Sequential` of Conv2d -> ReLU -> Conv2d -> ReLU.
  """
  layers = [
      nn.Conv2d(in_ch, out_ch, kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_ch, out_ch, kernel_size=3),
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*layers)

class Unet(nn.Module):
  """U-Net style encoder/decoder built from unpadded double convolutions.

  The encoder applies five `double_conv` blocks separated by 2x2 max pooling.
  The decoder upsamples with transposed convolutions, concatenates the
  center-cropped encoder feature map of the same level, and applies another
  `double_conv`. A final 1x1 convolution maps to `out_channels` logits.

  Args:
    in_channels: number of channels of the input image (default 1).
    out_channels: number of channels of the returned logits map (default 2).

  Because the 3x3 convolutions are unpadded, the output is spatially smaller
  than the input (a 572x572 input yields a 388x388 output).
  """

  def __init__(self, in_channels=1, out_channels=2):
    super(Unet, self).__init__()

    self.max_pool_2x2 = nn.MaxPool2d(kernel_size = 2, stride = 2)

    # Contracting path.
    self.double_conv_1 = double_conv(in_channels, 64)
    self.double_conv_2 = double_conv(64, 128)
    self.double_conv_3 = double_conv(128, 256)
    self.double_conv_4 = double_conv(256, 512)
    self.double_conv_5 = double_conv(512, 1024)

    # Expanding path: each up_conv takes twice the upsampled channel count
    # because the matching encoder feature map is concatenated on dim 1.
    self.up_trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size = 2, stride = 2)
    self.up_conv_1 = double_conv(1024, 512)

    self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size = 2, stride = 2)
    self.up_conv_2 = double_conv(512, 256)

    self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size = 2, stride = 2)
    self.up_conv_3 = double_conv(256, 128)

    self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size = 2, stride = 2)
    self.up_conv_4 = double_conv(128, 64)

    self.final_conv = nn.Conv2d(64, out_channels, kernel_size = 1)

  def forward(self, image):
    """Return the logits map for `image`.

    Args:
      image: input batch whose dim 1 has `in_channels` channels.

    Returns:
      Tensor with `out_channels` channels on dim 1.
    """
    # Encoder: keep the output of each level for the skip connections.
    skip_1 = self.double_conv_1(image)
    skip_2 = self.double_conv_2(self.max_pool_2x2(skip_1))
    skip_3 = self.double_conv_3(self.max_pool_2x2(skip_2))
    skip_4 = self.double_conv_4(self.max_pool_2x2(skip_3))
    x = self.double_conv_5(self.max_pool_2x2(skip_4))

    # Decoder: upsample, crop the skip feature map to the upsampled size,
    # concatenate along the channel dimension, then convolve.
    x = self.up_trans_1(x)
    x = self.up_conv_1(torch.cat([x, crop_image(skip_4, x)], 1))

    x = self.up_trans_2(x)
    x = self.up_conv_2(torch.cat([x, crop_image(skip_3, x)], 1))

    x = self.up_trans_3(x)
    x = self.up_conv_3(torch.cat([x, crop_image(skip_2, x)], 1))

    x = self.up_trans_4(x)
    x = self.up_conv_4(torch.cat([x, crop_image(skip_1, x)], 1))

    return self.final_conv(x)



In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class CarvanaDataset(Dataset):
  """Dataset of (image, mask) pairs read from two directories.

  Every regular file in `image_dir` is treated as one sample. The mask file
  name is derived from the image file name by swapping `image_suffix` for
  `mask_suffix`.

  Args:
    image_dir: directory containing the input images.
    mask_dir: directory containing the masks.
    transform: optional callable invoked as `transform(image=..., mask=...)`
      that returns a mapping with 'image' and 'mask' entries.
    image_suffix: trailing part of an image file name to replace.
    mask_suffix: replacement that yields the mask file name.

  NOTE(review): the default suffixes are kept from the original code. The
  public Carvana files are presumably named "*.jpg" / "*_mask.gif" -- confirm
  against the data on disk and pass the right suffixes if they differ.
  """

  def __init__(self, image_dir, mask_dir, transform=None,
               image_suffix=".jpeg", mask_suffix="_make.gif"):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    self.image_suffix = image_suffix
    self.mask_suffix = mask_suffix
    # Sort for a reproducible sample order (os.listdir order is arbitrary)
    # and skip sub-directories such as ".ipynb_checkpoints".
    self.images = sorted(
        name for name in os.listdir(image_dir)
        if os.path.isfile(os.path.join(image_dir, name))
    )

  def __len__(self):
    """Return the number of image files found in `image_dir`."""
    return len(self.images)

  def _mask_filename(self, image_name):
    """Map an image file name to its mask file name by swapping the suffix."""
    if self.image_suffix and image_name.endswith(self.image_suffix):
      return image_name[:-len(self.image_suffix)] + self.mask_suffix
    # Names without the expected suffix are looked up unchanged in mask_dir.
    return image_name

  def __getitem__(self, index):
    """Load sample `index` and return it as `(image, mask)`.

    The image is converted to RGB and the mask to 8-bit grayscale; mask
    pixels equal to 255 are rewritten to 1. If a transform was given, its
    'image' and 'mask' results are returned instead of the raw arrays.
    """
    image_name = self.images[index]
    img_path = os.path.join(self.image_dir, image_name)
    mask_path = os.path.join(self.mask_dir, self._mask_filename(image_name))

    # Context managers close the underlying files as soon as the pixel data
    # has been copied into numpy arrays.
    with Image.open(img_path) as img_file:
      image = np.array(img_file.convert("RGB"))
    with Image.open(mask_path) as mask_file:
      mask = np.array(mask_file.convert("L"))
    mask[mask == 255] = 1

    if self.transform is not None:
      augmentations = self.transform(image = image, mask = mask)
      image = augmentations['image']
      mask = augmentations['mask']

    return image, mask

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.optim as optim
import tqdm

# Optimisation settings.
Learning_Rate = 1e-4
Batch_Size = 16
Num_Epochs = 3

# Runtime settings: use the GPU when torch reports one, otherwise the CPU.
Device = 'cuda' if torch.cuda.is_available() else 'cpu'
Num_Workers = 2
Load_Model = True

# Size every image and mask is resized to before entering the network.
Image_Height = 572
Image_Width = 572

# Dataset directories; empty placeholders to be filled in before running.
Train_Img_Dir = ''
Train_Mask_Dir = ''
Val_Img_Dir = ''
Val_Mask_Dir = ''


def train_fn(loader, model, optimizer, loss_fn, scaler):
  """Train `model` for one pass over `loader` using autocast + gradient scaling.

  Args:
    loader: iterable yielding `(data, targets)` batches.
    model: module called as `model(data)` to produce predictions.
    optimizer: optimizer stepped once per batch through `scaler`.
    loss_fn: callable invoked as `loss_fn(predictions, targets)`.
    scaler: gradient scaler providing `scale`, `step` and `update`.
  """
  # `tqdm` is imported as a module in this file, so the progress-bar class
  # is `tqdm.tqdm`; calling the module object itself raises TypeError.
  loop = tqdm.tqdm(loader)
  for data, targets in loop:
    data = data.to(Device)
    # Cast to float and insert a channel axis at dim 1 so the targets can be
    # compared with the predictions.
    # NOTE(review): assumes targets arrive as (batch, height, width) -- confirm.
    targets = targets.float().unsqueeze(1).to(Device)

    # forward
    with torch.cuda.amp.autocast():
      predictions = model(data)
      loss = loss_fn(predictions, targets)

    # backward
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # update tqdm loop
    loop.set_postfix(loss = loss.item())

def main():
  """Build the transforms, model and loaders, then train for Num_Epochs epochs.

  After each epoch the model and optimizer state are checkpointed, validation
  accuracy is reported, and predictions for the validation set are written to
  'saved_images/'. Relies on `get_loaders`, `load_checkpoint`,
  `save_checkpoint`, `check_accuracy` and `save_predictions_as_images` being
  defined before this function is called.
  """
  train_transform = A.Compose(
      [
          A.Resize(height=Image_Height, width=Image_Width),
          A.Rotate(limit=35, p=1.0),
          A.HorizontalFlip(p=0.5),
          A.VerticalFlip(p=0.5),
          # mean 0 / std 1 / max 255: only rescales pixel values to [0, 1].
          A.Normalize(
              mean = [0.0, 0.0, 0.0],
              std = [1.0, 1.0, 1.0],
              max_pixel_value = 255.0,
          ),
          ToTensorV2()
      ]
  )

  # Validation uses the same resize and scaling but no random augmentation.
  val_transform = A.Compose(
      [
          A.Resize(Image_Height, Image_Width),
          A.Normalize(
              mean = [0.0, 0.0, 0.0],
              std = [1.0, 1.0, 1.0],
              max_pixel_value = 255.0,
          ),
          ToTensorV2()
      ]
  )

  # A single output channel: train_fn gives the targets one channel axis and
  # BCEWithLogitsLoss requires predictions and targets of identical shape.
  # NOTE(review): double_conv uses unpadded 3x3 convolutions, so the network
  # output is expected to be spatially smaller than the Image_Height x
  # Image_Width masks -- confirm the loss/metric shapes line up (crop the
  # masks or pad the convolutions).
  model = Unet(in_channels=3, out_channels=1).to(Device)
  loss_fn = nn.BCEWithLogitsLoss()
  optimizer = optim.Adam(model.parameters(), lr=Learning_Rate)

  train_loader, val_loader = get_loaders(
        Train_Img_Dir,
        Train_Mask_Dir,
        Val_Img_Dir,
        Val_Mask_Dir,
        Batch_Size,
        train_transform,
        val_transform,
        Num_Workers,
  )

  if Load_Model:
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine (and vice versa).
    load_checkpoint(
        torch.load("my_checkpoint.pth.tar", map_location=Device), model
    )

  check_accuracy(val_loader, model, Device)
  scaler = torch.cuda.amp.GradScaler()

  for epoch in range(Num_Epochs):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy
    check_accuracy(val_loader, model, Device)

    # print some examples to the folder
    save_predictions_as_images(
        val_loader, model, folder='saved_images/', device = Device
    )



In [None]:
import torchvision
from torch.utils.data import DataLoader

def save_checkpoint(state, filename = 'my_checkpoint.pth.tar'):
  """Announce the save on stdout, then write `state` to `filename`.

  Args:
    state: object to serialize with `torch.save`.
    filename: destination path of the checkpoint file.
  """
  print("===> Saving the checkpoint....")
  torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
  """Announce the load on stdout, then restore `model` from `checkpoint`.

  Args:
    checkpoint: mapping whose 'state_dict' entry holds the model weights.
    model: module that receives the weights via `load_state_dict`.
  """
  print("===> Loading the checkpoint....")
  weights = checkpoint['state_dict']
  model.load_state_dict(weights)

def get_loaders(
    train_img_dir,
    train_mask_dir,
    val_img_dir,
    val_mask_dir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
):
  """Create the training and validation data loaders.

  Each loader wraps a `CarvanaDataset` built from the matching image/mask
  directories and transform. Both share `batch_size` and `num_workers`; only
  the training loader shuffles.

  Returns:
    A `(train_loader, val_loader)` tuple.
  """
  def _make_loader(img_dir, mask_dir, transform, shuffle):
    # Dataset first, then the loader that batches it.
    dataset = CarvanaDataset(
        image_dir = img_dir,
        mask_dir = mask_dir,
        transform = transform
    )
    return DataLoader(
        dataset,
        batch_size = batch_size,
        num_workers = num_workers,
        shuffle = shuffle
    )

  train_loader = _make_loader(train_img_dir, train_mask_dir, train_transform, True)
  val_loader = _make_loader(val_img_dir, val_mask_dir, val_transform, False)
  return train_loader, val_loader

def check_accuracy(loader, model, device=Device):
  """Print pixel accuracy and mean Dice score of `model` over `loader`.

  Predictions are `sigmoid(model(x)) > 0.5`. The model is switched to eval
  mode for the pass and put back into train mode afterwards.

  Args:
    loader: iterable yielding `(x, y)` batches.
    model: module called as `model(x)` to produce logits.
    device: device the batches are moved to before the forward pass.
  """
  num_correct = 0
  num_pixels = 0
  dice_total = 0.0
  num_batches = 0
  model.eval()

  with torch.no_grad():
    for x, y in loader:
      x = x.to(device)
      # Insert a channel axis at dim 1 before comparing with the predictions.
      y = y.to(device).unsqueeze(1)
      preds = (torch.sigmoid(model(x)) > 0.5).float()

      num_correct += (preds == y).sum().item()
      num_pixels += torch.numel(preds)

      # Dice = 2 * |P n Y| / (|P| + |Y|); the denominator is the sum of both
      # masks, not their product. The epsilon avoids 0/0 on empty masks.
      intersection = (preds * y).sum()
      dice_total += (2 * intersection / ((preds + y).sum() + 1e-8)).item()
      num_batches += 1

  # Guard the averages so an empty loader reports zeros instead of raising.
  accuracy = num_correct / num_pixels * 100 if num_pixels else 0.0
  mean_dice = dice_total / num_batches if num_batches else 0.0
  print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
  print(f"Dice Score : {mean_dice}")
  model.train()

def save_predictions_as_images(loader, model, folder='saved_images/', device='cuda'):
  """Write thresholded predictions and their masks to `folder` as PNG files.

  For batch number `idx` the prediction grid is saved as 'pred_{idx}.png' and
  the corresponding masks as '{idx}.png'. The model is switched to eval mode
  for the pass and put back into train mode afterwards.

  Args:
    loader: iterable yielding `(x, y)` batches.
    model: module called as `model(x)` to produce logits.
    folder: output directory; created if it does not exist.
    device: device the inputs are moved to before the forward pass.
  """
  if folder:
    os.makedirs(folder, exist_ok=True)
  model.eval()

  for idx, (x, y) in enumerate(loader):
    # Honour the `device` argument rather than the module-level constant.
    x = x.to(device)
    with torch.no_grad():
      preds = torch.sigmoid(model(x))
      preds = (preds > 0.5).float()
    # os.path.join gives the same location whether or not `folder` ends with
    # a path separator.
    torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
    # Cast to float because save_image scales its input as a floating-point
    # image; the masks may arrive with an integer dtype.
    torchvision.utils.save_image(
        y.unsqueeze(1).float(), os.path.join(folder, f"{idx}.png")
    )

  model.train()
