# Head Files

In [1]:
import torch
import torchvision
# from torch.utils.tensorboard import SummaryWriter
import time
from torch import nn
from torch.utils.data import DataLoader
from load_data import MyData  # self-made
from torchvision import transforms
from tqdm import tqdm_notebook as tqdm # View procedure
import os
import scipy.io
from random import random
import numpy as np
import gc
from torch.utils.tensorboard import SummaryWriter
from network_cnn_lstm import MyNetwork
from torchnlp.word_to_vector import GloVe
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# 训练模型

## 设置超参数

In [2]:
BATCH_SIZE = 1
C,H,W = 1,1,2400
learn_rate = 0.0005
num_epochs = 80

## 设置随机种子

In [3]:
import random
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
manualSeed = 4
random.seed(manualSeed)
torch.manual_seed(manualSeed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(manualSeed)

## 设置优化器和损失函数

In [4]:
import torch.optim as optim
device = torch.device("cuda:0")
gc.collect()
torch.cuda.empty_cache()

# ==损失函数权重
# ======== 二分类HC/DOC
# 计算总样本数量
# condition1
# total_samples = 887 + 985 + 879
# condition2
# total_samples = 929 + 1029 + 886
# condition3
total_samples = 887 + 975 + 879
# 计算每个类别的权重
# condition1
# weights = [total_samples / 887, total_samples / (985 + 879)]
# condition2
# weights = [total_samples / 929, total_samples / (1029 + 886)]
# condition3
weights = [total_samples / 887, total_samples / (975 + 879)]

# ======== 二分类MCS/UWS
# 计算总样本数量
# condition1
# total_samples = 985 + 879
# condition2
# total_samples = 1029 + 886
# condition3
# total_samples = 975 + 879
# 计算每个类别的权重
# condition1
# weights = [total_samples / 985, total_samples / 879]
# condition2
# weights = [total_samples / 1029, total_samples / 886]
# condition3
# weights = [total_samples / 975, total_samples / 879]

# 将权重转换为张量
weights_tensor = torch.tensor(weights, device=device)

# 定义交叉熵损失函数并设置权重
criterion = nn.CrossEntropyLoss(weight=weights_tensor)
# criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

## Training

In [8]:
# 定义LSTM超参数
input_size = 64  # 输入特征维度
hidden_size = 64  # 隐藏单元数量
num_layers = 2  # LSTM层数
output_size = 2  # 输出类别数量
# 创建模型实例
model_list = ['', '_CNN', '_CNN_spa', '_CNN_spa_lstm']
model_name = model_list[0]
if model_name == model_list[0]: #  CascadeCept
    from network_cnn_lstm import MyNetwork
elif model_name == model_list[1]: #  CNN
    from network_cnn_lstm_2 import MyNetwork
elif model_name == model_list[2]: # CascadeCept_1
    from network_cnn_lstm_3 import MyNetwork
elif model_name == model_list[3]: # CascadeCept_2
    from network_cnn_lstm_4 import MyNetwork
model = MyNetwork(input_size, hidden_size, num_layers, output_size)
model = model.to(device)

# experimental dir: rest, conditionA, conditionB, conditionC
exper_dir = "conditionC"
root_dir = f"../data/eegmap_split/{exper_dir}"
# classification = "hc_doc" or "mcs_uws"
classification = "hc_doc"
fold_num = 1
for fold in tqdm(range(fold_num)):
    # train num folds
    fold = 0 # 选择折数
    # -- prepare datasets
    train_data = []
    train_label = []
    test_data = []
    test_label = []
    # ---- hc
    dataset = MyData(root_dir, "train", "hc") # hc
    # find the fold file
    count = 0
    for person in range(len(dataset)):
        filename = os.path.join(dataset.path, dataset.file_path[person])
        # extract the pure name of the file
        parts = filename.split("\\")
        file_name = parts[-1]
        name_without_extension = file_name.split(".")[0]
        # label or data
        file_last = name_without_extension.split("_")[-1]
        if file_last.isdigit(): # data
            # is this fold or not
            if int(file_last) == fold: # yes
                print(filename)
                count = count + 1
                data_map = torch.load(filename)
                # train or valid
                if name_without_extension.split("_")[-2] == "train":
                    for i in range(data_map.size(0)):
                        train_data.append(data_map[i])
                elif name_without_extension.split("_")[-2] == "val":
                    for i in range(data_map.size(0)):
                        test_data.append(data_map[i])
                if count == 4:
                    del data_map
                    gc.collect()
                    torch.cuda.empty_cache() 
                    break
            else:   # not
                pass
        else: # label
            # is this fold or not
            file_last = name_without_extension.split("_")[-2]
            if int(file_last) == fold: # yes
                print(filename)
                count = count + 1
                data_map = torch.load(filename)
                # train or valid
                if name_without_extension.split("_")[-3] == "train":
                    for i in range(data_map.size(0)):
                        train_label.append(data_map[i])
                elif name_without_extension.split("_")[-3] == "val":
                    for i in range(data_map.size(0)):
                        test_label.append(data_map[i])
                if count == 4:
                    del data_map
                    gc.collect()
                    torch.cuda.empty_cache() 
                    break
            else:   # not
                pass
        del filename, parts, file_name, name_without_extension, file_last
        gc.collect()
        torch.cuda.empty_cache()    
    # ---- mcs
    dataset = MyData(root_dir, "train", "mcs") # mcs
    # find the fold file
    count = 0
    for person in range(len(dataset)):
        filename = os.path.join(dataset.path, dataset.file_path[person])
        # extract the pure name of the file
        parts = filename.split("\\")
        file_name = parts[-1]
        name_without_extension = file_name.split(".")[0]
        # label or data
        file_last = name_without_extension.split("_")[-1]
        if file_last.isdigit(): # data
            # is this fold or not
            if int(file_last) == fold: # yes
                print(filename)
                count = count + 1
                data_map = torch.load(filename)
                # train or valid
                if name_without_extension.split("_")[-2] == "train":
                    for i in range(data_map.size(0)):
                        train_data.append(data_map[i])
                elif name_without_extension.split("_")[-2] == "val":
                    for i in range(data_map.size(0)):
                        test_data.append(data_map[i])
                if count == 4:
                    del data_map
                    gc.collect()
                    torch.cuda.empty_cache() 
                    break
            else:   # not
                pass
        else: # label
            # is this fold or not
            file_last = name_without_extension.split("_")[-2]
            if int(file_last) == fold: # yes
                print(filename)
                count = count + 1
                data_map = torch.load(filename)
                # train or valid
                if name_without_extension.split("_")[-3] == "train":
                    for i in range(data_map.size(0)):
                        train_label.append(data_map[i])
                elif name_without_extension.split("_")[-3] == "val":
                    for i in range(data_map.size(0)):
                        test_label.append(data_map[i])
                if count == 4:
                    del data_map
                    gc.collect()
                    torch.cuda.empty_cache() 
                    break
            else:   # not
                pass
        del filename, parts, file_name, name_without_extension, file_last
        gc.collect()
        torch.cuda.empty_cache()    
    # ---- uws
    dataset = MyData(root_dir, "train", "uws") # uws
    # find the fold file
    count = 0
    for person in range(len(dataset)):
        filename = os.path.join(dataset.path, dataset.file_path[person])
        # extract the pure name of the file
        parts = filename.split("\\")
        file_name = parts[-1]
        name_without_extension = file_name.split(".")[0]
        # label or data
        file_last = name_without_extension.split("_")[-1]
        if file_last.isdigit(): # data
            # is this fold or not
            if int(file_last) == fold: # yes
                print(filename)
                count = count + 1
                data_map = torch.load(filename)
                # train or valid
                if name_without_extension.split("_")[-2] == "train":
                    for i in range(data_map.size(0)):
                        train_data.append(data_map[i])
                elif name_without_extension.split("_")[-2] == "val":
                    for i in range(data_map.size(0)):
                        test_data.append(data_map[i])
                if count == 4:
                    del data_map
                    gc.collect()
                    torch.cuda.empty_cache() 
                    break
            else:   # not
                pass
        else: # label
            # is this fold or not
            file_last = name_without_extension.split("_")[-2]
            if int(file_last) == fold: # yes
                print(filename)
                count = count + 1
                data_map = torch.load(filename)
                # train or valid
                if name_without_extension.split("_")[-3] == "train":
                    for i in range(data_map.size(0)):
                        train_label.append(data_map[i])
                elif name_without_extension.split("_")[-3] == "val":
                    for i in range(data_map.size(0)):
                        test_label.append(data_map[i])
                if count == 4:
                    del data_map
                    gc.collect()
                    torch.cuda.empty_cache() 
                    break
            else:   # not
                pass 
        del filename, parts, file_name, name_without_extension, file_last
        gc.collect()
        torch.cuda.empty_cache()    
    print(torch.stack(train_data).size())
    print(torch.stack(train_label).size())
    print(torch.stack(test_data).size())
    print(torch.stack(test_label).size())
    del dataset
    gc.collect()
    torch.cuda.empty_cache()  
    
    train_data = torch.stack(train_data)
    train_label = torch.stack(train_label)
    test_data = torch.stack(test_data)
    test_label = torch.stack(test_label)
    # train dataset
    train_td = TensorDataset(train_data, train_label)
    train_loader = DataLoader(train_td, batch_size = BATCH_SIZE, shuffle = True)
    # test dataset
    test_td = TensorDataset(test_data, test_label)
    test_loader = DataLoader(test_td, batch_size = BATCH_SIZE, shuffle = True)
    del train_data, train_label, test_data, test_label, train_td, test_td
    gc.collect()
    torch.cuda.empty_cache()
    
    # set mode for each fold
    model = MyNetwork(input_size, hidden_size, num_layers, output_size)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr = learn_rate)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40, 60], gamma=0.2)
    # -- start training
    start_time = time.time()
    # train and test step records
    total_train_step = 0
    total_test_step = 0
    min_test_loss = 1000
    # add Tensorboard
    writer_train = SummaryWriter(f"../logs/{classification}/{exper_dir}{model_name}/logs_train_{fold}")
    writer_valid = SummaryWriter(f"../logs/{classification}/{exper_dir}{model_name}/logs_valid_{fold}")
    writer_valid_acc = SummaryWriter(f"../logs/{classification}/{exper_dir}{model_name}/logs_valid_acc_{fold}")
    for i in tqdm(range(num_epochs)):  
        print(f"========= Epoch {i} Training =========")
        # train steps
        model.train()
        for data in train_loader:
            # x, y
            data_map, label=data
            data_map_reshaped = torch.reshape(data_map, (110, 1, 1, 2400))
            label_int = label.long()
            data_map_reshaped=data_map_reshaped.to(device)
            label_int=label_int.to(device)
            del data_map, label
            gc.collect()
            torch.cuda.empty_cache()
            # y_pred
            label_pred = model(data_map_reshaped)
            # Loss Computation and Optimization
            loss = criterion(label_pred,label_int)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # draw tensorboard
            total_train_step = total_train_step + 1
            # print info
            if total_train_step % 1000 == 0:
                end_time = time.time()
                print(label_pred)
                print(f"Train time: {end_time - start_time}")
                print(f"Train steps: {total_train_step}, Loss: {loss.item()}")
            writer_train.add_scalar("train_loss",loss.item(),total_train_step)
            # Clear gpu
            del data, data_map_reshaped, label_int, label_pred, loss
            gc.collect()
            torch.cuda.empty_cache()
        
        # Evaluation and save the best model
        print(f"========= Epoch {i} Testing =========")
        model.eval()
        total_test_loss = 0
        test_count = 0
        total_test_acc = 0
        with torch.no_grad():
            for data in test_loader:
                test_count = test_count + 1
                # x, y
                data_map, label=data
                data_map_reshaped = torch.reshape(data_map, (110, 1, 1, 2400))
                label_int = label.long()
                data_map_reshaped = data_map_reshaped.to(device)
                label_int = label_int.to(device)
                del data_map, label
                gc.collect()
                torch.cuda.empty_cache()
                # y_pred
                label_pred_test = model(data_map_reshaped)
                loss = criterion(label_pred_test,label_int)
