In [None]:
#####################
#  RUN ONLY ONCE
#####################


# !pip install --upgrade albumentations

In [None]:
# model.py
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class DoubleConv(nn.Module):
  """Two back-to-back (3x3 conv -> batch norm -> ReLU) stages.

  The first stage maps `in_channels` to `out_channels`; the second keeps
  `out_channels`. Convolutions use stride 1 and padding 1, and carry no bias
  term (each is followed directly by a BatchNorm2d layer).
  """
  def __init__(self,in_channels,out_channels):
    super().__init__()
    stages = []
    # Stage 1 consumes `in_channels`, stage 2 consumes the output of stage 1.
    for stage_in in (in_channels,out_channels):
      stages.append(nn.Conv2d(stage_in,out_channels,kernel_size=3,stride=1,padding=1,bias=False))
      stages.append(nn.BatchNorm2d(out_channels))
      stages.append(nn.ReLU(inplace=True))
    self.conv = nn.Sequential(*stages)

  def forward(self,x):
    """Apply both conv/norm/activation stages to `x` and return the result."""
    out = self.conv(x)
    return out


class UNET(nn.Module):
  """U-Net style encoder/decoder assembled from DoubleConv blocks.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by the final 1x1 convolution.
    features: channel widths of the encoder levels, shallowest first. The
      decoder mirrors them in reverse order.

  `forward` returns the raw output of the final convolution (no activation
  is applied here).
  """
  def __init__(self,in_channels=3,out_channels=1,features=(64,128,256,512)):
    super(UNET,self).__init__()
    # An immutable default avoids sharing one list object between instances.
    features = tuple(features)
    self.downs = nn.ModuleList()
    self.ups = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size=2,stride=2)

    # Down part of the UNET: one DoubleConv per level.
    for feature in features:
      self.downs.append(DoubleConv(in_channels,feature))
      in_channels = feature

    # Up part of the UNET. `self.ups` alternates
    # [ConvTranspose2d, DoubleConv, ConvTranspose2d, DoubleConv, ...].
    # The transposed conv takes whatever the previous decoder stage (or the
    # bottleneck) emitted; for widths that double per level this equals
    # feature*2.
    incoming = features[-1]*2
    for feature in reversed(features):
      self.ups.append(
          nn.ConvTranspose2d(
              in_channels=incoming,out_channels=feature,kernel_size=2,stride=2
          )
      )
      # After concatenation with the skip tensor the channel count is feature*2.
      self.ups.append(DoubleConv(feature*2,feature))
      incoming = feature

    self.bottleneck = DoubleConv(features[-1],features[-1]*2)
    self.final_conv = nn.Conv2d(features[0],out_channels,kernel_size=1)


  def forward(self,x):
    """Run the encoder, bottleneck and decoder; return the final conv output."""
    skip_connections = []
    for down in self.downs:
      x = down(x)
      # Keep the pre-pooling activation for the matching decoder level.
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)

    # The deepest skip connection is consumed first.
    skip_connections = skip_connections[::-1]
    for idx in range(0,len(self.ups),2):
      x = self.ups[idx](x)
      skip_connection = skip_connections[idx//2]
      if x.shape != skip_connection.shape:
        # Shapes can differ after upsampling (e.g. when pooling reduced an odd
        # spatial size); resize to the skip tensor's height/width so the
        # concatenation below is valid.
        x = TF.resize(x,size=skip_connection.shape[2:])
      concat_skip = torch.cat((skip_connection,x),dim=1)
      x = self.ups[idx+1](concat_skip)

    return self.final_conv(x)

def test():
  """Smoke test: a 1-channel UNET must return a tensor shaped like its input."""
  sample = torch.randn(2,1,530,390)
  net = UNET(in_channels=1,out_channels=1)
  output = net(sample)
  print(sample.shape)
  print(output.shape)
  assert output.shape == sample.shape

# if __name__=='__main__':
#   test()









In [None]:
# Dataset 
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class Nucleus_Dataset(Dataset):
  """Image/mask pairs read from two directories.

  Args:
    image_dir: directory containing the input images.
    mask_dir: directory containing the masks.
    transform: optional callable invoked as `transform(image=..., mask=...)`
      that returns a mapping with "image" and "mask" entries.

  Raises:
    ValueError: if the two directories do not hold the same number of entries.

  NOTE(review): image i is paired with mask i after sorting both listings by
  file name; this assumes image and mask names sort in the same order --
  confirm against the dataset layout.
  """
  def __init__(self,image_dir,mask_dir,transform=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    # os.listdir returns entries in arbitrary order; sort both listings so the
    # positional image/mask pairing is deterministic.
    self.images = sorted(os.listdir(image_dir))
    self.masks = sorted(os.listdir(mask_dir))
    if len(self.images) != len(self.masks):
      raise ValueError(
          f"Found {len(self.images)} images in {image_dir!r} but "
          f"{len(self.masks)} masks in {mask_dir!r}; expected one mask per image"
      )

  def __len__(self):
    """Number of image/mask pairs."""
    return len(self.images)

  def __getitem__(self,index):
    """Load pair `index` and return `(image, mask)`, transformed if configured."""
    image_path = os.path.join(self.image_dir,self.images[index])
    mask_path = os.path.join(self.mask_dir,self.masks[index])
    image = np.array(Image.open(image_path).convert("RGB"))
    mask = np.array(Image.open(mask_path).convert("L"),dtype=np.float32)
    # Map the value 255 to 1.0; other values are left unchanged.
    mask[mask==255.0]=1.0

    if self.transform is not None:
      augmentations = self.transform(image=image,mask=mask)
      image = augmentations["image"]
      mask = augmentations["mask"]

    return image,mask
    
    
    





In [None]:
# Utils
import torch
import torchvision
from torch.utils.data import DataLoader

def save_checkpoint(state,filename="checkpoint.pth.tar"):
  """Announce the save on stdout, then serialize `state` to `filename` via torch.save."""
  print("Saving checkpoint =>")
  torch.save(obj=state,f=filename)

def load_checkpoint(checkpoint,model):
  """Announce the load on stdout, then copy checkpoint["state_dict"] into `model`."""
  print("Loading Checkpoint =>")
  weights = checkpoint["state_dict"]
  model.load_state_dict(weights)


def get_loader(
    train_img_dir,
    train_mask_dir,
    val_img_dir,
    val_mask_dir,
    batch_size,
    train_transforms,
    val_transforms,
    num_workers=4,
    pin_memory=True):
  """Build the training and validation DataLoaders.

  Each split wraps a Nucleus_Dataset over its image/mask directories with the
  given transforms. Both loaders share `batch_size`, `num_workers` and
  `pin_memory`; only the training loader shuffles.

  Returns:
    (train_loader, val_loader)
  """
  # Settings common to both splits.
  shared_kwargs = dict(
      batch_size=batch_size,
      num_workers=num_workers,
      pin_memory=pin_memory,
  )

  training_set = Nucleus_Dataset(train_img_dir,train_mask_dir,train_transforms)
  training_loader = DataLoader(training_set,shuffle=True,**shared_kwargs)

  validation_set = Nucleus_Dataset(val_img_dir,val_mask_dir,val_transforms)
  validation_loader = DataLoader(validation_set,shuffle=False,**shared_kwargs)

  return training_loader,validation_loader


def check_accuracy(loader,model,device="cuda"):
  """Evaluate `model` on `loader`; print pixel accuracy and mean Dice score.

  Predictions are `sigmoid(model(x)) > 0.5`. The per-batch Dice score is
  `2*sum(preds*y) / (sum(preds+y) + 1e-8)`, averaged over the batches.

  Args:
    loader: iterable of `(x, y)` batches. `y` gets a dimension inserted at
      index 1 before comparison -- assumes `y` arrives without a channel
      dimension (TODO confirm against the dataset).
    model: module to evaluate; it is put in eval mode for the pass and
      switched to train mode afterwards.
    device: device the batches are moved to.

  Returns:
    None. Results are printed to stdout.
  """
  num_correct = 0
  num_pixels = 0
  dice_score = 0.0
  num_batches = 0
  model.eval()
  with torch.no_grad():
    for x,y in loader:
      x = x.to(device=device)
      y = y.to(device=device).unsqueeze(1)
      preds = (torch.sigmoid(model(x)) > 0.5).float()
      # Accumulate plain Python numbers so the report prints `3/4`, not tensors.
      num_correct += (preds==y).sum().item()
      num_pixels += torch.numel(preds)
      # The small epsilon keeps the ratio defined when both tensors are all zero.
      dice_score += ((2*(preds*y).sum())/((preds+y).sum() + 1e-8)).item()
      num_batches += 1

  # Guard the divisions so an empty loader reports zeros instead of raising.
  accuracy = num_correct*100/num_pixels if num_pixels else 0.0
  mean_dice = dice_score/num_batches if num_batches else 0.0
  print(f"Got {num_correct}/{num_pixels} with accuracy {accuracy:.2f}")
  print(f"Dice Score: {mean_dice}")
  model.train()



  


In [None]:
# train
# !pip install --upgrade albumentations


import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim


# Hyperparameters

# Optimisation settings.
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
NUM_EPOCHS = 100

# Runtime / data-loading settings.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is available
NUM_WORKERS = 2
PIN_MEMORY = True
LOAD_MODEL = False

# Size every image and mask is resized to.
IMAGE_HEIGHT = 530
IMAGE_WIDTH = 390

# Dataset locations; left empty here and must be filled in before training.
Train_image_dir = ""
Train_mask_dir = ""
Val_image_dir = ""
Val_mask_dir = ""

def train_fnc(loader,model,optimizer,loss_fnc,scaler):
  """Run one pass over `loader`, updating `model` after every batch.

  The forward pass and loss are computed under `torch.cuda.amp.autocast()`;
  the backward pass and optimizer step go through `scaler`. The current
  batch loss is shown on the tqdm progress bar.
  """
  progress = tqdm(loader)
  for data,targets in progress:
    data = data.to(device=DEVICE)
    # unsqueeze(1) adds the channel dimension: (batch_size,1,h,w)
    targets = targets.float().unsqueeze(1).to(device=DEVICE)

    # forward
    with torch.cuda.amp.autocast():
      loss = loss_fnc(model(data),targets)

    # backward
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # update tqdm loop
    progress.set_postfix(loss=loss.item())


def main():
  """Assemble transforms, model, loss, optimizer and loaders, then train.

  Training runs `train_fnc` once per epoch for NUM_EPOCHS epochs. Saving
  checkpoints, checking accuracy and exporting example predictions are not
  implemented yet (see the TODO notes at the end of the loop).
  """
  # Training pipeline: resize, random rotation/flips, normalize, to tensor.
  train_transform = A.Compose([
      A.Resize(height=IMAGE_HEIGHT,width=IMAGE_WIDTH),
      A.Rotate(limit=35,p=1.0),
      A.HorizontalFlip(p=0.5),
      A.VerticalFlip(p=0.1),
      A.Normalize(mean=[0.0,0.0,0.0],std=[1.0,1.0,1.0],max_pixel_value=255),
      ToTensorV2(),
  ])
  # Validation pipeline: same resize/normalize, no random augmentation.
  val_transform = A.Compose([
      A.Resize(height=IMAGE_HEIGHT,width=IMAGE_WIDTH),
      A.Normalize(mean=[0.0,0.0,0.0],std=[1.0,1.0,1.0],max_pixel_value=255),
      ToTensorV2(),
  ])

  model = UNET(in_channels=3,out_channels=1).to(device=DEVICE)
  # The model emits raw logits; BCEWithLogitsLoss applies the sigmoid itself.
  loss_fnc = nn.BCEWithLogitsLoss()
  optimizer = optim.Adam(model.parameters(),lr=LEARNING_RATE)

  train_loader,val_loader = get_loader(
      train_img_dir=Train_image_dir,
      train_mask_dir=Train_mask_dir,
      val_img_dir=Val_image_dir,
      val_mask_dir=Val_mask_dir,
      batch_size=BATCH_SIZE,
      train_transforms=train_transform,
      val_transforms=val_transform,
      num_workers=NUM_WORKERS,
      pin_memory=PIN_MEMORY,
  )

  scaler = torch.cuda.amp.GradScaler()
  for _ in range(NUM_EPOCHS):
    train_fnc(train_loader,model,optimizer,loss_fnc,scaler)

    # TODO: save model
    # TODO: check accuracy
    # TODO: save examples to folder


if __name__ == '__main__':
  main()




