# Import Library

In [2]:
%load_ext autoreload
%autoreload 2


import numpy as np
import matplotlib.pyplot as plt
import datetime

import torch

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import optim
import numpy as np
import pickle
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils import *

from models import caranet
from unet import pretrained_unet

from metrics import DiceLoss



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
##### Hyperparameter Settings ####
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-4
weight_decay = 1e-10
batch_size = 16
num_epochs = 1000
early_stopping_patience = 10
random_seed = 42
date_time = datetime.datetime.now().strftime("%m-%d_%H-%M")

# model_type = 'caranet'
model_type = 'unet'

filename = f'models/{model_type}_{date_time}.pt'

if model_type == 'unet':
    mode = 'base'
elif model_type == 'caranet':
    mode = 'caranet'
#################################

# Dataloader

In [7]:
transform = A.Compose([
    A.HorizontalFlip(),
    ToTensorV2(transpose_mask=True)
])



In [19]:
_2_4_loader, _2_loader, _4_loader = make_dataloader(transform, random_seed, batch_size, mode)

train_2_4_loader, test_2_4_loader = _2_4_loader
train_2_loader, test_2_loader = _2_loader
train_4_loader, test_4_loader = _4_loader


train image shape: (1600, 400, 400, 3) 
train mask shape: (1600, 400, 400, 1)
test image shape: (200, 400, 400, 3) 
test mask shape: (200, 400, 400, 1)


# Pre-Training

In [20]:
if model_type == 'unet':
    model = pretrained_unet(True).to(device)
elif model_type == 'caranet':
    model = caranet().to(device)
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=20, min_lr=leraning_rate/1000, verbose=False)
criterion = DiceLoss()

early_stopping = EarlyStopping(patience = 20, verbose = False, path = filename)
loss_dict = {'train': [], 'val': []}

Downloading: "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt" to /home/sunbinlee/.cache/torch/hub/checkpoints/unet-e012d006.pt


In [21]:
train_loader = train_2_4_loader
val_loader = val_2_4_loader

In [11]:
for epoch in range(num_epochs):
    model.train()
    
    train_losses = []
    for it_1, (img, mask) in enumerate(train_loader):
        #print(train_img)
        img = img.to(device)
        mask = mask.to(device).float()
        #print(train_label)
        if mode == 'base':
            y_pred = model(img)
            loss = criterion(y_pred, mask)
            train_losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        elif mode == 'caranet':
            train_lrmap_5,train_lrmap_3,train_lrmap_2,train_lrmap_1 = model(img)
            train_loss5 = structure_loss(train_lrmap_5, mask)
            train_loss3 = structure_loss(train_lrmap_3, mask)
            train_loss2 = structure_loss(train_lrmap_2, mask)
            train_loss1 = structure_loss(train_lrmap_1, mask)
        
            loss = train_loss5 + train_loss3 + train_loss2 + train_loss1
        
            train_losses.append(loss.item())
            
            optimizer.zero_grad()
            loss.backward()
            clip_gradient(optimizer, 0.5)
            optimizer.step()
    
    train_loss = np.average(train_losses)

    loss_dict['train'].append(train_loss)

    model.eval()
    with torch.no_grad():
        valid_losses = []
        for it_2, (img, mask) in enumerate(val_loader):
            img = img.to(device)
            mask = mask.to(device).float()
            #print(train_label)
            if mode == 'base':
                y_pred = model(img)
                loss = criterion(y_pred, mask)
                valid_losses.append(loss.item())

            elif mode == 'caranet':
                lrmap_5, lrmap_3, lrmap_2, lrmap_1 = model(img)
                loss5 = structure_loss(lrmap_5, mask)
                loss3 = structure_loss(lrmap_3, mask)
                loss2 = structure_loss(lrmap_2, mask)
                loss1 = structure_loss(lrmap_1, mask)
            
                loss = loss5 + loss3 + loss2 + loss1
                valid_losses.append(loss.item())
        
        valid_loss = np.average(valid_losses)
        scheduler.step(valid_loss)
        
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
#            scheduler.step(float(val_loss))
        loss_dict['val'].append(valid_loss)
            
    print(f'Epoch {epoch} train_loss: {train_loss:0.5f}   val_loss: {valid_loss:0.5f}')

Epoch 0 train_loss: 0.71702   val_loss: 0.66310
Epoch 1 train_loss: 0.63567   val_loss: 0.62090
Epoch 2 train_loss: 0.60623   val_loss: 0.60054
Epoch 3 train_loss: 0.58137   val_loss: 0.58401
Epoch 4 train_loss: 0.55619   val_loss: 0.54848
Epoch 5 train_loss: 0.52998   val_loss: 0.52371
Epoch 6 train_loss: 0.49966   val_loss: 0.48504
Epoch 7 train_loss: 0.46458   val_loss: 0.47178
Epoch 8 train_loss: 0.43170   val_loss: 0.43726
Epoch 9 train_loss: 0.39988   val_loss: 0.38938
Epoch 10 train_loss: 0.37061   val_loss: 0.36117
Epoch 11 train_loss: 0.34461   val_loss: 0.32991
Epoch 12 train_loss: 0.31447   val_loss: 0.31413
Epoch 13 train_loss: 0.28770   val_loss: 0.29499
EarlyStopping counter: 1 out of 20
Epoch 14 train_loss: 0.26526   val_loss: 0.35615
Epoch 15 train_loss: 0.24566   val_loss: 0.25535
Epoch 16 train_loss: 0.22397   val_loss: 0.22701
Epoch 17 train_loss: 0.20889   val_loss: 0.20628
Epoch 18 train_loss: 0.18948   val_loss: 0.18851
Epoch 19 train_loss: 0.17482   val_loss: 0.1

