# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from DL_lab2_EEGclassification_model import EEGNet, DeepConvNet
from dataloader import read_bci_data

def train(train_dataloader, test_dataloader, model_name):
    """Train the ReLU / LeakyReLU / ELU variants of one architecture.

    Three models are built (one per activation) and each is trained for
    ``epochs`` epochs with Adam and cross-entropy loss. After every epoch the
    model is evaluated on ``test_dataloader``; whenever the test accuracy
    beats the best value seen so far for that model, a deep copy of its
    ``state_dict`` is kept.

    This function reads the module-level names ``device``, ``epochs`` and
    ``lr``; they must be defined before it is called (the ``__main__``
    section of this file sets them).

    Args:
        train_dataloader: Loader yielding ``(data, label)`` batches for
            training. Its ``dataset`` attribute must support ``len()``.
        test_dataloader: Loader yielding ``(data, label)`` batches for
            evaluation. Its ``dataset`` attribute must support ``len()``.
        model_name: ``'EEGNet'`` selects ``EEGNet``; any other value selects
            ``DeepConvNet``.

    Returns:
        A tuple ``(dfAcc, dfloss, best_model_wts)``:

        * ``dfAcc`` -- DataFrame with an ``epoch`` column plus
          ``<name>_train`` and ``<name>_test`` accuracy columns (percent).
        * ``dfloss`` -- DataFrame with an ``epoch`` column plus a
          ``<name>_loss`` column holding the mean training loss per sample.
        * ``best_model_wts`` -- dict mapping each model name to the
          ``state_dict`` copy with the best test accuracy, or ``None`` if
          the test accuracy never rose above 0.
    """
    if model_name == 'EEGNet':
        net_cls, prefix = EEGNet, 'EEGNet'
    else:
        net_cls, prefix = DeepConvNet, 'DeepConvNet'

    activations = {'ReLU': nn.ReLU, 'LeakyReLU': nn.LeakyReLU, 'ELU': nn.ELU}
    models = {
        f'{prefix}_{act_name}': net_cls(act_cls()).to(device)
        for act_name, act_cls in activations.items()
    }
    best_model_wts = dict.fromkeys(models)

    dfAcc = pd.DataFrame()
    dfAcc['epoch'] = range(1, epochs + 1)
    dfloss = pd.DataFrame()
    dfloss['epoch'] = range(1, epochs + 1)

    # The loss module is stateless, so a single instance serves every model.
    loss_func = nn.CrossEntropyLoss()
    n_train = len(train_dataloader.dataset)
    n_test = len(test_dataloader.dataset)

    for name, model in models.items():
        # Build the optimizer once per model: Adam keeps running moment
        # estimates, and re-creating it every epoch would throw them away.
        optimizer = optim.Adam(model.parameters(), lr=lr)
        best_acc = 0.0
        train_acc, test_acc, loss_epoch = [], [], []

        for epoch in range(epochs):
            # ---- training phase ----
            model.train()
            total_loss = 0.0
            correct = 0.0
            for data, label in train_dataloader:
                data = data.to(device, dtype=torch.float)
                # CrossEntropyLoss expects integer (64-bit) class indices.
                label = label.to(device, dtype=torch.long)
                optimizer.zero_grad()   # clear gradients from the last step

                predicts = model(data)
                loss = loss_func(predicts, label)
                # ``loss`` is the batch mean; weight it by the batch size so
                # that dividing by the sample count below yields a true
                # per-sample average even when the last batch is smaller.
                total_loss += loss.item() * label.size(0)
                correct += predicts.argmax(dim=1).eq(label).sum().item()

                loss.backward()     # compute gradients
                optimizer.step()    # update model parameters

            total_loss /= n_train
            correct = 100.0 * correct / n_train
            train_acc.append(correct)
            loss_epoch.append(total_loss)

            # ---- evaluation phase ----
            model.eval()
            tcorrect = 0.0
            with torch.no_grad():   # gradients are not needed for evaluation
                for data, label in test_dataloader:
                    data = data.to(device, dtype=torch.float)
                    label = label.to(device, dtype=torch.long)
                    predicts = model(data)
                    tcorrect += predicts.argmax(dim=1).eq(label).sum().item()

            tcorrect = 100.0 * tcorrect / n_test
            test_acc.append(tcorrect)

            # Report progress every 10th epoch (epoch index is 0-based).
            if epoch % 10 == 0:
                print('Training epoch: %d / loss: %.3f | acc: %.3f' % (epoch, total_loss, correct))
                print('Test acc: %.3f' % tcorrect)

            # Snapshot the weights whenever the test accuracy improves.
            if tcorrect > best_acc:
                best_model_wts[name] = copy.deepcopy(model.state_dict())
                best_acc = tcorrect

        # Store this model's per-epoch curves in the returned frames.
        dfAcc[name + '_train'] = train_acc
        dfAcc[name + '_test'] = test_acc
        dfloss[name + '_loss'] = loss_epoch

    return dfAcc, dfloss, best_model_wts

# Batch size = 64, learning rate = 1e-2, epochs = 300
# Optimizer: Adam; loss function: torch.nn.CrossEntropyLoss()
# NOTE(review): the script below sets lr = 0.001, not 1e-2 -- confirm which is intended.
def _plot_columns(frame, title, ylabel, out_path):
    """Plot every column after the first of *frame* against 'epoch' and save the figure."""
    plt.figure(figsize=(10, 6))
    for name in frame.columns[1:]:
        plt.plot('epoch', name, data=frame)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend()
    plt.savefig(out_path)


def _run_experiment(train_dataloader, test_dataloader, model_name, acc_png, loss_png):
    """Train one architecture, save its best weights, plot its curves and print a summary.

    Returns the ``(accuracy frame, loss frame, best weights)`` tuple produced
    by ``train`` so the caller can keep the results.
    """
    acc, train_loss, best_wts = train(train_dataloader, test_dataloader, model_name)

    # One checkpoint file per activation variant, written to the working directory.
    for name, model_wts in best_wts.items():
        torch.save(model_wts, './' + name + '.pt')

    title = f'Activation function comparison ({model_name})'
    _plot_columns(acc, title, 'Accuracy (%)', acc_png)
    _plot_columns(train_loss, title, 'loss', loss_png)

    print()
    for column in acc.columns[1:]:
        print(f'{column} best acc: {acc[column].max()}')

    return acc, train_loss, best_wts


if __name__ == '__main__':
    # These module-level names are read by train().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # print(device)
    batch_size = 64
    lr = 0.001
    epochs = 300

    # Load the dataset and wrap it in batch loaders.
    X_train, y_train, X_test, y_test = read_bci_data()
    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Optional: plot one random training sample (disabled).
#     randi=int(np.random.randint(0,X_train.shape[0],1))
#     print(f'sample_id:{randi}')
#     plt.figure(figsize=(10,2))
#     plt.plot(X_train[randi,0,0])
#     plt.figure(figsize=(10,2))
#     plt.plot(X_train[randi,0,1])

    # Experiment 1: EEGNet.
    EEG_acc, EEG_train_loss, best_model_wts = _run_experiment(
        train_dataloader, test_dataloader, 'EEGNet',
        'eeg result.png', 'EEG_train_loss.png',
    )

    # Experiment 2: DeepConvNet.
    print("\n Doing DeepConvNet model \n")
    DEEP_acc, DEEP_train_loss, DEEP_best_model_wts = _run_experiment(
        train_dataloader, test_dataloader, 'DeepConvNet',
        'DEEP result.png', 'DEEP_train_loss.png',
    )
