<a href="https://colab.research.google.com/github/RiccardoMPesce/eeg-fmri-rest-state-classification/blob/main/CNNForfMRIRestState.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import nibabel
import mne

import importlib

from glob import glob
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset, DataLoader

import torch.nn.functional as F

In [4]:
def extract_relevant_markers_from_eeg(eeg_file, kind):
    recording_metadata = list(zip(eeg_file.annotations.description.tolist(), eeg_file.annotations.onset.tolist()))
    clean_recs_markers = [(rec[0], int(rec[1] * 1000)) for rec in recording_metadata if rec[0] in ("eeo", "eec", "beo", "bec", "mri")]
    
    clean_markers = clean_recs_markers[:]

    t_r = 1.95 if kind == "trio" else 2.00

    mri_intervals = []

    last_annotation = None

    for (a, t) in clean_markers:
        if a in ("beo", "eeo", "bec", "eec"):
            last_annotation = a
        else:
            if last_annotation == "beo" and a == "mri":
                mri_intervals += [((t, t + int(t_r * 1000)), "eo")]
            elif last_annotation == "bec" and a == "mri":
                mri_intervals += [((t, t + int(t_r * 1000)), "ec")]
            else:
                pass
    
    return mri_intervals
    


In [5]:
def retrive_times_fmri_cwl(raw):
    """
    Extract information from annottation raw file about fmri frames.
    This information importatn for further interpo
    
    -----
    Input
    Raw is file from EEG set.
    Retrive fMRI time annotation. When occurs recordings in seconds
    It is useful for aligning EEG and fMRI data 
    
    Output: 
    times_fmri - np 
        array of times in ms .
    """
    
    times_fmri = []
    for annot in raw.annotations:
        if annot['description'] == "mri":
            times_fmri.append(annot['onset'])

    times_fmri = np.array(times_fmri)
    times_fmri = times_fmri * 1000  # seconds to milliseconds
    
    return times_fmri

Actual

In [6]:
from pathlib import Path
from glob import glob

import json

import torch

import numpy as np

CWL_BASE_PATH = Path("/content/drive/MyDrive/CWLData/")
MRI_BASE_PATH = CWL_BASE_PATH / "mri" / "epi_normalized"
EEG_BASE_PATH = CWL_BASE_PATH / "eeg" / "in-scan"
DATASET_BASE_PATH = CWL_BASE_PATH / "dataset"
CHECKPOINT_PATH = CWL_BASE_PATH / "checkpoints"
METRICS_PATH = CWL_BASE_PATH / "metrics"

# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 10 ** (-4)

# Backend options
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    torch.backends.cudnn.benchmark = True
    print("CUDA available")
    mne.set_config("MNE_USE_CUDA", "True")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    # BATCH_SIZE = 440
    print("MPS (Metal) available")
else:
    DEVICE = torch.device("cpu")
    print("CPU available")

CPU available


In [7]:
DATASET_BASE_PATH.mkdir(exist_ok=True)

In [8]:
mri_images = sorted([Path(image) for image in glob(str(MRI_BASE_PATH) + "/*.nii")])
eeg_files = sorted([Path(eeg) for eeg in glob(str(EEG_BASE_PATH) + "/*.set") if "mrcorrected" not in eeg])
eeg_files_mrcorrected = sorted([Path(eeg) for eeg in glob(str(EEG_BASE_PATH) + "/*.set") if "mrcorrected" in eeg])

