In [1]:
from collections import Counter
import warnings
import random
import os

from ray.tune.schedulers import ASHAScheduler
from medmnist.dataset import OrganAMNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
from ray.air import Checkpoint
from sklearn.metrics import *
import torch.optim as optim
from ray.air import session
from ray import air, tune
from tqdm import tqdm
import torch.nn as nn
import pandas as pd
import numpy as np
import torch

## 참고 : https://velog.io/@sdj4819/Focal-Loss
from misc.FocalLoss import FocalLoss

warnings.filterwarnings(action = 'ignore')

In [2]:
ROOT_PATH = '/'.join(os.getcwd().split('/')[:-2])
DATA_PATH = f'{ROOT_PATH}/Dataset/organ_MNIST'
DEVICE    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.makedirs(DATA_PATH, exist_ok = True)
DEVICE

device(type='cuda')

In [3]:
idx2label = {
                 0 : 'bladder',  1 :  'femur left', 2 :  'femur right',
                 3 :   'heart',  4 : 'kidney left', 5 : 'kidney right',
                 6 :   'liver',  7 :   'lung left', 8 :   'lung right',
                 9 :'pancreas', 10 :      'spleen'
            }

In [4]:
class OrganNet(nn.Module):
    
    def __init__(self, n_classes):
        
        super(OrganNet, self).__init__()
        
        self.conv1   = self.ConvBlock( 1, 16, 3)
        self.conv2   = self.ConvBlock(16, 16, 3)
        self.conv3   = self.ConvBlock(16, 64, 3)
        self.conv4   = self.ConvBlock(64, 64, 3)
        self.conv5   = self.ConvBlock(64, 64, 3, 1)
        self.pooling = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.linear  = self.LinearBlock(64, n_classes)
        
        
    
    def ConvBlock(self, in_feats, out_feats, kernel = 3, padding = None):
            
        Conv2d = nn.Conv2d(in_feats, out_feats, kernel_size = kernel) \
                 if padding == None else nn.Conv2d(in_feats, out_feats, kernel_size = kernel, padding = padding)
        
        layers = nn.Sequential(
                    Conv2d,
                    nn.BatchNorm2d(out_feats),
                    nn.ReLU()
                )
        
        return layers
    
    
    def LinearBlock(self, in_feats, n_classes):
        
        layers = nn.Sequential(
                    nn.Linear(in_feats * 4 * 4, 128),
                    nn.ReLU(),
                    nn.Linear(128, 128),
                    nn.ReLU(),
                    nn.Linear(128, n_classes)
                )
        
        return layers
    

    def forward(self, x):
        
        x = self.conv1(x)
        x = self.pooling(self.conv2(x))
        x = self.conv3(x)
        x = self.pooling(self.conv4(x))
        
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        
        return x

In [5]:
lb_counter = lambda dataset: sorted(Counter([lb[0] for lb in dataset.labels]).items())
to_list    = lambda tensor: tensor.detach().cpu().numpy().tolist()

def load_data():
    transform           = {}
    transform['train']  = transforms.Compose([
                                    transforms.RandomVerticalFlip(),
                                    transforms.Resize((28, 28)),
                                    transforms.ToTensor()
                            ]) 
    transform['valid']  = transforms.Compose([
                                transforms.Resize((28, 28)),
                                transforms.ToTensor()
                            ])

    train_dataset = OrganAMNIST(
                                    split  = 'train', download  = True,
                                    as_rgb = False  , transform = transform['train'],
                                    root   = DATA_PATH
                               )

    valid_dataset = OrganAMNIST(
                                    split  = 'val', download  = True,
                                    as_rgb = False, transform = transform['valid'],
                                    root   = DATA_PATH
                              )
    
    return train_dataset, valid_dataset

