In [2]:
from __future__ import print_function, division


import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import copy
import json
import importlib
import glob
import pandas as pd
from skimage import io, transform
import matplotlib.pyplot as plt
from matplotlib.image import imread
import numpy as np
from tqdm import tqdm_notebook as tqdm

import torch

from utils import (
    show_sbs,
    load_config,
    _print,
)

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

<matplotlib.pyplot._IonContext at 0x7f72a5158c40>

In [3]:
# Seed PyTorch's random number generator so runs are repeatable; the call returns the
# torch Generator object, which is what this cell displays as its output.
torch.manual_seed(0)

<torch._C.Generator at 0x7f71f8a63030>

In [4]:
# File name of the experiment configuration and its path inside the ./configs folder.
CONFIG_NAME = "isic2018_unet.yaml"
CONFIG_FILE_PATH = os.path.join("./configs", CONFIG_NAME)

In [5]:
# Load the experiment configuration with the project helper `load_config`.
# Later cells index the result like a nested dict (e.g. config['dataset']['input_size']).
config = load_config(CONFIG_FILE_PATH)
# Uncomment the lines below to pretty-print the loaded configuration:
# _print("Config:", "info_underline")
# print(json.dumps(config, indent=2))
# print(20*"~-", "\n")

## Functions

In [6]:
def show_img_pred_gt(img, pred, gt, figsize=(8,4)):
    """Show an image, its predicted mask and its ground-truth mask side by side.

    Args:
        img: image tensor. After ``squeeze()`` it is permuted with ``[1, 2, 0]``,
            so it must be 3-D at that point (presumably channels-first -- TODO confirm).
        pred: predicted mask tensor; squeezed and drawn with ``imshow``.
        gt: ground-truth mask tensor; squeezed and drawn with ``imshow``.
        figsize: figure size forwarded to ``plt.subplots``.
    """
    def _as_float_array(tensor):
        # detach() so tensors that still carry autograd history can be converted;
        # the builtin `float` replaces the `np.float` alias removed in NumPy 1.24.
        return tensor.detach().to('cpu').numpy().astype(float)

    _, axs = plt.subplots(1, 3, figsize=figsize)
    panels = (
        ('image', _as_float_array(img.squeeze().permute([1, 2, 0]))),
        ('pred', _as_float_array(pred.squeeze())),
        ('gt', _as_float_array(gt.squeeze())),
    )
    for ax, (title, data) in zip(axs, panels):
        ax.imshow(data)
        ax.set_title(title)
    plt.show()
    

def plot_loss_curves(tr, vl):
    """Plot per-epoch training and validation loss on the current matplotlib figure.

    Args:
        tr: sequence of training-loss values, one per epoch (drawn in red).
        vl: sequence of validation-loss values, one per epoch (drawn in blue).
    """
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("DiceLoss")
    # Attach the label to each line so the legend cannot get out of sync with the
    # plotting order (this also fixes the former 'Traning' typo).
    plt.plot(range(len(tr)), tr, 'r', label='Training')
    plt.plot(range(len(vl)), vl, 'b', label='Validation')
    plt.legend()
    plt.show()

### dataset and dataloader

In [7]:
# NOTE(review): unpinned upgrade -- the recorded output below shows torchvision going
# from 0.10.0 to 0.13.1. Consider pinning the version and using `%pip` so the install
# targets the kernel's own environment.
!pip install --upgrade torchvision

Collecting torchvision
  Using cached torchvision-0.13.1-cp38-cp38-manylinux1_x86_64.whl (19.1 MB)
