# Imports 

In [None]:
import os
import numpy as np

import torch
import torch.utils.data as data

import tensorboard
from pytorch_lightning import loggers as pl_loggers
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
import patch_path
import torchcde

from mt_code.datasets import load_dataset, P300Dataset 
from mt_code.models import NeuralCde, OdeLstm, EegNet

from mt_code.runners import Learner as IrregularSequenceLearner

# Regular Experiments

## EEGNET

In [None]:
# Keep the 3 EEGNet (regular data) checkpoints with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='logs/models/demons/EEGNET_reg/',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
)

In [None]:
# Regular P300 split for EEGNet: no timestamps and no spline coefficients requested.
(
    trainloader,
    testloader,
    in_features,
    num_classes,
    return_sequences,
    class_balance,
) = load_dataset(
    'p300',
    data_dir='../data/demons/nery_demons_dataset',
    irregular=False,
    coeffs=False,
    timestamps=False,
    batch_size=512,
)

In [None]:
# EEGNet for the regular split; input_size (40, 8) is presumably (time steps, channels) — confirm.
eegnet = EegNet(input_size=(40, 8), rate=50, F1=4, D=6)
# Loss weights are the reciprocal of class_balance (assumed per-class proportions — confirm).
learn = IrregularSequenceLearner(
    eegnet,
    class_weights=1 / class_balance,
    timestamps=False,
    lr=0.1,
)


In [None]:
# 10 epochs, validating every 10% of an epoch.
# NOTE(review): the test loader doubles as the validation set, so checkpoint
# selection sees the test data.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=10,
    val_check_interval=0.1,
    log_every_n_steps=1,
    progress_bar_refresh_rate=1,
)
trainer.fit(learn, trainloader, val_dataloaders=testloader)


In [None]:
best_path = checkpoint_callback.best_model_path
checkpoint = torch.load(best_path)
# Pair the bare model's keys with the checkpoint's values purely by position
# (checkpoint keys presumably carry the LightningModule attribute prefix — confirm).
states = dict(zip(eegnet.state_dict(), checkpoint['state_dict'].values()))
eegnet.load_state_dict(state_dict=states)
learn = IrregularSequenceLearner(eegnet, lr=0.05, timestamps=False, class_weights=1 / class_balance)

In [None]:
# Evaluate the restored EEGNet learner on the test loader.
results = trainer.test(learn, testloader)


In [None]:
# Same evaluation on the training loader, for a train/test comparison.
results = trainer.test(learn, trainloader)


## NCDE

In [None]:
# Keep the 3 neural-CDE (regular data) checkpoints with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='logs/models/demons/cde_reg/',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
)

In [None]:
# Regular P300 split with spline coefficients (coeffs=True) for the neural CDE.
(
    trainloader,
    testloader,
    in_features,
    num_classes,
    return_sequences,
    class_balance,
) = load_dataset(
    'p300',
    data_dir='../data/demons/nery_demons_dataset',
    batch_size=1024,
    irregular=False,
    coeffs=True,
    timestamps=False,
)

In [None]:
# Neural CDE; the positional 8 and 2 are presumably input channels and hidden size — confirm.
cde = NeuralCde(8, 2, num_classes, return_sequences=False)
learn = IrregularSequenceLearner(
    cde,
    class_weights=1 / class_balance,
    timestamps=False,
    lr=0.05,
)

In [None]:
# Single epoch with validation every 10% of it; gradients clipped at 1000.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=1,
    val_check_interval=0.1,
    gradient_clip_val=1000,
    log_every_n_steps=1,
    progress_bar_refresh_rate=1,
)


In [None]:
# Train the neural CDE; the test loader is also used for validation / checkpoint selection.
trainer.fit(learn, trainloader, val_dataloaders = testloader)

