In [1]:
import logging
import os
import torch
import torch.nn as nn
import json
import torch.nn.functional as F
import tqdm
from Utils.Losses import Dice_Ceff, dice_loss
from Utils.DataLoder import IMD2020Dataset
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader
from Utils.Eval import evaluate


In [18]:
# Path handed to train() as `data_dir`, which forwards it unchanged to
# IMD2020Dataset. Relative, so it is resolved against the kernel's working
# directory.
data_dir = "data.csv"

In [2]:
# Create the local checkpoint directory. `exist_ok=True` replaces the
# check-then-create pattern, so the cell is safe to re-run and has no window
# between the existence test and the mkdir.
# NOTE(review): train() writes under "/content/checkpoints" by default, which
# is this directory only when the working directory is /content.
os.makedirs("checkpoints", exist_ok=True)

In [4]:
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [5]:
# Inspection cell: bare expression so the notebook renders the selected device.
device

device(type='cpu')

In [6]:
# Inspection cell: `device.type` is the string train() passes to torch.autocast.
device.type

'cpu'

In [19]:
def _as_json_number(value):
    """Return ``value.item()`` when available, otherwise ``value`` unchanged.

    The return types of ``evaluate`` are not visible from this notebook; this
    keeps ``json.dump`` working if it hands back 0-dim tensors or numpy scalars.
    """
    return value.item() if hasattr(value, "item") else value


def _run_epoch(model, device, train_loader, optimizer, criterion, dice,
               grad_scaler, gradient_clipping, amp, pbar):
    """Run one optimisation pass over ``train_loader``.

    Returns the sum of the per-batch losses (0.0 for an empty loader) and
    advances ``pbar`` by the number of samples in each batch.
    """
    epoch_loss = 0.0
    # Same mapping as before: 'mps' is replaced by 'cpu' for the autocast context.
    autocast_device = device.type if device.type != 'mps' else 'cpu'

    for batch in train_loader:
        fakes, masks = batch[0], batch[1]
        fakes = fakes.to(device=device, dtype=torch.float32)
        # Cast through long first (as the original did) so any fractional mask
        # values are truncated before being used as float targets.
        masks = masks.to(device=device, dtype=torch.long).float()

        # Only the forward pass and the loss belong inside autocast.
        with torch.autocast(autocast_device, enabled=amp):
            mask_pred = model(fakes)
            dice_term = dice(torch.sigmoid(mask_pred), masks)
            loss = criterion(mask_pred, masks) + dice_term

        optimizer.zero_grad(set_to_none=True)
        grad_scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here; unscale them
        # so that `gradient_clipping` is applied to the true gradient norm.
        grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        grad_scaler.step(optimizer)
        grad_scaler.update()

        batch_loss = loss.item()
        epoch_loss += batch_loss

        pbar.update(fakes.shape[0])
        # NOTE(review): 'acc(batch)' is just 2 - loss, not a real accuracy;
        # kept so the progress-bar output stays the same.
        pbar.set_postfix(**{'loss (batch)': batch_loss,
                            'acc(batch)': 2 - batch_loss,
                            'dice_score(batch)': (1 - dice_term).item()})
    return epoch_loss