#                 print(label_pred_test)
                # accuracy 
                total_test_acc = total_test_acc + ((label_pred_test.argmax(1)) == label_int).sum()
                # draw tensorboad
                total_test_loss = total_test_loss + loss
                if test_count % 100 == 0:
                    print(f"Loss: {total_test_loss} Accuracy: {total_test_acc/test_count}")
                # Clear gpu
                del data_map_reshaped, label_int, label_pred_test, loss, data
                gc.collect()
                torch.cuda.empty_cache()
        print(f"Total Loss: {total_test_loss} Total Accuracy: {total_test_acc/test_count}")
        writer_valid.add_scalar("test_loss", total_test_loss, total_test_step)
        writer_valid_acc.add_scalar("test_acc", total_test_acc/test_count, total_test_step)
        total_test_step = total_test_step + 1
        print("..........Saving the model..........")
        torch.save(model.state_dict(),f"../model/{classification}/{exper_dir}{model_name}/Fold{fold}_Epoch{i}.pt") 
#         if total_test_loss < min_test_loss:
#             min_test_loss = total_test_loss
#             print("..........Saving the model..........")
#             torch.save(model.state_dict(),f"../model/{exper_dir}/Fold{fold}_Epoch{i}.pt")

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/1 [00:00<?, ?it/s]