In [9]:
def dump_dataset_to_json(mri_images, eeg_files):
    couples = zip(mri_images, eeg_files)

    for mri_file, eeg_file in couples:
        eeg = mne.io.read_raw_eeglab(eeg_file)
        mri = nibabel.load(mri_file)

        filename = eeg_file.name.replace(".set", "")
        kind = "trio" if "trio" in filename else "verio"

        intervals = extract_relevant_markers_from_eeg(eeg, kind=kind)

        chunk_size = min(len(intervals), mri.get_fdata().shape[-1])

        ds = {
            "eeg": [],
            "fmri": [],
            "label": []
        }

        ds_no_cwl = {
            "eeg": [],
            "fmri": [],
            "label": []
        }

        print(f"Shape of data: {eeg.get_data().shape}")

        eeg_data = eeg.get_data()
        eeg_data_no_cwl = eeg.get_data(picks=[ch for ch in eeg.ch_names if "cw" not in ch.lower()])
        mri_data = mri.get_fdata()
        fmri_data = list(a.tolist() for a in np.array_split(mri_data, mri_data.shape[-1], axis=3))

        for (start, end), label in (intervals if len(intervals) <= chunk_size else intervals[:chunk_size]):
            eeg_chunk = eeg_data[:, start:end].tolist()
            eeg_chunk_no_cwl = eeg_data_no_cwl[:, start:end].tolist()

            ds["label"] += [label]
            ds_no_cwl["label"] += [label]

            ds["eeg"] += [eeg_chunk]
            ds_no_cwl["eeg"] += [eeg_chunk_no_cwl]

        ds["fmri"] = fmri_data if len(fmri_data) < chunk_size else fmri_data[:chunk_size]
        ds_no_cwl["fmri"] = fmri_data if len(fmri_data) < chunk_size else fmri_data[:chunk_size]

        ds_file = DATASET_BASE_PATH / (filename + "_dataset.json")
        ds_file_no_cwl = DATASET_BASE_PATH / (filename + "_no_cwl_dataset.json")

        with open(ds_file, "w") as ds_file:
            json.dump(ds, ds_file)

        with open(ds_file_no_cwl, "w") as ds_file_no_cwl:
            json.dump(ds_no_cwl, ds_file_no_cwl)

        print("Dumped ", filename)
        
        eeg_data = None
        
        eeg = None
        mri = None
        ds = {}


In [10]:
def melt_json(folder_path, out_file_name, use_cwl=True, dump_every=2):
    melted = []

    files = [f for f in Path(folder_path).glob("*dataset*") if "no_cwl" not in f.name] if use_cwl else [f for f in Path(folder_path).glob("*dataset*") if "no_cwl" in f.name]
    
    Path(out_file_name).touch(exist_ok=True)

    for count_processed, f in enumerate(files):
        print(f"Processing {f.name}")

        with open(f, "r") as in_f:
            ds = json.load(in_f)

        with open(out_file_name, "w") as out_f:
            keys = list(ds.keys())

            size = len(keys)

            for i in range(size):
                for key in keys:
                    melted = [{
                        key: ds[key][i]
                    }]

            if count_processed % dump_every == 0:
                json.dump(melted, out_f)

    return len(melted)
        



In [11]:
# melt_json(DATASET_BASE_PATH, DATASET_BASE_PATH / "melted_ds.json")

In [12]:
def dump_json_by_step(mri_files, eeg_files):
    ds = []
    ds_no_cwl = []
    couples = zip(mri_images, eeg_files)
    
    for mri_file, eeg_file in couples:
        eeg = mne.io.read_raw_eeglab(eeg_file)
        mri = nibabel.load(mri_file)

        filename = eeg_file.name.replace(".set", "")
        kind = "trio" if "trio" in filename else "verio"

        intervals = extract_relevant_markers_from_eeg(eeg, kind=kind)

        chunk_size = min(len(intervals), mri.get_fdata().shape[-1])

        print(f"Shape of data: {eeg.get_data().shape}")

        eeg_data = eeg.get_data()
        eeg_data_no_cwl = eeg.get_data(picks=[ch for ch in eeg.ch_names if "cw" not in ch.lower()])
        mri_data = mri.get_fdata()
        fmri_data = list(a.tolist() for a in np.array_split(mri_data, mri_data.shape[-1], axis=3))

        entry = {}
        entry_no_cwl = {}

        for (start, end), label in (intervals if len(intervals) <= chunk_size else intervals[:chunk_size]):
            eeg_chunk = eeg_data[:, start:end].tolist()
            eeg_chunk_no_cwl = eeg_data_no_cwl[:, start:end].tolist()

            entry["label"] = label
            entry_no_cwl["label"] = label

            entry["eeg"] = eeg_chunk
            entry_no_cwl["eeg"] = eeg_chunk_no_cwl

        entry["fmri"] = fmri_data if len(fmri_data) < chunk_size else fmri_data[:chunk_size]
        entry_no_cwl["fmri"] = fmri_data if len(fmri_data) < chunk_size else fmri_data[:chunk_size]

        ds_file = DATASET_BASE_PATH / "dataset_melted.json"
        ds_file_no_cwl = DATASET_BASE_PATH / "no_cwl_dataset_melted.json"

        ds += [entry]
        ds_no_cwl += [entry_no_cwl]

        with open(ds_file, "w") as ds_file:
            json.dump(ds, ds_file)

        with open(ds_file_no_cwl, "w") as ds_file_no_cwl:
            json.dump(ds_no_cwl, ds_file_no_cwl)

        print("Dumped ", filename)

