# In [None]:
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import seaborn as sn
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import ResNetModel as models
from dataloader import LeukemiaLoader
from sklearn.metrics import confusion_matrix

def train(train_dataloader, test_dataloader, model, model_name,
          epochs=None, lr=None, device=None):
    """Train ``model`` with Adam and keep the weights of the best test epoch.

    After every training epoch the model is evaluated on ``test_dataloader``.
    The weights, labels and predictions of the epoch with the highest test
    accuracy are kept. A row-normalized confusion matrix for that epoch is
    written to ``f'{model_name}.png'`` via ``plot_confusion_matrix``.

    Args:
        train_dataloader: DataLoader yielding ``(data, label)`` training batches.
        test_dataloader: DataLoader yielding ``(data, label)`` evaluation batches.
        model: network to train. It is updated in place. This function does
            not move it, so it is expected to already be on ``device``.
        model_name: key of the returned weights dict, prefix of the DataFrame
            columns and base name of the confusion-matrix image.
        epochs: number of training epochs. Defaults to the module-level
            ``epochs`` (set in the ``__main__`` section).
        lr: Adam learning rate. Defaults to the module-level ``lr``.
        device: device each batch is moved to. Defaults to the module-level
            ``device``.

    Returns:
        ``(df_acc, df_loss, best_model_wts)``:

        - ``df_acc`` has a 1-based ``epoch`` column plus
          ``<model_name>_train`` and ``<model_name>_test`` accuracies (percent).
        - ``df_loss`` has ``epoch`` plus ``<model_name>_loss``, the mean
          per-sample training loss.
        - ``best_model_wts`` maps ``model_name`` to a deep copy of the best
          ``state_dict``.
    """
    # Backward compatibility: earlier versions read these directly as globals.
    if epochs is None:
        epochs = globals()['epochs']
    if lr is None:
        lr = globals()['lr']
    if device is None:
        device = globals()['device']

    loss_func = nn.CrossEntropyLoss()
    # Build the optimizer once. Re-creating Adam every epoch threw away its
    # running moment estimates at each epoch boundary.
    # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    n_train = len(train_dataloader.dataset)
    n_test = len(test_dataloader.dataset)

    best_model_wts = {model_name: None}
    # Start below any reachable accuracy so the first epoch is always recorded,
    # even if the model never gets a single test sample right.
    best_acc = -1.0
    best_labels, best_preds = [], []
    train_acc, test_acc, loss_epoch = [], [], []

    for epoch in range(epochs):
        # ---- training ----
        model.train()
        total_loss = 0.0
        correct = 0
        for data, label in tqdm(train_dataloader):
            data = data.to(device, dtype=torch.float)
            label = label.to(device, dtype=torch.long)

            predicts = model(data)
            loss = loss_func(predicts, label)

            optimizer.zero_grad()  # clear gradients left from the previous step
            loss.backward()        # compute gradients
            optimizer.step()       # update the model parameters

            # CrossEntropyLoss averages over the batch by default. Weight by
            # the batch size so dividing by the dataset size gives a true
            # per-sample mean.
            total_loss += loss.item() * label.size(0)
            correct += predicts.argmax(dim=1).eq(label).sum().item()

        total_loss /= n_train
        train_accuracy = 100.0 * correct / n_train
        train_acc.append(train_accuracy)
        loss_epoch.append(total_loss)

        # ---- evaluation ----
        model.eval()
        epoch_labels, epoch_preds = [], []
        tcorrect = 0
        with torch.no_grad():  # no gradients needed for evaluation
            for data, label in tqdm(test_dataloader):
                data = data.to(device, dtype=torch.float)
                label = label.to(device, dtype=torch.long)
                pred_num = model(data).argmax(dim=1)
                tcorrect += pred_num.eq(label).sum().item()
                # Collect ground truth together with predictions so the two
                # stay aligned even if the test loader shuffles.
                epoch_labels.extend(label.cpu().tolist())
                epoch_preds.extend(pred_num.cpu().tolist())

        test_accuracy = 100.0 * tcorrect / n_test
        test_acc.append(test_accuracy)

        if epoch % 10 == 0:
            print('Training epoch: %d / loss: %.3f | acc: %.3f' % (epoch, total_loss, train_accuracy))
            print('Test acc: %.3f' % test_accuracy)

        # Keep the weights and predictions of the best epoch so far.
        if test_accuracy > best_acc:
            best_acc = test_accuracy
            best_model_wts[model_name] = copy.deepcopy(model.state_dict())
            best_labels, best_preds = epoch_labels, epoch_preds

    torch.cuda.empty_cache()

    epoch_index = range(1, epochs + 1)
    df_acc = pd.DataFrame({
        'epoch': epoch_index,
        model_name + '_train': train_acc,
        model_name + '_test': test_acc,
    })
    df_loss = pd.DataFrame({
        'epoch': epoch_index,
        model_name + '_loss': loss_epoch,
    })

    # With zero epochs there is nothing to plot.
    if best_preds:
        plot_confusion_matrix(best_labels, best_preds,
                              title=f'Confusion Matrix ({model_name})',
                              filename=f'{model_name}.png')

    return df_acc, df_loss, best_model_wts

