# Import Library

In [1]:
%load_ext autoreload
%autoreload 2


import numpy as np
import matplotlib.pyplot as plt
import datetime

import torch

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import optim
import numpy as np
import pickle
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils import *

from models import caranet
from unet import pretrained_unet

from metrics import DiceLoss



In [2]:
##### Hyperparameter Settings ####
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-4
weight_decay = 1e-10
batch_size = 16
num_epochs = 1000
early_stopping_patience = 10
random_seed = 42
date_time = datetime.datetime.now().strftime("%m-%d_%H-%M")

# model_type = 'caranet'
model_type = 'unet'

filename = f'models/{model_type}_{date_time}.pt'

if model_type == 'unet':
    mode = 'base'
elif model_type == 'caranet':
    mode = 'caranet'
#################################

# Dataloader

In [6]:
transform = A.Compose([
    A.HorizontalFlip(),
    ToTensorV2(transpose_mask=True)
])
import albumentations as A
from albumentations.pytorch import ToTensorV2
def get_training_augmentation():
    train_transform = [

        A.ShiftScaleRotate(scale_limit=0, rotate_limit=0.1, shift_limit=0.1, p=1, border_mode=0),

        A.OneOf(
            [
                A.CLAHE(p=1),
                A.RandomBrightness(p=1),
                A.RandomContrast(p=1),
            ],
            p=0.9,
        ),

        ToTensorV2(transpose_mask=True)
    ]
    return A.Compose(train_transform)



In [7]:
_2_4_loader, _2_loader, _4_loader = make_dataloader(get_training_augmentation(), random_seed, batch_size, mode)

train_2_4_loader, test_2_4_loader = _2_4_loader
train_2_loader, test_2_loader = _2_loader
train_4_loader, test_4_loader = _4_loader


train image shape: (1600, 400, 400, 3) 
train mask shape: (1600, 400, 400, 1)
test image shape: (200, 400, 400, 3) 
test mask shape: (200, 400, 400, 1)


# Pre-Training

In [8]:
if model_type == 'unet':
    model = pretrained_unet(True).to(device)
elif model_type == 'caranet':
    model = caranet().to(device)
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, min_lr=learning_rate/1000, verbose=False)
criterion = DiceLoss()

early_stopping = EarlyStopping(patience = 10, verbose = False, path = filename)
loss_dict = {'train': [], 'val': []}

In [10]:
train_loader = train_2_4_loader
val_loader = test_2_4_loader

In [13]:
for epoch in range(num_epochs):
    model.train()
    
    train_losses = []
    for it_1, (img, mask) in enumerate(train_loader):
        #print(train_img)
        img = img.to(device).float()
        mask = mask.to(device).float()
        #print(train_label)
        if mode == 'base':
            y_pred = model(img)
            loss = criterion(y_pred, mask)
            train_losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        elif mode == 'caranet':
            train_lrmap_5,train_lrmap_3,train_lrmap_2,train_lrmap_1 = model(img)
            train_loss5 = structure_loss(train_lrmap_5, mask)
            train_loss3 = structure_loss(train_lrmap_3, mask)
            train_loss2 = structure_loss(train_lrmap_2, mask)
            train_loss1 = structure_loss(train_lrmap_1, mask)
        
            loss = train_loss5 + train_loss3 + train_loss2 + train_loss1
        
            train_losses.append(loss.item())
            
            optimizer.zero_grad()
            loss.backward()
            clip_gradient(optimizer, 0.5)
            optimizer.step()
    
    train_loss = np.average(train_losses)

    loss_dict['train'].append(train_loss)

    model.eval()
    with torch.no_grad():
        valid_losses = []
        for it_2, (img, mask) in enumerate(val_loader):
            img = img.float().to(device)
            mask = mask.to(device).float()
            #print(train_label)
            if mode == 'base':
                y_pred = model(img)
                loss = criterion(y_pred, mask)
                valid_losses.append(loss.item())

            elif mode == 'caranet':
                lrmap_5, lrmap_3, lrmap_2, lrmap_1 = model(img)
                loss5 = structure_loss(lrmap_5, mask)
                loss3 = structure_loss(lrmap_3, mask)
                loss2 = structure_loss(lrmap_2, mask)
                loss1 = structure_loss(lrmap_1, mask)
            
                loss = loss5 + loss3 + loss2 + loss1
                valid_losses.append(loss.item())
        
        valid_loss = np.average(valid_losses)
        scheduler.step(valid_loss)
        
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
#            scheduler.step(float(val_loss))
        loss_dict['val'].append(valid_loss)
            
    print(f'Epoch {epoch} train_loss: {train_loss:0.5f}   val_loss: {valid_loss:0.5f}')

Epoch 0 train_loss: 0.27685   val_loss: 0.11594
Epoch 1 train_loss: 0.11292   val_loss: 0.10284
Epoch 2 train_loss: 0.09432   val_loss: 0.07520
EarlyStopping counter: 1 out of 10
Epoch 3 train_loss: 0.08663   val_loss: 0.09246
EarlyStopping counter: 2 out of 10
Epoch 4 train_loss: 0.08122   val_loss: 0.07878
Epoch 5 train_loss: 0.07445   val_loss: 0.06550
EarlyStopping counter: 1 out of 10
Epoch 6 train_loss: 0.07002   val_loss: 0.06953
Epoch 7 train_loss: 0.06863   val_loss: 0.05963
Epoch 8 train_loss: 0.06658   val_loss: 0.05564
EarlyStopping counter: 1 out of 10
Epoch 9 train_loss: 0.06362   val_loss: 0.06291
Epoch 10 train_loss: 0.06257   val_loss: 0.05176
EarlyStopping counter: 1 out of 10
Epoch 11 train_loss: 0.06077   val_loss: 0.05313
Epoch 12 train_loss: 0.05916   val_loss: 0.05102
EarlyStopping counter: 1 out of 10
Epoch 13 train_loss: 0.05670   val_loss: 0.05169
EarlyStopping counter: 2 out of 10
Epoch 14 train_loss: 0.05619   val_loss: 0.05421
Epoch 15 train_loss: 0.05582  


