# Import Library

In [1]:
%load_ext autoreload
%autoreload 2


import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import datetime

import torch
from torch.utils.data import DataLoader

from sklearn.model_selection import ShuffleSplit
from torch.utils.data import Subset

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from torch import optim
import numpy as np
import pickle
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils import *

from models import caranet
from unet import pretrained_unet

from metrics import get_DC, get_JS, DiceLoss



In [2]:
##### Hyperparameter Settings ####
device = 'cuda' if torch.cuda.is_available() else 'cpu'
leraning_rate = 0.001
weight_decay = 1e-6
batch_size = 16
num_epochs = 1000
early_stopping_patience = 10
random_seed = 42
date_time = datetime.datetime.now().strftime("%m-%d_%H-%M")
# model_name = 'caranet' # 'unet'
model_name = 'unet'
filename = f'models/{model_name}_{date_time}.pt'

mode = 'base' # base | caranet 
# mode = 'caranet' # base | caranet 
#################################

# Dataloader

In [3]:
transform = A.Compose([
    A.HorizontalFlip(),
    ToTensorV2(transpose_mask=True)
])



In [4]:
_2_4_loader, _2_loader, _4_loader = create_loader(transform, random_seed, batch_size, mode)

train_2_4_loader, val_2_4_loader, test_2_4_loader = _2_4_loader
train_2_loader, val_2_loader, test_2_loader = _2_loader
train_4_loader, val_4_loader, test_4_loader = _4_loader


train image shape: (1600, 400, 400, 3) 
train mask shape: (1600, 400, 400, 1)
test image shape: (200, 400, 400, 3) 
test mask shape: (200, 400, 400, 1)


# Training

In [5]:
# model = caranet().to(device)
model = pretrained_unet(True).to(device)
optimizer = Adam(model.parameters(), lr=leraning_rate, weight_decay=weight_decay)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=20, min_lr=leraning_rate/1000, verbose=True)
criterion = DiceLoss()

early_stopping = EarlyStopping(patience = 20, verbose = True, path = filename)
loss_dict = {'train': [], 'val': []}

Downloading: "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt" to /home/kang/.cache/torch/hub/checkpoints/unet-e012d006.pt


In [6]:
train_loader = train_2_4_loader
val_loader = val_2_4_loader

In [9]:
for epoch in range(num_epochs):
    model.train()
    
    train_losses = []
    for it_1, (img, mask) in enumerate(tqdm(train_loader)):
        #print(train_img)
        img = img.to(device)
        mask = mask.to(device).float()
        #print(train_label)
        if mode == 'base':
            y_pred = model(img)
            loss = criterion(y_pred, mask)
            train_losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        elif mode == 'caranet':
            train_lrmap_5,train_lrmap_3,train_lrmap_2,train_lrmap_1 = model(img)
            train_loss5 = structure_loss(train_lrmap_5, mask)
            train_loss3 = structure_loss(train_lrmap_3, mask)
            train_loss2 = structure_loss(train_lrmap_2, mask)
            train_loss1 = structure_loss(train_lrmap_1, mask)
        
            loss = train_loss5 + train_loss3 + train_loss2 + train_loss1
        
            train_losses.append(loss.item())
            
            optimizer.zero_grad()
            loss.backward()
            clip_gradient(optimizer, 0.5)
            optimizer.step()
    
    train_loss = np.average(train_losses)

    loss_dict['train'].append(train_loss)

    model.eval()
    with torch.no_grad():
        valid_losses = []
        for it_2, (img, mask) in enumerate(val_loader):
            img = img.to(device)
            mask = mask.to(device).float()
            #print(train_label)
            if mode == 'base':
                y_pred = model(img)
                loss = criterion(y_pred, mask)
                valid_losses.append(loss.item())

            elif mode == 'caranet':
                lrmap_5, lrmap_3, lrmap_2, lrmap_1 = model(img)
                loss5 = structure_loss(lrmap_5, mask)
                loss3 = structure_loss(lrmap_3, mask)
                loss2 = structure_loss(lrmap_2, mask)
                loss1 = structure_loss(lrmap_1, mask)
            
                loss = loss5 + loss3 + loss2 + loss1
                valid_losses.append(loss.item())
        
        valid_loss = np.average(valid_losses)
        scheduler.step(valid_loss)
        
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
#            scheduler.step(float(val_loss))
        loss_dict['val'].append(valid_loss)
            
    print(f'train_loss: {train_loss:0.5f}   val_loss: {valid_loss:0.5f}')