In [None]:
best_path = checkpoint_callback.best_model_path
checkpoint = torch.load(best_path)
# Pair the bare model's keys with the checkpoint's values purely by position
# (checkpoint keys presumably carry the LightningModule attribute prefix — confirm).
states = dict(zip(cde.state_dict(), checkpoint['state_dict'].values()))
cde.load_state_dict(state_dict=states)

In [None]:
# Wrap the restored CDE weights in a fresh learner and evaluate on the test loader.
learn = IrregularSequenceLearner(
    cde,
    class_weights=1 / class_balance,
    timestamps=False,
    lr=0.05,
)
results = trainer.test(learn, testloader)

In [None]:
# Same evaluation on the training loader, for a train/test comparison.
results = trainer.test(learn, trainloader)

### Visualization: spline interpolation and CDE latent path

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Global matplotlib font sizes for the figures below.
SMALL_SIZE = 15
MEDIUM_SIZE = 20
BIGGER_SIZE = 30

plt.rc('font', size=SMALL_SIZE)                              # default text size
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)  # axes title; x/y labels
plt.rc('xtick', labelsize=SMALL_SIZE)                        # x tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)                        # y tick labels
plt.rc('legend', fontsize=SMALL_SIZE)                        # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)                      # figure title

In [None]:
# Load the raw dataset object directly (outside the data loaders) for the plots below.
ds = P300Dataset('../data/demons/nery_demons_dataset')
# Presumably populates train_x / train_y / train_t, which later cells read — confirm.
ds.get_data_for_experiments()

In [None]:
data = (ds.train_x, ds.train_y)

In [None]:
idc_p300 = np.where(data[1]==1)[0]

In [None]:
n = 100000
x = torch.stack((data[0][idc_p300[301]], data[0][2]))

In [None]:
c = torchcde.natural_cubic_coeffs(x)
X = torchcde.CubicSpline(c)
XX = X.evaluate(np.arange(0,x.size(1), x.size(1)/n))

In [None]:
z0=X.evaluate(0)
times = torch.arange(0, x.size(1), x.size(1)/n)
z_t = torchcde.cdeint(X=X, z0=z0, func=cde.func, t=times.to(torch.float32))

In [None]:
time = ds.train_t[0,:]
new_time = np.linspace(start=ds.train_t[0,0], stop=ds.train_t[0, -1], num=len(XX[1, :,0]))
plt.figure(figsize = (12,8))
plt.plot(time, x[0, :, 0], alpha=0.7, label = 'наблюдения сигнала')
plt.plot(new_time, XX[0, :,0], alpha=0.7, label = 'интерполяция сингала')
plt.plot(new_time, z_t[0, :, 0].detach().numpy(), alpha=0.7, label = 'z(t)')
plt.ylabel('амплитуда сигнала, мкВ')
plt.xlabel('время, мс')
plt.legend()
plt.show()

In [None]:
time = ds.train_t[0,:]
new_time = np.linspace(start=ds.train_t[0,0], stop=ds.train_t[0, -1], num=len(XX[1, :,0]))
plt.figure(figsize = (12,8))
plt.plot(time, x[1, :, 0], alpha=0.7, label = 'наблюдения сигнала')
plt.plot(new_time, XX[1, :,0], alpha=0.7, label = 'интерполяция сингала')
plt.plot(new_time, z_t[1, :, 0].detach().numpy(), alpha=0.7, label = 'z(t)')
plt.ylabel('амплитуда сигнала, мкВ')
plt.xlabel('время, мс')
plt.legend()
plt.show()

## ODELSTM

In [None]:
# Keep the 3 ODE-LSTM (regular data) checkpoints with the highest validation F1.
checkpoint_callback = ModelCheckpoint(
    dirpath='logs/models/demons/odelstm_reg/',
    monitor='val_f1',
    mode='max',
    save_top_k=3,
)

In [None]:
# Regular P300 split with timestamps (timestamps=True) for the ODE-LSTM.
(
    trainloader,
    testloader,
    in_features,
    num_classes,
    return_sequences,
    class_balance,
) = load_dataset(
    'p300',
    data_dir='../data/demons/nery_demons_dataset',
    batch_size=512,
    irregular=False,
    coeffs=False,
    timestamps=True,
)