[0mInstalling collected packages: torchvision
  Attempting uninstall: torchvision
[0m    Found existing installation: torchvision 0.10.0
    Uninstalling torchvision-0.10.0:
      Successfully uninstalled torchvision-0.10.0
Successfully installed torchvision-0.13.1
[0m

In [None]:
!conda update torchvision
# !pip install torchvision

In [None]:
from datasets.isic import ISIC2018TrainingDataset
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

In [None]:
# ------------------- params --------------------
INPUT_SIZE = config['dataset']['input_size']

# Fixed split sizes: 1815 training samples, 259 validation samples, and the samples
# up to index 2594 (520 of them) for testing.
# !cat ~/deeplearning/skin/Prepare_ISIC2018.py
N_TRAIN = 1815
N_VALID = 259
N_TOTAL = 2594


# ----------------- transforms ------------------
# Images are resized with bilinear interpolation ...
img_transform = transforms.Compose([
    transforms.Resize(
        size=[INPUT_SIZE] * 2,
        interpolation=transforms.functional.InterpolationMode.BILINEAR,
    ),
])
# ... while masks use nearest-neighbour interpolation.
msk_transform = transforms.Compose([
    transforms.Resize(
        size=[INPUT_SIZE] * 2,
        interpolation=transforms.functional.InterpolationMode.NEAREST,
    ),
])


# ----------------- dataset --------------------
# One underlying dataset; the three splits below are index-based views of it.
train_dataset = ISIC2018TrainingDataset(
    img_transform=img_transform,
    msk_transform=msk_transform
)

indices = list(range(len(train_dataset)))

# contiguous, non-overlapping index ranges: train | validation | test
tr_indices = indices[:N_TRAIN]
vl_indices = indices[N_TRAIN:N_TRAIN + N_VALID]
te_indices = indices[N_TRAIN + N_VALID:N_TOTAL]

tr_dataset, vl_dataset, te_dataset = (
    Subset(train_dataset, split) for split in (tr_indices, vl_indices, te_indices)
)
print(f"Length of trainig_dataset:\t{len(tr_dataset)}")
print(f"Length of validation_dataset:\t{len(vl_dataset)}")
print(f"Length of test_dataset:\t\t{len(te_dataset)}")


# --------------- dataloaders ------------------
# Each loader takes its keyword arguments from the matching config section.
tr_dataloader = DataLoader(tr_dataset, **config['data_loader']['train'])
vl_dataloader = DataLoader(vl_dataset, **config['data_loader']['validation'])
te_dataloader = DataLoader(te_dataset, **config['data_loader']['test'])

# -------------- test -----------------
# Visual sanity check: show one sample from the first batch of the training loader
# (sample 0) and of the validation loader (sample 1).
for _split_name, _loader, _sample_idx in (
    ("Training", tr_dataloader, 0),
    ("Validation", vl_dataloader, 1),
):
    for img, msk in _loader:
        print(_split_name)
        show_sbs(img[_sample_idx], msk[_sample_idx])
        break

### Device

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Torch device: {device}")

### model and config

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchmetrics
from models.unet import Unet

from torch.optim import (
  Adam
)

from losses import (
    DiceLoss
)

In [None]:
# Example nested metric collection (macro and micro variants of accuracy/precision).
# The classes are referenced through the `torchmetrics` namespace because the notebook
# only does `import torchmetrics`; the bare names raised NameError on a fresh kernel.
# NOTE(review): `metrics` is reassigned by the next cell, so this collection is never
# used afterwards -- consider deleting this cell.
metrics = torchmetrics.MetricCollection(
    [
        torchmetrics.MetricCollection([
            torchmetrics.Accuracy(num_classes=3, average='macro'),
            torchmetrics.Precision(num_classes=3, average='macro')
        ], postfix='_macro'),
        torchmetrics.MetricCollection([
            torchmetrics.Accuracy(num_classes=3, average='micro'),
            torchmetrics.Precision(num_classes=3, average='micro')
        ], postfix='_micro'),
    ],
    prefix='valmetrics/'
)

In [None]:
# Template collection: F1, accuracy and Dice, each in a micro- and a macro-averaged
# variant (distinguished by the '_micro' / '_macro' postfix).
metrics = torchmetrics.MetricCollection(
    [
        torchmetrics.MetricCollection(
            [
                torchmetrics.F1Score(num_classes=2, threshold=0.5, average='micro'),
                torchmetrics.Accuracy(num_classes=2, threshold=0.5, average='micro'),
                torchmetrics.Dice(num_classes=2, threshold=0.5, average='micro'),
            ], postfix='_micro'),
        torchmetrics.MetricCollection(
            [
                torchmetrics.F1Score(num_classes=2, threshold=0.5, average='macro'),
                torchmetrics.Accuracy(num_classes=2, threshold=0.5, average='macro'),
                torchmetrics.Dice(num_classes=2, threshold=0.5, average='macro'),
            ], postfix='_macro'),
    ],
    prefix='train_metrics/'
)

# Independent copies per phase, each with its own key prefix. They are moved to
# `device` because the training/validation code feeds them tensors that were moved
# to `device`; metric state left on the CPU would not match GPU inputs.

# train_metrics
train_metrics = metrics.clone(prefix='train_metrics/').to(device)

# valid_metrics
valid_metrics = metrics.clone(prefix='valid_metrics/').to(device)

# test_metrics
test_metrics = metrics.clone(prefix='test_metrics/').to(device)

## validate

In [None]:
def validate(model, criterion, vl_dataloader, device=None):
    """Evaluate `model` on `vl_dataloader` and return its metrics and loss.

    The notebook-level `valid_metrics` collection is used as the evaluator; it is
    reset before and after the pass so no state leaks between calls.

    Args:
        model: network to evaluate; it is switched to eval mode and left in it.
        criterion: callable invoked as ``criterion(preds, msks)``; its result must
            support ``.item()``.
        vl_dataloader: iterable yielding ``(imgs, msks)`` batches.
        device: device the batches (and the evaluator) are moved to. Defaults to the
            device of the model's first parameter.

    Returns:
        ``(t_m, loss)`` where ``t_m`` is the dict returned by the evaluator's
        ``compute()`` and ``loss`` is the sum of the per-batch loss values divided by
        the number of samples seen.
    """
    if device is None:
        device = next(model.parameters()).device

    # Metric state must live on the same device as the tensors it is updated with.
    evaluator = valid_metrics.to(device)
    evaluator.reset()

    model.eval()
    losses = []
    cnt = 0.
    with torch.no_grad():
        iterator = tqdm(enumerate(vl_dataloader), leave=None)
        for batch, (imgs, msks) in iterator:
            cnt += msks.shape[0]

            imgs = imgs.to(device)
            msks = msks.to(device)

            preds = model(imgs)
            loss = criterion(preds, msks)
            losses.append(loss.item())

            # Same target conversion as the training loop, so training and validation
            # metrics are computed the same way.
            # NOTE(review): assumes the masks carry a class axis at dim 1 -- confirm.
            msks_ = torch.argmax(msks.squeeze(), dim=1)
            evaluator.update(preds, msks_)

            _cml = f"curr_mean-loss:{np.sum(losses)/cnt:0.5f}"
            _bl = f"batch-loss:{losses[-1]/msks.shape[0]:0.5f}"
            iterator.set_description(f"Validation) batch:{batch+1:04d} -> {_cml}, {_bl}")

        # aggregate over the whole validation set
        loss = np.sum(losses)/cnt
        t_m = evaluator.compute()
        _ams = ', '.join([f'{k}: {v:0.5f}' for k,v in t_m.items()])
        iterator.set_description(f"Validation_result (on all val_data): {_ams}")

        evaluator.reset()

    return t_m, loss

In [None]:
# Scratch cell: rebuilds `metrics` / `train_metrics` (replacing the objects created by
# the earlier metrics cell) with default-configured metrics placed on `device`, then
# evaluates them sample by sample on `preds` / `msks`.
# NOTE(review): no cell shown above assigns `preds` or `msks` at top level, so this
# cell appears to depend on leftover kernel state and will likely fail after
# Restart & Run All -- confirm, then move it after training or delete it.
metrics = torchmetrics.MetricCollection(
    [
        torchmetrics.Accuracy().to(device),
        torchmetrics.Precision().to(device),
        torchmetrics.Specificity().to(device),
        torchmetrics.Recall().to(device),
        torchmetrics.Dice().to(device),
        torchmetrics.F1Score(
#             reduce='macro', 
#             num_classes=2, 
#             multiclass=True,
#             threshold=0.5,
#             reduction='elementwise_mean',
#             average='macro'
        ).to(device),
        torchmetrics.JaccardIndex(
            num_classes=2, 
            ignore_index=None, 
            absent_score=0.0, 
#             threshold=0.5, 
            multilabel=False, 
            reduction='elementwise_mean', 
            compute_on_step=None).to(device),
#         torchmetrics.Accuracy(threshold=0.5, num_classes=1).to(device),
#         torchmetrics.Dice(num_classes=2, threshold=0.5, average='macro').to(device),
    ],
    prefix='train/')

# train_metrics (this clone keeps the 'train/' prefix set above)
train_metrics = metrics.clone()


evaluator = train_metrics
# boolean masks: True where the mask value exceeds 0.5
msks_int = msks>0.5

# update the metrics one sample at a time, with prediction and mask flattened to 1-D
for p, m in zip(preds, msks_int):
    evaluator(p.reshape(1, -1).squeeze(), m.reshape(1, -1).squeeze())

# convert the tensor results to plain Python numbers so they can be printed as JSON
res = evaluator.compute()
for k,v in res.items():
    res[k] = v.item()
print(json.dumps(res, indent=4))
evaluator.reset()

In [None]:
# Visual check of one sample: the prediction is binarised by thresholding at 0.
# NOTE(review): `imgs`, `preds` and `msks` are not assigned at top level by any cell
# shown above; this looks like it relies on leftover kernel state -- confirm.
idx = 1
show_img_pred_gt(imgs[idx], preds[idx]>0, msks[idx])

In [None]:
# Scratch cell: a top-level copy of the body of `train()` (defined in a later cell),
# with the result/weight saving commented out.
# NOTE(review): this cell uses `model`, which is only created in a later cell, and
# `save_file_id`, which no visible cell assigns at top level, so it cannot run on a
# fresh kernel in notebook order. It also duplicates `train()`; consider deleting it
# and calling `train()` instead.
tr_prms = config['training']
EPOCHS = tr_prms['epochs']

# criterion / optimizer classes are looked up by name among the notebook's globals
criterion = globals()[tr_prms['criterion']['name']](**tr_prms['criterion']['params'])
optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params'])
# optimizer = optim.RMSprop(Net.parameters(), lr= float(config['lr']), weight_decay=1e-8, momentum=0.9)
# NOTE(review): the scheduler is created but `scheduler.step(...)` is never called below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# calculate metrics per batch

evaluator = train_metrics


epochs_info = []
best_vl_loss = np.Inf
epoch_tqdm = tqdm(range(EPOCHS), nrows=2)
for epoch in epoch_tqdm:
    model.train()

    evaluator.reset()
    tr_iterator = tqdm(enumerate(tr_dataloader), leave=None)
    tr_losses = []
    cnt = 0  # number of training samples seen in this epoch
    for batch, (imgs, msks) in tr_iterator:
        imgs = imgs.to(device)
        msks = msks.to(device)

        # one optimisation step
        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, msks)
        loss.backward()
        optimizer.step()

        # metric targets are the argmax over dim 1 of the squeezed mask
        msks_ = torch.argmax(msks.squeeze(), dim=1)
        tr_metrics = evaluator(preds, msks_)
        cnt += imgs.shape[0]
        tr_losses.append(loss.item())

        # write details for each training batch
        _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}"
        _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}"
        tr_iterator.set_description(f"Training) batch:{batch+1:04d} -> {_cml}, {_bl}")


