# Import libraries


In [1]:
import sys, os, json
import mne, sklearn, wandb
import numpy as np
import pandas as pd

from scipy.interpolate import interp1d
from nilearn import datasets, image, masking, plotting
from nilearn.input_data import NiftiLabelsMasker


# animation part
from IPython.display import HTML
import matplotlib
import matplotlib.pyplot as plt
# from celluloid import Camera   # a convenient way to animate
from matplotlib import animation, rc
from matplotlib.animation import FuncAnimation


## torch libraries 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader, Subset

from pytorch_model_summary import summary




In [2]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to the project code are picked up live.
%load_ext autoreload
%autoreload 2
# Add the parent of the current working directory to the import path
# (presumably where the project's `utils` package lives - confirm).
sys.path.insert(1, os.path.realpath(os.path.pardir))

# Project-local helper modules used by the cells below.
from utils import get_datasets
from utils import preproc
from utils import torch_dataset
from utils import train_utils
from utils import inference
from utils.models_arch import autoencoder_new

# Set all hyperparameters
- CUDA and GPU.
- Parameters of the dataset.
- Random seed (if necessary).


In [3]:
# Optional reproducibility block: uncomment to fix the RNG seeds and to make
# cuDNN deterministic.
# import random

# torch.manual_seed(0)
# random.seed(0)  # python operation seed
# np.random.seed(0)

# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

# Index of the GPU this notebook should run on.
GPU_INDEX = 2

print(torch.cuda.is_available(), torch.cuda.device_count())

# Select the requested GPU only when it actually exists; calling
# torch.cuda.set_device with a missing index raises, which would stop the
# notebook on machines with fewer GPUs (or none at all).
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    torch.cuda.set_device(GPU_INDEX)
else:
    print(f'GPU {GPU_INDEX} is not available; keeping the default device.')

True 4


In [4]:
# Experiment settings: dataset, preprocessing, optimisation and bookkeeping.
config = {
    # dataset / signal description
    'dataset_name': 'CWL',  # CWL
    'new_fps': 100,
    'freqs': np.logspace(np.log10(2), np.log10(99), 16),

    'n_channels': 30,  # 30
    'n_roi': 6,

    # sampling of training windows
    'bold_delay': 5,
    'to_many': True,
    'random_subsample': True,
    'sample_per_epoch': 2048,
    'WINDOW_SIZE': 2048,

    # optimisation
    'optimizer': 'adam',
    'lr': 3e-4,
    'weight_decay': 0,
    'batch_size': 32,

    # descriptive tags stored with the run
    'preproc_type': 'dB_log',
    'loss_function': 'mse_corr',
    'model_type': '1D_CNN_AE_wav2vec2',
}

# Keyword arguments for the autoencoder constructor, derived from `config`.
hp_autoencoder = {
    'n_electrodes': config['n_channels'],
    'n_freqs': len(config['freqs']),
    'n_channels_out': config['n_roi'],

    'channels': [128, 128, 128, 128],
    'kernel_sizes': [3, 3, 3],
    'strides': [8, 4, 4],
    'dilation': [1, 1, 1],
    'decoder_reduce': 4,
}

# Single flat dictionary: model hyperparameters first, then the experiment
# settings (which win on any key collision).
config = dict(hp_autoencoder, **config)

# DataLoader keyword arguments for the training and validation loaders.
params_train = dict(batch_size=config['batch_size'],
                    shuffle=True,
                    num_workers=0)

params_val = dict(batch_size=config['batch_size'],
                  shuffle=False)

# Load the preprocessed dataset from .npz files
Loading preprocessed arrays should speed up the experiments.

In [5]:
# ROI labels, read from the JSON file next to the processed data.
with open("../data/processed/labels_roi_6.json", 'r') as f:
    labels_roi = json.load(f)


# Build `dataset_paths`, the list of .npz files to load, for the chosen
# dataset. Every branch must define it, because the loading code iterates
# over `dataset_paths`.
if config['dataset_name']=='CWL':
    config['patients'] = ['trio1', 'trio2', 'trio3', 'trio4']
    dataset_paths = []
    for name in config['patients']:
        dataset_path = f'../data/processed/CWL/{name}_100_hz_6_roi_2_99_freqs.npz'
        dataset_paths.append(dataset_path)

elif config['dataset_name']=='NODDI':
    dataset_path = '../data/processed/NODDI/32_100_hz_6_roi_2_99_freqs.npz'
    # A single recording still goes through the same list-based loading code.
    dataset_paths = [dataset_path]
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(f"no such dataset: {config['dataset_name']!r}")

    
    

