In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, utils

import cv2
import pickle
import pathlib
import random
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

import segmentation_models_pytorch as smp
from torchsampler import ImbalancedDatasetSampler
from Hard_set import Chest_Single_Data_Generator
from multi_preprocessing import *

# Fail fast: everything below assumes a CUDA device (DEVICE = 'cuda').
if not torch.cuda.is_available():
    raise RuntimeError('torch cuda is not available')

# Seed every RNG that may drive sampling or augmentation. numpy is imported
# by this notebook but was previously left unseeded; seed it too so runs
# are reproducible if any augmentation draws from np.random.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

DEVICE = 'cuda'
EPOCHS = 100
IMG_SIZE = (1024, 1024)  # (H, W) passed to the dataset generator

# Checkpoint to fine-tune from: a whole pickled model saved with torch.save(model).
LOAD_PATH = './checkpoints/11_support_device/best_model.pth'
# Root output folder; weights, logs and the confusion matrix go in subfolders.
SAVE_PATH = './result/11_support_device'
pathlib.Path(SAVE_PATH).mkdir(exist_ok=True, parents=True)

# Data Loader

In [None]:
# Per-split path lists, indexed below as load_paths[split]['image'|'mask'|'label'].
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("./Data/support_device.pickle", "rb") as f:
    load_paths = pickle.load(f)

# Training-time augmentation (helpers from multi_preprocessing), then tensor conversion.
transform_train = transforms.Compose([
    Gamma_2D(),
    Shift_2D(),
    Rotation_2D(),
    RandomSharp(),
    RandomBlur(),
    RandomNoise(),
    ToTensor(),
])

# No augmentation at evaluation time.
transform_test = transforms.Compose([
    ToTensor(),
])

trainset = Chest_Single_Data_Generator(IMG_SIZE,
                                       load_paths['train']['image'],
                                       load_paths['train']['mask'],
                                       load_paths['train']['label'],
                                       transform=transform_train)

# Every test-set field must come from the 'test' split. The masks used to be
# read from load_paths['train'], which misaligned them with the test images.
testset = Chest_Single_Data_Generator(IMG_SIZE,
                                      load_paths['test']['image'],
                                      load_paths['test']['mask'],
                                      load_paths['test']['label'],
                                      transform=transform_test)

# ImbalancedDatasetSampler decides the sampling order, so shuffle must stay
# False (DataLoader rejects shuffle=True combined with a sampler).
trainloader = DataLoader(trainset, batch_size=4, shuffle=False, num_workers=2,
                         sampler=ImbalancedDatasetSampler(trainset))
testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)

# Training

In [None]:
def freeze_children(module, indices=(1, 2)):
    """Set ``requires_grad = False`` on all parameters of selected direct children.

    Args:
        module: model whose ``children()`` are enumerated in registration order.
        indices: 0-based positions of the children to freeze.

    NOTE(review): for an ``smp.Unet`` the children are presumably
    (encoder, decoder, segmentation_head, classification_head), so the
    default freezes the decoder and segmentation head and leaves the encoder
    and classification head trainable — confirm against the installed smp.
    """
    for idx, child in enumerate(module.children()):
        if idx in indices:
            for param in child.parameters():
                param.requires_grad = False


if LOAD_PATH is not None:
    # Fine-tune from a whole pickled model (not a state_dict). Load it on the
    # CPU first; it is moved to DEVICE at the end of this cell.
    model = torch.load(LOAD_PATH, map_location=torch.device('cpu'))
    # Replace module 4 of the classification head (presumably its activation)
    # with a softmax over classes. dim=1 is what the implicit-dim rule selects
    # for (batch, classes) input; stating it removes the deprecation warning.
    # NOTE(review): if the loss wraps nn.CrossEntropyLoss it expects raw
    # logits, so training on softmax output applies softmax twice — confirm.
    model.classification_head[4] = nn.Softmax(dim=1)
    print("PATH ==> {}".format(LOAD_PATH))
    print("Model Load!!!")
    freeze_children(model)