In [13]:
def dump_dataset_by_interval(eeg_files, mri_files):
    (DATASET_BASE_PATH / "by_interval").mkdir(exist_ok=True)

    couples = zip(eeg_files, mri_files)

    for mri_file, eeg_file in couples:
        eeg = mne.io.read_raw_eeglab(eeg_file)
        mri = nibabel.load(mri_file)

        filename = eeg_file.name.replace(".set", "")
        kind = "trio" if "trio" in filename else "verio"

        intervals = extract_relevant_markers_from_eeg(eeg, kind=kind)

        chunk_size = min(len(intervals), mri.get_fdata().shape[-1])

        eeg_data = eeg.get_data()
        eeg_data_no_cwl = eeg.get_data(picks=[ch for ch in eeg.ch_names if "cw" not in ch.lower()])
        mri_data = mri.get_fdata()
        fmri_data = list(a.tolist() for a in np.array_split(mri_data, mri_data.shape[-1], axis=3))

        for i, ((start, end), label) in enumerate((intervals if len(intervals) <= chunk_size else intervals[:chunk_size])):
            entry = {}
            entry_no_cwl = {}

            eeg_chunk = eeg_data[:, start:end].tolist()
            eeg_chunk_no_cwl = eeg_data_no_cwl[:, start:end].tolist()

            entry["label"] = label
            entry_no_cwl["label"] = label

            entry["eeg"] = eeg_chunk
            entry_no_cwl["eeg"] = eeg_chunk_no_cwl

            entry["fmri"] = fmri_data[i] 
            entry_no_cwl["fmri"] = fmri_data[i] 

            with open(DATASET_BASE_PATH / "by_interval" / f"{filename}_s{start}_e{end}", "w") as ds_file:
                json.dump(entry, ds_file)

            with open(DATASET_BASE_PATH / "by_interval" / f"{filename}_no_cwl_s{start}_e{end}", "w") as ds_file_no_cwl:
                json.dump(entry_no_cwl, ds_file_no_cwl)

In [14]:
print(True)

True


In [15]:
from pathlib import Path

class EEGMRIDataset(Dataset):
    def __init__(self, directory_path, use_cwl=None, he_pump=None):
        self.dataset_files = list(Path(directory_path).glob("*"))

        self.use_cwl = use_cwl
        
        if use_cwl is not None:
            if use_cwl:
                self.dataset_files = [f for f in self.dataset_files if "no_cwl" not in f.name]
            else:
                self.dataset_files = [f for f in self.dataset_files if "no_cwl" in f.name]

        if he_pump is not None:
            if he_pump:
                self.dataset_files = [f for f in self.dataset_files if "hpump-on" in f.name]
            else:
                self.dataset_files = [f for f in self.dataset_files if "hpump-off" in f.name]

    def __len__(self):
        return len(self.dataset_files) 

    def __getitem__(self, idx):
        f_path = self.dataset_files[idx]
        
        with open(f_path, "r") as f_p:
            observation = json.load(f_p)
        
        label = 1 if observation["label"] == "eo" else 0
        fmri = torch.Tensor(observation["fmri"])
        eeg = torch.Tensor(observation["eeg"])

        if eeg.shape[1] < 2000:
            eeg = torch.cat((eeg, torch.zeros((eeg.shape[0], 50))), axis=1)

        if eeg.shape[0] not in (32, 38):
            if self.use_cwl:
                eeg = torch.cat((eeg, torch.zeros((38 - eeg.shape[0], eeg.shape[1]))), axis=0)
            else:
                eeg = torch.cat((eeg, torch.zeros((32 - eeg.shape[0], eeg.shape[1]))), axis=0)

        fmri = fmri.reshape(1, fmri.shape[0], fmri.shape[1], fmri.shape[2])

        return eeg, fmri, label


In [16]:
dataset = EEGMRIDataset(DATASET_BASE_PATH / "by_interval", use_cwl=True)
dataset_no_cwl = EEGMRIDataset(DATASET_BASE_PATH / "by_interval", use_cwl=False)


In [17]:
import random


