<a href="https://colab.research.google.com/github/RiccardoMPesce/eeg-fmri-rest-state-classification/blob/main/CNNForfMRIRestState.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import importlib
import json
import random
from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import mne
import nibabel
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import Dataset, DataLoader

from dataset_utils import *
from dataset import *
from train_utils import *
from basic_models import *

In [2]:
CWL_BASE_PATH = Path("CWLData")
MRI_BASE_PATH = CWL_BASE_PATH / "mri" / "epi_normalized"
EEG_BASE_PATH = CWL_BASE_PATH / "eeg" / "in-scan"
DATASET_BASE_PATH = CWL_BASE_PATH / "dataset"
CHECKPOINT_PATH = CWL_BASE_PATH / "checkpoints"
METRICS_PATH = CWL_BASE_PATH / "metrics"

# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-3
EPOCHS = 100
SEED = 42

# Seed every RNG the pipeline may draw from (DataLoader shuffling, weight
# initialisation, and presumably make_splits — confirm which RNG it uses)
# so that Restart & Run All reproduces the same results.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create every directory the notebook writes into. parents=True also creates
# CWLData/ itself on a fresh checkout instead of raising FileNotFoundError;
# the metrics/checkpoint directories are passed to the training loop below.
for output_dir in (DATASET_BASE_PATH, CHECKPOINT_PATH, METRICS_PATH):
    output_dir.mkdir(parents=True, exist_ok=True)

In [3]:
# Backend options: use CUDA when present, otherwise Apple Metal (MPS),
# and fall back to the CPU when neither accelerator is available.
if not torch.cuda.is_available():
    if torch.backends.mps.is_available():
        DEVICE = torch.device("mps")
        print("MPS (Metal) available")
    else:
        DEVICE = torch.device("cpu")
        print("CPU available")
else:
    DEVICE = torch.device("cuda")
    # Enable the cuDNN autotuner (benefits workloads whose input shapes stay constant).
    torch.backends.cudnn.benchmark = True
    print("CUDA available")
    # Also tell MNE to use CUDA, via its configuration key.
    mne.set_config("MNE_USE_CUDA", "True")

MPS (Metal) available


In [4]:
# Load the interval-based dataset twice from the same directory: first with
# use_cwl=True, then with use_cwl=False, so the two variants can be compared.
by_interval_dir = DATASET_BASE_PATH / "by_interval"
dataset, dataset_no_cwl = (
    EEGMRIDataset(by_interval_dir, use_cwl=cwl_flag) for cwl_flag in (True, False)
)

In [5]:
# Compute the split once, on the use_cwl=True dataset; the same `splits` object is
# reused for both dataset variants when building the loaders, so the with- and
# without-CWL runs share one train/val/test partition.
# NOTE(review): assumes both datasets enumerate the same samples in the same order — confirm.
splits = make_splits(dataset)

In [6]:
def build_split_loaders(source_dataset, split_dict, batch_size=BATCH_SIZE):
    """Build one DataLoader per split ("train", "val", "test") over ``source_dataset``.

    Only the training loader shuffles and drops its final incomplete batch.
    The validation and test loaders visit every sample in a fixed order, so their
    metrics cover the whole split and stay comparable from epoch to epoch.

    Args:
        source_dataset: dataset that ``Splitter`` restricts to each split.
        split_dict: split assignment, as returned by ``make_splits``.
        batch_size: number of samples per batch.

    Returns:
        dict mapping each split name to its DataLoader.
    """
    loaders_by_split = {}
    for split_name in ("train", "val", "test"):
        is_train = split_name == "train"
        loaders_by_split[split_name] = DataLoader(
            Splitter(source_dataset, split_dict=split_dict, split_name=split_name),
            batch_size=batch_size,
            shuffle=is_train,
            # NOTE(review): if the model or train_loop_eeg requires every batch to hold
            # exactly batch_size samples, evaluation loaders need drop_last=True as well.
            drop_last=is_train,
        )
    return loaders_by_split


loaders = build_split_loaders(dataset, splits)
loaders_no_cwl = build_split_loaders(dataset_no_cwl, splits)

In [7]:
# Conv1D baseline. The training output shows input batches of shape
# torch.Size([16, 38, 2000]), which matches 38 input channels.
EEG_IN_CHANNELS = 38
NUM_CLASSES = 2

conv1d_base_net_model = Conv1DBaseNet(
    {"in_channels": EEG_IN_CHANNELS, "num_classes": NUM_CLASSES, "verbose": True}
).to(DEVICE)
# Build the optimizer only after the model is on DEVICE, as torch.optim recommends,
# so the optimizer references the parameters that will actually be trained.
optimizer = torch.optim.Adam(conv1d_base_net_model.parameters(), lr=LEARNING_RATE)

In [8]:
# Train the Conv1D baseline on the use_cwl=True loaders. The accuracy/loss JSON
# paths and the checkpoint location share the "conv1d_base_net_model" prefix.
conv1d_accuracy_json = METRICS_PATH / "conv1d_base_net_model_accuracy.json"
conv1d_loss_json = METRICS_PATH / "conv1d_base_net_model_loss.json"
conv1d_checkpoint_path = CHECKPOINT_PATH / "conv1d_base_net_model"

train_loop_eeg(
    conv1d_base_net_model,
    loaders,
    optimizer,
    F.cross_entropy,
    conv1d_accuracy_json,
    conv1d_loss_json,
    conv1d_checkpoint_path,
    EPOCHS,
    DEVICE,
    LEARNING_RATE,
    # NOTE(review): the meaning of this positional flag is defined by train_loop_eeg
    # (not visible here) — confirm it and pass it by keyword.
    False,
)

[34m[1mwandb[0m: Currently logged in as: [33mriccardompesce[0m ([33meeg-fmri-rest-state[0m). Use [1m`wandb login --relogin`[0m to force relogin


Training starting at epoch 1
torch.Size([16, 38, 2000])


: 

: 