# Finetuning

In [30]:
learning_rate = 1e-5
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [23]:
train_loader = train_2_loader
val_loader = test_2_loader

In [14]:
filename = 'models/unet_12-02_00-41.pt'
early_stopping = EarlyStopping(patience = 20, verbose = False, path = filename)

In [28]:
model.load_state_dict(torch.load(f'{filename}'))

<All keys matched successfully>

In [None]:
for epoch in range(num_epochs):
    model.train()
    
    train_losses = []
    for it_1, (img, mask) in enumerate(train_loader):
        #print(train_img)
        img = img.to(device)
        mask = mask.to(device).float()
        #print(train_label)
        if mode == 'base':
            y_pred = model(img)
            loss = criterion(y_pred, mask)
            train_losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        elif mode == 'caranet':
            train_lrmap_5,train_lrmap_3,train_lrmap_2,train_lrmap_1 = model(img)
            train_loss5 = structure_loss(train_lrmap_5, mask)
            train_loss3 = structure_loss(train_lrmap_3, mask)
            train_loss2 = structure_loss(train_lrmap_2, mask)
            train_loss1 = structure_loss(train_lrmap_1, mask)
        
            loss = train_loss5 + train_loss3 + train_loss2 + train_loss1
        
            train_losses.append(loss.item())
            
            optimizer.zero_grad()
            loss.backward()
            clip_gradient(optimizer, 0.5)
            optimizer.step()
    
    train_loss = np.average(train_losses)

    loss_dict['train'].append(train_loss)

    model.eval()
    with torch.no_grad():
        valid_losses = []
        for it_2, (img, mask) in enumerate(val_loader):
            img = img.to(device)
            mask = mask.to(device).float()
            #print(train_label)
            if mode == 'base':
                y_pred = model(img)
                loss = criterion(y_pred, mask)
                valid_losses.append(loss.item())

            elif mode == 'caranet':
                lrmap_5, lrmap_3, lrmap_2, lrmap_1 = model(img)
                loss5 = structure_loss(lrmap_5, mask)
                loss3 = structure_loss(lrmap_3, mask)
                loss2 = structure_loss(lrmap_2, mask)
                loss1 = structure_loss(lrmap_1, mask)
            
                loss = loss5 + loss3 + loss2 + loss1
                valid_losses.append(loss.item())
        
        valid_loss = np.average(valid_losses)
        scheduler.step(valid_loss)
        
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
#            scheduler.step(float(val_loss))
        loss_dict['val'].append(valid_loss)
            
    print(f'Epoch {epoch} train_loss: {train_loss:0.5f}   val_loss: {valid_loss:0.5f}')

Epoch 0 train_loss: 0.06642   val_loss: 0.05514
Epoch 1 train_loss: 0.06631   val_loss: 0.05412
Epoch 2 train_loss: 0.06525   val_loss: 0.05400
EarlyStopping counter: 1 out of 20
Epoch 3 train_loss: 0.06496   val_loss: 0.05402
Epoch 4 train_loss: 0.06476   val_loss: 0.05365
Epoch 5 train_loss: 0.06543   val_loss: 0.05338
EarlyStopping counter: 1 out of 20
Epoch 6 train_loss: 0.06453   val_loss: 0.05350
Epoch 7 train_loss: 0.06429   val_loss: 0.05331
EarlyStopping counter: 1 out of 20
Epoch 8 train_loss: 0.06460   val_loss: 0.05350
EarlyStopping counter: 2 out of 20
Epoch 9 train_loss: 0.06437   val_loss: 0.05336
Epoch 10 train_loss: 0.06396   val_loss: 0.05326
Epoch 11 train_loss: 0.06372   val_loss: 0.05311
Epoch 12 train_loss: 0.06307   val_loss: 0.05307
Epoch 13 train_loss: 0.06331   val_loss: 0.05298
Epoch 14 train_loss: 0.06292   val_loss: 0.05278
Epoch 15 train_loss: 0.06352   val_loss: 0.05266
Epoch 16 train_loss: 0.06328   val_loss: 0.05258
EarlyStopping counter: 1 out of 20
Ep

# Evaluate

In [15]:
# load best model
model.load_state_dict(torch.load(f'{filename}'))

<All keys matched successfully>

In [16]:
from utils import *

In [14]:
x,y = next(iter(test_2_4_loader))

In [15]:
x.shape, y.shape

(torch.Size([16, 3, 400, 400]), torch.Size([16, 1, 400, 400]))

In [18]:
evaluate(model, test_2_4_loader, mode)

Dice Similarity:    0.9536 
Jaccard Similarity: 0.9132


(0.9536186449431789, 0.9131997225313269)

In [19]:
evaluate(model, test_2_loader, mode)

Dice Similarity:    0.9462 
Jaccard Similarity: 0.9008


(0.9461996130151895, 0.9007514524226488)

In [20]:
evaluate(model, test_4_loader, mode)

Dice Similarity:    0.9610 
Jaccard Similarity: 0.9256


(0.9610376768711683, 0.9256479926400051)