In [None]:
from __future__ import print_function, division

import os
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.utils.class_weight import compute_class_weight

import matplotlib.pyplot as plt
from matplotlib.pyplot import bar
# from tqdm.notebook import tqdm as tqdm

### Custom Classes
from core.hierarchicalCrossEntropyLoss import hierarchicalCrossEntropyLoss as h_loss
from core.trainer import train_model
from utils import display_funcs as disp_funcs
from utils import data_convert_funcs as convert_funcs

# plt.ion()
# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('device:', device)

import datetime
now = datetime.datetime.now()
# Run identifier: the TensorBoard log directory below is derived from it, and
# other cells reuse it for the checkpoint directory / model file name.
save_name = 'hierarchicalSELoss'
writer = SummaryWriter('./runs/%s' % save_name)

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration.
# Each switch keeps its alternative value(s) as a commented-out line directly
# above the active assignment; swap the comment marker to change the setting.
# ---------------------------------------------------------------------------

# Input / training size.
resolution = 224  # side length of the square crop produced by the transforms
batch_size = 24
num_epochs = 50

# Locations of the dataset and of the directory that receives checkpoints.
data_dir = '../data/customDataSet'
save_model_path = './trained/{}'.format(save_name)
os.makedirs(save_model_path, exist_ok=True)

# Backbone: the model cell handles 'ResNet50', 'ResNet152' and 'MobileNet_v2'.
# model_type = 'ResNet50'
model_type = 'MobileNet_v2'

# Task: 'Classification' gives one output per class, anything else one output.
# (The 'Regression' line used to be a live assignment that was overwritten two
# lines later; it is now a commented alternative like the other switches.)
# purpose = 'Regression'
purpose = 'Classification'

# Loss: the criterion cell handles 'MSE', 'CrossEntropy' and
# 'HierarchicalCrossEntropy'.
# loss_type = 'MSE'
# loss_type = 'CrossEntropy'
loss_type = 'HierarchicalCrossEntropy'

# Optimizer: the optimizer cell handles 'SGD' and 'Adam'.
opt_type = 'Adam'

# Weight the loss by balanced class weights computed from the training labels.
# use_sample_weights = False
use_sample_weights = True

# Attach a learning-rate scheduler to the optimizer.
# use_scheduler = True
use_scheduler = False

# Start from pretrained weights instead of a random initialisation.
# use_finetuning = False
use_finetuning = True

# Class hierarchy handed to the hierarchical loss.
#   '0' -> descriptive label for the flat (non-grouped) level.
#   '1' -> two groups, each given as a 9-element 0/1 membership mask.
#   '2' -> four groups, each given as a 9-element 0/1 membership mask.
# NOTE(review): the masks presumably index the 9 fine-grained classes in
# ImageFolder order - confirm against the dataset's class list.
hierarchy_dict = {
    '0': 'Normal Multiclass Classification',
    '1': {
        '0': [1, 1, 1, 1, 0, 0, 0, 0, 0],
        '1': [0, 0, 0, 0, 1, 1, 1, 1, 1],
    },
    '2': {
        '0': [1, 1, 0, 0, 0, 0, 0, 0, 0],
        '1': [0, 0, 1, 1, 0, 0, 0, 0, 0],
        '2': [0, 0, 0, 0, 1, 1, 1, 0, 0],
        '3': [0, 0, 0, 0, 0, 0, 0, 1, 1],
    },
}

# One label splitter per grouped level, in level order ('1' first, then '2').
hierarchy_label_spilitters = [
    convert_funcs.convertHierarchyDict2labelSplitters(hierarchy_dict[level])
    for level in ('1', '2')
]

# Three coefficients passed to the hierarchical loss.
# NOTE(review): presumably one weight per hierarchy level ('0', '1', '2') -
# confirm against the loss implementation.
coefficient = [0.50, 0.10, 0.40]
    
# reserve
# Empty containers / placeholders created up front.
class_weights = list()

# Both are overwritten once the image datasets have been built.
class_names = class_num = None

# History containers keyed by phase; both are handed to train_model.
# NOTE(review): presumably one entry is appended per epoch - confirm in trainer.
losses = {phase: [] for phase in ('train', 'val')}
accs = {phase: [] for phase in ('train', 'val')}