#             if cnt>150: break

    # validate model
#         tr_iterator.set_description(f"Validation... (tr-loss:{np.sum(tr_losses)/cnt:0.5f})")
    vl_metrics, vl_loss = validate(model, criterion, vl_dataloader)
    if vl_loss < best_vl_loss:
        # find a better model
        # NOTE(review): this binds the same object as `model`; it is not a snapshot of
        # the weights at this epoch.
        best_model = model
        best_vl_loss = vl_loss

    # print the final results
    epoch_info = {
        'tr_loss': np.sum(tr_losses)/cnt,
        'vl_loss': vl_loss,
        'tr_metrics': evaluator.compute(),
        'vl_metrics': vl_metrics
    }
    epochs_info.append(epoch_info)

    # write details for this epoch
    _bvl = f'best_vl-loss:{best_vl_loss:0.5f}'
    _ltl = f"last_tr-loss:{epoch_info['tr_loss']:0.5f}"
    _tr_ams = ', '.join([f'tr_{k}: {v:0.4f}' for k,v in epoch_info['tr_metrics'].items()])
    _vl_ams = ', '.join([f'vl_{k}: {v:0.4f}' for k,v in epoch_info['vl_metrics'].items()])

    epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> {_bvl}, {_ltl}, {_vl_ams}, {_tr_ams}")

    evaluator.reset()