class Splitter:
    def __init__(self, dataset, split_dict, split_name):
        # Set EEG dataset
        self.dataset = dataset
        # Load split
        self.split_idx = split_dict[split_name]
        # Compute size
        self.size = len(self.split_idx)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Get sample from dataset
        eeg, fmri, label = self.dataset[self.split_idx[idx]]
        # Return
        return eeg, fmri, label


def make_splits(dataset, train_frac=0.9, val_frac=0.05, test_frac=0.05):
    splits = {}
    
    if train_frac + val_frac + test_frac != 1:
        train_frac, val_frac, test_frac = 0.9, 0.05, 0.05

    indices = list(range(len(dataset)))

    random.shuffle(indices)

    for split in ["train", "val", "test"]:
        if split == "train":
            frac = train_frac 
        elif split == "val":
            frac = val_frac
        else:
            frac = test_frac 

        splits[split] = [indices.pop() for _ in range(int(round(len(dataset) * frac)))]
    
    return splits

In [18]:
splits = make_splits(dataset)

In [19]:
loaders = {split: DataLoader(Splitter(dataset, split_dict = splits, split_name = split), batch_size = BATCH_SIZE, drop_last = True, shuffle = True) for split in ["train", "val", "test"]}
loaders_no_cwl = {split: DataLoader(Splitter(dataset_no_cwl, split_dict = splits, split_name = split), batch_size = BATCH_SIZE, drop_last = True, shuffle = True) for split in ["train", "val", "test"]}

In [20]:
from torch import nn, optim, functional
from torch.nn import functional as F

class ConvNet1D(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv1d(input_channels, 32, kernel_size=3),
            nn.ReLU(),
            # nn.Dropout(0.5),
            nn.MaxPool1d(5))
        self.layer2 = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=3, dilation=2),
            nn.ReLU(),
            # nn.Dropout(0.5),
            nn.MaxPool1d(10))
        self.layer3 = nn.Flatten()
        self.layer4 = nn.Sequential(
            nn.Linear(2496, 512),
            nn.Linear(512, 2),
            nn.Softmax())

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out


In [21]:
model_conv1d = ConvNet1D(38)

In [22]:
import re

def find_last_version(checkpoint_path):
    logs_path = checkpoint_path / "lightning_logs"
    versions = [p for p in logs_path.glob("*") if (p / "checkpoints").exists()]
    return logs_path / sorted(versions)[-1] if versions != [] else logs_path / ""

def find_last_checkpoint(checkpoint_path):
    checkpoints_path = find_last_version(checkpoint_path) / "checkpoints"
    print(checkpoints_path)
    if list(checkpoints_path.glob("*")) == []:
        return None
    else:
        last_checkpoint = list(checkpoints_path.glob("*"))[0]
    return checkpoints_path / last_checkpoint

