In [1]:
from collections import Counter
import warnings
import random
import os

from ray.tune.schedulers import ASHAScheduler
from medmnist.dataset import OrganAMNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
from ray.air import Checkpoint
from sklearn.metrics import *
import torch.optim as optim
from ray.air import session
from ray import air, tune
from tqdm import tqdm
import torch.nn as nn
import pandas as pd
import numpy as np
import torch

## 참고 : https://velog.io/@sdj4819/Focal-Loss
from misc.FocalLoss import FocalLoss

warnings.filterwarnings(action = 'ignore')

In [2]:
ROOT_PATH = '/'.join(os.getcwd().split('/')[:-2])
DATA_PATH = f'{ROOT_PATH}/dataset/organ_MNIST'
DEVICE    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.makedirs(DATA_PATH, exist_ok = True)
DEVICE

device(type='cuda')

In [3]:
idx2label = {
                 0 : 'bladder',  1 :  'femur left', 2 :  'femur right',
                 3 :   'heart',  4 : 'kidney left', 5 : 'kidney right',
                 6 :   'liver',  7 :   'lung left', 8 :   'lung right',
                 9 :'pancreas', 10 :      'spleen'
            }

In [4]:
class OrganNet(nn.Module):
    
    def __init__(self, n_classes):
        
        super(OrganNet, self).__init__()
        
        self.conv1   = self.ConvBlock( 1, 16, 3)
        self.conv2   = self.ConvBlock(16, 16, 3)
        self.conv3   = self.ConvBlock(16, 64, 3)
        self.conv4   = self.ConvBlock(64, 64, 3)
        self.conv5   = self.ConvBlock(64, 64, 3, 1)
        self.pooling = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.linear  = self.LinearBlock(64, n_classes)
        
        
    
    def ConvBlock(self, in_feats, out_feats, kernel = 3, padding = None):
            
        Conv2d = nn.Conv2d(in_feats, out_feats, kernel_size = kernel) \
                 if padding == None else nn.Conv2d(in_feats, out_feats, kernel_size = kernel, padding = padding)
        
        layers = nn.Sequential(
                    Conv2d,
                    nn.BatchNorm2d(out_feats),
                    nn.ReLU()
                )
        
        return layers
    
    
    def LinearBlock(self, in_feats, n_classes):
        
        layers = nn.Sequential(
                    nn.Linear(in_feats * 4 * 4, 128),
                    nn.ReLU(),
                    nn.Linear(128, 128),
                    nn.ReLU(),
                    nn.Linear(128, n_classes)
                )
        
        return layers
    

    def forward(self, x):
        
        x = self.conv1(x)
        x = self.pooling(self.conv2(x))
        x = self.conv3(x)
        x = self.pooling(self.conv4(x))
        
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        
        return x

In [5]:
lb_counter = lambda dataset: sorted(Counter([lb[0] for lb in dataset.labels]).items())
to_list    = lambda tensor: tensor.detach().cpu().numpy().tolist()

def load_data():
    transform           = {}
    transform['train']  = transforms.Compose([
                                    transforms.RandomVerticalFlip(),
                                    transforms.Resize((28, 28)),
                                    transforms.ToTensor()
                            ]) 
    transform['valid']  = transforms.Compose([
                                transforms.Resize((28, 28)),
                                transforms.ToTensor()
                            ])

    train_dataset = OrganAMNIST(
                                    split  = 'train', download  = True,
                                    as_rgb = False  , transform = transform['train'],
                                    root   = DATA_PATH
                               )

    valid_dataset = OrganAMNIST(
                                    split  = 'val', download  = True,
                                    as_rgb = False, transform = transform['valid'],
                                    root   = DATA_PATH
                              )
    
    return train_dataset, valid_dataset


