# Head Files

In [3]:
import torch
import torchvision
# from torch.utils.tensorboard import SummaryWriter
import time
from torch import nn
from torch.utils.data import DataLoader
from load_data import MyData  # self-made
from torchvision import transforms
from tqdm import tqdm_notebook as tqdm # View procedure
import os
import scipy.io
from random import random
import numpy as np
import gc
from torch.utils.tensorboard import SummaryWriter
from network_cnn import MyNetwork

# 1. Prepare datasets

In [19]:
hc_path = f"../eegmap_chunks/rest/hc/hc_rest.pt"
mcs_path = f"../eegmap_chunks/rest/mcs/mcs_rest.pt"
uws_path = f"../eegmap_chunks/rest/uws/uws_rest.pt"
hc = torch.load(hc_path)
mcs = torch.load(mcs_path)
uws = torch.load(uws_path)
print(hc.size())
print(mcs.size())
print(uws.size())

torch.Size([852, 2400, 10, 11])
torch.Size([1051, 2400, 10, 11])
torch.Size([872, 2400, 10, 11])


## 1.1 Map Data

In [12]:
# del hc_indices
# del train_hc
# del test_hc
gc.collect()
torch.cuda.empty_cache()

# generate random seed
torch.manual_seed(32)
# generate random index
hc_indices = torch.randperm(len(hc))
mcs_indices = torch.randperm(len(mcs))
uws_indices = torch.randperm(len(uws))
# split the datasets
len_hc = len(hc)
train_size = int(0.8*len_hc)
train_hc = hc[hc_indices[:train_size]]
test_hc = hc[hc_indices[train_size:]]
torch.save(train_hc, "../eegmap_chunks/rest/hc/rest_hc_train.pt")
torch.save(test_hc, "../eegmap_chunks/rest/hc/rest_hc_test.pt")
del hc_indices
del train_hc
del test_hc
gc.collect()
torch.cuda.empty_cache()

len_mcs = len(mcs)
train_size = int(0.8*len_mcs)
train_mcs = mcs[mcs_indices[:train_size]]
test_mcs = mcs[mcs_indices[train_size:]]
torch.save(train_mcs, "../eegmap_chunks/rest/mcs/rest_mcs_train.pt")
torch.save(test_mcs, "../eegmap_chunks/rest/mcs/rest_mcs_test.pt")
del mcs_indices
del train_mcs
del test_mcs
gc.collect()
torch.cuda.empty_cache()

len_uws = len(uws)
train_size = int(0.8*len_uws)
train_uws = uws[uws_indices[:train_size]]
test_uws = uws[uws_indices[train_size:]]
torch.save(train_uws, "../eegmap_chunks/rest/uws/rest_uws_train.pt")
torch.save(test_uws, "../eegmap_chunks/rest/uws/rest_uws_test.pt")
del uws_indices
del train_uws
del test_uws
gc.collect()
torch.cuda.empty_cache()

## 1.2 Labels

In [26]:
data_mode = "test"
hc_path = f"../eegmap_chunks/rest/hc/rest_hc_{data_mode}.pt"
mcs_path = f"../eegmap_chunks/rest/mcs/rest_mcs_{data_mode}.pt"
uws_path = f"../eegmap_chunks/rest/uws/rest_uws_{data_mode}.pt"
hc=torch.load(hc_path)
print(hc.size())
mcs = torch.load(mcs_path)
print(mcs.size())
uws=torch.load(uws_path)
print(uws.size())

torch.Size([171, 2400, 10, 11])
torch.Size([211, 2400, 10, 11])
torch.Size([175, 2400, 10, 11])


In [27]:
label_hc = torch.zeros(len(hc))
label_mcs = torch.ones(len(mcs))
label_uws = torch.ones(len(uws)) * 2
torch.save(label_hc, f"../eegmap_chunks/rest/hc/rest_hc_{data_mode}_label.pt")
torch.save(label_mcs, f"../eegmap_chunks/rest/mcs/rest_mcs_{data_mode}_label.pt")
torch.save(label_uws, f"../eegmap_chunks/rest/uws/rest_uws_{data_mode}_label.pt")

In [38]:
del hc
del mcs
del uws
gc.collect()
torch.cuda.empty_cache()

NameError: name 'hc' is not defined

## 1.3 Trainset and Testset

