In [None]:
!pip install mne pyriemann geoopt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mne
  Downloading mne-1.1.1-py3-none-any.whl (7.5 MB)
[K     |████████████████████████████████| 7.5 MB 4.3 MB/s 
[?25hCollecting pyriemann
  Downloading pyriemann-0.3.tar.gz (365 kB)
[K     |████████████████████████████████| 365 kB 51.5 MB/s 
[?25hCollecting geoopt
  Downloading geoopt-0.5.0-py3-none-any.whl (90 kB)
[K     |████████████████████████████████| 90 kB 7.9 MB/s 
Building wheels for collected packages: pyriemann
  Building wheel for pyriemann (setup.py) ... [?25l[?25hdone
  Created wheel for pyriemann: filename=pyriemann-0.3-py2.py3-none-any.whl size=78033 sha256=6558be42809a5738280a05611c36abcf5da877bc758a5e19f61adfca83e66245
  Stored in directory: /root/.cache/pip/wheels/0b/1b/bf/a537f9e17e6c3490004ede419c72f863af1d0d765d25e532ef
Successfully built pyriemann
Installing collected packages: pyriemann, mne, geoopt
Successfully installed geoopt-0.5.0 mne-1.1.1 py

In [None]:
import time
import pandas as pd
import numpy as np



#import torch and sklearn
from torch.autograd import Variable
import torch.nn.functional as F
import torch as th
from torch.utils.data.sampler import SubsetRandomSampler
import torch.utils.data
from sklearn.model_selection import StratifiedShuffleSplit


#import util folder
from Geometric_Methods.utils.model import Tensor_CSPNet_Basic
from Geometric_Methods.utils.early_stopping import EarlyStopping
from Geometric_Methods.utils.load_data import load_KU, load_BCIC, dataloader_in_main
from Geometric_Methods.utils.args import args_parser
import Geometric_Methods.utils.geoopt as geoopt

In [None]:
import argparse

def args_parser(argv=None):
    """Build the experiment configuration as an ``argparse.Namespace``.

    Parameters
    ----------
    argv : list of str, optional
        Command-line style tokens to parse, e.g. ``['--epochs', '3']``.
        When omitted an empty list is parsed, so every option takes its
        default value and ``sys.argv`` is never consulted (this is what the
        function always did before ``argv`` was added).

    Returns
    -------
    argparse.Namespace
        One attribute per option declared below (dashes in option names become
        underscores, e.g. ``--no-cuda`` -> ``no_cuda``).
    """

    def _str2bool(text):
        # Without an explicit converter argparse stores the raw string, and
        # any non-empty string (including "False") is truthy.
        lowered = text.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(text))

    parser = argparse.ArgumentParser()

    # Model selection.
    parser.add_argument('--alg_name', default = 'Tensor_CSPNet', help = 'name of model')
    parser.add_argument('--mlp',      type = _str2bool, default = False, help = 'whether the classifier is a multiple layer perception or not')

    # Device and optimisation.
    parser.add_argument('--no-cuda',  action = 'store_true', default=False, help='disables CUDA training')
    parser.add_argument('--initial_lr', type = float, default = 1e-3, help="initial_lr for optimizer.")
    parser.add_argument('--decay',      type = float, default = 1, help="decay rate for adjust_learning")

    # Subject range and training length.
    parser.add_argument('--start_No', type=int, default = 1,  help='testing starts on subject #')
    parser.add_argument('--end_No',   type=int, default = 9,  help='testing ends on subject #')
    parser.add_argument('--epochs',   type=int, default = 50, help='number of epochs to train')
    parser.add_argument('--patience', type=int, default = 10, help='patience for early stopping')

    # Batch sizes.
    parser.add_argument('--train_batch_size', type = int, default = 29, help = 'batch size in each epoch for trainning')
    parser.add_argument('--test_batch_size',  type = int, default = 29, help = 'batch size in each epoch for testing')
    parser.add_argument('--valid_batch_size', type = int, default = 29, help = 'batch size in each epoch for validation')

    # Reproducibility, logging and checkpointing.
    parser.add_argument('--seed',         type = int, default = 1, metavar='S', help='random seed (default: 1)')
    parser.add_argument('--log_interval', type = int, default = 1, help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action = 'store_true', default=False, help='for Saving the current Model')

    # Output locations.
    parser.add_argument('--folder_name',         default = 'results')
    parser.add_argument('--weights_folder_path', default = 'model_paras/')

    return parser.parse_args(args=[] if argv is None else list(argv))

In [None]:
def adjust_learning_rate(optimizer, epoch):
    """Apply a step decay to the optimizer's learning rate.

    The new rate is ``args.initial_lr * args.decay ** (epoch // 100)``, i.e.
    the initial rate multiplied by ``args.decay`` once every 100 epochs.
    ``args`` is read from module scope.

    Parameters
    ----------
    optimizer : object exposing ``param_groups``
        Optimizer whose parameter groups receive the new rate.
    epoch : int
        Current epoch number.
    """
    new_lr = args.initial_lr * (args.decay ** (epoch // 100))
    # The rate an optimizer actually steps with is stored per parameter group;
    # assigning to an `lr` attribute on the optimizer object changes nothing.
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def _build_loader(dataset, batch_size, use_cuda, num_workers, sampler=None, shuffle=False):
    """Create a ``DataLoader`` over ``dataset``.

    ``DataLoader`` rejects ``sampler`` combined with ``shuffle=True``, so
    ``shuffle`` is only forwarded when no sampler is given. Worker processes
    and pinned memory are requested only when CUDA is in use.
    """
    loader_kwargs = {'batch_size': batch_size}
    if sampler is not None:
        loader_kwargs['sampler'] = sampler
    else:
        loader_kwargs['shuffle'] = shuffle
    if use_cuda:
        loader_kwargs['num_workers'] = num_workers
        loader_kwargs['pin_memory'] = True
    return torch.utils.data.DataLoader(dataset=dataset, **loader_kwargs)


def main(args, train, test, train_y, test_y, sub, total_sub, kf_iter, validation):
    """Train ``Tensor_CSPNet_Basic`` on one fold and evaluate it on the test set.

    Parameters
    ----------
    args : argparse.Namespace
        Configuration; the fields read here are ``no_cuda``, ``mlp``,
        ``initial_lr``, ``epochs``, ``patience``, ``log_interval``,
        ``alg_name``, ``weights_folder_path`` and the three batch sizes.
    train, test : numpy.ndarray
        Training / test inputs. ``shape[1] * shape[2]`` of ``train`` is passed
        to the model as ``channel_num``.
    train_y, test_y : array-like of int
        Class labels for ``train`` / ``test``.
    sub, total_sub, kf_iter : int
        Subject number, number of subjects and fold index; only used in the
        progress messages.
    validation : bool
        When true, 10% of the training data (stratified, fixed seed 42) is held
        out for validation and drives early stopping.

    Returns
    -------
    tuple of float
        ``(test accuracy, mean per-sample test NLL loss)``.
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device   = torch.device("cuda" if use_cuda else "cpu")

    # The split is computed on the numpy inputs, before tensor conversion.
    train_sampler = None
    valid_sampler = None
    if validation:
        index_split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
        train_index, valid_index = next(index_split.split(train, train_y))
        train_sampler = SubsetRandomSampler(train_index)
        valid_sampler = SubsetRandomSampler(valid_index)

    train   = torch.from_numpy(train).double()
    test    = torch.from_numpy(test).double()
    train_y = torch.LongTensor(train_y)
    test_y  = torch.LongTensor(test_y)

    train_dataset = dataloader_in_main(train, train_y)
    test_dataset  = dataloader_in_main(test, test_y)

    # Samplers and shuffling are applied on every device. Previously they were
    # only set on CUDA, so a CPU run trained on the held-out samples and
    # validated on the whole training set, and the CUDA run passed `sampler`
    # together with `shuffle=True`, which DataLoader rejects.
    # Worker counts are the ones from the original configuration.
    train_loader = _build_loader(train_dataset, args.train_batch_size, use_cuda,
                                 num_workers=100, sampler=train_sampler, shuffle=True)
    # Evaluation metrics do not depend on batch order, so the test set is not shuffled.
    test_loader  = _build_loader(test_dataset, args.test_batch_size, use_cuda,
                                 num_workers=1 if validation else 100)

    if validation:
        valid_loader = _build_loader(train_dataset, args.valid_batch_size, use_cuda,
                                     num_workers=10, sampler=valid_sampler)
        n_train = len(train_index)
        n_valid = len(valid_index)
    else:
        n_train = len(train_loader.dataset)

    model = Tensor_CSPNet_Basic(channel_num = train.shape[1]*train.shape[2],
        mlp = args.mlp,
        dataset = 'BCIC',
        ).to(device)

    optimizer = geoopt.optim.RiemannianAdam(model.parameters(), lr=args.initial_lr)
    # MixOptimizer is for StiefelParameter, SPDParameter, etc.
    #optimizer = MixOptimizer(model.parameters(), lr= args.initial_lr)

    early_stopping = EarlyStopping(
        alg_name = args.alg_name,
        path_w   = args.weights_folder_path + args.alg_name + '_checkpoint.pt',
        patience = args.patience,
        verbose  = True,
        )

    print('#####Start Trainning######')

    for epoch in range(1, args.epochs+1):

        adjust_learning_rate(optimizer, epoch)

        model.train()

        # Running count of correct predictions, updated on logged batches only.
        train_correct = 0

        for batch_idx, (batch_train, batch_train_y) in enumerate(train_loader):

            batch_train_y = batch_train_y.to(device)

            optimizer.zero_grad()

            logits = model(batch_train.to(device))
            output = F.log_softmax(logits, dim = -1)
            loss   = F.nll_loss(output, batch_train_y)

            loss.backward()
            optimizer.step()

            if batch_idx % args.log_interval == 0:
                print('----#------#-----#-----#-----#-----#-----#-----')
                pred    = output.data.max(1, keepdim=True)[1]
                train_correct += pred.eq(batch_train_y.data.view_as(pred)).sum().item()
                # Checkpoints are overwritten at every logged batch.
                torch.save(model.state_dict(), args.weights_folder_path + args.alg_name+'_model.pth')
                torch.save(optimizer.state_dict(), args.weights_folder_path+'optimizer.pth')

                # Accuracy is relative to the samples actually used for training.
                print('['+args.alg_name+': Sub No.{}/{} Fold {}/10, Epoch {}/{}, Completed {:.0f}%]:\nTrainning loss {:.10f} Acc.: {:.4f}'.format(\
                        sub, total_sub, kf_iter+1, epoch, args.epochs, 100. * (1+batch_idx) / len(train_loader), loss.item(),\
                        train_correct/n_train))

        if validation:
            # Validate the model on the held-out subset.
            valid_losses  = []
            valid_correct = 0

            model.eval()

            # No gradients are needed here; this also avoids keeping one
            # autograd graph alive per validation batch.
            with torch.no_grad():
                for batch_valid, batch_valid_y in valid_loader:

                    batch_valid_y = batch_valid_y.to(device)

                    logits = model(batch_valid.to(device))
                    output = F.log_softmax(logits, dim = -1)
                    # One entry per batch (not a running total), so the value
                    # given to early stopping is a true mean of batch losses.
                    valid_losses.append(F.nll_loss(output, batch_valid_y).item())

                    pred           = output.data.max(1, keepdim=True)[1]
                    valid_correct += pred.eq(batch_valid_y.data.view_as(pred)).sum().item()

            print('Validate loss: {:.10f} Acc: {:.4f}'.format(sum(valid_losses), valid_correct/n_valid))

            early_stopping(np.average(valid_losses), model)

            if early_stopping.early_stop:
                print("Early Stopping!")
                break

    #Testing
    print('###############################################################')
    print('START TESTING')
    print('###############################################################')

    model.eval()

    # Plain Python accumulators: still valid when the loader yields no batch.
    test_loss    = 0.0
    test_correct = 0

    with torch.no_grad():
        for batch_idx, (batch_test, batch_test_y) in enumerate(test_loader):

            logits        = model(batch_test.to(device))
            output        = F.log_softmax(logits, dim = -1)
            # Sum over the batch so that dividing by the number of samples
            # below yields the mean per-sample loss.
            test_loss    += F.nll_loss(output, batch_test_y.to(device), reduction='sum').item()

            test_pred     = output.data.max(1, keepdim=True)[1]
            test_correct += test_pred.eq(batch_test_y.to(device).data.view_as(test_pred)).sum().item()

            print('-----------------------------------')
            print('Testing Batch {}:'.format(batch_idx))
            print('  Pred Label:', test_pred.view(1, test_pred.shape[0]).cpu().numpy()[0])
            print('Ground Truth:', batch_test_y.numpy())

    n_test = len(test_loader.dataset)

    return test_correct/n_test, test_loss/n_test

In [None]:
# Driver: for every subject, run the 10 cross-validation folds, record the
# per-fold test accuracy plus their mean, and write the table to a CSV file.
if __name__ == '__main__':
    import os

    args   = args_parser()

    # Apply the configured seed so that repeated runs are comparable.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Results and checkpoints are written into these folders; create them up
    # front so a fresh runtime does not fail on the first write.
    for folder in (args.folder_name, args.weights_folder_path):
        if folder:
            os.makedirs(folder, exist_ok=True)

    # One row per subject: accuracy of each of the 10 folds and their average.
    alg_df = pd.DataFrame(columns=['R1', 'R2', 'R3','R4', 'R5', 'R6', 'R7', 'R8','R9', 'R10','Avg'])

    print('############Start Task#################')

    for sub in range(args.start_No, args.end_No + 1):

        BCIC_dataset = load_BCIC(sub,
            TorE     = True,
            alg_name = args.alg_name,
            scenario = 'CV'
            )

        alg_record = []

        start      = time.time()

        for kf_iter in range(0, 10):

            x_train_stack, x_test_stack, y_train, y_test = BCIC_dataset.generate_training_test_set_CV(kf_iter)

            acc, loss = main(
                args       = args,
                train      = x_train_stack,
                test       = x_test_stack,
                train_y    = y_train,
                test_y     = y_test,
                kf_iter    = kf_iter,
                sub        = sub,
                total_sub  = args.end_No - args.start_No + 1,
                validation = False,
                )

            print('##################################################################')

            print(args.alg_name + ' Testing Loss.: {:4f} Acc: {:4f}'.format(loss, acc))

            alg_record.append(acc)

        end = time.time()

        print('Subject {} finished in {:.1f} s'.format(sub, end - start))

        # Last column holds the mean accuracy over the folds.
        alg_record.append(np.mean(alg_record))

        alg_df.loc[sub] = alg_record

        # The table is rewritten after every subject under a freshly
        # timestamped name, so partial results survive an interrupted run.
        alg_df.to_csv(args.folder_name + '/' \
        + time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()) \
        + args.alg_name \
        +'_Sub(' \
        + str(args.start_No) \
        +'-' \
        +str(args.end_No) \
        +')' \
        +'_' \
        + str(args.epochs)\
        + '.csv'\
        , index = False)

############Start Task#################
#####Start Trainning######


  return expm(x)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[Tensor_CSPNet: Sub No.9/9 Fold 7/10, Epoch 17/50, Completed 33%]:
Trainning loss 0.2366576055 Acc.: 0.3282
----#------#-----#-----#-----#-----#-----#-----
[Tensor_CSPNet: Sub No.9/9 Fold 7/10, Epoch 17/50, Completed 44%]:
Trainning loss 0.0958075175 Acc.: 0.4363
----#------#-----#-----#-----#-----#-----#-----
[Tensor_CSPNet: Sub No.9/9 Fold 7/10, Epoch 17/50, Completed 56%]:
Trainning loss 0.0740202722 Acc.: 0.5483
----#------#-----#-----#-----#-----#-----#-----
[Tensor_CSPNet: Sub No.9/9 Fold 7/10, Epoch 17/50, Completed 67%]:
Trainning loss 0.1237825530 Acc.: 0.6602
----#------#-----#-----#-----#-----#-----#-----
[Tensor_CSPNet: Sub No.9/9 Fold 7/10, Epoch 17/50, Completed 78%]:
Trainning loss 0.1884280073 Acc.: 0.7606
----#------#-----#-----#-----#-----#-----#-----
[Tensor_CSPNet: Sub No.9/9 Fold 7/10, Epoch 17/50, Completed 89%]:
Trainning loss 0.2660569246 Acc.: 0.8649
----#------#-----#-----#-----#-----#-----#-----