In [24]:
def train_loop(model, loaders, optimizer, accuracy_json, loss_json, checkpoint_prefix, n_epochs, debug=False, save_every=2):
    # Initialize training, validation, test losses and accuracy list

    accuracy_json_file_path = Path(accuracy_json)
    loss_json_file_path = Path(loss_json)

    if accuracy_json_file_path.exists() and accuracy_json_file_path.is_file():
        with open(accuracy_json_file_path, "r+") as accuracy_json_f:
            accuracies_per_epoch = json.load(accuracy_json_f)
            print(f"Loaded accuracy dictionary at {accuracy_json_file_path}")
    else:
        accuracies_per_epoch = {"train": [], "val": [], "test": []}

    if loss_json_file_path.exists() and loss_json_file_path.is_file():
        with open(loss_json_file_path, "r+") as loss_json_f:
            losses_per_epoch = json.load(loss_json_f)
            print(f"Loaded loss dictionary at {loss_json_file_path}")
    else:
        losses_per_epoch = {"train": [], "val": [], "test": []}

    starting_epoch = 0

    checkpoint_path = Path(checkpoint_prefix) 

    checkpoint_path.mkdir(exist_ok=True)

    # Check for the latest weights
    checkpoints = [checkpoint.name for checkpoint in checkpoint_path.glob("*.pth") if checkpoint.is_file()]

    if len(checkpoints) > 0:
        epochs = [int(s.replace(".pth", "")) for s in checkpoints]

        latest_state_dict_path = checkpoint_path / f"{max(epochs)}.pth"

        # Loading state dict
        model.load_state_dict(torch.load(latest_state_dict_path, map_location=DEVICE))
        print(f"Loaded weights generated at epoch {max(epochs)}")

        starting_epoch = max(epochs)

    best_accuracy = 0
    best_accuracy_val = 0
    best_epoch = 0
    
    predicted_labels = [] 
    correct_labels = []

    print(f"Training starting at epoch {starting_epoch + 1}")
    for epoch in range(starting_epoch + 1, n_epochs + starting_epoch + 1):
        # Initialize loss/accuracy variables
        losses = {"train": 0, "val": 0, "test": 0}
        accuracies = {"train": 0, "val": 0, "test": 0}
        counts = {"train": 0, "val": 0, "test": 0}
        
        # Adjust learning rate for SGD
        """
        if OPTIMIZER == "SGD":
            lr = LEARNING_RATE * (LR_DECAY ** (epoch // LEARNING_RATE_DECAY_EVERY))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
        """
        
        # Process each split
        for split in ("train", "val", "test"):
            # Set network mode
            if split == "train":
                model.train()
                torch.set_grad_enabled(True)
            else:
                model.eval()
                torch.set_grad_enabled(False)
            
            # Process all split batches
            for i, (input, _, target) in enumerate(loaders[split]):

                # Move model to device
                model = model.to(DEVICE)
                
                # Move tensors to device
                input = input.to(DEVICE)
                target = target.to(DEVICE)
                
                if debug:
                    print(input.device)

                # Forward
                output = model(input.squeeze())

                # Compute loss
                loss = F.cross_entropy(output, target)
                losses[split] += loss.item()
                
                # Compute accuracy
                _, pred = output.data.max(1)
                correct = pred.eq(target.data).sum().item()
                accuracy = correct / input.data.size(0)   
                accuracies[split] += accuracy
                counts[split] += 1
                
                # Backward and optimize
                if split == "train":
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
        
        # Print info at the end of the epoch
        if accuracies["val"] / counts["val"] >= best_accuracy_val:
            best_accuracy_val = accuracies["val"] / counts["val"]
            best_accuracy = accuracies["test"] / counts["test"]
            best_epoch = epoch

        train_loss = losses["train"] / counts["train"]
        train_accuracy = accuracies["train"] / counts["train"]
        validation_loss = losses["val"] / counts["val"]
        validation_accuracy = accuracies["val"] / counts["val"]
        test_loss = losses["test"] / counts["test"]
        test_accuracy = accuracies["test"] / counts["test"]

        print("INFO")
        print(f"- Model: {model.__class__.__name__} - epoch {epoch}")
        print("STATS")
        print(f"- Training: Loss {train_loss:.4f}, Accuracy {train_accuracy:.4f} " 
            f"- Validation: Loss {validation_loss:.4f}, Accuracy {validation_accuracy:.4f} "
            f"- Test: Loss {test_loss:.4f}, Accuracy {test_accuracy:.4f}")
        print(f"Best Test Accuracy at maximum Validation Accuracy (validation_accuracy = {best_accuracy_val:.4f}) is {best_accuracy:.4f} at epoch {best_epoch}\n")

        losses_per_epoch["train"].append(train_loss)
        losses_per_epoch["val"].append(validation_loss)
        losses_per_epoch["test"].append(test_loss)
        accuracies_per_epoch["train"].append(train_accuracy)
        accuracies_per_epoch["val"].append(validation_accuracy)
        accuracies_per_epoch["test"].append(test_accuracy)

        if epoch % save_every == 0:
            torch.save(model.state_dict(), checkpoint_path / f"{epoch}.pth")
            
            with open(accuracy_json_file_path, "w+") as accuracy_json_f:
                json.dump(accuracies_per_epoch, accuracy_json_f)
            
            with open(loss_json_file_path, "w+") as loss_json_f:
                json.dump(losses_per_epoch, loss_json_f)

    # At the end of training, save
    torch.save(model.state_dict(), checkpoint_path / f"{epoch}.pth")

In [25]:
conv1d_optimizer = torch.optim.Adam(model_conv1d.parameters(), lr=LEARNING_RATE)
train_loop(model_conv1d, loaders, conv1d_optimizer, METRICS_PATH / "conv1d_accuracies.json", METRICS_PATH / "conv1d_loss.json", CHECKPOINT_PATH / "conv1d_checkpoints", 52)