def train(model, loader, criterion, optimizer, batch_size):
    
    total_loss = 0
    correct    = 0
    
    gt, predicted  = [], []
    
    model.train()
    for image, label in loader:
        
        image    = image.float().to(DEVICE)
        label    = label.to(DEVICE).squeeze().long()
        outputs  = model(image)
        _, preds = torch.max(outputs, 1)
        loss     = criterion(outputs, label)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        total_loss += loss.item()
        correct    += torch.sum(preds == label.data)
        predicted  += to_list(preds)
        gt         += to_list(label)
        
    
    avg_loss = total_loss / len(loader)
    accuracy =    correct / (len(loader) * batch_size)
    accuracy = float(accuracy.detach().cpu().numpy())
    f1       = f1_score(gt, predicted, average = 'weighted')
    
def valid(model, loader, criterion, batch_size):
    
    total_loss, correct =  0,  0
    gt, predicted       = [], []
    
    model.eval()
    with torch.no_grad():
        
        for image, label in loader:
            
            image    = image.float().to(DEVICE)
            label    = label.to(DEVICE).squeeze().long()
            outputs  = model(image)
            _, preds = torch.max(outputs, 1)
            loss     = criterion(outputs, label)
            
            total_loss += loss.item()
            correct    += torch.sum(preds == label.data)
            
            predicted  += to_list(preds)
            gt         += to_list(label)
            
        
    avg_loss = total_loss / len(loader)
    accuracy = correct / (len(loader) * batch_size)
    accuracy = float(accuracy.detach().cpu().numpy())
    f1       = f1_score(gt, predicted, average = 'weighted')

    return avg_loss, accuracy, f1


def tune_function(config):
    model     = OrganNet(n_classes = len(idx2label.keys())).to(DEVICE)
        
    criterion = config['criterion']
    optimizer = config['optimizer'](model.parameters(), lr = config['lr'])
        
    to_list = lambda tensor: tensor.detach().cpu().numpy().tolist()
    
    train_dataset, valid_dataset = load_data()
    train_loader = DataLoader(train_dataset, shuffle =  True, batch_size = config['batch_size'])
    valid_loader = DataLoader(valid_dataset, shuffle = False, batch_size = config['batch_size'])
    
    train_count = [cnt[1] for cnt in lb_counter(train_dataset)]
    alpha       = [1 - cnt / sum(train_count) for cnt in train_count]
    
    if 'focal' in criterion.__class__.__name__.lower():
        criterion.alpha = alpha
        criterion.gamma = config['gamma']

    early_stopping_cnt =  0
    
    for epoch in range(config['epoch']):
        train(model, train_loader, criterion, optimizer, config['batch_size'])
        loss, accuracy, f1 = valid(model, valid_loader, criterion, config['batch_size'])
        
        os.makedirs('tuned_model', exist_ok = True)
        checkpoint = Checkpoint.from_directory('tuned_model')
        session.report({'loss' : loss, 'accuracy' : accuracy, 'f1' : f1}, checkpoint = checkpoint)

In [10]:
config = {
            'criterion'  : tune.grid_search([FocalLoss(), nn.CrossEntropyLoss()]),
            'lr'         : tune.loguniform(1e-7, 1e-2),
            'batch_size' : tune.grid_search([8, 16, 32, 64, 128]),
            'optimizer'  : tune.grid_search([optim.Adam, optim.SGD]),
            'epoch'      : tune.grid_search([10, 15, 20]),
            'gamma'      : tune.grid_search([0, 0.125, 0.25, 0.5, 1, 2, 5])
        }

In [None]:
scheduler = ASHAScheduler(
                max_t = 20, grace_period = 1, reduction_factor = 2
            )

tuner = tune.Tuner(
                tune.with_resources(
                    tune.with_parameters(tune_function),
                    resources = {'cpu' : 2, 'gpu' : 1}
                ),
                tune_config = tune.TuneConfig(
                                metric = 'accuracy',
                                mode   = 'max',
                                scheduler = scheduler
                            ),
                param_space = config,
            )

results = tuner.fit()