# Load the preprocessed train/test arrays of every recording.
list_train_dataset_prep = []
list_test_dataset_prep = []

for dataset_path in dataset_paths:

    # np.load on an .npz returns an NpzFile that keeps the archive open until
    # it is closed; the context manager releases the file handle as soon as
    # the arrays have been read into memory.
    with np.load(dataset_path) as data:
        train_dataset_prep = (data['x_train'], data['y_train'])
        test_dataset_prep = (data['x_test'], data['y_test'])


    # apply the BOLD time-delay correction
    train_dataset_prep = preproc.bold_time_delay_align(train_dataset_prep,
                                                       config['new_fps'],
                                                       config['bold_delay'])
    test_dataset_prep = preproc.bold_time_delay_align(test_dataset_prep,
                                                      config['new_fps'],
                                                      config['bold_delay'])

    list_train_dataset_prep.append(train_dataset_prep)
    list_test_dataset_prep.append(test_dataset_prep)
    print('Size of train dataset:', train_dataset_prep[0].shape, train_dataset_prep[1].shape)
    print('Size of test dataset:', test_dataset_prep[0].shape, test_dataset_prep[1].shape)


    
# Training uses the recordings of all patients concatenated together;
# evaluation uses the test split of the first recording only.

# One torch dataset per recording, merged into a single training dataset.
list_torch_dataset_train = [
    torch_dataset.CreateDataset_eeg_fmri(
        recording,
        random_sample=config['random_subsample'],
        sample_per_epoch=config['sample_per_epoch'],
        to_many=config['to_many'],
        window_size=config['WINDOW_SIZE'],
    )
    for recording in list_train_dataset_prep
]
torch_dataset_train = torch.utils.data.ConcatDataset(list_torch_dataset_train)

# Test split of the first recording, without random sub-sampling.
test_dataset_prep = list_test_dataset_prep[0]
torch_dataset_test = torch_dataset.CreateDataset_eeg_fmri(
    test_dataset_prep,
    random_sample=False,
    sample_per_epoch=None,
    to_many=config['to_many'],
    window_size=config['WINDOW_SIZE'],
)
# The validation dataset has no stride option, so keep every 100th item.
val_indices = np.arange(0, len(torch_dataset_test), 100)
torch_dataset_test = Subset(torch_dataset_test, val_indices)

# Data loaders used by the training loop.
train_loader = torch.utils.data.DataLoader(torch_dataset_train, **params_train)
val_loader = torch.utils.data.DataLoader(torch_dataset_test, **params_val)




Size of train dataset: (30, 16, 20690) (6, 20690)
Size of test dataset: (30, 16, 5500) (6, 5500)
Size of train dataset: (30, 16, 21080) (6, 21080)
Size of test dataset: (30, 16, 5500) (6, 5500)
Size of train dataset: (30, 16, 20885) (6, 20885)
Size of test dataset: (30, 16, 5500) (6, 5500)
Size of train dataset: (30, 16, 21860) (6, 21860)
Size of test dataset: (30, 16, 5500) (6, 5500)


# Init Model, Loss, optimizers

In [6]:
# Build the model and print a per-layer summary for a dummy batch of four
# zero-filled inputs.
model = autoencoder_new.AutoEncoder1D(**hp_autoencoder)

example_input = torch.zeros(4,
                            config['n_channels'],
                            len(config['freqs']),
                            config['WINDOW_SIZE'])
print(summary(model, example_input, show_input=False))


-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
       ConvBlock-1      [4, 128, 2048]         184,576         184,576
       ConvBlock-2       [4, 128, 256]          49,408          49,408
       ConvBlock-3        [4, 128, 64]          49,408          49,408
       ConvBlock-4        [4, 128, 16]          49,408          49,408
     UpConvBlock-5         [4, 32, 64]          12,352          12,352
     UpConvBlock-6        [4, 32, 256]           3,136           3,136
     UpConvBlock-7       [4, 32, 2048]           3,136           3,136
          Conv1d-8        [4, 6, 2048]             198             198
Total params: 351,622
Trainable params: 351,622
Non-trainable params: 0
-----------------------------------------------------------------------


# Model training

In [None]:
# Train several models with identical hyperparameters; each iteration is
# logged to Weights & Biases as a separate run.
n_runs = 3