In [None]:
# Channel-wise mean / std expected by the ImageNet-pretrained backbones.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _eval_transform():
    """Return the deterministic pipeline shared by the 'val' and 'test' splits.

    Resize to 256, centre-crop to ``resolution``, convert to a tensor and
    normalise. A fresh ``Compose`` is built on every call so the two splits do
    not share a transform object.
    """
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ])


data_transforms = {
    # Training adds light augmentation: random crop and random horizontal flip.
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(resolution),
        transforms.RandomHorizontalFlip(),
        # Optional augmentations, currently disabled:
        # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0, hue=0.05),
        # transforms.RandomAffine(degrees=5, translate=(0,0), scale=(0.8, 1.2)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
        # transforms.RandomErasing(),
    ]),
    'val': _eval_transform(),
    'test': _eval_transform(),
}

In [None]:
# Dataset splits; each one is a sub-directory of data_dir in ImageFolder layout.
phases = ['train', 'val', 'test']

labels_list = {}

image_datasets = {
    phase: datasets.ImageFolder(os.path.join(data_dir, phase),
                                transform=data_transforms[phase])
    for phase in phases
}

# Only the training split is shuffled. Evaluation splits keep a fixed order so
# that validation / test passes are repeatable from run to run.
dataloaders = {
    phase: torch.utils.data.DataLoader(image_datasets[phase],
                                       batch_size=batch_size,
                                       shuffle=(phase == 'train'),
                                       num_workers=4)
    for phase in phases
}

dataset_size = {phase: len(image_datasets[phase]) for phase in phases}

# The class list is taken from the training split.
class_names = image_datasets['train'].classes
class_num = len(class_names)

# Plot the label histogram and keep the value returned by the helper.
# NOTE(review): presumably the returned value is the per-image label list; it
# is later used as `y` when computing balanced class weights - confirm.
for x in ['train']:  # , 'val']:
    labels_list[x] = disp_funcs.show_data_histogram('{}/{}'.format(data_dir, x),
                                                    '{} data'.format(x),
                                                    show=True)

# Reported once, after the loop, because the message is specific to 'train'.
print(f'training data num is {len(labels_list["train"])}')

In [None]:
# One output per class for classification, a single output otherwise.
num_classes = class_num if purpose == 'Classification' else 1

# Pretrained weights are loaded only when fine-tuning is requested. The same
# flag now drives every backbone; the MobileNet branch previously ignored it.
use_pretrained = bool(use_finetuning)

if model_type in ('ResNet50', 'ResNet152'):
    resnet_factory = models.resnet50 if model_type == 'ResNet50' else models.resnet152
    model = resnet_factory(pretrained=use_pretrained)
    # Replace the classification head; read the input width from the existing
    # layer instead of hard-coding it.
    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
elif model_type == 'MobileNet_v2':
    # Fetched through torch.hub, so the first run needs network access.
    model = torch.hub.load('pytorch/vision:v0.5.0', 'mobilenet_v2',
                           pretrained=use_pretrained)
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(model.last_channel, num_classes),
    )
else:
    # Fail here with a clear message rather than with a NameError on `model`.
    raise ValueError(
        "Unsupported model_type {!r}; expected 'ResNet50', 'ResNet152' or "
        "'MobileNet_v2'".format(model_type))

model.train()
print('Chose {}'.format(model_type), 'Network setting is completed!')

In [None]:
# Weights handed to the loss. Stays None when weighting is switched off, so the
# criterion construction below needs no conditional.
sample_weights = None

if use_sample_weights:
    train_labels = labels_list['train']

    # Balanced weights for the flat (per-class) level.
    sample_weights = compute_class_weight(class_weight='balanced',
                                          classes=np.unique(train_labels),
                                          y=train_labels)
    sample_weights = torch.FloatTensor(sample_weights).to(device)
    print(sample_weights)

    if loss_type == 'HierarchicalCrossEntropy':
        # The hierarchical loss receives a list: the flat weights first, then
        # one weight tensor for each remaining coefficient / label splitter.
        sample_weights = [sample_weights]
        for i in range(len(coefficient) - 1):
            splitter = hierarchy_label_spilitters[i]
            # Map training labels to their group at this level, balance the
            # groups, then convert the group weights with the project helper.
            # NOTE(review): presumably the helper expands group weights back to
            # one weight per fine-grained class - confirm in convert_funcs.
            temp_classes = convert_funcs.convertClass2HierarchicalClass(train_labels, splitter)
            temp_weights = compute_class_weight(class_weight='balanced',
                                                classes=np.unique(temp_classes),
                                                y=temp_classes)
            temp_weights = convert_funcs.convertClassWeights2HierarchicalClassWeights(temp_weights, splitter)
            temp_weights = torch.FloatTensor(temp_weights).to(device)
            sample_weights.append(temp_weights)