0,1
Current time:,2023-05-24 08:46:16
Running for:,00:02:34.52
Memory:,46.3/1007.7 GiB

Trial name,# failures,error file
tune_function_15db5_00000,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00000_0_batch_size=8,criterion=ref_ph_f6251774,epoch=10,gamma=0,lr=0.0000,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-42/error.txt"
tune_function_15db5_00001,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00001_1_batch_size=16,criterion=ref_ph_f6251774,epoch=10,gamma=0,lr=0.0000,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-45/error.txt"
tune_function_15db5_00002,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00002_2_batch_size=32,criterion=ref_ph_f6251774,epoch=10,gamma=0,lr=0.0000,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-45/error.txt"
tune_function_15db5_00003,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00003_3_batch_size=64,criterion=ref_ph_f6251774,epoch=10,gamma=0,lr=0.0004,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-45/error.txt"
tune_function_15db5_00004,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00004_4_batch_size=128,criterion=ref_ph_f6251774,epoch=10,gamma=0,lr=0.0009,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-45/error.txt"
tune_function_15db5_00005,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00005_5_batch_size=8,criterion=ref_ph_449a4ea2,epoch=10,gamma=0,lr=0.0002,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-45/error.txt"
tune_function_15db5_00006,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00006_6_batch_size=16,criterion=ref_ph_449a4ea2,epoch=10,gamma=0,lr=0.0001,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-45/error.txt"
tune_function_15db5_00007,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00007_7_batch_size=32,criterion=ref_ph_449a4ea2,epoch=10,gamma=0,lr=0.0004,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-45/error.txt"
tune_function_15db5_00008,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00008_8_batch_size=64,criterion=ref_ph_449a4ea2,epoch=10,gamma=0,lr=0.0001,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-46/error.txt"
tune_function_15db5_00009,1,"/home/jovyan/ray_results/tune_function_2023-05-24_08-43-41/tune_function_15db5_00009_9_batch_size=128,criterion=ref_ph_449a4ea2,epoch=10,gamma=0,lr=0.0001,optimizer=ref_ph_7bdb82fb_2023-05-24_08-43-49/error.txt"

Trial name,status,loc,batch_size,criterion,epoch,gamma,lr,optimizer
tune_function_15db5_00219,RUNNING,121.160.102.68:1514671,128,CrossEntropyLoss(),10,0,2.81236e-06,<class 'torch.o_d810
tune_function_15db5_00220,RUNNING,121.160.102.68:1514674,8,FocalLoss(),15,0,1.24989e-06,<class 'torch.o_d810
tune_function_15db5_00221,RUNNING,121.160.102.68:1514676,16,FocalLoss(),15,0,0.00127023,<class 'torch.o_d810
tune_function_15db5_00223,RUNNING,121.160.102.68:1514680,64,FocalLoss(),15,0,0.000308983,<class 'torch.o_d810
tune_function_15db5_00224,RUNNING,121.160.102.68:1514681,128,FocalLoss(),15,0,4.74462e-06,<class 'torch.o_d810
tune_function_15db5_00225,RUNNING,121.160.102.68:1515075,8,CrossEntropyLoss(),15,0,1.79367e-05,<class 'torch.o_d810
tune_function_15db5_00226,RUNNING,121.160.102.68:1515077,16,CrossEntropyLoss(),15,0,8.65403e-05,<class 'torch.o_d810
tune_function_15db5_00227,PENDING,,32,CrossEntropyLoss(),15,0,5.86557e-05,<class 'torch.o_d810
tune_function_15db5_00228,PENDING,,64,CrossEntropyLoss(),15,0,4.68248e-07,<class 'torch.o_d810
tune_function_15db5_00229,PENDING,,128,CrossEntropyLoss(),15,0,0.00268388,<class 'torch.o_d810