def train(model, loader, criterion, optimizer, batch_size):
    
    total_loss = 0
    correct    = 0
    
    gt, predicted  = [], []
    
    model.train()
    for image, label in loader:
        
        image    = image.float().to(DEVICE)
        label    = label.to(DEVICE).squeeze().long()
        outputs  = model(image)
        _, preds = torch.max(outputs, 1)
        loss     = criterion(outputs, label)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        total_loss += loss.item()
        correct    += torch.sum(preds == label.data)
        predicted  += to_list(preds)
        gt         += to_list(label)
        
    
    avg_loss = total_loss / len(loader)
    accuracy =    correct / (len(loader) * batch_size)
    accuracy = float(accuracy.detach().cpu().numpy())
    f1       = f1_score(gt, predicted, average = 'weighted')
    
def valid(model, loader, criterion, batch_size):
    
    total_loss, correct =  0,  0
    gt, predicted       = [], []
    
    model.eval()
    with torch.no_grad():
        
        for image, label in loader:
            
            image    = image.float().to(DEVICE)
            label    = label.to(DEVICE).squeeze().long()
            outputs  = model(image)
            _, preds = torch.max(outputs, 1)
            loss     = criterion(outputs, label)
            
            total_loss += loss.item()
            correct    += torch.sum(preds == label.data)
            
            predicted  += to_list(preds)
            gt         += to_list(label)
            
        
    avg_loss = total_loss / len(loader)
    accuracy = correct / (len(loader) * batch_size)
    accuracy = float(accuracy.detach().cpu().numpy())
    f1       = f1_score(gt, predicted, average = 'weighted')

    return avg_loss, accuracy, f1


def tune_function(config):
    model     = OrganNet(n_classes = len(idx2label.keys())).to(DEVICE)
        
    criterion = config['criterion']
    optimizer = config['optimizer'](model.parameters(), lr = config['lr'])
        
    to_list = lambda tensor: tensor.detach().cpu().numpy().tolist()
    

    train_loader = DataLoader(train_dataset, shuffle =  True, batch_size = config['batch_size'])
    valid_loader = DataLoader(valid_dataset, shuffle = False, batch_size = config['batch_size'])
    
    train_count = [cnt[1] for cnt in lb_counter(train_dataset)]
    alpha       = [1 - cnt / sum(train_count) for cnt in train_count]
    
    if 'focal' in criterion.__class__.__name__.lower():
        criterion.alpha = alpha
        criterion.gamma = config['gamma']

    early_stopping_cnt =  0
    
    for epoch in range(config['epoch']):
        train(model, train_loader, criterion, optimizer, config['batch_size'])
        loss, accuracy, f1 = valid(model, valid_loader, criterion, config['batch_size'])
        
        os.makedirs('tuned_model', exist_ok = True)
        checkpoint = Checkpoint.from_directory('tuned_model')
        session.report({'loss' : loss, 'accuracy' : accuracy, 'f1' : f1}, checkpoint = checkpoint)

In [6]:
config = {
            'criterion'  : tune.grid_search([FocalLoss(), nn.CrossEntropyLoss()]),
            'lr'         : tune.loguniform(1e-7, 1e-2),
            'batch_size' : tune.grid_search([8, 16, 32, 64, 128]),
            'optimizer'  : tune.grid_search([optim.Adam, optim.SGD]),
            'epoch'      : tune.grid_search([10, 15, 20]),
            'gamma'      : tune.grid_search([0, 0.125, 0.25, 0.5, 1, 2, 5])
        }

train_dataset, valid_dataset = load_data()

Using downloaded and verified file: /home/jovyan/NVIDIA_CUDA-11.1_Samples/TIL/AI_study/Dataset/organ_MNIST/organamnist.npz
Using downloaded and verified file: /home/jovyan/NVIDIA_CUDA-11.1_Samples/TIL/AI_study/Dataset/organ_MNIST/organamnist.npz


In [None]:
scheduler = ASHAScheduler(
                max_t = 20, grace_period = 1, reduction_factor = 2
            )

