# Import Library

In [1]:
%load_ext autoreload
%autoreload 2


import numpy as np
import matplotlib.pyplot as plt
import datetime

import torch

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import optim
import numpy as np
import pickle
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils import *

from models import caranet
from unet import pretrained_unet

from metrics import DiceLoss



In [20]:
##### Hyperparameter Settings ####
device = 'cuda' if torch.cuda.is_available() else 'cpu'
leraning_rate = 0.001
weight_decay = 1e-10
batch_size = 16
num_epochs = 1000
early_stopping_patience = 10
random_seed = 42
date_time = datetime.datetime.now().strftime("%m-%d_%H-%M")

model_type = 'caranet'
#model_type = 'unet'

filename = f'models/{model_type}_{date_time}.pt'

if model_type == 'unet':
    mode = 'base'
elif model_type == 'caranet':
    mode = 'caranet'
#################################

# Dataloader

In [21]:
transform = A.Compose([
    A.HorizontalFlip(),
    ToTensorV2(transpose_mask=True)
])



In [22]:
_2_4_loader, _2_loader, _4_loader = create_loader(transform, random_seed, batch_size, mode)

train_2_4_loader, val_2_4_loader, test_2_4_loader = _2_4_loader
train_2_loader, val_2_loader, test_2_loader = _2_loader
train_4_loader, val_4_loader, test_4_loader = _4_loader


train image shape: (1600, 400, 400, 3) 
train mask shape: (1600, 400, 400, 1)
test image shape: (200, 400, 400, 3) 
test mask shape: (200, 400, 400, 1)


# Training

In [23]:
# model = caranet().to(device)
model = pretrained_unet(True).to(device)
optimizer = Adam(model.parameters(), lr=leraning_rate, weight_decay=weight_decay)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=20, min_lr=leraning_rate/1000, verbose=True)
criterion = DiceLoss()

early_stopping = EarlyStopping(patience = 20, verbose = True, path = filename)
loss_dict = {'train': [], 'val': []}

In [24]:
train_loader = train_2_4_loader
val_loader = val_2_4_loader

In [25]:
for epoch in range(num_epochs):
    model.train()
    
    train_losses = []
    for it_1, (img, mask) in enumerate(train_loader):
        #print(train_img)
        img = img.to(device)
        mask = mask.to(device).float()
        #print(train_label)
        if mode == 'base':
            y_pred = model(img)
            loss = criterion(y_pred, mask)
            train_losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        elif mode == 'caranet':
            train_lrmap_5,train_lrmap_3,train_lrmap_2,train_lrmap_1 = model(img)
            train_loss5 = structure_loss(train_lrmap_5, mask)
            train_loss3 = structure_loss(train_lrmap_3, mask)
            train_loss2 = structure_loss(train_lrmap_2, mask)
            train_loss1 = structure_loss(train_lrmap_1, mask)
        
            loss = train_loss5 + train_loss3 + train_loss2 + train_loss1
        
            train_losses.append(loss.item())
            
            optimizer.zero_grad()
            loss.backward()
            clip_gradient(optimizer, 0.5)
            optimizer.step()
    
    train_loss = np.average(train_losses)

    loss_dict['train'].append(train_loss)

    model.eval()
    with torch.no_grad():
        valid_losses = []
        for it_2, (img, mask) in enumerate(val_loader):
            img = img.to(device)
            mask = mask.to(device).float()
            #print(train_label)
            if mode == 'base':
                y_pred = model(img)
                loss = criterion(y_pred, mask)
                valid_losses.append(loss.item())

            elif mode == 'caranet':
                lrmap_5, lrmap_3, lrmap_2, lrmap_1 = model(img)
                loss5 = structure_loss(lrmap_5, mask)
                loss3 = structure_loss(lrmap_3, mask)
                loss2 = structure_loss(lrmap_2, mask)
                loss1 = structure_loss(lrmap_1, mask)
            
                loss = loss5 + loss3 + loss2 + loss1
                valid_losses.append(loss.item())
        
        valid_loss = np.average(valid_losses)
        scheduler.step(valid_loss)
        
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
#            scheduler.step(float(val_loss))
        loss_dict['val'].append(valid_loss)
            
    print(f'Epoch {epoch} train_loss: {train_loss:0.5f}   val_loss: {valid_loss:0.5f}')

Validation loss decreased (inf --> 0.169942).  Saving model ...
Epoch 0 train_loss: 0.23479   val_loss: 0.16994
EarlyStopping counter: 1 out of 20
Epoch 1 train_loss: 0.11068   val_loss: 0.27615
Validation loss decreased (0.169942 --> 0.122242).  Saving model ...
Epoch 2 train_loss: 0.10089   val_loss: 0.12224
Validation loss decreased (0.122242 --> 0.106740).  Saving model ...
Epoch 3 train_loss: 0.08454   val_loss: 0.10674
EarlyStopping counter: 1 out of 20
Epoch 4 train_loss: 0.08212   val_loss: 0.10675
EarlyStopping counter: 2 out of 20
Epoch 5 train_loss: 0.07812   val_loss: 0.12205
Validation loss decreased (0.106740 --> 0.104486).  Saving model ...
Epoch 6 train_loss: 0.07440   val_loss: 0.10449
Validation loss decreased (0.104486 --> 0.092919).  Saving model ...
Epoch 7 train_loss: 0.07052   val_loss: 0.09292
Validation loss decreased (0.092919 --> 0.079587).  Saving model ...
Epoch 8 train_loss: 0.06921   val_loss: 0.07959
EarlyStopping counter: 1 out of 20
Epoch 9 train_loss:

# Finetuning

# Evaluate

In [None]:
# load best model
model.load_state_dict(torch.load(f'{filename}'))

<All keys matched successfully>

In [None]:
from utils import *

In [None]:
x,y = next(iter(test_2_4_loader))

In [None]:
x.shape, y.shape

(torch.Size([16, 3, 400, 400]), torch.Size([16, 1, 400, 400]))

In [None]:
evaluate(model, test_2_4_loader, mode)

(200, 1, 400, 400) (200, 1, 400, 400)
Dice Similarity:    0.9310 
Jaccard Similarity: 0.8746


(0.93097356219048, 0.8746172610912453)

In [None]:
evaluate(model, test_2_loader, mode)

(100, 1, 400, 400) (100, 1, 400, 400)
Dice Similarity:    0.9288 
Jaccard Similarity: 0.8709


(0.9287668186120144, 0.8709392917465969)

In [19]:
evaluate(model, test_4_loader, mode)

(100, 1, 400, 400) (100, 1, 400, 400)
Dice Similarity:    0.9332 
Jaccard Similarity: 0.8783


(0.9331803057689456, 0.8782952304358933)