In [69]:
train_hc = torch.load("../eegmap_chunks/rest/hc/rest_hc_train.pt")
test_hc = torch.load("../eegmap_chunks/rest/hc/rest_hc_test.pt")
train_mcs = torch.load("../eegmap_chunks/rest/mcs/rest_mcs_train.pt")
test_mcs = torch.load("../eegmap_chunks/rest/mcs/rest_mcs_test.pt")
train_uws = torch.load("../eegmap_chunks/rest/uws/rest_uws_train.pt")
test_uws = torch.load("../eegmap_chunks/rest/uws/rest_uws_test.pt")
train_hc_label = torch.load("../eegmap_chunks/rest/hc/rest_hc_train_label.pt")
test_hc_label = torch.load("../eegmap_chunks/rest/hc/rest_hc_test_label.pt")
train_mcs_label = torch.load("../eegmap_chunks/rest/mcs/rest_mcs_train_label.pt")
test_mcs_label = torch.load("../eegmap_chunks/rest/mcs/rest_mcs_test_label.pt")
train_uws_label = torch.load("../eegmap_chunks/rest/uws/rest_uws_train_label.pt")
test_uws_label = torch.load("../eegmap_chunks/rest/uws/rest_uws_test_label.pt")

In [70]:
print(train_hc.size())
print(train_mcs.size())
print(train_uws.size())
train_data = []
train_label = []
test_data = []
test_label = []
print(train_hc[0].size())
# train dataset
for i in tqdm(range(len(train_hc))):
    train_data.append(train_hc[i])
    train_label.append(train_hc_label[i])
for i in tqdm(range(len(train_mcs))):
    train_data.append(train_mcs[i])
    train_label.append(train_mcs_label[i])
for i in tqdm(range(len(train_uws))):
    train_data.append(train_uws[i])
    train_label.append(train_uws_label[i])
print(len(train_data))
print(len(train_label))
# test dataset
for i in tqdm(range(len(test_hc))):
    test_data.append(test_hc[i])
    test_label.append(test_hc_label[i])
for i in tqdm(range(len(test_mcs))):
    test_data.append(test_mcs[i])
    test_label.append(test_mcs_label[i])
for i in tqdm(range(len(test_uws))):
    test_data.append(test_uws[i])
    test_label.append(test_uws_label[i])
print(len(test_data))
print(len(test_label))

torch.Size([681, 2400, 10, 11])
torch.Size([840, 2400, 10, 11])
torch.Size([697, 2400, 10, 11])
torch.Size([2400, 10, 11])


HBox(children=(IntProgress(value=0, max=681), HTML(value='')))




HBox(children=(IntProgress(value=0, max=840), HTML(value='')))




HBox(children=(IntProgress(value=0, max=697), HTML(value='')))


2218
2218


HBox(children=(IntProgress(value=0, max=171), HTML(value='')))




HBox(children=(IntProgress(value=0, max=211), HTML(value='')))




HBox(children=(IntProgress(value=0, max=175), HTML(value='')))


557
557


In [71]:
torch.save(torch.stack(train_data),f"../eegmap_chunks/rest/train_data.pt")
torch.save(torch.stack(train_label),f"../eegmap_chunks/rest/train_label.pt")
torch.save(torch.stack(test_data),f"../eegmap_chunks/rest/test_data.pt")
torch.save(torch.stack(test_label),f"../eegmap_chunks/rest/test_label.pt")

In [72]:
del train_hc
del test_hc
del train_mcs
del test_mcs
del train_uws
del test_uws
del train_hc_label
del test_hc_label
del train_mcs_label
del test_mcs_label
del train_uws_label
del test_uws_label
del train_data
del train_label
del test_data
del test_label
gc.collect()
torch.cuda.empty_cache()

# 2. Train the model

## Load data and labels via .pt files

In [4]:
train_data = torch.load(f"../eegmap_chunks/rest/train_data.pt")
train_label = torch.load(f"../eegmap_chunks/rest/train_label.pt")
test_data = torch.load(f"../eegmap_chunks/rest/test_data.pt")
test_label = torch.load(f"../eegmap_chunks/rest/test_label.pt")

In [5]:
print(train_data.size())
print(train_label.size())
print(test_data.size())
print(test_label.size())

torch.Size([2218, 2400, 10, 11])
torch.Size([2218])
torch.Size([557, 2400, 10, 11])
torch.Size([557])


## Hyperparameters and Related parameters

In [6]:
BATCH_SIZE = 64
C,H,W = 1,1,2400
learn_rate = 0.001
num_epochs = 100

In [7]:
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
# train dataset
train_td = TensorDataset(train_data, train_label)
train_loader = DataLoader(train_td, batch_size = BATCH_SIZE, shuffle = True)
# test dataset
test_td = TensorDataset(test_data, test_label)
test_loader = DataLoader(test_td, batch_size = BATCH_SIZE, shuffle = True)