In [None]:
# ODE-LSTM; the positional 8, 8 are presumably input channels and hidden size — confirm.
olstm = OdeLstm(8, 8, num_classes, return_sequences=False)

learn = IrregularSequenceLearner(
    olstm,
    class_weights=1 / class_balance,
    timestamps=True,
    lr=0.05,
)

In [None]:
# 5 epochs, validating every 5% of an epoch; gradients clipped at 1000.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=5,
    val_check_interval=0.05,
    gradient_clip_val=1000,
    log_every_n_steps=1,
    progress_bar_refresh_rate=1,
)

trainer.fit(learn, trainloader, val_dataloaders=testloader)

In [None]:
best_path = checkpoint_callback.best_model_path

checkpoint = torch.load(best_path)
# Pair the bare model's keys with the checkpoint's values purely by position
# (checkpoint keys presumably carry the LightningModule attribute prefix — confirm).
states = dict(zip(olstm.state_dict(), checkpoint['state_dict'].values()))
olstm.load_state_dict(state_dict=states)

learn = IrregularSequenceLearner(olstm, lr=0.05, timestamps=True, class_weights=1 / class_balance)
results = trainer.test(learn, testloader)

In [None]:
# Same evaluation on the training loader, for a train/test comparison.
results = trainer.test(learn, trainloader)


# Irregular experiments 

## EEGNET

In [None]:
# Keep the 3 EEGNet (irregular data) checkpoints with the lowest validation loss.
# mode must be 'min' for a loss: 'max' retained the worst checkpoints and made
# best_model_path point at the highest-loss one.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    dirpath='logs/models/demons/EEGNET_irreg/',
    save_top_k=3,
)

In [None]:
# Irregular P300 split for EEGNet: no timestamps and no spline coefficients requested.
(
    trainloader,
    testloader,
    in_features,
    num_classes,
    return_sequences,
    class_balance,
) = load_dataset(
    'p300',
    data_dir='../data/demons/nery_demons_dataset',
    irregular=True,
    coeffs=False,
    timestamps=False,
    batch_size=1024,
)



In [None]:
# EEGNet for the irregular split; input_size (32, 8) is presumably (time steps, channels) — confirm.
eegnet = EegNet(input_size=(32, 8), rate=50, F1=2, D=4)
learn = IrregularSequenceLearner(
    eegnet,
    class_weights=1 / class_balance,
    timestamps=False,
    lr=0.05,
)


In [None]:
# 10 epochs, validating every 10% of an epoch; gradients clipped at 1000.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=10,
    val_check_interval=0.1,
    gradient_clip_val=1000,
    log_every_n_steps=1,
    progress_bar_refresh_rate=1,
)
trainer.fit(learn, trainloader, val_dataloaders=testloader)

In [None]:
best_path = checkpoint_callback.best_model_path
checkpoint = torch.load(best_path)
# Pair the bare model's keys with the checkpoint's values purely by position
# (checkpoint keys presumably carry the LightningModule attribute prefix — confirm).
states = dict(zip(eegnet.state_dict(), checkpoint['state_dict'].values()))
eegnet.load_state_dict(state_dict=states)
learn = IrregularSequenceLearner(eegnet, lr=0.05, timestamps=False, class_weights=1 / class_balance)
results = trainer.test(learn, testloader)


In [None]:
# Same evaluation on the training loader, for a train/test comparison.
results = trainer.test(learn, trainloader)


## NCDE

In [None]:
# Keep the 3 neural-CDE (irregular data) checkpoints with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='logs/models/demons/cde_irr/',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
)

In [None]:
# Irregular P300 split with spline coefficients (coeffs=True) for the neural CDE.
(
    trainloader,
    testloader,
    in_features,
    num_classes,
    return_sequences,
    class_balance,
) = load_dataset(
    'p300',
    data_dir='../data/demons/nery_demons_dataset',
    batch_size=1024,
    irregular=True,
    coeffs=True,
    timestamps=False,
)