for i in range(n_runs):

    # Fresh model and optimizer for every run.
    model = autoencoder_new.AutoEncoder1D(**hp_autoencoder)
    optimizer = optim.Adam(model.parameters(),
                           lr=config['lr'],
                           weight_decay=config['weight_decay'])

    # Loss built by the project helper with the weights given here.
    loss_func = train_utils.make_complex_loss_function(mse_weight=0.1,
                                                       corr_weight=1,
                                                       manifold_weight=0,
                                                       bound=1)
    train_step = train_utils.train_step

    # Keyword arguments forwarded to the project training routine.
    parameters = dict(
        EPOCHS=500,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_function=loss_func,
        train_step=train_step,
        optimizer=optimizer,
        device='cuda',
        raw_test_data=test_dataset_prep,
        show_info=20,
        num_losses=5,
        labels=labels_roi,
        inference_function=inference.model_inference_function,
        to_many=config['to_many'],
    )

    path_to_save_wandb = 'common/koval_alvi/Checkpoints/wandb_brain'

    with wandb.init(project="eeg_fmri", config=config, save_code=True):

        # Keep the maximum of the validation correlation in the run summary.
        wandb.define_metric("val/corr_mean", summary="max")

        # The auto-generated name of the first run becomes the shared prefix
        # for all runs of this experiment.
        if i == 0:
            exp_name = wandb.run.name

        wandb.run.name = exp_name + '_run_' + str(i)

        # Record the configuration and architecture in the cell output.
        print(config)
        print(parameters['model'])
        example_input = torch.zeros(4,
                                    config['n_channels'],
                                    len(config['freqs']),
                                    config['WINDOW_SIZE'])
        print(summary(model, example_input, show_input=False))

        model = train_utils.wanb_train_regression(**parameters)
        

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


{'n_electrodes': 30, 'n_freqs': 16, 'n_channels_out': 6, 'channels': [128, 128, 128, 128], 'kernel_sizes': [3, 3, 3], 'strides': [8, 4, 4], 'dilation': [1, 1, 1], 'decoder_reduce': 4, 'dataset_name': 'CWL', 'new_fps': 100, 'freqs': array([ 2.        ,  2.59420132,  3.36494024,  4.3646662 ,  5.6614114 ,
        7.34342046,  9.52515552, 12.3550855 , 16.02578954, 20.78706217,
       26.96291204, 34.97361097, 45.36429384, 58.84205542, 76.32406886,
       99.        ]), 'n_channels': 30, 'n_roi': 6, 'bold_delay': 5, 'to_many': True, 'random_subsample': True, 'sample_per_epoch': 2048, 'WINDOW_SIZE': 2048, 'optimizer': 'adam', 'lr': 0.0003, 'weight_decay': 0, 'batch_size': 32, 'preproc_type': 'dB_log', 'loss_function': 'mse_corr', 'model_type': '1D_CNN_AE_wav2vec2', 'patients': ['trio1', 'trio2', 'trio3', 'trio4']}
