In [None]:
import os, sys, random
from time import time

import hydra
import matplotlib.pyplot as plt
import mne
import numpy as np
import torch
import torch.nn as nn
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, open_dict
from termcolor import cprint
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from tqdm import tqdm, trange
# import wandb

from constants import device
from meg_decoding.models import get_model, Classifier
from meg_decoding.utils.get_dataloaders import get_dataloaders, get_samplers
from meg_decoding.utils.loss import *
from meg_decoding.dataclass.god import GODDatasetBase, GODCollator
from meg_decoding.utils.loggers import Pickleogger
from meg_decoding.utils.vis_grad import get_grad
from meg_decoding.matlab_utils.load_meg import get_meg_data, roi, time_window, get_baseline

from hydra import initialize, compose

# Compose the experiment configuration through Hydra's compose API (no
# @hydra.main entry point is available inside a notebook). The `with` statement
# must sit at top level: the previous leading indentation raised an
# IndentationError.
with initialize(version_base=None, config_path="../configs/"):
    args = compose(config_name='20230417_sbj01_seq2stat')

In [None]:

# Load one MEG recording, keep the ROI channels, and plot its power spectral density.
#
# NOTE(review): `sub` and `split` are not defined anywhere in the visible notebook;
# they must be set (ideally in a config cell) before this cell runs on a fresh kernel.

DATAROOT = args.data_root
processed_meg_path_pattern = os.path.join(DATAROOT, '{sub}/mat/{name}')
label_path_pattern = os.path.join(DATAROOT, '{sub}/labels/{name}')
trigger_meg_path_pattern = os.path.join(DATAROOT, '{sub}/trigger/{name}')
processed_rest_meg_path_pattern = os.path.join(DATAROOT, '{sub}/mat/{name}')

# Only the first file of each kind is used for the selected subject/split.
subject_files = args.subjects[sub][split]
meg_name = subject_files['mat'][0]
label_name = subject_files['labels'][0]
trigger_name = subject_files['trigger'][0]
rest_name = subject_files['rest'][0]

processed_meg_path = processed_meg_path_pattern.format(sub=sub, name=meg_name)
label_path = label_path_pattern.format(sub=sub, name=label_name)
trigger_path = trigger_meg_path_pattern.format(sub=sub, name=trigger_name)
# Built for completeness; not consumed below (rest_mean / rest_std are passed as None).
processed_rest_meg_path = processed_rest_meg_path_pattern.format(sub=sub, name=rest_name)

# Pass the paths built above (the previous call referenced undefined
# meg_filepath / label_filepath / trigger_filepath names).
# Note: split is hard-coded to 'train' here, independent of the `split` variable above.
MEG_Data, image_features, labels, triggers = get_meg_data(
    processed_meg_path,
    label_path,
    trigger_path,
    rest_mean=None,
    rest_std=None,
    split='train',
)

# Keep only the channels selected by the ROI helper
# (presumably channel indices along the first axis -- verify against `roi`).
roi_ids = roi(args)
data = MEG_Data[roi_ids, :]  # fixed typo: was `roi_ods`

# create raw
# An integer ch_names makes MNE generate that many default channel names.
# NOTE(review): assumes the recording is sampled at 1000 Hz -- TODO confirm.
# NOTE(review): confirm 'meg' is an accepted ch_types value for mne.create_info
# (MNE documents 'mag' / 'grad' for MEG sensors).
info = mne.create_info(
    ch_names=len(roi_ids),
    sfreq=1000,
    ch_types='meg',
)
raw = mne.io.RawArray(data, info)

mne.viz.plot_raw_psd(raw)
plt.show()

In [None]:
# Stack every MEG channel found in `df`, drop the trailing channels, band-pass
# filter, then resample.
# NOTE(review): `df`, `num_channels` and the `brain_*` settings are not defined in
# the visible notebook -- they must exist before this cell runs.
meg_keys = [name for name in df.keys() if "MEG" in name]
meg_raw = np.stack([df[name] for name in meg_keys])  # original note: ( 224, ~396000 )
# NOTE: (kind of) confirmed that last 16 channels are REF
meg_raw = meg_raw[:num_channels]  # original note: ( 208, ~396000 )

meg_filtered = mne.filter.filter_data(
    meg_raw,
    sfreq=brain_orig_rate,
    l_freq=brain_filter_low,
    h_freq=brain_filter_high,
)

# Original note: "To 120 Hz" (depends on brain_resample_rate).
downsample_factor = brain_orig_rate / brain_resample_rate
meg_resampled = mne.filter.resample(meg_filtered, down=downsample_factor)

In [None]:
from meg_decoding.utils.reproducibility import seed_worker

# NOTE: We do need it (IMHO).
# Reproducibility switch: when disabled, neither a seeded generator nor a worker
# init function is handed to the DataLoaders below.
if not args.reproducible:
    g = None
    seed_worker = None
else:
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)
    # Seed every RNG in play with the same fixed value.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    # manual_seed returns the generator itself, so this yields a seeded Generator.
    g = torch.Generator().manual_seed(0)
    # `seed_worker` keeps referring to the function imported above.


train_dataset = GODDatasetBase(args, 'train')
val_dataset = GODDatasetBase(args, 'val')

# open_dict is used so that a new key can be assigned on the config.
with open_dict(args):
    args.num_subjects = train_dataset.num_subjects
    print('num subject is {}'.format(args.num_subjects))


if not args.use_sampler:
    # Settings shared by both plain DataLoaders; only `shuffle` differs.
    # NOTE(review): drop_last=True also applies to the validation loader, so a
    # trailing partial batch is skipped there -- confirm this is intended.
    _loader_kwargs = dict(
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
    test_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

else:
    # Sampler path: the test batch size is the first dimension of val_dataset.Y.
    test_size = val_dataset.Y.shape[0]
    train_loader, test_loader = get_samplers(
        train_dataset,
        val_dataset,
        args,
        test_bsz=test_size,
        collate_fn=GODCollator(args),
    )