def plot_confusion_matrix(y_true, y_pred, classes=(0, 1), title=None,
                          cmap=plt.cm.Blues, filename=None):
    """Draw a row-normalized confusion-matrix heatmap and save it to ``filename``.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels, aligned element-wise with ``y_true``.
        classes: label values to include, in display order. They are also
            used as the axis tick labels. Defaults to the binary case ``(0, 1)``.
        title: optional plot title.
        cmap: matplotlib colormap used for the heatmap.
        filename: path of the image written by ``savefig`` (300 dpi).
    """
    # Use an immutable default and normalize to a list; sklearn and seaborn
    # both accept it.
    classes = list(classes)
    # normalize='true' divides each row by the number of true samples of
    # that class.
    matrix = confusion_matrix(y_true, y_pred, labels=classes, normalize='true')

    fig, ax = plt.subplots()
    try:
        # Label the ticks with the real class values rather than the positional
        # indices seaborn would otherwise show.
        sn.heatmap(matrix, annot=True, ax=ax, cmap=cmap, fmt='.2f',
                   xticklabels=classes, yticklabels=classes)
        ax.set_xlabel('Predicted label')
        ax.set_ylabel('True label')
        ax.set_title(title)
        fig.savefig(filename, dpi=300)
    finally:
        # Always release this figure, even if drawing or saving fails.
        plt.close(fig)

def _save_curves(frame, title, ylabel, filename):
    """Plot every non-``epoch`` column of ``frame`` against ``epoch`` and save it."""
    fig, ax = plt.subplots(figsize=(10, 6))
    # With data= and a string y, matplotlib uses the column name as the
    # legend label.
    for column in frame.columns[1:]:
        ax.plot('epoch', column, data=frame)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend()
    fig.savefig(filename)
    plt.close(fig)


def _run_experiment(model_name, build_model, train_loader, test_loader, device):
    """Train one model, save its best weights and curves, and print its best accuracies."""
    model = build_model().to(device)
    all_acc, train_loss, best_model_wts = train(train_loader, test_loader, model, model_name)

    for name, model_wts in best_model_wts.items():
        torch.save(model_wts, f'./{name}.pt')

    # Show results.
    _save_curves(all_acc, 'Accuracy Curve', 'Accuracy (%)', f'{model_name} result.png')
    _save_curves(train_loss, 'Loss Curve', 'loss', f'{model_name}_train_loss.png')

    print()
    for column in all_acc.columns[1:]:
        print(f'{column} best acc: {all_acc[column].max()}')

    # Free cached GPU memory before the next model is built.
    torch.cuda.empty_cache()


if __name__ == "__main__":
    # Set the allocator config before anything touches CUDA so it is
    # guaranteed to take effect.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
    # device, lr and epochs are module-level globals read by train().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    batch_size = 32
    lr = 1e-3
    epochs = 80

    # Load the datasets; the validation split is kept in a fixed order.
    train_dataloader = DataLoader(LeukemiaLoader('', 'train'), batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(LeukemiaLoader('', 'valid'), batch_size=batch_size, shuffle=False)

    # Models to train, in order. Uncomment entries to also run larger variants.
    experiments = [
        ('ResNet18', models.ResNet18),
        # ('ResNet50', models.ResNet50),
        # ('ResNet152', models.ResNet152),
    ]
    for model_name, build_model in experiments:
        _run_experiment(model_name, build_model, train_dataloader, test_dataloader, device)