In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to the project sources
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Make the project sources importable. The relative entry '../src' is resolved
# against the kernel's working directory, so the notebook has to be started
# from its own folder for these imports to work.
import sys
sys.path.append('../src')

# Project-local modules.
# NOTE(review): the star imports below put an unknown set of names into the
# notebook namespace, which hides where a name used in later cells comes from;
# consider importing the required names explicitly.
from data import GraphMatrixDataset, CNNMatrixDataset
from utils import P300Getter, train_model, plot_sample, show_progress
from interpretation import *
from models_cnn import *
from models_gnn import *
from graph import get_delaunay_graph, get_pos_init_graph, plot_graph, get_neighbors_graph
import run_exp
import regularization

In [3]:
import math
import pickle
import time

import matplotlib.pyplot as plt
import mne
import networkx as nx
import numpy as np
import pandas as pd
import scipy.io
import scipy.sparse as sp
import seaborn as sns
import torch
import wandb
from matplotlib.colors import LogNorm, Normalize
from mne import Epochs, pick_types, events_from_annotations
from mne.channels import make_standard_montage, DigMontage
from mne.datasets import eegbci
from mne.decoding import Scaler
from mne.io import concatenate_raws, read_raw_edf
from scipy.spatial import Delaunay
from sklearn.neighbors import NearestNeighbors
from torch import nn, optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm

In [4]:
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

import pathlib

In [5]:
from IPython.display import clear_output

In [6]:
# Use the first CUDA device when one is visible to torch, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(device)

cuda:0


In [7]:
# Folder holding the subject .mat files and the electrode-location file.
# The original machine-specific location remains the default so existing
# setups keep working; set EEG_MATRIX_DATA_PATH to run the notebook elsewhere.
DATA_PATH = pathlib.Path(
    os.environ.get(
        'EEG_MATRIX_DATA_PATH',
        r'C:\Users\Vladimir\PycharmProjects\EEGPatternRecognition\matrix_dataset',
    )
)
if not DATA_PATH.is_dir():
    # Fail early with an actionable message instead of a bare loadmat error.
    raise FileNotFoundError(
        f"Dataset folder not found: {DATA_PATH}. "
        "Set the EEG_MATRIX_DATA_PATH environment variable to its location."
    )

# Raw MATLAB structures for subject A (training and test sessions).
train_A_raw = scipy.io.loadmat(DATA_PATH / 'Subject_A_Train.mat')
test_A_raw = scipy.io.loadmat(DATA_PATH / 'Subject_A_Test.mat')

# Electrode montage and a matching MNE info object: 64 channels typed 'eeg',
# sampling frequency 240.
eloc = mne.channels.read_custom_montage(DATA_PATH / 'eloc64.txt')
info = mne.create_info(ch_names=eloc.ch_names, ch_types=['eeg'] * 64, sfreq=240)

# Target characters: the training file carries them in 'TargetChar'; for the
# test file they are supplied here as a literal string.
train_A_chars = list(train_A_raw['TargetChar'][0])
test_A_chars = list('WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU')

# Project helper that turns the raw structures into model-ready samples.
# NOTE(review): the meaning of sample_size=72 is defined in utils.P300Getter
# and is not visible from this notebook — confirm there.
A_train_ds = P300Getter(train_A_raw, eloc, sample_size=72)
A_test_ds = P300Getter(test_A_raw, eloc, sample_size=72, target_chars=test_A_chars)

A_train_ds.get_cnn_p300_dataset(filter=True)
A_test_ds.get_cnn_p300_dataset(filter=True)

# Only the training set is upsampled; the matching call for the test set
# (`A_test_ds.upsample(2)`) is intentionally left disabled.
A_train_ds.upsample(4)

X_train_A, y_train_A = A_train_ds.get_data()
X_test_A, y_test_A = A_test_ds.get_data()

# Drop whatever the loading steps printed and leave a single status line.
clear_output()
print("Success")

Success


In [8]:
batch_size = 1024