2023-05-24 08:43:46,205	ERROR trial_runner.py:1450 -- Trial tune_function_15db5_00000: Error happened when processing _ExecutorEventType.TRAINING_RESULT.
ray.exceptions.RayTaskError(RuntimeError): [36mray::ImplicitFunc.train()[39m (pid=1484118, ip=121.160.102.68, repr=tune_function)
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 384, in train
    raise skipped from exception_cause(skipped)
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/function_trainable.py", line 336, in entrypoint
    return self._trainable_func(
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/function_trainable.py", line 653, in _trainable_func
    output = fn()
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/util.py", line 421, in _inner
    return inner(config, checkpoint_dir=None)
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/util.py", line

Trial name,date,hostname,node_ip,pid,timestamp,trial_id
tune_function_15db5_00000,2023-05-24_08-43-44,ubuntu,121.160.102.68,1484118,1684917824,15db5_00000
tune_function_15db5_00001,2023-05-24_08-43-47,ubuntu,121.160.102.68,1484248,1684917827,15db5_00001
tune_function_15db5_00002,2023-05-24_08-43-47,ubuntu,121.160.102.68,1484250,1684917827,15db5_00002
tune_function_15db5_00003,2023-05-24_08-43-47,ubuntu,121.160.102.68,1484252,1684917827,15db5_00003
tune_function_15db5_00004,2023-05-24_08-43-47,ubuntu,121.160.102.68,1484254,1684917827,15db5_00004
tune_function_15db5_00005,2023-05-24_08-43-47,ubuntu,121.160.102.68,1484256,1684917827,15db5_00005
tune_function_15db5_00006,2023-05-24_08-43-47,ubuntu,121.160.102.68,1484258,1684917827,15db5_00006
tune_function_15db5_00007,2023-05-24_08-43-47,ubuntu,121.160.102.68,1484261,1684917827,15db5_00007
tune_function_15db5_00008,2023-05-24_08-43-48,ubuntu,121.160.102.68,1485605,1684917828,15db5_00008
tune_function_15db5_00009,2023-05-24_08-43-52,ubuntu,121.160.102.68,1486211,1684917832,15db5_00009


2023-05-24 08:43:49,457	ERROR trial_runner.py:1450 -- Trial tune_function_15db5_00004: Error happened when processing _ExecutorEventType.TRAINING_RESULT.
ray.exceptions.RayTaskError(RuntimeError): [36mray::ImplicitFunc.train()[39m (pid=1484254, ip=121.160.102.68, repr=tune_function)
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 384, in train
    raise skipped from exception_cause(skipped)
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/function_trainable.py", line 336, in entrypoint
    return self._trainable_func(
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/function_trainable.py", line 653, in _trainable_func
    output = fn()
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/util.py", line 421, in _inner
    return inner(config, checkpoint_dir=None)
  File "/opt/conda/envs/tensor/lib/python3.8/site-packages/ray/tune/trainable/util.py", line

In [8]:
results

ResultGrid<[
  Result(
    metrics={'loss': 2.0876925903208114, 'accuracy': 0.09375, 'f1': 0.05604647312191422, 'should_checkpoint': True, 'done': True, 'trial_id': '53c40_00000', 'experiment_tag': '0_batch_size=64,criterion=ref_ph_f6251774,epoch=10,gamma=0.2500,lr=0.0000,optimizer=ref_ph_46a9caad'},
    path='/home/jovyan/ray_results/tune_function_2023-05-24_08-38-09/tune_function_53c40_00000_0_batch_size=64,criterion=ref_ph_f6251774,epoch=10,gamma=0.2500,lr=0.0000,optimizer=ref_ph_46a9caad_2023-05-24_08-38-16',
    checkpoint=Checkpoint(local_path=/home/jovyan/ray_results/tune_function_2023-05-24_08-38-09/tune_function_53c40_00000_0_batch_size=64,criterion=ref_ph_f6251774,epoch=10,gamma=0.2500,lr=0.0000,optimizer=ref_ph_46a9caad_2023-05-24_08-38-16/checkpoint_000009)
  )
]>