if loss_type == 'CrossEntropy':
    criterion = torch.nn.CrossEntropyLoss(weight=sample_weights)
elif loss_type == 'HierarchicalCrossEntropy':
    criterion = h_loss(coefficient, hierarchy_dict,
                       weight=sample_weights, device=device)
elif loss_type == 'MSE':
    criterion = torch.nn.MSELoss()
else:
    # Fail here rather than with a NameError on `criterion` at training time.
    raise ValueError(
        "Unsupported loss_type {!r}; expected 'CrossEntropy', "
        "'HierarchicalCrossEntropy' or 'MSE'".format(loss_type))

In [None]:
# Build the optimizer over all model parameters.
if opt_type == 'SGD':
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
elif opt_type == 'Adam':
    # Library-default Adam hyper-parameters.
    optimizer = optim.Adam(model.parameters())
else:
    # Fail here rather than with a NameError on `optimizer` in a later cell.
    raise ValueError(
        "Unsupported opt_type {!r}; expected 'SGD' or 'Adam'".format(opt_type))

In [None]:
# Optional learning-rate schedule: multiply the LR by 0.97 every 2.4 epochs.
#
# StepLR expects an integer step_size: it decays only when
# `last_epoch % step_size == 0`, which never holds for the float 2.4, so the
# previous StepLR(step_size=2.4) configuration silently kept the LR constant.
# LambdaLR expresses the same staircase with an explicit multiplier.
if use_scheduler == True:
    lr_decay_gamma = 0.97
    lr_decay_every_epochs = 2.4
    scheduler = lr_scheduler.LambdaLR(
        optimizer,
        # True division then floor: `epoch // 2.4` is off by one at exact
        # multiples because 2.4 is not exactly representable as a float.
        lr_lambda=lambda epoch: lr_decay_gamma ** math.floor(epoch / lr_decay_every_epochs),
    )
else:
    scheduler = None

In [None]:
# model

In [None]:
# Run the project training loop and rebind `model` to whatever it returns.
# All arguments are positional, so their order must match the signature of
# core.trainer.train_model.
# NOTE(review): presumably the returned model carries the best validation
# weights and checkpoints are written under save_model_path - confirm in
# core/trainer.py.
model = train_model(model, dataloaders, 
                    purpose, criterion, optimizer, scheduler, 
                    num_epochs, losses, accs, 
                    save_model_path, device, writer)

# Results (re-run the cells below after training has finished)

In [None]:
from utils import analyze_result_funcs as ar

# Build the result analyzer and show the confusion matrix for the test split.
# NOTE(review): the model file name is save_name immediately followed by
# 'best_model.pth', with no separator - confirm this matches the file written
# by the trainer.
analyzer = ar.show_results(
    class_num=class_num,
    device=device,
    model=model,
    trained_model_name='{}best_model.pth'.format(save_name),
)
analyzer.calc_confusion_matrix(purpose=purpose, dataset=image_datasets['test'])

In [None]:
# Classification report for the test split.
analyzer.calc_classification_report(
    batch_size=8,
    purpose=purpose,
    dataset=image_datasets['test'],
)

In [None]:
from utils import analyze_result_funcs as ar
# import importlib
# importlib.reload(ar)

# Rebuild the analyzer and show the confusion matrix for the validation split.
# NOTE(review): the model file name is save_name immediately followed by
# 'best_model.pth', with no separator - confirm this matches the file written
# by the trainer.
analyzer = ar.show_results(
    class_num=class_num,
    device=device,
    model=model,
    trained_model_name='{}best_model.pth'.format(save_name),
)
analyzer.calc_confusion_matrix(purpose=purpose, dataset=image_datasets['val'])

In [None]:
# Classification report for the validation split.
analyzer.calc_classification_report(
    batch_size=8,
    purpose=purpose,
    dataset=image_datasets['val'],
)