else:
    # No checkpoint: build a fresh single-channel U-Net with an auxiliary
    # 3-class classification head. This branch was previously unreachable
    # because it repeated the `if` condition; a None LOAD_PATH hit a bare `raise`.
    ENCODER = 'densenet121'
    ENCODER_WEIGHTS = None
    CLASSES = 1
    ACTIVATION = 'sigmoid'

    aux_params = dict(
        pooling='avg',
        dropout=0.5,
        activation='softmax',
        classes=3,
    )

    model = smp.Unet(ENCODER, classes=CLASSES, aux_params=aux_params, in_channels=1,
                     activation=ACTIVATION, encoder_weights=ENCODER_WEIGHTS)
    freeze_children(model)
    # NOTE(review): `summary` is not imported explicitly in this notebook;
    # presumably it comes from `from multi_preprocessing import *` — confirm.
    summary(model, torch.rand((1, 1, *IMG_SIZE)))

model = model.to(DEVICE)

# Joint training of classification and reconstruction (disabled; kept for reference)

In [None]:
# Disabled alternative setup, kept for reference: jointly trains the
# classification output (c_loss / c_metrics) and the mask output
# (s_loss = MSE, s_metrics = IoU) through smp.utils.trainaux epochs.
# NOTE(review): smp.utils.trainaux does not appear to be part of upstream
# segmentation_models_pytorch (presumably a local fork) — confirm before
# re-enabling. `class_weight` in the first variant is not defined anywhere
# in this notebook.
# # c_loss = smp.utils.losses.BCEWithLogitsLoss(weight=class_weight)
# # c_metrics = [
# #     smp.utils.metrics.ML_Accuracy(threshold=0.5),
# # ]

# c_loss = smp.utils.losses.CrossEntropyLoss()
# c_metrics = [
#     smp.utils.metrics.Accuracy(threshold=0.5),
# ]

# s_loss = smp.utils.losses.MSELoss()
# s_metrics = [
#     smp.utils.metrics.IoU(threshold=0.5),
# ]

# optimizer = torch.optim.Adam([ 
#     dict(params=model.parameters(), lr=0.0001,  weight_decay=5e-4),
# ])

# ## scheduler
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# train_epoch = smp.utils.trainaux.TrainEpoch(
#     model,
#     c_loss=c_loss,
#     s_loss=s_loss,
#     c_metrics=c_metrics,
#     s_metrics=s_metrics,
#     optimizer=optimizer,
#     device=DEVICE,
#     verbose=True,
# )

# valid_epoch = smp.utils.trainaux.ValidEpoch(
#     model,
#     c_loss=c_loss,
#     s_loss=s_loss,
#     c_metrics=c_metrics,
#     s_metrics=s_metrics,
#     device=DEVICE,
#     verbose=True,
# )

# Classification-only training

In [None]:
# Criterion and metric for the classification output.
# NOTE(review): if smp.utils.losses.CrossEntropyLoss wraps nn.CrossEntropyLoss
# it expects raw logits, while the model cell installs a Softmax at the end of
# the classification head — confirm the loss is not seeing softmax twice.
loss = smp.utils.losses.CrossEntropyLoss()
metrics = [
    smp.utils.metrics.Accuracy(threshold=0.5),
]

# Give Adam only the parameters that are still trainable after freezing.
# Frozen parameters never receive gradients, so Adam would skip them anyway.
# Excluding them makes the intent explicit and keeps them out of optimizer state.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam([
    dict(params=trainable_params, lr=0.0001, weight_decay=5e-4),
])

# NOTE(review): the training-loop cell never calls scheduler.step(); it halves
# the LR manually instead. The scheduler is kept only so the name stays defined.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Epoch runners from smp.utils.train_classification (classification-only training).
train_epoch = smp.utils.train_classification.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train_classification.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

In [None]:
# Lowest validation cross-entropy seen so far. Starting at +inf means the first
# epoch always produces a checkpoint, even if its loss exceeds the old 1e6 sentinel.
best_loss = float('inf')

LR_DECAY_EVERY = 50    # epochs between manual learning-rate halvings
LR_DECAY_FACTOR = 0.5

train_logs = []
valid_logs = []

save_path = os.path.join(SAVE_PATH, 'weight')
pathlib.Path(save_path).mkdir(exist_ok=True, parents=True)