In [8]:
del train_data
del train_label
del test_data
del test_label
del train_td
del test_td
gc.collect()
torch.cuda.empty_cache()

## Ensuring deterministicity through Random seeding

In [9]:
import random
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
manualSeed = 4
random.seed(manualSeed)
torch.manual_seed(manualSeed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(manualSeed)

# Test Model Validation

In [89]:
# 原始输入矩阵
input_matrix = torch.randn(10, 11, 1, 1, 2400)

# 重塑为 (110, 1, 1, 2400)
reshaped_input = torch.reshape(input_matrix, (110, 1, 1, 2400))

# 创建 MyNetwork 实例
net = MyNetwork()

# 将重塑后的输入传入网络进行处理
output = net(reshaped_input)

# 输出的形状
print("Output shape:", output.shape)
print(MyNetwork())

Output shape: torch.Size([1, 30])
MyNetwork(
  (conv1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(1, 9), stride=(1, 3), padding=(0, 4))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    (7): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
  )
  (conv2_1): Sequential(
    (0): Conv2d(16, 32, kernel_size=(1, 9), stride=(1, 2), padding=(0, 4))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (conv2_2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(1, 13), stride=(1, 2), padding=(0, 6))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, 

In [108]:
del input_matrix
del reshaped_input
del net
del output
gc.collect()
torch.cuda.empty_cache()

## Setting the optimizer and Loss function

In [19]:
import torch.optim as optim
device = torch.device("cuda:0")
gc.collect()
torch.cuda.empty_cache()
model = MyNetwork()
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr = learn_rate)# Try out weight decay , 
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[90,190,450], gamma=0.1)

criterion = nn.SmoothL1Loss()
criterion = criterion.to(device)

## Initializing Tensorboard

In [95]:
# add Tensorboard
writer = SummaryWriter("../logs_train")


## Training

In [20]:
start_time = time.time()
# train and test step records
total_train_step = 0
total_test_step = 0
min_test_loss = 1000

for i in tqdm(range(num_epochs)):    
    # train steps
    model.train()
    for data in train_loader:
        data_map, label=data
        data_map = torch.reshape(data_map, (110, 1, BATCH_SIZE, 2400))
        data_map = data_map.to(torch.float32)
        label = label.to(torch.float32)
        gc.collect()
        torch.cuda.empty_cache()
        data_map=data_map.to(device)
        label=label.to(device)
        label_pred = model(data_map)
        
        # Loss Computation and Optimization
        loss=criterion(label,label_pred)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # draw tensorboard
        total_train_step = total_train_step + 1
        # print info
#         if total_train_step % 100 == 0:
        end_time = time.time()
        print(f"Train time: {end_time - start_time}")
        print(f"Train steps: {total_train_step}, Loss: {loss.item()}")
        writer.add_scalar("train_loss",loss.item(),total_train_step)
        
        # Clear gpu
        del data
        del data_map
        del label
        del label_pred
        del loss
        gc.collect()
        torch.cuda.empty_cache()
        
    # Evaluation and save the best model
    print(f"========= Epoch {i+1} Testing =========")
    model.eval()
    total_test_loss = 0
    with torch.no_grad():
        for data in test_loader:
            data_map, label=data
            data_map = torch.reshape(data_map, (110, 1, 64, 2400))
            data_map = data_map.to(torch.float32)
            label = label.to(torch.float32)
            data_map = data_map.to(device)
            label = label.to(device)
            label_pred_test = model(data_map)
            loss = criterion(label_pred_test,label)
            # draw tensorboad
            total_test_loss = total_test_loss + loss
            # Clear gpu
            del data_map
            del label
            del label_pred_test
            del loss
            del data
            gc.collect()
            torch.cuda.empty_cache()
    print(f"Total Loss: {total_test_loss}")
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    total_test_step = total_test_step + 1
    if total_test_loss < min_test_loss:
        min_test_loss = total_test_loss
        print("..........Saving the model..........")
        torch.save(model.state_dict(),f"../model/Epoch{i+1}.pt")

HBox(children=(IntProgress(value=0), HTML(value='')))

RuntimeError: CUDA out of memory. Tried to allocate 344.00 MiB (GPU 0; 8.00 GiB total capacity; 7.16 GiB already allocated; 0 bytes free; 7.17 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [26]:
# Clear gpu
del optimizer
del scheduler
del criterion
del model
del data_map
del label
del label_pred
del label_pred_test
del loss
del data

NameError: name 'model' is not defined

In [27]:
gc.collect()
torch.cuda.empty_cache()