# Wrap the (X, y) pairs of each split in the project's CNN dataset class.
train_A_dataset = CNNMatrixDataset(
    tensors=(X_train_A, y_train_A),
    with_target=True,
    transform=None,
)
test_A_dataset = CNNMatrixDataset(
    tensors=(X_test_A, y_test_A),
    with_target=True,
    transform=None,
)

# NOTE(review): the held-out loader is shuffled as well. Turning that off
# would change how the global torch RNG is consumed, so it is kept as-is.
train_A_CNN = DataLoader(train_A_dataset, shuffle=True, batch_size=batch_size)
test_A_CNN = DataLoader(test_A_dataset, shuffle=True, batch_size=batch_size)

# The test split is registered under the 'val' key.
data_loaders_CNN = {
    'train': train_A_CNN,
    'val': test_A_CNN,
}

In [9]:
# Seed both RNGs so the random matrix below is reproducible.
SEED = 44
torch.manual_seed(SEED)
np.random.seed(SEED)

# Side length of the square matrix.
# presumably one row/column per EEG channel (the info object is built with
# 64 channels) — confirm against the model code.
N_CHANNELS = 64

# Draw entries uniformly from [-k, k] with k = sqrt(1 / N_CHANNELS**2),
# i.e. 1 / 64 for the current size.
A_init = torch.empty(N_CHANNELS, N_CHANNELS)
k = math.sqrt(1 / (N_CHANNELS * N_CHANNELS))
nn.init.uniform_(A_init, -k, k)

# Symmetrise by averaging the matrix with its transpose.
A_init = (A_init + A_init.T) / 2

# NOTE(review): the experiment cell further down passes
# torch.zeros_like(A_init), so there only the shape/dtype of this tensor
# are used, not its random values.
A_init

tensor([[ 6.8622e-03,  1.0667e-02,  3.4481e-03,  ..., -2.1118e-03,
          5.2476e-03,  9.8840e-03],
        [ 1.0667e-02,  5.4418e-04,  3.9523e-03,  ...,  2.1553e-03,
          6.6262e-03,  3.1635e-04],
        [ 3.4481e-03,  3.9523e-03, -6.8051e-03,  ...,  1.0208e-04,
         -5.5562e-04,  3.8470e-03],
        ...,
        [-2.1118e-03,  2.1553e-03,  1.0208e-04,  ...,  1.4679e-02,
          2.0596e-03,  1.4644e-04],
        [ 5.2476e-03,  6.6262e-03, -5.5562e-04,  ...,  2.0596e-03,
         -1.1903e-02,  2.9926e-03],
        [ 9.8840e-03,  3.1635e-04,  3.8470e-03,  ...,  1.4644e-04,
          2.9926e-03,  9.0171e-06]])

In [10]:
# Hyper-parameters handed to the experiment runner.
learning_params = {
    # Training length (a shorter 150-epoch setting was used previously).
    'num_epochs': 500,
    # Optimiser settings.
    'lr': 1e-4,
    'weight_decay': 1e-2,
    # NOTE(review): presumably learning-rate scheduler settings; if they feed
    # a step scheduler, gamma=1 would keep the rate constant — confirm in
    # run_exp.
    'step_size': 5,
    'gamma': 1,
    # Task / model selection.
    'num_classes': 2,
    'model_type': 'CNN',
}

In [11]:
# Run the experiment with the acyclicity regulariser over six log-spaced
# strengths (1e-3 ... 1e2). The recorded output shows six 500-epoch runs.
#
# NOTE(review): the run-name template says 'sym-random', yet the initial
# matrix passed in is all zeros (torch.zeros_like) and the results are later
# dumped to a file named 'A_zero_acyclic...'. The label looks stale — confirm
# before comparing runs by name.
A_res_acyclic = run_exp.run(
    device=device,
    learning_params=learning_params,
    data_loaders=data_loaders_CNN,
    reg_cls=regularization.AcyclicReg,
    gamma_grid=np.logspace(-3, 2, num=6, base=10),
    A_init=torch.zeros_like(A_init),
    run_name_fmt='sym-random-acyclic-g{0}-sd',
)

wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: bogachevv. Use `wandb login --relogin` to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

  0%|          | 0/500 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/A_grad,▁▂▂▁▂▂▂▂▂▂▃▂▂▁▂▃▃▃▂▂▃▂▂▃▃█▂▂▂▃▂▅▃▂▄▃▂▂▂▂
train/epoch_acc,▁▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
train/epoch_bc,▁▄▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████
train/epoch_f1,▁▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████
train/epoch_ones,▁▂▁▄▅▅▅▅▆▅▅▅▅▅▅▅▄▅▅▄▆▅▇▆▅█▆▆▆▆▆▆▆▇▇▆▇▇▇▆
train/epoch_precision,▁▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█▇▇████▇█████████
train/epoch_recall,▁▄▆▆▆▆▇▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇█▇▇█████
train/loss,██▅▅▄▄▄▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
train/max_acc,▁▂▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇███████
train/min_acc,▁▁▂▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████

0,1
train/A_grad,0.12687
train/epoch_acc,0.71659
train/epoch_bc,0.71659
train/epoch_f1,0.71788
train/epoch_ones,536.125
train/epoch_precision,0.71462
train/epoch_recall,0.72118
train/loss,0.19242
train/max_acc,0.72212
train/min_acc,0.71106


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

  0%|          | 0/500 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/A_grad,▁▃▅▆▅▄▃▃▄▃▆▄▂▄▃▄▂▅▄▆▅▆▆▆▆▄▃▃▅▃▄▄▃▃▄▄█▅▅▄
train/epoch_acc,▁▃▄▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████
train/epoch_bc,▁▂▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇██████████
train/epoch_f1,▁▁▅▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
train/epoch_ones,▁▂▃▅▅▅▅▆▅▆▅▅▆▅▅▅▅▅▅▆▅▆▇▆▆▆▆▆▆▆▅▆▆▆▆▆▇▇█▇
train/epoch_precision,▁▂▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇██▇▇█▇███████
train/epoch_recall,▁▁▂▁▃▄▄▃▃▄▄▄▅▄▄▅▅▆▆▅▆▅▅▆▅▆▆▇▆▇▇▇▇▆▇▇▇▇▇█
train/loss,█▇▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
train/max_acc,▁▃▄▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████
train/min_acc,▁▁▁▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇███████

0,1
train/A_grad,0.16996
train/epoch_acc,0.71506
train/epoch_bc,0.71506
train/epoch_f1,0.71595
train/epoch_ones,534.58333
train/epoch_precision,0.71372
train/epoch_recall,0.7182
train/loss,0.19289
train/max_acc,0.7206
train/min_acc,0.70952


  0%|          | 0/500 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/A_grad,▁▁▂▃▄▃▃▃▄▃▂▃▅▄▄▇▄▄▃▇▃▃▂▄▃▆▃█▃▃▂▇▄▄█▄▆▅▅▇
train/epoch_acc,▁▄▄▄▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇██████
train/epoch_bc,▁▃▄▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████
train/epoch_f1,▁▄▅▆▆▆▆▆▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████
train/epoch_ones,▁▂▄▄▅▄▅▄▅▅▅▄▄▄▅▅▅▄▄▆▄▅▄▆▅▆▅▆▆▅▆▆▆▇▆▅▇█▇▆
train/epoch_precision,▁▃▄▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇███
train/epoch_recall,▁▅▅▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇█████████
train/loss,██▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
train/max_acc,▁▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇███████
train/min_acc,▁▅▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████

0,1
train/A_grad,0.19838
train/epoch_acc,0.71549
train/epoch_bc,0.71549
train/epoch_f1,0.71697
train/epoch_ones,536.79167
train/epoch_precision,0.71327
train/epoch_recall,0.72071
train/loss,0.19325
train/max_acc,0.72103
train/min_acc,0.70995


  0%|          | 0/500 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/A_grad,▁▁▁▂▂▃▅▆▃▃▇█▄▄▅▄▄▃▃▄▅▄▃▅▇▂▄▄▄▃█▆█▃▃▄▃▃▆▄
