In [3]:
import torch
from torch.nn import functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch import optim
import numpy as np
from Embedding import Normalization
from tqdm import tqdm
import time
import random
from earlystop import EarlyStopping
from ret import ReT
from convolution import Convolution
from data_loader import DataSets, StratifiedSampler

In [4]:
def fix_seed(seed=66):
    """Make runs repeatable by seeding every RNG this notebook relies on.

    Seeds, in this order, Python's `random` module, PyTorch
    (`torch.manual_seed`) and NumPy's legacy global generator, all with the
    same integer `seed`.

    NOTE(review): on GPU, cuDNN may still select non-deterministic kernels;
    see `torch.backends.cudnn.deterministic` if bit-exact reruns are required.
    """
    seeders = (random.seed, torch.manual_seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)


# Fix the random seeds once, at the start of the notebook.
fix_seed()

In [None]:
def _init_weights(m):
    """Kaiming-normal initialisation for the weights of linear/conv layers.

    Uses an exact type match (not isinstance), so subclasses of these layers
    keep PyTorch's default initialisation; biases are left untouched.
    """
    if type(m) in (nn.Linear, nn.Conv1d, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)


def _move_batch(batch, device):
    """Unpack a (data, ecg_feature, eog_feature, label) batch and move it to `device`.

    Labels are cast to int64 because nn.CrossEntropyLoss expects class-index
    targets of dtype long.
    """
    data, ecg_feature, eog_feature, label = batch
    return (
        data.to(device),
        ecg_feature.to(device),
        eog_feature.to(device),
        label.to(device=device, dtype=torch.long),
    )


def _train_one_epoch(model, train_iter, criterion, optimizer, device):
    """Run one optimisation pass over `train_iter`.

    Returns:
        (mean loss, accuracy) as Python floats, weighted by the number of label
        elements in each batch so a short final batch is not over-counted.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for batch in tqdm(train_iter):
        data, ecg_feature, eog_feature, label = _move_batch(batch, device)
        optimizer.zero_grad()
        pred = model(data, ecg_feature, eog_feature)
        batch_loss = criterion(pred, label)
        batch_loss.backward()
        optimizer.step()
        # detach(): summing the live loss tensor would chain every batch's
        # autograd graph onto the running total for the whole epoch. The sums
        # stay on-device and are pulled to Python once, at the end.
        n = label.numel()
        loss_sum += batch_loss.detach() * n  # criterion uses reduction='mean'
        correct += (pred.argmax(dim=1) == label).sum()
        seen += n
    seen = max(seen, 1)  # empty loader -> report 0 instead of dividing by zero
    return float(loss_sum) / seen, float(correct) / seen


@torch.no_grad()
def _evaluate(model, valid_iter, criterion, device):
    """Compute (mean loss, accuracy) on `valid_iter` in eval mode, without gradients.

    Metrics are weighted by the number of label elements per batch, matching
    `_train_one_epoch`.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    for batch in tqdm(valid_iter):
        data, ecg_feature, eog_feature, label = _move_batch(batch, device)
        pred = model(data, ecg_feature, eog_feature)
        n = label.numel()
        loss_sum += criterion(pred, label) * n
        correct += (pred.argmax(dim=1) == label).sum()
        seen += n
    seen = max(seen, 1)
    return float(loss_sum) / seen, float(correct) / seen


# Train the model
def train(model, train_iter, valid_iter, lr, num_epochs, device="cuda",
          patience=10, weight_decay=0.003):
    """Train `model` with Adam and cross-entropy, validating and early-stopping each epoch.

    Weights of nn.Linear / nn.Conv1d / nn.Conv2d layers are re-initialised
    (Kaiming normal) before training starts. The model is updated in place;
    nothing is returned.

    Args:
        model: module called as ``model(data, ecg_feature, eog_feature)``;
            its output is fed to nn.CrossEntropyLoss as logits.
        train_iter: iterable (e.g. a DataLoader) yielding
            ``(data, ecg_feature, eog_feature, label)`` batches for training.
        valid_iter: iterable yielding batches of the same form for validation.
        lr: Adam learning rate.
        num_epochs: maximum number of epochs; early stopping may end sooner.
        device: device the model and every batch are moved to.
        patience: ``patience`` argument forwarded to EarlyStopping.
        weight_decay: Adam weight decay.
    """
    early_stop = EarlyStopping(patience=patience, verbose=True)

    model.apply(_init_weights)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        start_time = time.time()
        train_loss, train_acc = _train_one_epoch(model, train_iter, criterion, optimizer, device)
        # The reported time covers the training pass only (validation excluded).
        epoch_time = time.time() - start_time
        val_loss, val_acc = _evaluate(model, valid_iter, criterion, device)
        print(
            f"epoch {epoch + 1}: train loss {train_loss:.3f}, train acc {train_acc:.3f}, "
            f"val loss {val_loss:.3f}, val acc {val_acc:.3f}, time {epoch_time:.1f}s"
        )
        # NOTE(review): validation *accuracy* (higher is better) is passed here.
        # Confirm the local EarlyStopping treats its first argument as
        # higher-is-better; a loss-style implementation would stop on improvement.
        early_stop(val_acc, model)
        if early_stop.early_stop:
            print("Early stopping")
            break