In [17]:
import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from Losses import Dice_Ceff, dice_loss
from DataLoder import IMD2020Dataset
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader


In [18]:
# Path handed to train() as `data_dir`, which forwards it to IMD2020Dataset.
# NOTE(review): presumably a CSV index of the image/mask pairs — confirm against DataLoder.py.
data_dir = "data.csv"

In [3]:
# Create the checkpoint directory. `exist_ok=True` keeps this cell idempotent,
# so re-running it (or Restart & Run All) no longer raises FileExistsError.
os.makedirs("checkpoints", exist_ok=True)

In [4]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
# Display the selected torch.device (inspection only).
device

device(type='cpu')

In [6]:
# Display just the device type string (inspection only).
device.type

'cpu'

In [19]:
def _forward_loss(model, criterion, batch, device):
    """Run one batch through ``model`` and compute the combined loss.

    Returns ``(loss, mask_pred, masks, n_images)`` where ``loss`` is
    ``criterion(mask_pred.squeeze(1), masks) + dice_loss(mask_pred, masks)``.
    """
    fakes, masks = batch[0], batch[1]
    fakes = fakes.to(device=device, dtype=torch.float32)
    masks = masks.to(device=device, dtype=torch.float32)

    mask_pred = model(fakes)

    loss = criterion(mask_pred.squeeze(1), masks.float())
    loss = loss + dice_loss(mask_pred, masks)
    return loss, mask_pred, masks, fakes.shape[0]


def _mean_validation_loss(model, criterion, val_loader, device):
    """Return the mean per-batch loss over ``val_loader`` (``None`` if it is empty)."""
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for batch in val_loader:
            loss, _, _, _ = _forward_loss(model, criterion, batch, device)
            total += loss.item()
            n_batches += 1
    return total / n_batches if n_batches else None


def train(model , device, data_dir: str = None , epochs: int = 5,
          batch_size : int = 32,
          learning_rate: float = 1e-5,
          save_checkpoints: bool = True,
          isValidation: bool = False,
          val_split: float = 0.1,
          weight_decay: float = 1e-8,
          momentum: float = 0.999,
          gradient_clipping: float = 1.0,
          amp: bool = None ):
    """Train ``model`` on ``IMD2020Dataset(data_dir)`` with BCE-with-logits + dice loss.

    Args:
        model: Network to optimise; called as ``model(fakes)``.
        device: ``torch.device`` the batches are moved to.
        data_dir: Passed unchanged to ``IMD2020Dataset``.
        epochs: Number of passes over the training data.
        batch_size: Batch size for the training (and validation) loader.
        learning_rate: Initial learning rate for AdamW.
        save_checkpoints: If true, write ``model.state_dict()`` to
            ``checkpoints/checkpoint_epoch{N}.pth`` after every epoch.
        isValidation: If true, hold out ``val_split`` of the data, evaluate on it
            after each epoch, and drive the LR scheduler with that loss.
        val_split: Fraction of the dataset held out when ``isValidation`` is true.
        weight_decay: Weight decay for AdamW.
        momentum: Accepted for backward compatibility; currently unused
            (it only fed an RMSprop optimizer that was never stepped).
        gradient_clipping: Max gradient norm given to ``clip_grad_norm_``.
        amp: Accepted for backward compatibility; currently unused.

    Returns:
        None.
    """
    data = IMD2020Dataset(data_dir)

    val_loader = None
    if isValidation:
        dev_len = int(len(data) * val_split)
        train_len = len(data) - dev_len

        train_set, dev_set = random_split(data, [train_len, dev_len])
        train_loader = DataLoader(train_set, batch_size, shuffle=True)
        val_loader = DataLoader(dev_set, batch_size, False)
    else:
        train_len = len(data)
        train_loader = DataLoader(data, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # The monitored quantity is a loss, so the plateau scheduler must run in "min" mode.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=3)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        n_batches = 0
        with tqdm.tqdm(total=train_len, desc=f'Epoch {epoch} of {epochs}', unit='img') as pbar:

            for batch in train_loader:
                optimizer.zero_grad()

                loss, mask_pred, masks, n_images = _forward_loss(model, criterion, batch, device)

                # The dice coefficient is only reported, so keep it out of the autograd graph.
                with torch.no_grad():
                    dice_score = Dice_Ceff(mask_pred, masks)

                loss.backward()
                # Clip AFTER backward: before it, the gradients are the freshly zeroed ones.
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                optimizer.step()

                pbar.update(n_images)

                batch_loss = loss.item()
                epoch_loss += batch_loss
                n_batches += 1
                # NOTE(review): 'acc(batch)' is just 1 - loss, not a true accuracy metric.
                pbar.set_postfix(**{'loss (batch)': batch_loss, 'acc(batch)': 1 - batch_loss, 'dice_score(batch)': dice_score})

        mean_train_loss = epoch_loss / n_batches if n_batches else 0.0

        # Step the plateau scheduler once per epoch, on validation loss when available.
        monitored_loss = mean_train_loss
        if val_loader is not None:
            val_loss = _mean_validation_loss(model, criterion, val_loader, device)
            if val_loss is not None:
                monitored_loss = val_loss
        scheduler.step(monitored_loss)

        logging.info("Epoch %d/%d - mean train loss %.6f - monitored loss %.6f",
                     epoch, epochs, mean_train_loss, monitored_loss)

        if save_checkpoints:
            # Make sure the target directory exists so torch.save cannot fail on a fresh checkout.
            os.makedirs("checkpoints", exist_ok=True)
            state_dict = model.state_dict()
            torch.save(state_dict, str("checkpoints/checkpoint_epoch{}.pth".format(epoch)))

        


In [20]:
from Basline_U_Net_Model import UNet

In [21]:
# Build the baseline U-Net. The recorded repr shows the first conv taking 3 input channels;
# NOTE(review): the second argument is presumably the number of output channels (1 = single mask) — confirm in Basline_U_Net_Model.
model = UNet(3,1)

UNet(
  (initial): ConvBlock(
    (dconv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (contract1): EncodeDown(
    (down): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): ConvBlock(
        (dconv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         

In [22]:
# Start training with the defaults from train(); per-epoch checkpoint files are disabled here.
train(
    model=model,
    device=device,
    data_dir=data_dir,
    save_checkpoints=False,
)



: 