In [None]:
# !. ../venv/bin/activate

In [None]:
# !python3 -mpip install pip --upgrade
# !python3 -mpip install -r ./requirements.txt

In [None]:
from __future__ import print_function, division


import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import copy
import json
import importlib
import glob
import pandas as pd
from skimage import io, transform
import matplotlib.pyplot as plt
from matplotlib.image import imread
import numpy as np

import torch
# from sklearn.model_selection import train_test_split
from torch.optim import (
  Adam
)
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from extra.utils import (
  load_config,
  _print,
)

# Ignore warnings
# NOTE(review): this installs a blanket "ignore" filter, silencing *all*
# warnings for the rest of the kernel session (including deprecation warnings
# from numpy/torch). Consider narrowing it to specific categories/modules.
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode (matplotlib pyplot interactive on)

## Global variables

In [None]:
# YAML experiment config to use; looked up in ./configs relative to the
# kernel's current working directory.
CONFIG_NAME = "isic2018_unet.yaml"
CONFIG_FILE_PATH = os.path.join("./configs", CONFIG_NAME)

In [None]:
config = load_config(CONFIG_FILE_PATH)

# Echo the loaded configuration so the notebook output records the exact
# settings this run used, followed by a visual separator line.
_print("Config:", "info_underline")
print(json.dumps(config, sort_keys=False, indent=2))
print("~-" * 20, "\n")

In [None]:
from datasets import (
  ISIC2018Dataset
)

# Look up the dataset class named in the config among this notebook's globals
# (so it must be imported above). Bind it to `DatasetClass` rather than
# `Dataset` so that torch.utils.data.Dataset, imported at the top of the
# notebook, is not shadowed.
dataset_class_name = config['dataset']['class_name']
if dataset_class_name not in globals():
    raise KeyError(
        f"Dataset class {dataset_class_name!r} named in the config "
        f"is not imported in this notebook"
    )
DatasetClass = globals()[dataset_class_name]
training_dataset = DatasetClass(**config['dataset']['training']['params'])
validation_dataset = DatasetClass(**config['dataset']['validation']['params'])

print(f"Length of training_dataset:\t{len(training_dataset)}")
print(f"Length of validation_dataset:\t{len(validation_dataset)}")

train_dataloader = DataLoader(training_dataset, **config['data_loader']['train'])
validation_dataloader = DataLoader(validation_dataset, **config['data_loader']['validation'])

In [None]:
from models.unet import Unet
from losses import DiceLoss


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model class named in the config (it must be imported above,
# e.g. `Unet`) and move it to the target device *before* constructing the
# optimizer, as the torch.optim documentation recommends.
Net = globals()[config['model']['name']](**config['model']['params'])
Net.to(device)

criterion = DiceLoss()

# Resolve the optimizer class by name: first among names imported into this
# notebook (e.g. `Adam`), otherwise from torch.optim (e.g. "RMSprop", "SGD"),
# so any torch optimizer can be selected from the config.
optimizer_name = config['training']['optimizer']['name']
optimizer_cls = globals().get(optimizer_name) or getattr(optim, optimizer_name)
optimizer = optimizer_cls(Net.parameters(), **config['training']['optimizer']['params'])

# Lowers the LR when the monitored (to-be-minimised) metric plateaus; it only
# takes effect when `scheduler.step(metric)` is called.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

In [None]:
epochs = 1
steps = 0
running_loss = 0
print_every = 1
train_losses, test_losses = [], []

# NOTE(review): `scheduler` (ReduceLROnPlateau) is never stepped in this loop,
# so the learning rate never changes. Consider `scheduler.step(val_loss)`.
for epoch in range(epochs):
    Net.train()
    for batch in train_dataloader:
        imgs = batch['img'].to(device)
        msks = batch['msk'].to(device)
        steps += 1

        optimizer.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        preds = Net(imgs)
        loss = criterion(preds, msks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if steps % print_every == 0:
            test_loss = 0
            accuracy = 0
            Net.eval()
            with torch.no_grad():
                # Separate names so `imgs`/`msks` keep holding the last
                # *training* batch after validation (later cells may read them).
                for val_batch in validation_dataloader:
                    val_imgs = val_batch['img'].to(device)
                    val_msks = val_batch['msk'].to(device)
                    val_preds = Net(val_imgs)
                    test_loss += criterion(val_preds, val_msks).item()

                    # NOTE(review): this metric treats dim 1 as class channels of
                    # log-probabilities (exp + topk), as in a classifier. If the
                    # model emits a single-channel mask, topk over dim 1 always
                    # picks index 0 and the metric is not meaningful. TODO confirm
                    # the model's output format.
                    ps = torch.exp(val_preds)
                    _, top_class = ps.topk(1, dim=1)
                    equals = top_class == val_msks.view(*top_class.shape)
                    accuracy += equals.float().mean().item()

            n_val_batches = len(validation_dataloader)
            # running_loss accumulates only over the last `print_every` steps
            # (it is reset below), so that is the correct denominator.
            train_losses.append(running_loss / print_every)
            test_losses.append(test_loss / n_val_batches)
            print(f"Epoch {epoch+1}/{epochs}.. "
                  f"Train loss: {running_loss/print_every:.3f}.. "
                  f"Test loss: {test_loss/n_val_batches:.3f}.. "
                  f"Test accuracy: {accuracy/n_val_batches:.3f}")
            running_loss = 0
            Net.train()

# Pickles the whole module object; loading it requires the model class to be
# importable. NOTE(review): saving Net.state_dict() is the more portable option.
torch.save(Net, 'basemodel.pth')


In [None]:
def show_img_msk(img, msk):
    """Plot an image and its mask side by side and print their array shapes.

    Args:
        img: image tensor in channels-first layout, e.g. (C, H, W). Size-1
            dimensions are squeezed away first; if a 2-D (grayscale) image
            remains it is plotted as-is. May live on any device and may
            require grad.
        msk: mask tensor; size-1 dimensions are squeezed away before plotting.
    """
    _, axs = plt.subplots(1, 2, figsize=(8, 4))

    # Detach and move to CPU first: .numpy() raises on CUDA tensors and on
    # tensors that require grad.
    x = img.detach().cpu().squeeze()
    if x.dim() == 3:
        # channels-first (C, H, W) -> channels-last (H, W, C) for imshow
        x = x.permute(1, 2, 0)
    # np.float (an alias of the builtin float, i.e. float64) was removed in
    # NumPy 1.24, so use np.float64 explicitly.
    x = x.numpy().astype(np.float64)
    y = msk.detach().cpu().squeeze().numpy().astype(np.float64)

    print(f"x shape: {x.shape}, y shape: {y.shape}")

    axs[0].imshow(x)
    axs[0].set_title("image")
    axs[1].imshow(y)
    axs[1].set_title("mask")
    plt.show()

In [None]:
# Draw a fresh batch instead of reusing loop variables leaked from the
# training cell: those may have been rebound during validation (so the image
# and mask need not belong together) and may live on the GPU.
sample_batch = next(iter(train_dataloader))
show_img_msk(sample_batch['img'][0], sample_batch['msk'][0])