../data/eegmap_split/conditionC\train\hc\conditionC_hc_train_0.pt
../data/eegmap_split/conditionC\train\hc\conditionC_hc_train_0_label.pt
../data/eegmap_split/conditionC\train\hc\conditionC_hc_val_0.pt
../data/eegmap_split/conditionC\train\hc\conditionC_hc_val_0_label.pt
../data/eegmap_split/conditionC\train\mcs\conditionC_mcs_train_0.pt
../data/eegmap_split/conditionC\train\mcs\conditionC_mcs_train_0_label.pt
../data/eegmap_split/conditionC\train\mcs\conditionC_mcs_val_0.pt
../data/eegmap_split/conditionC\train\mcs\conditionC_mcs_val_0_label.pt
../data/eegmap_split/conditionC\train\uws\conditionC_uws_train_0.pt
../data/eegmap_split/conditionC\train\uws\conditionC_uws_train_0_label.pt
../data/eegmap_split/conditionC\train\uws\conditionC_uws_val_0.pt
../data/eegmap_split/conditionC\train\uws\conditionC_uws_val_0_label.pt
torch.Size([2741, 2400, 10, 11])
torch.Size([2741])
torch.Size([683, 2400, 10, 11])
torch.Size([683])


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/80 [00:00<?, ?it/s]

tensor([[-0.0882, -0.0169]], device='cuda:0', grad_fn=<AddmmBackward0>)
Train time: 148.96242833137512
Train steps: 1000, Loss: 0.7294105291366577
tensor([[-0.8039,  0.5978]], device='cuda:0', grad_fn=<AddmmBackward0>)
Train time: 298.4015381336212
Train steps: 2000, Loss: 0.22008463740348816
Loss: 52.29698181152344 Accuracy: 0.7199999690055847
Loss: 93.4146957397461 Accuracy: 0.7699999809265137
Loss: 143.3115997314453 Accuracy: 0.7633333802223206
Loss: 193.63156127929688 Accuracy: 0.7549999952316284
Loss: 239.86019897460938 Accuracy: 0.7500000596046448
Loss: 286.62542724609375 Accuracy: 0.7483333349227905
Total Loss: 323.0703430175781 Total Accuracy: 0.7540263533592224
..........Saving the model..........


FileNotFoundError: [Errno 2] No such file or directory: '../model/hc_doc/conditionC/Fold0_Epoch0.pt'