# Import Library

In [1]:
%load_ext autoreload
%autoreload 2


import datetime
import pickle
import random

import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import optim
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils import *

from models import caranet
from unet import pretrained_unet

from metrics import DiceLoss



In [2]:
##### Hyperparameter Settings ####
# Central configuration; later cells read these module-level names.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-4
weight_decay = 1e-10
batch_size = 16
num_epochs = 1000  # upper bound only: the training loops break on early stopping
# Non-improving validation epochs tolerated before stopping. 20 is the value the
# EarlyStopping runs actually used (their logs print "out of 20").
early_stopping_patience = 20
random_seed = 42
date_time = datetime.datetime.now().strftime("%m-%d_%H-%M")

# Seed every RNG the notebook touches so runs are repeatable.
# (This does not make cuDNN kernels deterministic.)
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # also seeds all CUDA devices

# model_type = 'caranet'
model_type = 'unet'

filename = f'models/{model_type}_{date_time}.pt'

# `mode` selects the forward/loss branch used in training and evaluation.
if model_type == 'unet':
    mode = 'base'
elif model_type == 'caranet':
    mode = 'caranet'
else:
    raise ValueError(f"Unknown model_type: {model_type!r}")
#################################
#################################

# Dataloader

In [3]:
# Augmentation pipeline handed to make_dataloader: a random horizontal flip,
# then conversion to torch tensors. transpose_mask=True makes the mask
# channel-first; the later batch check shows masks of shape (16, 1, 400, 400).
# NOTE(review): whether make_dataloader also applies this random flip to the
# test split is not visible here -- confirm in utils.
transform = A.Compose([
    A.HorizontalFlip(),
    ToTensorV2(transpose_mask=True)
])



In [4]:
# make_dataloader returns three (train, test) loader pairs. The names suggest
# the "2+4", "2" and "4" data subsets -- presumably; TODO confirm in utils.
_2_4_loader, _2_loader, _4_loader = make_dataloader(transform, random_seed, batch_size, mode)

# Unpack each pair into explicitly named train/test loaders.
train_2_4_loader, test_2_4_loader = _2_4_loader
train_2_loader, test_2_loader = _2_loader
train_4_loader, test_4_loader = _4_loader


train image shape: (1600, 400, 400, 3) 
train mask shape: (1600, 400, 400, 1)
test image shape: (200, 400, 400, 3) 
test mask shape: (200, 400, 400, 1)


# Pre-Training

In [7]:
# Build the model chosen in the config cell, plus its optimizer, LR scheduler,
# loss function, early-stopping tracker and the loss-history container.
if model_type == 'unet':
    model = pretrained_unet(True).to(device)
elif model_type == 'caranet':
    model = caranet().to(device)
else:
    raise ValueError(f"Unknown model_type: {model_type!r}")
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# NOTE(review): ReduceLROnPlateau only lowers the LR after *more than* `patience`
# bad epochs. EarlyStopping appears (from its "counter: k out of 20" log) to stop
# at `patience` bad epochs. With both at 20, the LR is likely never reduced;
# consider a smaller scheduler patience.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=20, min_lr=learning_rate/1000, verbose=False)
criterion = DiceLoss()

# Presumably checkpoints the best model to `filename` via `path` (later cells
# reload models/unet_*.pt files) -- confirm in utils.
early_stopping = EarlyStopping(patience=early_stopping_patience, verbose=False, path=filename)
loss_dict = {'train': [], 'val': []}

In [9]:
# Pre-train on the combined "2_4" split; its test loader drives validation,
# LR scheduling and early stopping.
# NOTE(review): the same test split is later used to report final metrics, so
# those numbers are optimistically biased by model selection.
train_loader = train_2_4_loader
val_loader = test_2_4_loader

In [11]:
# Pre-training loop. Each epoch makes one optimization pass over train_loader,
# then evaluates on val_loader. ReduceLROnPlateau and EarlyStopping both watch
# the mean validation loss.
if mode not in ('base', 'caranet'):
    raise ValueError(f"Unknown mode: {mode!r}")

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    train_losses = []
    for img, mask in train_loader:
        img = img.to(device)
        mask = mask.to(device).float()
        if mode == 'base':
            y_pred = model(img)
            loss = criterion(y_pred, mask)
        else:  # 'caranet': sum structure_loss over the four returned maps
            train_lrmap_5, train_lrmap_3, train_lrmap_2, train_lrmap_1 = model(img)
            loss = (structure_loss(train_lrmap_5, mask)
                    + structure_loss(train_lrmap_3, mask)
                    + structure_loss(train_lrmap_2, mask)
                    + structure_loss(train_lrmap_1, mask))
        train_losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        if mode == 'caranet':
            clip_gradient(optimizer, 0.5)
        optimizer.step()

    train_loss = np.average(train_losses)
    loss_dict['train'].append(train_loss)

    # ---- validation ----
    model.eval()
    valid_losses = []
    with torch.no_grad():
        for img, mask in val_loader:
            img = img.to(device)
            mask = mask.to(device).float()
            if mode == 'base':
                y_pred = model(img)
                loss = criterion(y_pred, mask)
            else:
                lrmap_5, lrmap_3, lrmap_2, lrmap_1 = model(img)
                loss = (structure_loss(lrmap_5, mask)
                        + structure_loss(lrmap_3, mask)
                        + structure_loss(lrmap_2, mask)
                        + structure_loss(lrmap_1, mask))
            valid_losses.append(loss.item())
    valid_loss = np.average(valid_losses)

    # Record and report BEFORE the early-stopping check. Otherwise the stopping
    # epoch appends a train loss without its val loss, and the histories drift apart.
    loss_dict['val'].append(valid_loss)
    print(f'Epoch {epoch} train_loss: {train_loss:0.5f}   val_loss: {valid_loss:0.5f}')

    scheduler.step(valid_loss)
    early_stopping(valid_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

Epoch 0 train_loss: 0.71702   val_loss: 0.66310
Epoch 1 train_loss: 0.63567   val_loss: 0.62090
Epoch 2 train_loss: 0.60623   val_loss: 0.60054
Epoch 3 train_loss: 0.58137   val_loss: 0.58401
Epoch 4 train_loss: 0.55619   val_loss: 0.54848
Epoch 5 train_loss: 0.52998   val_loss: 0.52371
Epoch 6 train_loss: 0.49966   val_loss: 0.48504
Epoch 7 train_loss: 0.46458   val_loss: 0.47178
Epoch 8 train_loss: 0.43170   val_loss: 0.43726
Epoch 9 train_loss: 0.39988   val_loss: 0.38938
Epoch 10 train_loss: 0.37061   val_loss: 0.36117
Epoch 11 train_loss: 0.34461   val_loss: 0.32991
Epoch 12 train_loss: 0.31447   val_loss: 0.31413
Epoch 13 train_loss: 0.28770   val_loss: 0.29499
EarlyStopping counter: 1 out of 20
Epoch 14 train_loss: 0.26526   val_loss: 0.35615
Epoch 15 train_loss: 0.24566   val_loss: 0.25535
Epoch 16 train_loss: 0.22397   val_loss: 0.22701
Epoch 17 train_loss: 0.20889   val_loss: 0.20628
Epoch 18 train_loss: 0.18948   val_loss: 0.18851
Epoch 19 train_loss: 0.17482   val_loss: 0.1

KeyboardInterrupt: 

# Finetuning

In [36]:
# Fresh optimizer for fine-tuning.
learning_rate = 1e-4
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Rebuild the scheduler as well: the one from pre-training still wraps the OLD
# optimizer, so stepping it would never change this optimizer's learning rate.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=20, min_lr=learning_rate/1000, verbose=False)

In [37]:
# Fine-tune on the "2" split only; its test loader drives validation and
# early stopping.
train_loader = train_2_loader
val_loader = test_2_loader

In [40]:
# Start fine-tuning from the saved pre-training checkpoint.
# NOTE(review): hard-coded run timestamp -- update after re-training.
filename = 'models/unet_12-01_15-41.pt'
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(filename, map_location=device))

<All keys matched successfully>

In [41]:
# New timestamped checkpoint path and a fresh early-stopping tracker for the
# fine-tuning run. The pre-training checkpoint is therefore not overwritten,
# unless both runs start in the same minute.
date_time = datetime.datetime.now().strftime("%m-%d_%H-%M")
filename = f'models/{model_type}_{date_time}.pt'
early_stopping = EarlyStopping(patience=early_stopping_patience, verbose=False, path=filename)

In [42]:
# Fine-tuning loop, using the same procedure as pre-training. Note that
# loss_dict is not reset, so these epochs are appended after the pre-training
# history.
if mode not in ('base', 'caranet'):
    raise ValueError(f"Unknown mode: {mode!r}")

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    train_losses = []
    for img, mask in train_loader:
        img = img.to(device)
        mask = mask.to(device).float()
        if mode == 'base':
            y_pred = model(img)
            loss = criterion(y_pred, mask)
        else:  # 'caranet': sum structure_loss over the four returned maps
            train_lrmap_5, train_lrmap_3, train_lrmap_2, train_lrmap_1 = model(img)
            loss = (structure_loss(train_lrmap_5, mask)
                    + structure_loss(train_lrmap_3, mask)
                    + structure_loss(train_lrmap_2, mask)
                    + structure_loss(train_lrmap_1, mask))
        train_losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        if mode == 'caranet':
            clip_gradient(optimizer, 0.5)
        optimizer.step()

    train_loss = np.average(train_losses)
    loss_dict['train'].append(train_loss)

    # ---- validation ----
    model.eval()
    valid_losses = []
    with torch.no_grad():
        for img, mask in val_loader:
            img = img.to(device)
            mask = mask.to(device).float()
            if mode == 'base':
                y_pred = model(img)
                loss = criterion(y_pred, mask)
            else:
                lrmap_5, lrmap_3, lrmap_2, lrmap_1 = model(img)
                loss = (structure_loss(lrmap_5, mask)
                        + structure_loss(lrmap_3, mask)
                        + structure_loss(lrmap_2, mask)
                        + structure_loss(lrmap_1, mask))
            valid_losses.append(loss.item())
    valid_loss = np.average(valid_losses)

    # Record and report BEFORE the early-stopping check, so the stopping epoch
    # keeps loss_dict['train'] and loss_dict['val'] the same length.
    loss_dict['val'].append(valid_loss)
    print(f'Epoch {epoch} train_loss: {train_loss:0.5f}   val_loss: {valid_loss:0.5f}')

    scheduler.step(valid_loss)
    early_stopping(valid_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

Epoch 0 train_loss: 0.07450   val_loss: 0.05871
Epoch 1 train_loss: 0.06933   val_loss: 0.05740
Epoch 2 train_loss: 0.06705   val_loss: 0.05477
Epoch 3 train_loss: 0.06616   val_loss: 0.05416
EarlyStopping counter: 1 out of 20
Epoch 4 train_loss: 0.06497   val_loss: 0.05498
EarlyStopping counter: 2 out of 20
Epoch 5 train_loss: 0.06404   val_loss: 0.05470
Epoch 6 train_loss: 0.06264   val_loss: 0.05356
EarlyStopping counter: 1 out of 20
Epoch 7 train_loss: 0.06253   val_loss: 0.05517
Epoch 8 train_loss: 0.06094   val_loss: 0.05098
EarlyStopping counter: 1 out of 20
Epoch 9 train_loss: 0.06238   val_loss: 0.05193
EarlyStopping counter: 2 out of 20
Epoch 10 train_loss: 0.06099   val_loss: 0.05142
EarlyStopping counter: 3 out of 20
Epoch 11 train_loss: 0.06101   val_loss: 0.05224
EarlyStopping counter: 4 out of 20
Epoch 12 train_loss: 0.05951   val_loss: 0.05272
EarlyStopping counter: 5 out of 20
Epoch 13 train_loss: 0.05880   val_loss: 0.05280
EarlyStopping counter: 6 out of 20
Epoch 14 

# Evaluate

In [32]:
def _per_image_overlap(pred_hard, gt):
    """Per-image Dice, Jaccard, recall and precision for 0/1 masks.

    Args:
        pred_hard: (N, ...) array of thresholded predictions (0/1).
        gt: ground-truth masks with N images and the same number of
            elements per image as `pred_hard`.

    Returns:
        Four length-N float arrays: dice, jaccard, recall, precision.
        A 1e-8 term in each denominator avoids division by zero on empty masks.

    Raises:
        ValueError: if the image counts or per-image sizes differ.
    """
    if pred_hard.shape[0] != gt.shape[0] or pred_hard.size != gt.size:
        raise ValueError(
            f"prediction shape {pred_hard.shape} does not match mask shape {gt.shape}")
    n = gt.shape[0]
    # One row per image, so every count is a single reduction over axis 1.
    pred_flat = pred_hard.reshape(n, -1)
    gt_flat = gt.reshape(n, -1)
    tp = np.sum((pred_flat + gt_flat) == 2, axis=1)          # true positives
    fn = np.sum(((1 - pred_flat) + gt_flat) == 2, axis=1)    # false negatives
    fp = np.sum((pred_flat + (1 - gt_flat)) == 2, axis=1)    # false positives
    ds_union = np.sum(pred_flat, axis=1) + np.sum(gt_flat, axis=1)
    union = np.sum((pred_flat + gt_flat) >= 1, axis=1)
    dice = (tp * 2) / (ds_union + 1e-8)
    jaccard = tp / (union + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    return dice, jaccard, recall, precision


def evaluate(model, testloader, mode='base', device=None):
    """Run `model` over `testloader` and report per-image overlap metrics.

    Predictions are thresholded at 0.5. Dice, Jaccard (IoU), recall and
    precision are computed per image, averaged over all images, and printed.

    Args:
        model: segmentation network. In 'base' mode its output is thresholded
            as-is, with no sigmoid applied here (assumes it already yields
            probabilities -- TODO confirm for the UNet). In 'caranet' mode the
            first returned map is passed through a sigmoid.
        testloader: iterable of ``(img, gt_mask)`` CPU tensor batches.
        mode: 'base' or 'caranet'.
        device: device used for inference. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model), so CUDA
            is no longer required.

    Returns:
        (mean Dice, mean Jaccard).

    Raises:
        ValueError: unknown `mode`, an empty `testloader`, or predictions
            whose size does not match the masks.
    """
    if mode not in ('base', 'caranet'):
        raise ValueError(f"mode must be 'base' or 'caranet', got {mode!r}")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    pred_mask_list = []
    gt_mask_list = []
    model.eval()
    with torch.no_grad():
        for img, gt_mask in testloader:
            output = model(img.to(device))
            if mode == 'base':  # generic model: output used directly
                pred = output
            else:  # 'caranet' (Jongwook's model): sigmoid of the first map
                pred = output[0].sigmoid()
            pred_mask_list.append(pred.cpu().numpy())
            gt_mask_list.append(gt_mask.numpy())
    if not gt_mask_list:
        raise ValueError("testloader yielded no batches")

    pred_masks = np.vstack(pred_mask_list)
    gt_masks = np.vstack(gt_mask_list)
    pred_masks_hard = (pred_masks > 0.5) + 0

    DS, JS, RC, PC = _per_image_overlap(pred_masks_hard, gt_masks)
    DS_mean = np.mean(DS)
    JS_mean = np.mean(JS)
    RC_mean = np.mean(RC)
    PC_mean = np.mean(PC)
    print(f'Dice Similarity:    {DS_mean:0.4f} \nJaccard Similarity: {JS_mean:0.4f} \nRecall: {RC_mean:0.4f} \nPrecision: {PC_mean:0.4f}')
    return DS_mean, JS_mean

In [30]:
# Checkpoint to evaluate. NOTE(review): hard-coded timestamp of a specific run.
filename = 'models/unet_12-01_15-41.pt'

In [14]:
# load best model
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(filename, map_location=device))

<All keys matched successfully>

In [15]:
# Point the generic loader names back at the combined "2_4" split for evaluation.
train_loader = train_2_4_loader
val_loader = test_2_4_loader

In [33]:
# Metrics on val_loader (the 2_4 test split after the cell above).
evaluate(model, val_loader, mode)

Dice Similarity:    0.9310 
Jaccard Similarity: 0.8746 
Recall: 0.9600 
Precision: 0.9098


(0.9309747890656567, 0.8746189577149983)

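### Recall vs. precision

$$\text{Recall} = \frac{TP}{TP + FN}, \qquad \text{Precision} = \frac{TP}{TP + FP}$$

Precision이 더 낮으니까 FP가 큰 것으로 생각됨 (넓게 잡는듯)
— Precision (0.91) is lower than recall (0.96), so false positives dominate: the model seems to segment too broadly.

그리고 ... loader가 바뀌어서 성능이 더 낮게 나오는 거 아닐까 싶습니다
— Also, I suspect the score comes out lower because the data loader changed.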

In [44]:
from utils import *

In [45]:
# Pull one batch from the 2_4 test loader for a quick shape sanity check.
x,y = next(iter(test_2_4_loader))

In [46]:
# Expect channel-first images (B, 3, H, W) and masks (B, 1, H, W).
x.shape, y.shape

(torch.Size([16, 3, 400, 400]), torch.Size([16, 1, 400, 400]))

In [35]:
# Metrics on the 2_4 test split.
# NOTE(review): the output below prints no recall/precision, so it came from an
# earlier version of `evaluate` -- re-run for current numbers.
evaluate(model, test_2_4_loader, mode)

Dice Similarity:    0.9444 
Jaccard Similarity: 0.8971


(0.9443990112442497, 0.8971101449344325)

In [47]:
# Metrics on val_loader.
# NOTE(review): by execution count this ran after the fine-tuning cells, so it
# scores whatever weights `model` held at that point. That need not be the
# checkpoint loaded above, and `val_loader` may be whichever split was assigned
# last. Re-run in order to get reproducible numbers.
evaluate(model, val_loader, mode)

Dice Similarity:    0.9473 
Jaccard Similarity: 0.9027


(0.9473162733524139, 0.9026605143462214)

In [18]:
# Metrics on the "4" test split.
evaluate(model, test_4_loader, mode)

Dice Similarity:    0.9581 
Jaccard Similarity: 0.9206


(0.9580883533139442, 0.9205887956478481)