In [1]:
import os
import sys
import glob
import pathlib
import yaml
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import argparse
from sklearn.model_selection import KFold, train_test_split
from torchvision import transforms
from torchvision.models.resnet import resnet18, resnet50, ResNet50_Weights
from dataloader import MedicalData
from trainer import Trainer
from datetime import datetime
from utils import *
import albumentations as A

In [2]:
cfg_path = './config/config.yml'
cfg = load_config(cfg_path)

In [3]:
df = load_data(cfg["datasets"]["root"])
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

In [4]:
cfg['trainer']['epochs'] = 30

In [5]:
device = check_gpu(cfg['use_gpu'])

In [6]:
model_dir, visual_dir = create_folder()

In [7]:
def setup_dataflow(train_idx, val_idx):
    root = cfg['datasets']['root']
    width = cfg['datasets']['image_width']
    height = cfg['datasets']['image_height']
    batch_size = cfg['datasets']['batch_size']
    classes = pd.read_csv('./class.csv')
    
    

    transform_train = A.Compose([
        A.Resize(224, 224),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        A.RandomCrop(p=0.2, height=224, width=224),
        A.Normalize(),
    ])
    
    transform_valid = A.Compose([
        A.Resize(224, 224),
        A.Normalize(),
    ])
    
    train_sampler = MedicalData(root, train_idx, classes, device, transform_train)
    val_sampler = MedicalData(root, val_idx, classes, device, transform_valid)

    train_loader = torch.utils.data.DataLoader(train_sampler, batch_size=batch_size)
    val_loader = torch.utils.data.DataLoader(val_sampler, batch_size=batch_size)

    return train_loader, val_loader

In [8]:
n_splits = 5
kfold = KFold(n_splits, shuffle=True, random_state=42)

In [9]:
def load_model():
    model = resnet50(ResNet50_Weights)
    model.fc = nn.Linear(2048, cfg['model']['num_classes'])
    for idx, child in enumerate(model.children()):
        if idx < 6:
            for param in child.parameters():
                param.requires_grad = False
    return model

In [10]:
torch.cuda.empty_cache()

In [11]:
for i, (train_idx, valid_idx) in enumerate(kfold.split(df)):
    print('===' * 10 + "{}".format(i+1) + "===" * 10)
    df_train = df.iloc[train_idx]
    df_valid = df.iloc[valid_idx]
    train_loader, valid_loader = setup_dataflow(df_train, df_valid)
    test_loader = None
    model = load_model()
    if i >= 4:
        trainer = Trainer(model=model, 
                            config=cfg,
                            train_loader=train_loader,
                            valid_loader=valid_loader,
                            test_loader=test_loader,
                            device=device)

        trainer.train(kfold=i+1, save_dir=model_dir)
        trainer.visualize(visual_dir, i+1)





















EPOCH: 1/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 2/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 3/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 4/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 5/30


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

              precision    recall  f1-score   support

      ASC_US       0.53      0.60      0.56       682
        LSIL       0.68      0.60      0.63       835
         SCC       0.71      0.82      0.76       373
        HSIL       0.57      0.59      0.58       972
       ASC_H       0.50      0.48      0.49       760
         AIS       0.78      0.38      0.51       128

    accuracy                           0.59      3750
   macro avg       0.63      0.58      0.59      3750
weighted avg       0.59      0.59      0.59      3750