log_save_path = os.path.join(SAVE_PATH, 'logs')
pathlib.Path(log_save_path).mkdir(exist_ok=True, parents=True)

for epoch in range(EPOCHS):
    print('\nEpoch: {}, LR: {}'.format(epoch, optimizer.param_groups[0]['lr']))
    train_log = train_epoch.run(trainloader)
    valid_log = valid_epoch.run(testloader)

    train_logs.append(train_log)
    valid_logs.append(valid_log)

    # Save a new checkpoint file each time validation loss improves. The epoch
    # index in the name lets the test cell pick the most recent one.
    if valid_log['cross_entropy_loss'] < best_loss:
        best_loss = valid_log['cross_entropy_loss']
        torch.save(model, os.path.join(save_path, 'best_model_{}.pth'.format(epoch)))
        print('*******************   Best saved!   *******************')

    torch.save(model, os.path.join(save_path, 'last_model.pth'))
    print('Model saved!')

    # This manual step is the only LR schedule applied; the StepLR `scheduler`
    # is not stepped here. Decay every param group, not just the first.
    if (epoch + 1) % LR_DECAY_EVERY == 0:
        for group in optimizer.param_groups:
            group['lr'] *= LR_DECAY_FACTOR
        print('Decrease learning rate to {}'.format(optimizer.param_groups[0]['lr']))

    # Rewrite the full log history every epoch so an interrupted run still leaves usable logs.
    with open(os.path.join(log_save_path, 'train_log.pickle'), 'wb') as f:
        pickle.dump(train_logs, f)
    with open(os.path.join(log_save_path, 'valid_log.pickle'), 'wb') as f:
        pickle.dump(valid_logs, f)

# Test: evaluate a saved checkpoint on the test set

In [None]:
# Evaluate the most recent "best" checkpoint on the test set and save a confusion matrix.
# The training cell writes best_model_<epoch>.pth files plus last_model.pth.
# glob() returns files in arbitrary filesystem order, so glob(...)[-1] could
# select last_model.pth or any best_model_N. Select by epoch number instead.
# NOTE(review): stale best_model_*.pth files left from an earlier run in the
# same folder are also considered — clear the folder between runs.
weight_dir = os.path.join(SAVE_PATH, 'weight')
best_candidates = glob(os.path.join(weight_dir, 'best_model_*.pth'))
if not best_candidates:
    raise FileNotFoundError('no best_model_*.pth checkpoint found in {}'.format(weight_dir))
best_path = max(best_candidates,
                key=lambda p: int(pathlib.Path(p).stem.rsplit('_', 1)[-1]))
print('Evaluating {}'.format(best_path))

best_model = torch.load(best_path)
best_model.eval()
best_model.cuda()

labels = []
pred_labels = []
# Display names indexed by class id.
# NOTE(review): assumes labels are integer ids 0..2 in this order — confirm
# against Chest_Single_Data_Generator.
class_names = ['normal', 'support device', 'others']

# Inference only: no_grad skips building an autograd graph for every image.
with torch.no_grad():
    for sample in testloader:
        outputs = best_model(sample[0].cuda())
        labels.append(sample[2].item())
        # outputs[1] is presumably the classification output of the aux-head
        # U-Net. testloader uses batch_size=1, so argmax over classes is one id.
        pred_labels.append(torch.argmax(outputs[1], dim=1).item())

confusion_matrix_path = os.path.join(SAVE_PATH, 'confusion_matrix')
pathlib.Path(confusion_matrix_path).mkdir(exist_ok=True, parents=True)
cm_file = os.path.join(confusion_matrix_path, 'confusion_matrix.png')

# Pin labels= so the matrix is always len(class_names) square, even when a
# class never appears in either the ground truth or the predictions.
conf_matrix = confusion_matrix(labels, pred_labels, labels=list(range(len(class_names))))
fig, ax = plt.subplots()
sns.heatmap(conf_matrix, cmap="Blues", annot=True, fmt='g',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_title('Confusion Matrix')
ax.set_xlabel('predicted value')
ax.set_ylabel('true value')
fig.savefig(cm_file, bbox_inches='tight', pad_inches=0, dpi=400)
plt.close(fig)