tuner = tune.Tuner(
                tune.with_resources(
                    tune.with_parameters(tune_function),
                    resources = {'cpu' : 2, 'gpu' : 1}
                ),
                tune_config = tune.TuneConfig(
                                metric = 'accuracy',
                                mode   = 'max',
                                scheduler = scheduler
                            ),
                param_space = config,
            )

results = tuner.fit()

2023-05-24 08:53:15,709	INFO worker.py:1625 -- Started a local Ray instance.
2023-05-24 08:53:20,733	INFO tune.py:218 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `Tuner(...)`.


0,1
Current time:,2023-05-24 08:54:24
Running for:,00:01:03.34
Memory:,62.8/1007.7 GiB

Trial name,status,loc,batch_size,criterion,epoch,gamma,lr,optimizer,iter,total time (s),loss,accuracy,f1
tune_function_6f3de_00000,RUNNING,121.160.102.68:1592113,8,FocalLoss(),10,0,5.9822e-05,<class 'torch.o_06e0,1.0,27.1615,0.445597,0.793411,0.778323
tune_function_6f3de_00001,RUNNING,121.160.102.68:1592373,16,FocalLoss(),10,0,3.80643e-06,<class 'torch.o_06e0,1.0,17.9292,1.47093,0.37931,0.291246
tune_function_6f3de_00004,RUNNING,121.160.102.68:1592384,128,FocalLoss(),10,0,8.04893e-06,<class 'torch.o_06e0,4.0,24.9931,1.06209,0.586857,0.561325
tune_function_6f3de_00005,RUNNING,121.160.102.68:1592446,8,CrossEntropyLoss(),10,0,7.40525e-05,<class 'torch.o_06e0,,,,,
tune_function_6f3de_00006,RUNNING,121.160.102.68:1592581,16,CrossEntropyLoss(),10,0,4.29617e-06,<class 'torch.o_06e0,1.0,17.175,1.61422,0.380234,0.28758
tune_function_6f3de_00007,RUNNING,121.160.102.68:1592774,32,CrossEntropyLoss(),10,0,0.000105145,<class 'torch.o_06e0,2.0,20.4066,0.464368,0.80819,0.805074
tune_function_6f3de_00010,RUNNING,121.160.102.68:1592377,8,FocalLoss(),15,0,0.00197399,<class 'torch.o_06e0,,,,,
tune_function_6f3de_00011,RUNNING,121.160.102.68:1592381,16,FocalLoss(),15,0,2.81982e-06,<class 'torch.o_06e0,,,,,
tune_function_6f3de_00012,PENDING,,32,FocalLoss(),15,0,0.000287036,<class 'torch.o_06e0,,,,,
tune_function_6f3de_00013,PENDING,,64,FocalLoss(),15,0,0.000231454,<class 'torch.o_06e0,,,,,