metrics: {'loss': 0.47249152542372874, 'accuracy': 0.5874406779661016, 'f1_scores': 0.5874406779661016, 'precision': 0.5498305084745764, 'mAP': 0.5259491525423727, 'recall': 0.5528728813559322}
Confusion Matrix:
 tensor([[410., 182.,   9.,  35.,  45.,   1.],
        [225., 500.,   7.,  52.,  50.,   1.],
        [  3.,   1., 306.,  48.,  14.,   1.],
        [ 46.,  27.,  67., 570., 254.,   8.],
        [ 89.,  30.,  35., 237., 366.,   3.],
        [  1.,

  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 7/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 8/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 9/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 10/30


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

              precision    recall  f1-score   support

      ASC_US       0.37      0.83      0.51       682
        LSIL       0.53      0.27      0.36       835
         SCC       0.84      0.64      0.73       373
        HSIL       0.54      0.30      0.39       972
       ASC_H       0.41      0.51      0.45       760
         AIS       0.91      0.16      0.27       128

    accuracy                           0.46      3750
   macro avg       0.60      0.45      0.45      3750
weighted avg       0.52      0.46      0.45      3750

metrics: {'loss': 0.7028813559322035, 'accuracy': 0.46233898305084753, 'f1_scores': 0.46233898305084753, 'precision': 0.4589915254237288, 'mAP': 0.42584745762711845, 'recall': 0.43758474576271195}
Confusion Matrix:
 tensor([[563.,  62.,   2.,  14.,  41.,   0.],
        [530., 228.,   3.,  26.,  48.,   0.],
        [  5.,   6., 239.,  75.,  48.,   0.],
        [162.,  73.,  23., 296., 416.,   2.],
        [239.,  55.,  11.,  64., 391.,   0.],
        [  

  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 12/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 13/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 14/30


  0%|          | 0/469 [00:00<?, ?it/s]

Saved!
EPOCH: 15/30


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

              precision    recall  f1-score   support

      ASC_US       0.51      0.59      0.55       682
        LSIL       0.66      0.49      0.56       835
         SCC       0.73      0.83      0.78       373
        HSIL       0.60      0.42      0.49       972
       ASC_H       0.44      0.68      0.53       760
         AIS       0.76      0.45      0.57       128

    accuracy                           0.56      3750
   macro avg       0.62      0.58      0.58      3750
weighted avg       0.58      0.56      0.56      3750

metrics: {'loss': 0.48384745762711884, 'accuracy': 0.562101694915254, 'f1_scores': 0.562101694915254, 'precision': 0.5460084745762708, 'mAP': 0.5207372881355932, 'recall': 0.5463813559322036}
Confusion Matrix:
 tensor([[403., 173.,   9.,  23.,  74.,   0.],
        [271., 410.,  13.,  40.,  97.,   4.],
        [  0.,   1., 310.,  40.,  22.,   0.],
        [ 40.,  21.,  52., 407., 442.,  10.],
        [ 78.,  19.,  26., 120., 513.,   4.],
        [  0.,  

  0%|          | 0/469 [00:00<?, ?it/s]

In [None]:
cfg_copy = cfg
cfg_copy['trainer']['epochs'] = 1
cfg_copy['trainer']['print_every'] = 1
cfg_copy['trainer']['lr'] = 1e-4

In [None]:
model.load_state_dict(torch.load('./save/save_2023-06-14 10:56:24.952644/model_fold5_epoch4.pt'))
df_train = df.iloc[train_idx]
df_valid = df.iloc[valid_idx]
train_loader, valid_loader = setup_dataflow(df_train, df_valid)
test_loader = None
trainer = Trainer(model=model, 
                    config=cfg_copy,
                    train_loader=train_loader,
                    valid_loader=valid_loader,
                    test_loader=test_loader,
                    device=device)

trainer.train(kfold=i+1, save_dir=model_dir)

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var". 
	size mismatch for layer1.0.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
	size mismatch for layer1.1.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 256, 1, 1]).
	size mismatch for layer2.0.conv1.weight: copying a param with shape torch.Size([128, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).
	size mismatch for layer2.0.downsample.0.weight: copying a param with shape torch.Size([128, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
	size mismatch for layer2.0.downsample.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for layer2.0.downsample.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for layer2.0.downsample.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for layer2.0.downsample.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for layer2.1.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 512, 1, 1]).
	size mismatch for layer3.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for layer3.0.downsample.0.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 512, 1, 1]).
	size mismatch for layer3.0.downsample.1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for layer3.0.downsample.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for layer3.0.downsample.1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for layer3.0.downsample.1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for layer3.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
	size mismatch for layer4.0.conv1.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 1024, 1, 1]).
	size mismatch for layer4.0.downsample.0.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([2048, 1024, 1, 1]).
	size mismatch for layer4.0.downsample.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for layer4.0.downsample.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for layer4.0.downsample.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for layer4.0.downsample.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for layer4.1.conv1.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 2048, 1, 1]).
	size mismatch for fc.weight: copying a param with shape torch.Size([6, 512]) from checkpoint, the shape in current model is torch.Size([6, 2048]).

In [None]:
df_train.to_csv('./data/train.csv')
df_valid.to_csv('./data/valid.csv')