KeyboardInterrupt: 

# Finetuning

In [30]:
learning_rate = 1e-5
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [23]:
train_loader = train_2_loader
val_loader = test_2_loader

In [27]:
filename = 'models/unet_12-01_15-41.pt'
early_stopping = EarlyStopping(patience = 20, verbose = False, path = filename)

In [28]:
model.load_state_dict(torch.load(f'{filename}'))

<All keys matched successfully>

In [None]:
for epoch in range(num_epochs):
    model.train()
    
    train_losses = []
    for it_1, (img, mask) in enumerate(train_loader):
        #print(train_img)
        img = img.to(device)
        mask = mask.to(device).float()
        #print(train_label)
        if mode == 'base':
            y_pred = model(img)
            loss = criterion(y_pred, mask)
            train_losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        elif mode == 'caranet':
            train_lrmap_5,train_lrmap_3,train_lrmap_2,train_lrmap_1 = model(img)
            train_loss5 = structure_loss(train_lrmap_5, mask)
            train_loss3 = structure_loss(train_lrmap_3, mask)
            train_loss2 = structure_loss(train_lrmap_2, mask)
            train_loss1 = structure_loss(train_lrmap_1, mask)
        
            loss = train_loss5 + train_loss3 + train_loss2 + train_loss1
        
            train_losses.append(loss.item())
            
            optimizer.zero_grad()
            loss.backward()
            clip_gradient(optimizer, 0.5)
            optimizer.step()
    
    train_loss = np.average(train_losses)

    loss_dict['train'].append(train_loss)

    model.eval()
    with torch.no_grad():
        valid_losses = []
        for it_2, (img, mask) in enumerate(val_loader):
            img = img.to(device)
            mask = mask.to(device).float()
            #print(train_label)
            if mode == 'base':
                y_pred = model(img)
                loss = criterion(y_pred, mask)
                valid_losses.append(loss.item())

            elif mode == 'caranet':
                lrmap_5, lrmap_3, lrmap_2, lrmap_1 = model(img)
                loss5 = structure_loss(lrmap_5, mask)
                loss3 = structure_loss(lrmap_3, mask)
                loss2 = structure_loss(lrmap_2, mask)
                loss1 = structure_loss(lrmap_1, mask)
            
                loss = loss5 + loss3 + loss2 + loss1
                valid_losses.append(loss.item())
        
        valid_loss = np.average(valid_losses)
        scheduler.step(valid_loss)
        
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
#            scheduler.step(float(val_loss))
        loss_dict['val'].append(valid_loss)
            
    print(f'Epoch {epoch} train_loss: {train_loss:0.5f}   val_loss: {valid_loss:0.5f}')

Epoch 0 train_loss: 0.06642   val_loss: 0.05514
Epoch 1 train_loss: 0.06631   val_loss: 0.05412
Epoch 2 train_loss: 0.06525   val_loss: 0.05400
EarlyStopping counter: 1 out of 20
Epoch 3 train_loss: 0.06496   val_loss: 0.05402
Epoch 4 train_loss: 0.06476   val_loss: 0.05365
Epoch 5 train_loss: 0.06543   val_loss: 0.05338
EarlyStopping counter: 1 out of 20
Epoch 6 train_loss: 0.06453   val_loss: 0.05350
Epoch 7 train_loss: 0.06429   val_loss: 0.05331
EarlyStopping counter: 1 out of 20
Epoch 8 train_loss: 0.06460   val_loss: 0.05350
EarlyStopping counter: 2 out of 20
Epoch 9 train_loss: 0.06437   val_loss: 0.05336
Epoch 10 train_loss: 0.06396   val_loss: 0.05326
Epoch 11 train_loss: 0.06372   val_loss: 0.05311
Epoch 12 train_loss: 0.06307   val_loss: 0.05307
Epoch 13 train_loss: 0.06331   val_loss: 0.05298
Epoch 14 train_loss: 0.06292   val_loss: 0.05278
Epoch 15 train_loss: 0.06352   val_loss: 0.05266
Epoch 16 train_loss: 0.06328   val_loss: 0.05258
EarlyStopping counter: 1 out of 20
Ep

# Evaluate

In [12]:
# load best model
model.load_state_dict(torch.load(f'{filename}'))

<All keys matched successfully>

In [13]:
from utils import *

In [14]:
x,y = next(iter(test_2_4_loader))

In [15]:
x.shape, y.shape

(torch.Size([16, 3, 400, 400]), torch.Size([16, 1, 400, 400]))

In [16]:
evaluate(model, test_2_4_loader, mode)

Dice Similarity:    0.9488 
Jaccard Similarity: 0.9047


(0.9488382706760882, 0.9046990612617313)

In [17]:
evaluate(model, test_2_loader, mode)

Dice Similarity:    0.9396 
Jaccard Similarity: 0.8888


(0.9395881880382325, 0.8888093268756143)

In [18]:
evaluate(model, test_4_loader, mode)

Dice Similarity:    0.9581 
Jaccard Similarity: 0.9206


(0.9580883533139442, 0.9205887956478481)