# In [None]:
import torch
from torch import optim
from torch.utils.data import DataLoader
import torch.nn as nn
import os
import shutil
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import matplotlib
matplotlib.use('Agg')
import numpy as np
import matplotlib.pyplot as plt


from torchvision import transforms
import torch
from torch.utils.data import DataLoader
from torchvision import models

# Image preprocessing pipeline handed to train_utils.split_train: each image
# is resized to 224x224 and then converted to a tensor.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd


from helper import train_utils

# Root directory for all experiment outputs. It is wiped at start-up so that
# every execution of the script begins with an empty results tree.
results_path = 'dataset/results'

if os.path.exists(results_path):
    shutil.rmtree(results_path)

# makedirs (rather than mkdir) also creates a missing 'dataset' parent
# directory instead of raising FileNotFoundError.
os.makedirs(results_path)

# NOTE(review): this model is never read -- `model` is rebound to a 'vgg16'
# model inside the experiment loop before any use. The call is kept only in
# case get_pretrained_model has side effects; confirm and remove if it has none.
model = train_utils.get_pretrained_model('simple',10)

# Hyper-parameter grid: one independent training run is performed for every
# (N_FFT, batch_size) combination.
N_FFTs = [1024 * (i + 1) for i in range(8)]
batch_sizes = [16, 32, 64, 128, 256]
n_epochs = 50
n_classes = 10
# NOTE(review): nothing in the visible code ever appends to this list.
histories = []
for N_FFT in N_FFTs:
    for batch_size in batch_sizes:
        experience_name = f'Melspectogram_Experiment_N_FFT_{N_FFT}_batch_size_{batch_size}'
        graphs_path = f'dataset/graph_classes_{N_FFT}_batch_size_{batch_size}'
        # Every artefact of this run (history, confusion matrix, text report)
        # is written below this directory.
        run_dir = f'{results_path}/{N_FFT}_batch_size_{batch_size}'

        print(f"running on {experience_name}")

        # Split the source train set into a train part and a held-out part
        # (0.8 is passed as the split ratio; the held-out part is the "20%"
        # evaluated after training).
        X_train, X_test = train_utils.split_train(graphs_path, 0.8, transform)
        train_dataloader = DataLoader(X_train, batch_size=batch_size, shuffle=True)
        # Evaluation does not depend on sample order, so no shuffling is needed.
        test_dataloader = DataLoader(X_test, batch_size=batch_size, shuffle=False)

        # A fresh model and optimizer per run so that runs do not influence
        # each other.
        model = train_utils.get_pretrained_model('vgg16', n_classes)

        # NOTE(review): NLLLoss expects log-probabilities as input -- confirm
        # that the model returned by get_pretrained_model ends in LogSoftmax.
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(model.parameters())

        print('Training ...')
        model, history = train_utils.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            train_loader=train_dataloader,
            test_loader=None,  # evaluated after training to build the confusion matrix
            n_epochs=n_epochs,
            print_log=True)

        # Start from an empty output directory for this run.
        if os.path.exists(run_dir):
            shutil.rmtree(run_dir)
        os.makedirs(run_dir)

        torch.save(history, f'{run_dir}/history_{experience_name}')

        print('Testing on 20% of trainSet...')
        acc_test_train, y_true_, y_pred_ = train_utils.test(model, test_dataloader, print_log=True)

        # Class labels are assumed to be the integers 0 .. n_classes - 1
        # (the same assumption the axis labels of the heatmap rely on).
        classes = list(range(n_classes))

        # Passing `labels` pins the matrix to n_classes x n_classes even when a
        # class is absent from both the targets and the predictions; without it
        # the matrix shrinks and no longer matches the DataFrame index.
        cf_matrix = confusion_matrix(y_true_, y_pred_, labels=classes)

        # Row-normalise (each row sums to 1 = per-class recall distribution).
        # Rows of classes without any true sample are left at 0 instead of
        # producing NaN through a division by zero.
        row_totals = cf_matrix.sum(axis=1, keepdims=True)
        normalized_cm = np.divide(
            cf_matrix,
            row_totals,
            out=np.zeros(cf_matrix.shape, dtype=float),
            where=row_totals != 0,
        )
        df_cm = pd.DataFrame(normalized_cm, index=classes, columns=classes)

        matplotlib.rcParams.update({'font.size': 10})
        fig = plt.figure(figsize=(12, 12))
        sn.heatmap(df_cm, annot=True)
        plt.savefig(f'{run_dir}/confusion_{experience_name}.svg')
        # Close explicitly so figures do not accumulate over the 40 runs.
        plt.close(fig)

        # NOTE: an evaluation on a separate test folder and a CSV export of the
        # predicted labels were sketched here earlier but are disabled.

        print('##############################################\n\n\n')
        # The `with` statement closes the file; no explicit close() is needed.
        with open(f'{run_dir}/history_{experience_name}.txt', 'w') as fi:
            fi.write(str(history))
            fi.write('##############################################\n\n\n')
            fi.write(f'acc_test_train = {str(acc_test_train)}')

        print("===============================")
    print("==============================================================")
    print("==============================================================")
    print("==============================================================")