100%|██████████| 88/88 [00:19<00:00,  4.44it/s]


Validation loss decreased (inf --> 0.156516).  Saving model ...
train_loss: 0.10209789059378883 val_loss: 0.15651596967990583


100%|██████████| 88/88 [00:19<00:00,  4.41it/s]


Validation loss decreased (0.156516 --> 0.156196).  Saving model ...
train_loss: 0.09034808047793129 val_loss: 0.15619559929921076


100%|██████████| 88/88 [00:20<00:00,  4.38it/s]


Validation loss decreased (0.156196 --> 0.095465).  Saving model ...
train_loss: 0.0871887511827729 val_loss: 0.09546504570887639


100%|██████████| 88/88 [00:20<00:00,  4.36it/s]


EarlyStopping counter: 1 out of 20
train_loss: 0.08117315444079312 val_loss: 0.10694830692731418


100%|██████████| 88/88 [00:20<00:00,  4.35it/s]


EarlyStopping counter: 2 out of 20
train_loss: 0.07712138850580562 val_loss: 0.1141365308028001


100%|██████████| 88/88 [00:20<00:00,  4.34it/s]


Validation loss decreased (0.095465 --> 0.089737).  Saving model ...
train_loss: 0.07607273012399673 val_loss: 0.08973680092738225


100%|██████████| 88/88 [00:20<00:00,  4.34it/s]


Validation loss decreased (0.089737 --> 0.083466).  Saving model ...
train_loss: 0.07237258959900249 val_loss: 0.0834656862112192


100%|██████████| 88/88 [00:20<00:00,  4.33it/s]


EarlyStopping counter: 1 out of 20
train_loss: 0.07064845345237038 val_loss: 0.08862636181024405


100%|██████████| 88/88 [00:20<00:00,  4.33it/s]


EarlyStopping counter: 2 out of 20
train_loss: 0.06594702127304944 val_loss: 0.11146635275620681


100%|██████████| 88/88 [00:20<00:00,  4.32it/s]


EarlyStopping counter: 3 out of 20
train_loss: 0.07163622027093713 val_loss: 0.08799645075431237


100%|██████████| 88/88 [00:20<00:00,  4.31it/s]


EarlyStopping counter: 4 out of 20
train_loss: 0.06668705967339603 val_loss: 0.09421787353662345


100%|██████████| 88/88 [00:20<00:00,  4.32it/s]


Validation loss decreased (0.083466 --> 0.074835).  Saving model ...
train_loss: 0.06430418992584402 val_loss: 0.07483511704664964


100%|██████████| 88/88 [00:20<00:00,  4.32it/s]


EarlyStopping counter: 1 out of 20
train_loss: 0.06243199448693882 val_loss: 0.07642429150067843


 23%|██▎       | 20/88 [00:04<00:16,  4.14it/s]


KeyboardInterrupt: 

Traceback (most recent call last):
  File "/home/kang/anaconda3/envs/torch/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/kang/anaconda3/envs/torch/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/kang/anaconda3/envs/torch/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/kang/anaconda3/envs/torch/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/home/kang/anaconda3/envs/torch/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/kang/anaconda3/envs/torch/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/kang/anaconda3/envs/torch/lib/python3.7/mult

# Finetuning

# Evaluate

In [None]:
# load best model
model.load_state_dict(torch.load(f'{filename}.pt'))