Loaded accuracy dictionary at /content/drive/MyDrive/CWLData/metrics/conv1d_accuracies.json
Loaded loss dictionary at /content/drive/MyDrive/CWLData/metrics/conv1d_loss.json
Loaded weights generated at epoch 46
Training starting at epoch 47


  input = module(input)


INFO
- Model: ConvNet1D - epoch 47
STATS
- Training: Loss 0.6887, Accuracy 0.5492 - Validation: Loss 0.7057, Accuracy 0.3750 - Test: Loss 0.6866, Accuracy 0.5833
Best Test Accuracy at maximum Validation Accuracy (validation_accuracy = 0.3750) is 0.5833 at epoch 47

INFO
- Model: ConvNet1D - epoch 48
STATS
- Training: Loss 0.6878, Accuracy 0.5530 - Validation: Loss 0.7055, Accuracy 0.3750 - Test: Loss 0.6904, Accuracy 0.5417
Best Test Accuracy at maximum Validation Accuracy (validation_accuracy = 0.3750) is 0.5417 at epoch 48

INFO
- Model: ConvNet1D - epoch 49
STATS
- Training: Loss 0.6879, Accuracy 0.5530 - Validation: Loss 0.7017, Accuracy 0.3958 - Test: Loss 0.6892, Accuracy 0.5625
Best Test Accuracy at maximum Validation Accuracy (validation_accuracy = 0.3958) is 0.5625 at epoch 49

INFO
- Model: ConvNet1D - epoch 50
STATS
- Training: Loss 0.6880, Accuracy 0.5521 - Validation: Loss 0.7023, Accuracy 0.4167 - Test: Loss 0.6923, Accuracy 0.5208
Best Test Accuracy at maximum Validation

In [1]:
model_conv1d_no_cwl = ConvNet1D(32)
conv1d_no_cwl_optimizer = torch.optim.Adam(model_conv1d_no_cwl.parameters(), lr=LEARNING_RATE)
train_loop(model_conv1d_no_cwl, loaders_no_cwl, conv1d_no_cwl_optimizer, METRICS_PATH / "conv1d_no_cwl_accuracies.json", METRICS_PATH / "conv1d_no_cwl_loss.json", CHECKPOINT_PATH / "conv1d_no_cwl_checkpoints", 100)

NameError: ignored