def train(model , device, data_dir: str = None , epochs: int = 5,
          batch_size : int = 32,
          learning_rate: float = 1e-5,
          save_checkpoints: bool = True,
          isValidation: bool = True,
          val_split: float = 0.1,
          eval_step : int = 5,
          weight_decay: float = 1e-8,
          momentum: float = 0.999,
          gradient_clipping: float = 1.0,
          amp: bool = True,
          checkpoint_dir: str = "/content/checkpoints"):
    """Train ``model`` on ``IMD2020Dataset(data_dir)`` with BCE-with-logits + Dice loss.

    Args:
        model: Network to train; called as ``model(images)`` and expected to
            return logits that ``nn.BCEWithLogitsLoss`` can compare with the masks.
        device: ``torch.device`` that batches are moved to.
        data_dir: Argument forwarded to ``IMD2020Dataset``.
        epochs: Number of passes over the training data.
        batch_size: Batch size for the training and validation loaders.
        learning_rate: Initial learning rate of the optimizer.
        save_checkpoints: When true, write a JSON metrics file and a ``.pt``
            checkpoint (model and optimizer state dicts) after every epoch.
        isValidation: When true, hold out ``val_split`` of the data and run
            ``evaluate`` every ``eval_step`` epochs.
        val_split: Fraction of the dataset used for validation.
        eval_step: Validation frequency in epochs.
        weight_decay: Weight decay passed to the optimizers.
        momentum: Momentum passed to RMSprop.
        gradient_clipping: Max gradient norm for ``clip_grad_norm_``.
        amp: Enable autocast and gradient scaling.
        checkpoint_dir: Root folder for checkpoints; a sub-folder named after
            the optimizer class is created inside it. Defaults to the path the
            original code hard-coded.

    Returns:
        None.
    """
    dataset = IMD2020Dataset(data_dir)

    if isValidation:
        dev_len = int(len(dataset) * val_split)
        train_len = len(dataset) - dev_len

        train_set, dev_set = random_split(dataset, [train_len, dev_len])
        train_loader = DataLoader(train_set, batch_size, shuffle=True)
        val_loader = DataLoader(dev_set, batch_size, shuffle=False)
    else:
        train_len = len(dataset)
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer2 = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    optimizer1 = torch.optim.RMSprop(model.parameters(), lr=learning_rate,
                                     weight_decay=weight_decay, momentum=momentum)
    optimizer3 = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # The loop ends with an unconditional `break`, so only the first candidate
    # (RMSprop) is ever trained; the other two are kept as ready-made alternatives.
    for optimizer in [optimizer1, optimizer2, optimizer3]:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=4)
        criterion = nn.BCEWithLogitsLoss()
        # NOTE(review): `DiceLoss` is not imported in any visible cell (only
        # `Dice_Ceff` and `dice_loss` come from Utils.Losses) -- confirm where
        # it is defined, otherwise this line raises NameError on a fresh kernel.
        dice = DiceLoss()
        grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)

        opt_folder_path = os.path.join(checkpoint_dir, type(optimizer).__name__)
        if save_checkpoints:
            # makedirs also creates a missing parent and tolerates re-runs.
            os.makedirs(opt_folder_path, exist_ok=True)

        for epoch in range(1, epochs + 1):
            model.train()
            with tqdm.tqdm(total=train_len, desc=f'Epoch {epoch} of {epochs}', unit='img') as pbar:
                epoch_loss = _run_epoch(model, device, train_loader, optimizer, criterion,
                                        dice, grad_scaler, gradient_clipping, amp, pbar)

            metrics = {"Epoch": epoch, "Epoch_Loss": epoch_loss}

            if isValidation and epoch % eval_step == 0:
                val_loss, val_score = evaluate(model, device, val_loader, criterion, dev_len,
                                               True, eval_step, epoch, learning_rate)
                scheduler.step(val_score)
                # The progress bar is already closed here, so print the summary
                # through tqdm.write instead of updating the bar's description.
                tqdm.tqdm.write(f"Epoch-Loss : {epoch_loss}, Val_loss : {val_loss}, Val_Score : {val_score}")
                # Only record validation numbers for epochs that were actually
                # evaluated, so stale values never leak into later files.
                metrics["Val_loss"] = _as_json_number(val_loss)
                metrics["Val_Score"] = _as_json_number(val_score)

            if save_checkpoints:
                model_path = os.path.join(opt_folder_path, f"checkpoint_epoch{epoch}.pt")
                metric_file_path = os.path.join(opt_folder_path, f"Epoch_{epoch}.json")
                metrics["model_path"] = model_path
                with open(metric_file_path, "w") as f:
                    json.dump(metrics, f, indent=2)
                torch.save({"model_state_dict": model.state_dict(),
                            "opt_state_dict": optimizer.state_dict()}, model_path)
        break



In [20]:
from Utils.Basline_U_Net_Model import UNet

In [21]:
# Build the baseline U-Net. The printed repr below shows the first convolution
# taking 3 input channels; the second argument is presumably the number of
# output channels (a single mask channel) -- TODO confirm against UNet.__init__.
model = UNet(3,1)

In [10]:
# Move the model to the selected device. The expression's value is left as the
# cell result, which is why the notebook prints the module structure below.
model.to(device=device)

UNet(
  (initial): ConvBlock(
    (dconv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (contract1): EncodeDown(
    (down): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): ConvBlock(
        (dconv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         

In [None]:
# Launch training; on a CUDA out-of-memory error, retry once after switching
# the model to checkpointing via `model.use_checkpointing()`.
train_kwargs = dict(model=model, device=device, data_dir=data_dir,
                    save_checkpoints=True, epochs=100)

ran_out_of_memory = False
try:
    train(**train_kwargs)
except torch.cuda.OutOfMemoryError:
    # Only record the failure here. While the except clause is running, the
    # exception object still references the stack frames (and their tensors)
    # of the failed run, so recovery has to happen after leaving this block.
    ran_out_of_memory = True

if ran_out_of_memory:
    torch.cuda.empty_cache()
    model.use_checkpointing()
    train(**train_kwargs)