In [None]:
# Neural CDE; the positional 8, 8 are presumably input channels and hidden size — confirm.
cde = NeuralCde(8, 8, num_classes, return_sequences=False)

learn = IrregularSequenceLearner(
    cde,
    class_weights=1 / class_balance,
    timestamps=False,
    lr=0.05,
)

In [None]:
# 5 epochs, validating every 2% of an epoch; gradients clipped at 1000.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=5,
    val_check_interval=0.02,
    gradient_clip_val=1000,
    log_every_n_steps=1,
    progress_bar_refresh_rate=1,
)

trainer.fit(learn, trainloader, val_dataloaders=testloader)

In [None]:
best_path = checkpoint_callback.best_model_path
checkpoint = torch.load(best_path)
# Pair the bare model's keys with the checkpoint's values purely by position
# (checkpoint keys presumably carry the LightningModule attribute prefix — confirm).
states = dict(zip(cde.state_dict(), checkpoint['state_dict'].values()))
cde.load_state_dict(state_dict=states)

learn = IrregularSequenceLearner(cde, lr=0.01, timestamps=False, class_weights=1 / class_balance)

results = trainer.test(learn, testloader)

In [None]:
# Same evaluation on the training loader, for a train/test comparison.
results = trainer.test(learn, trainloader)


## ODELSTM

In [None]:
# Keep the 3 ODE-LSTM (irregular data) checkpoints with the highest validation F1.
checkpoint_callback = ModelCheckpoint(
    dirpath='logs/models/demons/odelstm_irr/',
    monitor='val_f1',
    mode='max',
    save_top_k=3,
)



In [None]:
# Irregular P300 split with timestamps for the ODE-LSTM. This is the irregular
# experiment (checkpoints go to odelstm_irr/), so irregular must be True;
# irregular=False silently re-ran the regular ODE-LSTM experiment.
trainloader, testloader, in_features, num_classes, return_sequences, class_balance = load_dataset(
    'p300',
    timestamps=True,
    coeffs=False,
    irregular=True,
    batch_size=512,
    data_dir='../data/demons/nery_demons_dataset',
)



In [None]:
# ODE-LSTM; the positional 8, 8 are presumably input channels and hidden size — confirm.
olstm = OdeLstm(8, 8, num_classes, return_sequences=False)
learn = IrregularSequenceLearner(
    olstm,
    class_weights=1 / class_balance,
    timestamps=True,
    lr=0.05,
)






In [None]:
# 5 epochs with gradients clipped at 1.
# NOTE(review): unlike the other sections, no val_check_interval / log_every_n_steps
# and a much smaller clip value are used here — confirm this is intentional.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=5,
    gradient_clip_val=1,
    progress_bar_refresh_rate=1,
)
trainer.fit(learn, trainloader, val_dataloaders=testloader)




In [None]:

# Re-read the path from *this* section's callback. Without this line best_path
# still holds whatever an earlier cell set (the NCDE checkpoint when the notebook
# runs top to bottom), and those weights were being mapped onto the ODE-LSTM.
best_path = checkpoint_callback.best_model_path
checkpoint = torch.load(best_path)
# Pair the bare model's keys with the checkpoint's entries purely by position
# (checkpoint keys presumably carry the LightningModule attribute prefix — confirm).
states = {}
for k_new, k_old in zip(olstm.state_dict().keys(), checkpoint['state_dict'].keys()):
    states[k_new] = checkpoint['state_dict'].get(k_old)
olstm.load_state_dict(state_dict=states)

learn = IrregularSequenceLearner(olstm, lr=0.05, timestamps=True, class_weights=1 / class_balance)

results = trainer.test(learn, testloader)



In [None]:
# Same evaluation on the training loader, for a train/test comparison.
results = trainer.test(learn, trainloader)
