## Train Network Basics

## How to use the dataset class

Loading a dataset requires the data to be in the correct format (see the Prepare data tutorial). Create the dataset object, then call its `load` method:

```python
import logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
)

from meegnet.dataloaders import ContinuousDataset

data_path = "/home/arthur/data/camcan/smt"

# ContinuousDataset cuts continuous recordings into trials using the window and
# overlap parameters; for data already cut into trials, use the Dataset class instead.
dataset = ContinuousDataset(
    window=0.8,        # 0.8 s windows
    overlap=0,         # no overlap between consecutive windows
    sfreq=500,         # sampling frequency of 500 Hz
    n_subjects=20,     # only load 20 subjects
    n_samples=10,      # limit the number of samples for each subject to 10
    sensortype="ALL",  # use all sensor types (magnetometers and gradiometers)
    lso=True,          # use leave-subject-out data splits
)

dataset.load(data_path)
```

```
10/18/2024 08:35:40 AM Logging subjects and labels from /home/arthur/data/camcan/smt...
10/18/2024 08:35:40 AM Found 20 subjects to load.
```


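We have loaded 20 subjects from the Cam-CAN recordings located in `data_path`. With `n_samples=10`, each subject contributes at most 10 examples, so the dataset holds up to 200 examples in total. With `sensortype="ALL"`, the magnetometers and both gradiometer orientations are used, giving 3 channels of 102 sensor locations each (hence the 3 input channels and the (102, 1) spatial kernel of the network printed below). Each 0.8 s window sampled at 500 Hz contains 0.8 × 500 = 400 time points, so each example has shape (3, 102, 400).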
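To make the window and overlap arithmetic concrete, here is a minimal NumPy sketch of cutting a continuous recording into fixed-length trials. It is not meegnet's implementation: the function `segment_continuous` is hypothetical, and it assumes `window` is in seconds and `overlap` is a fraction of the window (0 meaning no overlap), which may differ from how `ContinuousDataset` interprets these parameters.

```python
import numpy as np


def segment_continuous(signal, sfreq, window, overlap=0.0):
    """Hypothetical sketch: cut a recording of shape (..., n_times) into windows.

    Assumes `window` is in seconds and `overlap` is a fraction of the window.
    """
    win = int(round(window * sfreq))                 # samples per window
    step = max(1, int(round(win * (1 - overlap))))   # hop between window starts
    if signal.shape[-1] < win:
        raise ValueError("recording is shorter than one window")
    n_windows = (signal.shape[-1] - win) // step + 1
    starts = step * np.arange(n_windows)
    # stack the windows along a new leading axis: (n_windows, ..., win)
    return np.stack([signal[..., s:s + win] for s in starts])


# Stand-in recording: 3 channel types x 102 sensor locations, 10 s at 500 Hz
recording = np.zeros((3, 102, 5000))
trials = segment_continuous(recording, sfreq=500, window=0.8, overlap=0)
print(trials.shape)  # (12, 3, 102, 400)
```

Each trial has the (3, 102, 400) shape described above; with `n_samples=10`, at most 10 such trials are kept per subject.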

## How to use the network class

Create an instance of the `Model` class, then call its `train` method on the dataset created above.

```python
from meegnet.network import Model

save_path = data_path                 # where the model is saved
net_option = "eegnet"                 # network architecture to use
input_size = dataset.data[0].shape    # shape of a single example
n_outputs = 2                         # 2 possible output labels
name = "smt_meegnet"

net_params = {"linear": 100, "hlayers": 3, "dropout": 0.5}
my_model = Model(name, net_option, input_size, n_outputs, save_path, net_params=net_params)

print(my_model.net)

my_model.train(dataset)
```

```
10/18/2024 08:36:10 AM Creating DataLoaders...
10/18/2024 08:36:10 AM Starting Training with:
10/18/2024 08:36:10 AM batch size: 128
10/18/2024 08:36:10 AM learning rate: 1e-05
10/18/2024 08:36:10 AM patience: 20
```


```
EEGNet(
  (feature_extraction): Sequential(
    (0): Conv2d(3, 16, kernel_size=(1, 64), stride=(1, 1), padding=(1, 32), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): DepthwiseConv2d(
      (depthwise): Conv2d(16, 32, kernel_size=(102, 1), stride=(1, 1), groups=16, bias=False)
    )
    (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ELU(alpha=1.0)
    (5): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0)
    (6): Dropout(p=0.5, inplace=False)
    (7): SeparableConv2d(
      (depthwise): DepthwiseConv2d(
        (depthwise): Conv2d(32, 32, kernel_size=(1, 16), stride=(1, 1), padding=(1, 8), groups=32, bias=False)
      )
      (pointwise): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), padding=(1, 8), bias=False)
    )
    (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ELU(alpha=1.0)
    (10): AvgPool2d(kernel_size=(1, 8),
```
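The printed layers determine how the (sensors, time) dimensions shrink through the network. The sketch below applies the standard PyTorch output-size formula for `Conv2d` and `AvgPool2d` (dilation 1, no ceil mode) to the first layers, assuming an input of shape (3, 102, 400) as derived from the dataset settings; `conv_out` is a hypothetical helper, BatchNorm, ELU and Dropout layers are skipped because they do not change the shape, and the truncated layer (10) is not covered.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output length along one axis of Conv2d/AvgPool2d (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


# (height = sensors, width = time) for an assumed input of shape (3, 102, 400)
h, w = 102, 400
# (0) Conv2d(3, 16, kernel_size=(1, 64), padding=(1, 32))
h, w = conv_out(h, 1, padding=1), conv_out(w, 64, padding=32)
print("after (0):", (16, h, w))  # (16, 104, 401)
# (2) depthwise Conv2d(16, 32, kernel_size=(102, 1))
h, w = conv_out(h, 102), conv_out(w, 1)
print("after (2):", (32, h, w))  # (32, 3, 401)
# (5) AvgPool2d(kernel_size=(1, 4), stride=(1, 4))
h, w = conv_out(h, 1), conv_out(w, 4, stride=4)
print("after (5):", (32, h, w))  # (32, 3, 100)
# (7) SeparableConv2d: depthwise (1, 16) then pointwise (1, 1), both with padding (1, 8)
h, w = conv_out(h, 1, padding=1), conv_out(w, 16, padding=8)
h, w = conv_out(h, 1, padding=1), conv_out(w, 1, padding=8)
print("after (7):", (32, h, w))  # (32, 7, 117)
```

Note that the padding of 1 on the sensor axis in layer (0) adds two rows, so the (102, 1) spatial kernel leaves 3 sensor rows rather than 1.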

```
10/18/2024 08:36:21 AM Epoch: 1 // Batch 1/2 // loss = 0.71562
10/18/2024 08:36:23 AM Epoch: 1 // Batch 2/2 // loss = 0.78837
10/18/2024 08:36:38 AM Epoch: 1
10/18/2024 08:36:38 AM  [LOSS] TRAIN 0.7540044474329745 / VALID 0.8416368411741818
10/18/2024 08:36:38 AM  [ACC] TRAIN 0.48125 / VALID 0.30000001192092896
10/18/2024 08:36:48 AM Epoch: 2 // Batch 1/2 // loss = 0.71293
10/18/2024 08:36:51 AM Epoch: 2 // Batch 2/2 // loss = 0.75221
10/18/2024 08:37:03 AM Epoch: 2
10/18/2024 08:37:03 AM  [LOSS] TRAIN 0.7409208513169709 / VALID 0.8747682543030727
10/18/2024 08:37:03 AM  [ACC] TRAIN 0.45 / VALID 0.20000000298023224
10/18/2024 08:37:14 AM Epoch: 3 // Batch 1/2 // loss = 0.75706
10/18/2024 08:37:16 AM Epoch: 3 // Batch 2/2 // loss = 0.70096
10/18/2024 08:37:30 AM Epoch: 3
10/18/2024 08:37:30 AM  [LOSS] TRAIN 0.7284259412796468 / VALID 0.7865380775675327
10/18/2024 08:37:30 AM  [ACC] TRAIN 0.475 / VALID 0.3499999940395355
10/18/2024 08:37:41 AM Epoch: 4 // Batch 1/2 // loss = 0.74296
10/1
```
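With `lso=True`, the data are split by subject, so all examples of a given subject land in a single split. The training log is consistent with the 20 subjects × 10 examples loaded above being split 80/10/10 by subject: two training batches of at most 128 examples (between 129 and 256 training examples), a training accuracy of 0.48125 = 77/160, and validation accuracies in steps of 0.05 (20 validation examples). The sketch below is a hypothetical NumPy version of such a split, not meegnet's code; the function name `leave_subject_out_split` and the 80/10/10 fractions are assumptions.

```python
import numpy as np


def leave_subject_out_split(subject_ids, fractions=(0.8, 0.1, 0.1), seed=0):
    """Hypothetical sketch: split example indices so that no subject appears
    in more than one of the train / valid / test sets."""
    subjects = np.unique(subject_ids)
    rng = np.random.default_rng(seed)
    rng.shuffle(subjects)
    n_train = int(round(fractions[0] * len(subjects)))
    n_valid = int(round(fractions[1] * len(subjects)))
    groups = (
        subjects[:n_train],
        subjects[n_train:n_train + n_valid],
        subjects[n_train + n_valid:],
    )
    # map each group of subjects back to the indices of their examples
    return tuple(np.flatnonzero(np.isin(subject_ids, g)) for g in groups)


# 20 subjects x 10 examples each, as in the dataset loaded above
subject_ids = np.repeat(np.arange(20), 10)
train_idx, valid_idx, test_idx = leave_subject_out_split(subject_ids)
print(len(train_idx), len(valid_idx), len(test_idx))  # 160 20 20
```

Splitting by subject rather than by example prevents the network from being evaluated on windows from subjects it has already seen during training.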

The underlying PyTorch network remains accessible as `my_model.net`, which is useful, for example, to make single-trial predictions for a figure:

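```python
import numpy as np
import torch

random_sample = 22
# add a leading batch axis to respect the expected shapes
# (not needed when passing several examples at once)
data_example = dataset.data[random_sample][np.newaxis]

# convert the input to the dtype and device of the network weights
param = next(my_model.net.parameters())
x = torch.as_tensor(data_example, dtype=param.dtype, device=param.device)

my_model.net.eval()  # disable dropout and use BatchNorm running statistics
with torch.no_grad():
    pred = my_model.net(x)

print(f"predicted label: {pred.argmax(dim=1).item()}, original label: {dataset.labels[random_sample]}")
```

The network weights are stored in double precision, so passing a float32 input such as `torch.Tensor(data_example).cuda()` fails with `RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.DoubleTensor) should be the same`. Converting the input to the dtype and device of the weights avoids this error.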
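To predict several examples at once, a slice already carries a batch axis, so no `np.newaxis` is needed, and the predicted label of each example is the argmax over its row of class scores. The minimal sketch below shows the shapes and the argmax step with stand-in NumPy arrays: the zero-filled data, random labels and random scores are hypothetical placeholders for `dataset.data`, `dataset.labels` and the network output.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in data: 30 examples of shape (3, 102, 400) and their labels (hypothetical)
data = np.zeros((30, 3, 102, 400), dtype=np.float32)
labels = rng.integers(0, 2, size=30)

single = data[22][np.newaxis]  # one example needs an explicit batch axis
batch = data[10:20]            # a slice already keeps the batch axis
print(single.shape, batch.shape)  # (1, 3, 102, 400) (10, 3, 102, 400)

# Stand-in for the network output on `batch`: one row of 2 class scores per example
scores = rng.standard_normal((len(batch), 2))
predicted = scores.argmax(axis=1)             # predicted label per example
accuracy = np.mean(predicted == labels[10:20])
print(predicted, accuracy)
```

With the real network, the same pattern applies to the output of `my_model.net` on a batch tensor, using `pred.argmax(dim=1)`.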