#         if cnt>5: break

# save final results
res = {
    'id': save_file_id,
    'config': config,
    'epochs_info': epochs_info
}
# fn = f"{save_file_id+'_' if save_file_id else ''}result.json"
# fp = os.path.join(config['model']['save_dir'],fn)
# with open(fp, "w") as write_file:
#     json.dump(res, write_file, indent=4)

# # save model's state_dict
# fn = f"{save_file_id if save_file_id else 'model'}_state_dict.pt"
# fp = os.path.join(config['model']['save_dir'],fn)
# torch.save(model.state_dict(), fp)


## train

In [None]:
def _metrics_to_floats(metric_values):
    """Turn a ``{name: scalar tensor}`` mapping into ``{name: float}`` (JSON-friendly)."""
    return {k: (v.item() if hasattr(v, 'item') else float(v)) for k, v in metric_values.items()}


def train(
    model, 
    device, 
    tr_dataloader,
    vl_dataloader,
    config,
    save_dir='./',
    save_file_id=None,
):
    """Train `model`, validate after every epoch, and save results and weights.

    Args:
        model: network to optimise (trained in place).
        device: device the batches and the metric collection are moved to.
        tr_dataloader: iterable yielding ``(imgs, msks)`` training batches.
        vl_dataloader: iterable passed to ``validate`` after each epoch.
        config: nested dict; ``config['training']`` provides ``epochs`` and the
            ``criterion`` / ``optimizer`` entries (``name`` + ``params``). The names
            are resolved among the notebook's globals.
        save_dir: directory the result JSON and the weight files are written to
            (created if missing).
        save_file_id: optional prefix for the output file names.

    Returns:
        ``(model, res)`` where ``res`` holds ``id``, ``config`` and ``epochs_info``
        (per-epoch losses and metrics as plain floats).

    Side effects:
        Writes ``[<id>_]result.json``, ``<id|model>_state_dict.pt`` (final weights)
        and, when at least one epoch ran, ``<id|model>_best_state_dict.pt`` (weights
        of the epoch with the lowest validation loss) into ``save_dir``.
    """
    tr_prms = config['training']
    EPOCHS = tr_prms['epochs']

    # criterion / optimizer classes are looked up by name among the notebook's globals
    criterion = globals()[tr_prms['criterion']['name']](**tr_prms['criterion']['params'])
    optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params'])
    # optimizer = optim.RMSprop(Net.parameters(), lr= float(config['lr']), weight_decay=1e-8, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    # metric state must be on the same device as the tensors it is updated with
    evaluator = train_metrics.to(device)

    epochs_info = []
    best_vl_loss = np.inf
    best_state_dict = None
    epoch_tqdm = tqdm(range(EPOCHS), nrows=2)
    for epoch in epoch_tqdm:
        model.train()

        evaluator.reset()
        tr_iterator = tqdm(enumerate(tr_dataloader), leave=None)
        tr_losses = []
        cnt = 0  # number of training samples seen in this epoch
        for batch, (imgs, msks) in tr_iterator:
            imgs = imgs.to(device)
            msks = msks.to(device)

            optimizer.zero_grad()
            preds = model(imgs)
            loss = criterion(preds, msks)
            loss.backward()
            optimizer.step()

            # metric targets are the argmax over dim 1 of the squeezed mask
            msks_ = torch.argmax(msks.squeeze(), dim=1)
            evaluator(preds, msks_)
            cnt += imgs.shape[0]
            tr_losses.append(loss.item())

            # write details for each training batch
            _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}"
            _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}"
            tr_iterator.set_description(f"Training) batch:{batch+1:04d} -> {_cml}, {_bl}")

        # validate model
        vl_metrics, vl_loss = validate(model, criterion, vl_dataloader)

        # ReduceLROnPlateau only acts when it is stepped with the monitored value
        scheduler.step(vl_loss)

        if vl_loss < best_vl_loss:
            # Snapshot the weights: keeping a reference to `model` would just follow
            # the live model through later epochs.
            best_state_dict = copy.deepcopy(model.state_dict())
            best_vl_loss = vl_loss

        # Collect this epoch's results as plain floats so `json.dump` below can
        # serialise them (tensors are not JSON-serialisable).
        epoch_info = {
            'tr_loss': float(np.sum(tr_losses)/cnt),
            'vl_loss': float(vl_loss),
            'tr_metrics': _metrics_to_floats(evaluator.compute()),
            'vl_metrics': _metrics_to_floats(vl_metrics)
        }
        epochs_info.append(epoch_info)

        # write details for this epoch
        _bvl = f'best_vl-loss:{best_vl_loss:0.5f}'
        _ltl = f"last_tr-loss:{epoch_info['tr_loss']:0.5f}"
        _tr_ams = ', '.join([f'tr_{k}: {v:0.4f}' for k,v in epoch_info['tr_metrics'].items()])
        _vl_ams = ', '.join([f'vl_{k}: {v:0.4f}' for k,v in epoch_info['vl_metrics'].items()])

        epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> {_bvl}, {_ltl}, {_vl_ams}, {_tr_ams}")

        evaluator.reset()

    res = {
        'id': save_file_id,
        'config': config,
        'epochs_info': epochs_info
    }

    # Honour the `save_dir` argument (it used to be ignored in favour of the config).
    os.makedirs(save_dir, exist_ok=True)
    weights_prefix = save_file_id if save_file_id else 'model'

    # Save the weights first so a problem while writing the JSON cannot lose them.
    fp = os.path.join(save_dir, f"{weights_prefix}_state_dict.pt")
    torch.save(model.state_dict(), fp)

    if best_state_dict is not None:
        fp = os.path.join(save_dir, f"{weights_prefix}_best_state_dict.pt")
        torch.save(best_state_dict, fp)

    # save final results
    fn = f"{save_file_id+'_' if save_file_id else ''}result.json"
    fp = os.path.join(save_dir, fn)
    with open(fp, "w") as write_file:
        json.dump(res, write_file, indent=4)

    return model, res