train/epoch_acc,▁▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇██████████
train/epoch_bc,▁▂▂▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████
train/epoch_f1,▁▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇██████████
train/epoch_ones,▁▃▅▆▅▅▅▆▆▆▅▆▆▅▆▆▆▆▆▆▆▆▆▇▅▆▆▇▇▇█▇▇█▆▇▇▇█▇
train/epoch_precision,▁▁▂▃▄▄▅▅▅▅▅▅▅▆▅▆▆▆▆▆▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇█▇▇
train/epoch_recall,▁▂▃▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇██▇████████████
train/loss,█▆▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train/max_acc,▁▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████
train/min_acc,▁▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████

0,1
train/A_grad,0.14062
train/epoch_acc,0.71298
train/epoch_bc,0.71298
train/epoch_f1,0.71471
train/epoch_ones,537.70833
train/epoch_precision,0.71042
train/epoch_recall,0.71906
train/loss,0.19421
train/max_acc,0.71853
train/min_acc,0.70743


  0%|          | 0/500 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/A_grad,▁▁▃▄▃▃▄▄▄▆▂▃▆▄▆▄█▃▄▄▄▅▇▄▅▃▆▅▃▄▅▅▅▃▅▆▅█▃▄
train/epoch_acc,▁▃▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇███████████
train/epoch_bc,▁▂▂▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████
train/epoch_f1,▁▂▂▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇███
train/epoch_ones,▁▁▅▅▆▅▅▅▆▅▅▄▅▇▅▆▅▇▆▇▅█▆▆▆▇▅▇▇▇▇▆█▇▇▇█▇▆█
train/epoch_precision,▁▂▄▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇██▇█████
train/epoch_recall,▁▅▆▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇█████████
train/loss,█▆▄▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
train/max_acc,▁▂▃▄▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████
train/min_acc,▁▃▃▄▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇██████████

0,1
train/A_grad,0.12568
train/epoch_acc,0.71482
train/epoch_bc,0.71482
train/epoch_f1,0.71706
train/epoch_ones,539.66667
train/epoch_precision,0.71147
train/epoch_recall,0.72275
train/loss,0.19504
train/max_acc,0.72037
train/min_acc,0.70928


  0%|          | 0/500 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/A_grad,▁▁▃▄▅▇▄▂▅▄▅▄▃▅▅▄▃▅▃▅▅▅▅▄▄▃▄▄▇▆█▄▅▄▅▃▇▅▄▅
train/epoch_acc,▁▂▃▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████
train/epoch_bc,▁▁▃▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████
train/epoch_f1,▁▂▂▃▄▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇██▇▇▇████████
train/epoch_ones,▃▃▁▁▄▇▆▆▅▆▅▅▆▅▅▆▅▅▅▅▅▅▅▆▄▆▅▅▅▆▆▆▆▆▆▆█▇▆▆
train/epoch_precision,▁▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████
train/epoch_recall,▁▂▂▄▄▆▆▆▆▆▇▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇█████
train/loss,███▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
train/max_acc,▁▂▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████
train/min_acc,▁▂▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████

0,1
train/A_grad,0.12443
train/epoch_acc,0.71227
train/epoch_bc,0.71227
train/epoch_f1,0.71186
train/epoch_ones,529.70833
train/epoch_precision,0.71289
train/epoch_recall,0.71082
train/loss,0.19361
train/max_acc,0.71783
train/min_acc,0.70672


In [12]:
# Persist the experiment results next to the notebook folder (one level up).
# Only load this pickle back from a trusted location.
with open('../A_zero_acyclic_sd_dump.bin', 'wb') as dump_file:
    pickle.dump(A_res_acyclic, dump_file)