[2m[36m(tune_function pid=1592113)[0m   log_pt = F.log_softmax(input)
[2m[36m(tune_function pid=1592384)[0m   log_pt = F.log_softmax(input)[32m [repeated 5x across cluster][0m


Trial name,accuracy,date,done,f1,hostname,iterations_since_restore,loss,node_ip,pid,should_checkpoint,time_since_restore,time_this_iter_s,time_total_s,timestamp,training_iteration,trial_id
tune_function_6f3de_00000,0.793411,2023-05-24_08-54-23,False,0.778323,ubuntu,1,0.445597,121.160.102.68,1592113,True,27.1615,27.1615,27.1615,1684918463,1,6f3de_00000
tune_function_6f3de_00001,0.37931,2023-05-24_08-54-16,False,0.291246,ubuntu,1,1.47093,121.160.102.68,1592373,True,17.9292,17.9292,17.9292,1684918456,1,6f3de_00001
tune_function_6f3de_00002,0.491071,2023-05-24_08-54-20,True,0.427025,ubuntu,2,1.29432,121.160.102.68,1592377,True,21.3717,9.25627,21.3717,1684918460,2,6f3de_00002
tune_function_6f3de_00003,0.135263,2023-05-24_08-54-09,True,0.127305,ubuntu,1,2.12342,121.160.102.68,1592381,True,9.8358,9.8358,9.8358,1684918449,1,6f3de_00003
tune_function_6f3de_00004,0.586857,2023-05-24_08-54-24,False,0.561325,ubuntu,4,1.06209,121.160.102.68,1592384,True,24.9931,5.24376,24.9931,1684918464,4,6f3de_00004
tune_function_6f3de_00005,0.809729,2023-05-24_08-54-26,False,0.794907,ubuntu,1,0.472503,121.160.102.68,1592446,True,27.1213,27.1213,27.1213,1684918466,1,6f3de_00005
tune_function_6f3de_00006,0.380234,2023-05-24_08-54-17,False,0.28758,ubuntu,1,1.61422,121.160.102.68,1592581,True,17.175,17.175,17.175,1684918457,1,6f3de_00006
tune_function_6f3de_00007,0.821121,2023-05-24_08-54-29,False,0.808481,ubuntu,3,0.427611,121.160.102.68,1592774,True,28.9541,8.54751,28.9541,1684918469,3,6f3de_00007
tune_function_6f3de_00008,0.13894,2023-05-24_08-54-15,True,0.113302,ubuntu,1,2.36386,121.160.102.68,1592381,True,6.65298,6.65298,6.65298,1684918455,1,6f3de_00008
tune_function_6f3de_00009,0.303615,2023-05-24_08-54-21,True,0.170501,ubuntu,1,1.92799,121.160.102.68,1592381,True,5.21597,5.21597,5.21597,1684918461,1,6f3de_00009


2023-05-24 08:54:09,272	INFO tensorboardx.py:269 -- Removed the following hyperparameter values when logging to tensorboard: {'criterion': ('__ref_ph', 'f6251774'), 'optimizer': ('__ref_ph', '7bdb82fb')}
[2m[36m(tune_function pid=1592384)[0m   log_pt = F.log_softmax(input)[32m [repeated 5x across cluster][0m
2023-05-24 08:54:15,950	INFO tensorboardx.py:269 -- Removed the following hyperparameter values when logging to tensorboard: {'criterion': ('__ref_ph', '449a4ea2'), 'optimizer': ('__ref_ph', '7bdb82fb')}
[2m[36m(tune_function pid=1592384)[0m   log_pt = F.log_softmax(input)[32m [repeated 4x across cluster][0m
2023-05-24 08:54:20,529	INFO tensorboardx.py:269 -- Removed the following hyperparameter values when logging to tensorboard: {'criterion': ('__ref_ph', 'f6251774'), 'optimizer': ('__ref_ph', '7bdb82fb')}
2023-05-24 08:54:21,187	INFO tensorboardx.py:269 -- Removed the following hyperparameter values when logging to tensorboard: {'criterion': ('__ref_ph', '449a4ea2'), '

In [8]:
results

ResultGrid<[
  Result(
    metrics={'loss': 2.0876925903208114, 'accuracy': 0.09375, 'f1': 0.05604647312191422, 'should_checkpoint': True, 'done': True, 'trial_id': '53c40_00000', 'experiment_tag': '0_batch_size=64,criterion=ref_ph_f6251774,epoch=10,gamma=0.2500,lr=0.0000,optimizer=ref_ph_46a9caad'},
    path='/home/jovyan/ray_results/tune_function_2023-05-24_08-38-09/tune_function_53c40_00000_0_batch_size=64,criterion=ref_ph_f6251774,epoch=10,gamma=0.2500,lr=0.0000,optimizer=ref_ph_46a9caad_2023-05-24_08-38-16',
    checkpoint=Checkpoint(local_path=/home/jovyan/ray_results/tune_function_2023-05-24_08-38-09/tune_function_53c40_00000_0_batch_size=64,criterion=ref_ph_f6251774,epoch=10,gamma=0.2500,lr=0.0000,optimizer=ref_ph_46a9caad_2023-05-24_08-38-16/checkpoint_000009)
  )
]>