In [None]:
# show_img_pred_gt(imgs[idx], tp[idx], tn[idx], figsize=(12,4))

In [None]:
# def test(model, modelPath):
#     if not os.path.exists("results"):
#         os.makedirs("results")
#     model.load_state_dict(torch.load(modelPath))
#     model.eval()
#     model = model.to(device)
#     df = pd.DataFrame(columns=classes)
#     with torch.no_grad():
#         for idx, sample in enumerate(dataloader_test):
#             outputs = model(sample['image'].to(device))
#             outputs = softmax(outputs)
#             outputs = outputs.cpu().numpy()
#             df = df.append(
#                 pd.DataFrame(data=outputs, columns=classes))
#         df = df.reset_index()
#         del df['index']
#         df.insert(0, 'image', images['image'])
#     df.to_csv(f'results/test_results.csv', index=False)

In [None]:
# if __name__ == "__main__":
#     model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=9)
#     model, fileName = train_model(model, 16, 0.01, 1, 0, 0, '3')
#     validate(model)
#     test_model(model, fileName)

In [None]:
# Build the U-Net from the config and move it to the selected device.
model = Unet(**config['model']['params'])
torch.cuda.empty_cache()
model = model.to(device)
print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

# Weights are stored in / loaded from the configured save directory.
os.makedirs(config['model']['save_dir'], exist_ok=True)
model_path = f"{config['model']['save_dir']}/model_state_dict.pt"

if config['model']['load_weights']:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    # (and the other way round).
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Loaded pre-trained weights...")

In [None]:
# Run the training; the returned (model, results) pair is discarded because the
# model is trained in place.
_ = train(
    model=model,
    device=device,
    tr_dataloader=tr_dataloader,
    vl_dataloader=vl_dataloader,
    config=config,
    save_dir=config['model']['save_dir'],
    save_file_id=None,
)