In [None]:
class ConvNet3D(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv3d(input_channels, 32, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.MaxPool1d(5))
        self.layer2 = nn.Sequential(
            nn.Conv3d(32, 64, kernel_size=3, dilation=2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.MaxPool1d(10))
        self.layer3 = nn.Flatten()
        self.layer4 = nn.Sequential(
            nn.Linear(2496,2),
            nn.Softmax())

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out

In [None]:
def train_loop_fmri(model, loaders, optimizer, accuracy_json, loss_json, checkpoint_prefix, n_epochs, debug=False, save_every=2):
    # Initialize training, validation, test losses and accuracy list

    accuracy_json_file_path = Path(accuracy_json)
    loss_json_file_path = Path(loss_json)

    if accuracy_json_file_path.exists() and accuracy_json_file_path.is_file():
        with open(accuracy_json_file_path, "r+") as accuracy_json_f:
            accuracies_per_epoch = json.load(accuracy_json_f)
            print(f"Loaded accuracy dictionary at {accuracy_json_file_path}")
    else:
        accuracies_per_epoch = {"train": [], "val": [], "test": []}

    if loss_json_file_path.exists() and loss_json_file_path.is_file():
        with open(loss_json_file_path, "r+") as loss_json_f:
            losses_per_epoch = json.load(loss_json_f)
            print(f"Loaded loss dictionary at {loss_json_file_path}")
    else:
        losses_per_epoch = {"train": [], "val": [], "test": []}

    starting_epoch = 0

    checkpoint_path = Path(checkpoint_prefix) 

    checkpoint_path.mkdir(exist_ok=True)

    # Check for the latest weights
    checkpoints = [checkpoint.name for checkpoint in checkpoint_path.glob("*.pth") if checkpoint.is_file()]

    if len(checkpoints) > 0:
        epochs = [int(s.replace(".pth", "")) for s in checkpoints]

        latest_state_dict_path = checkpoint_path / f"{max(epochs)}.pth"

        # Loading state dict
        model.load_state_dict(torch.load(latest_state_dict_path, map_location=DEVICE))
        print(f"Loaded weights generated at epoch {max(epochs)}")

        starting_epoch = max(epochs)

    best_accuracy = 0
    best_accuracy_val = 0
    best_epoch = 0
    
    predicted_labels = [] 
    correct_labels = []

    print(f"Training starting at epoch {starting_epoch + 1}")
    for epoch in range(starting_epoch + 1, n_epochs + starting_epoch + 1):
        # Initialize loss/accuracy variables
        losses = {"train": 0, "val": 0, "test": 0}
        accuracies = {"train": 0, "val": 0, "test": 0}
        counts = {"train": 0, "val": 0, "test": 0}
        
        # Adjust learning rate for SGD
        """
        if OPTIMIZER == "SGD":
            lr = LEARNING_RATE * (LR_DECAY ** (epoch // LEARNING_RATE_DECAY_EVERY))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
        """
        
        # Process each split
        for split in ("train", "val", "test"):
            # Set network mode
            if split == "train":
                model.train()
                torch.set_grad_enabled(True)
            else:
                model.eval()
                torch.set_grad_enabled(False)
            
            # Process all split batches
            for i, (_, input, target) in enumerate(loaders[split]):

                # Move model to device
                model = model.to(DEVICE)
                
                # Move tensors to device
                input = input.to(DEVICE)
                target = target.to(DEVICE)
                
                if debug:
                    print(input.device)

                # Forward
                output = model(input.squeeze())

                # Compute loss
                loss = F.cross_entropy(output, target)
                losses[split] += loss.item()
                
                # Compute accuracy
                _, pred = output.data.max(1)
                correct = pred.eq(target.data).sum().item()
                accuracy = correct / input.data.size(0)   
                accuracies[split] += accuracy
                counts[split] += 1
                
                # Backward and optimize
                if split == "train":
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
        
        # Print info at the end of the epoch
        if accuracies["val"] / counts["val"] >= best_accuracy_val:
            best_accuracy_val = accuracies["val"] / counts["val"]
            best_accuracy = accuracies["test"] / counts["test"]
            best_epoch = epoch

        train_loss = losses["train"] / counts["train"]
        train_accuracy = accuracies["train"] / counts["train"]
        validation_loss = losses["val"] / counts["val"]
        validation_accuracy = accuracies["val"] / counts["val"]
        test_loss = losses["test"] / counts["test"]
        test_accuracy = accuracies["test"] / counts["test"]

        print("INFO")
        print(f"- Model: {model.__class__.__name__}")
        print("STATS")
        print(f"- Training: Loss {train_loss:.4f}, Accuracy {train_accuracy:.4f} " 
            f"- Validation: Loss {validation_loss:.4f}, Accuracy {validation_accuracy:.4f} "
            f"- Test: Loss {test_loss:.4f}, Accuracy {test_accuracy:.4f}")
        print(f"Best Test Accuracy at maximum Validation Accuracy (validation_accuracy = {best_accuracy_val:.4f}) is {best_accuracy:.4f} at epoch {best_epoch}\n")

        losses_per_epoch["train"].append(train_loss)
        losses_per_epoch["val"].append(validation_loss)
        losses_per_epoch["test"].append(test_loss)
        accuracies_per_epoch["train"].append(train_accuracy)
        accuracies_per_epoch["val"].append(validation_accuracy)
        accuracies_per_epoch["test"].append(test_accuracy)

        if epoch % save_every == 0:
            torch.save(model.state_dict(), checkpoint_path / f"{epoch}.pth")
            
            with open(accuracy_json_file_path, "w+") as accuracy_json_f:
                json.dump(accuracies_per_epoch, accuracy_json_f)
            
            with open(loss_json_file_path, "w+") as loss_json_f:
                json.dump(losses_per_epoch, loss_json_f)

    # At the end of training, save
    torch.save(model.state_dict(), checkpoint_path / f"{epoch}.pth")

In [34]:
model_conv3d = ConvNet3D(1)
conv3d_optimizer = torch.optim.Adam(model_conv3d.parameters(), lr=LEARNING_RATE)
train_loop_fmri(model_conv3d, loaders, conv3d_optimizer, METRICS_PATH / "conv3d_accuracies.json", METRICS_PATH / "conv3d_loss.json", CHECKPOINT_PATH / "conv3d_checkpoints", 100)

'ConvNet1D'