AutoEncoder1D(
  (spatial_reduce): ConvBlock(
    (conv1d): Conv1d(480, 128, kernel_size=(3,), stride=(1,), padding=same, bias=False)
    (norm): LayerNorm((128,), eps=1e-05, eleme

wandb: Network error (ReadTimeout), entering retry loop.


........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

VBox(children=(Label(value=' 0.59MB of 0.59MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train/loss_0,█▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss_1,▁▄▆▆▇▇▇▇████████████████████████████████
train/loss_2,█▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss_3,▇█▆▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/corr_mean,▁▆▇█
val/loss_0,▅▁▂▃▄█▅▄▅▆▆▇▆▇▆██▇▆█▇▆▆▇▆▆▆▇▆▆▆▇▅▅▅▅▆▅▅▅
val/loss_1,▅█▇▇▅▁▄▅▄▃▃▂▃▂▃▁▁▂▃▁▂▂▂▂▂▂▂▂▃▃▃▂▄▄▄▃▂▄▄▄
val/loss_2,▅▁▄▅▃▅▄▅▄▅▅▆▄▆▅▆█▄▃▅▅▄▄▃▄▅▄▄▃▅▂▄▄▃▃▃▄▄▃▃
val/loss_3,▇█▅▂▂▂▂▂▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁

0,1
train/loss_0,-0.92823
train/loss_1,0.94014
train/loss_2,0.11907
train/loss_3,1.9383
val/loss_0,-0.19257
val/loss_1,0.33282
val/loss_2,1.40254
val/loss_3,4.9815


[34m[1mwandb[0m: wandb version 0.12.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


{'n_electrodes': 30, 'n_freqs': 16, 'n_channels_out': 6, 'channels': [128, 128, 128, 128], 'kernel_sizes': [3, 3, 3], 'strides': [8, 4, 4], 'dilation': [1, 1, 1], 'decoder_reduce': 4, 'dataset_name': 'CWL', 'new_fps': 100, 'freqs': array([ 2.        ,  2.59420132,  3.36494024,  4.3646662 ,  5.6614114 ,
        7.34342046,  9.52515552, 12.3550855 , 16.02578954, 20.78706217,
       26.96291204, 34.97361097, 45.36429384, 58.84205542, 76.32406886,
       99.        ]), 'n_channels': 30, 'n_roi': 6, 'bold_delay': 5, 'to_many': True, 'random_subsample': True, 'sample_per_epoch': 2048, 'WINDOW_SIZE': 2048, 'optimizer': 'adam', 'lr': 0.0003, 'weight_decay': 0, 'batch_size': 32, 'preproc_type': 'dB_log', 'loss_function': 'mse_corr', 'model_type': '1D_CNN_AE_wav2vec2', 'patients': ['trio1', 'trio2', 'trio3', 'trio4']}
AutoEncoder1D(
  (spatial_reduce): ConvBlock(
    (conv1d): Conv1d(480, 128, kernel_size=(3,), stride=(1,), padding=same, bias=False)
    (norm): LayerNorm((128,), eps=1e-05, eleme

wandb: Network error (ReadTimeout), entering retry loop.


........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

wandb: Network error (ReadTimeout), entering retry loop.


........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

wandb: Network error (ReadTimeout), entering retry loop.


........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

VBox(children=(Label(value=' 0.34MB of 0.34MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train/loss_0,█▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss_1,▁▄▅▆▇▇▇▇████████████████████████████████
train/loss_2,█▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss_3,▆█▆▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/corr_mean,▁█
val/loss_0,▃▄▂▃▂▁▂▂▂▅▅▃▄▃▂▅▄▄▄▃▃▅▄▃▅▃▅▅▆▇▆▆▆▅▇█▇█▇▆
val/loss_1,▆▅▇▆▇█▇▇▆▄▄▆▅▆▇▄▄▅▄▆▅▄▄▆▄▅▄▄▃▂▃▃▃▃▂▁▁▁▁▂
val/loss_2,▃▄▂▄▂▁▂▁▂▅▅▃▄▂▁▄▄▃▄▂▃▄▃▃▅▃▅▆▄▆▅▆▅▅▆█▆▇▇▅
val/loss_3,██▇▆▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train/loss_0,-0.92893
train/loss_1,0.94072
train/loss_2,0.11795
train/loss_3,1.92469
val/loss_0,-0.00535
val/loss_1,0.17515
val/loss_2,1.69797
val/loss_3,4.83262


[34m[1mwandb[0m: wandb version 0.12.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


{'n_electrodes': 30, 'n_freqs': 16, 'n_channels_out': 6, 'channels': [128, 128, 128, 128], 'kernel_sizes': [3, 3, 3], 'strides': [8, 4, 4], 'dilation': [1, 1, 1], 'decoder_reduce': 4, 'dataset_name': 'CWL', 'new_fps': 100, 'freqs': array([ 2.        ,  2.59420132,  3.36494024,  4.3646662 ,  5.6614114 ,
        7.34342046,  9.52515552, 12.3550855 , 16.02578954, 20.78706217,
       26.96291204, 34.97361097, 45.36429384, 58.84205542, 76.32406886,
       99.        ]), 'n_channels': 30, 'n_roi': 6, 'bold_delay': 5, 'to_many': True, 'random_subsample': True, 'sample_per_epoch': 2048, 'WINDOW_SIZE': 2048, 'optimizer': 'adam', 'lr': 0.0003, 'weight_decay': 0, 'batch_size': 32, 'preproc_type': 'dB_log', 'loss_function': 'mse_corr', 'model_type': '1D_CNN_AE_wav2vec2', 'patients': ['trio1', 'trio2', 'trio3', 'trio4']}
AutoEncoder1D(
  (spatial_reduce): ConvBlock(
    (conv1d): Conv1d(480, 128, kernel_size=(3,), stride=(1,), padding=same, bias=False)
    (norm): LayerNorm((128,), eps=1e-05, eleme

wandb: Network error (ReadTimeout), entering retry loop.


........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

wandb: Network error (ReadTimeout), entering retry loop.


........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

# 