In [1]:
import os
import sys
import glob
import pathlib
import yaml
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import argparse
from sklearn.model_selection import KFold, train_test_split
from torchvision import transforms
from torchvision.models.resnet import resnet18
from dataloader import MedicalData
from trainer import Trainer
from datetime import datetime
from utils import *

In [2]:
cfg_path = './config/config.yml'
cfg = load_config(cfg_path)

In [3]:
df = load_data(cfg["datasets"]["root"])
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

In [4]:
cfg['trainer']['epochs'] = 10


In [5]:
device = check_gpu(cfg['use_gpu'])

In [6]:
model_dir, visual_dir = create_folder()

In [7]:
def setup_dataflow(train_idx, val_idx):
    root = cfg['datasets']['root']
    width = cfg['datasets']['image_width']
    height = cfg['datasets']['image_height']
    batch_size = cfg['datasets']['batch_size']
    classes = pd.read_csv('./class.csv')
    
    
    transform_train = transforms.Compose([
        transforms.RandomRotation(30),
        transforms.GaussianBlur(3),
        transforms.RandomHorizontalFlip(0.3),
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
        transforms.RandomErasing()
    ])
    
    transform_valid = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262])
    ])
    train_sampler = MedicalData(root, train_idx, classes, device, transform_valid)
    val_sampler = MedicalData(root, val_idx, classes, device, transform_valid)

    train_loader = torch.utils.data.DataLoader(train_sampler, batch_size=batch_size)
    val_loader = torch.utils.data.DataLoader(val_sampler, batch_size=batch_size)

    return train_loader, val_loader

In [8]:
n_splits = 5
kfold = KFold(n_splits, shuffle=True, random_state=42)

In [None]:
for i, (train_idx, valid_idx) in enumerate(kfold.split(df)):
    print('===' * 10 + "{}".format(i+1) + "===" * 10)
    df_train = df.iloc[train_idx]
    df_valid = df.iloc[valid_idx]
    train_loader, valid_loader = setup_dataflow(df_train, df_valid)
    test_loader = None
    model = resnet18()
    model.fc = nn.Linear(512, cfg['model']['num_classes'])
    trainer = Trainer(model=model, 
                        config=cfg,
                        train_loader=train_loader,
                        valid_loader=valid_loader,
                        test_loader=test_loader,
                        device=device)

    trainer.train(kfold=i+1, save_dir=model_dir)
    trainer.visualize(visual_dir, i+1)

In [19]:
cfg_copy = cfg
cfg_copy['trainer']['epochs'] = 1
cfg_copy['trainer']['print_every'] = 1
cfg_copy['trainer']['lr'] = 1e-4

In [20]:
model.load_state_dict(torch.load('./save/save_2023-06-14 10:56:24.952644/model_fold5_epoch4.pt'))
df_train = df.iloc[train_idx]
df_valid = df.iloc[valid_idx]
train_loader, valid_loader = setup_dataflow(df_train, df_valid)
test_loader = None
trainer = Trainer(model=model, 
                    config=cfg_copy,
                    train_loader=train_loader,
                    valid_loader=valid_loader,
                    test_loader=test_loader,
                    device=device)

trainer.train(kfold=i+1, save_dir=model_dir)

EPOCH: 1/1


loss: 1.1280405521392822: 100%|███████████████████████████████████| 469/469 [01:35<00:00,  4.92it/s]
100%|█████████████████████████████████████████████████████████████| 118/118 [00:16<00:00,  7.29it/s]


              precision    recall  f1-score   support

      ASC_US       0.50      0.55      0.53       682
        LSIL       0.54      0.51      0.53       835
         SCC       0.63      0.78      0.70       373
        HSIL       0.58      0.56      0.57       972
       ASC_H       0.49      0.50      0.49       760
         AIS       0.67      0.11      0.19       128

    accuracy                           0.54      3750
   macro avg       0.57      0.50      0.50      3750
weighted avg       0.55      0.54      0.54      3750

Avg loss: 1.0773418567328057
metrics: {'loss': 1.0428050847457624, 'accuracy': 0.543771186440678, 'f1_scores': 0.543771186440678, 'precision': 0.47533898305084743, 'mAP': 0.4856271186440677, 'recall': 0.4918389830508474}
Confusion Matrix:
 tensor([[377., 217.,   6.,  30.,  52.,   0.],
        [273., 429.,  17.,  32.,  84.,   0.],
        [ 20.,   8., 290.,  44.,  10.,   1.],
        [ 36.,  58.,  80., 548., 246.,   4.],
        [ 46.,  80.,  32., 222., 

In [21]:
df_train.to_csv('./data/train.csv')
df_valid.to_csv('./data/valid.csv')