In [2]:
import time
from torch.utils.tensorboard import SummaryWriter
import os
import surpport.mySQL as mySQL
import surpport.myfunction as MF
from surpport.Args import Args
from surpport.dataprocess import dataload
import torch.nn.functional as F
import torch
import numpy as np
from surpport.nnstructure import Classifier7
from Experiment.Baseline.NoneGraph.EEGNet import EEGNet2
from Experiment.Baseline.GraphModel.GNN.GNN import GCN, DGCN

Using backend: pytorch


In [3]:
def recorder_build(arg):
    """Create the recorder dict for a single experiment run.

    Args:
        arg: Experiment configuration, forwarded unchanged to
            ``mySQL.gen_base_rcd``.

    Returns:
        dict: A new dict with one entry, ``'base'``, holding whatever
        ``mySQL.gen_base_rcd(arg)`` returned.
    """
    return {'base': mySQL.gen_base_rcd(arg)}


def train(arg, model, recorder, opt, tr_dataloader, ts_dataloader, writer):
    """Run the train/evaluate cycle for ``arg.epoch_num`` epochs.

    Epochs are numbered from 1. For each epoch a fresh dict is stored in
    ``recorder`` under the key ``'<epoch>-th'`` and handed to the ``mySQL``
    logging helpers. After the last epoch the recorder is saved with the
    tag ``'flow'`` and ``arg.best_acc`` is printed.

    Args:
        arg: Experiment configuration; ``epoch_num`` and ``best_acc`` are
            read here, the object is also forwarded to the helpers.
        model: Model forwarded to ``MF.train`` / ``MF.evaluate`` and to the
            checkpoint helpers.
        recorder: Dict that receives one sub-dict per epoch (mutated in place).
        opt: Optimizer forwarded to the helpers.
        tr_dataloader: Loader passed to ``MF.train``.
        ts_dataloader: Loader passed to ``MF.evaluate``.
        writer: Summary writer forwarded to the helpers.

    Returns:
        None
    """
    final_epoch = arg.epoch_num
    for epoch in range(1, final_epoch + 1):
        epoch_key = str(epoch) + '-th'
        epoch_log = dict()
        recorder[epoch_key] = epoch_log

        # Training pass. Original author's note: because the graphs are
        # merged together, an extra dataloader step is needed; the first two
        # return values are lists.
        train_out = MF.train(epoch, model, opt, tr_dataloader, arg, writer)
        tr_loss, tr_acc = train_out
        mySQL.rcd_log(tr_loss, tr_acc, writer, epoch_log, epoch, 'train')

        # Evaluation pass on the test loader.
        eval_out = MF.evaluate(epoch, model, opt, ts_dataloader, arg, writer)
        el_loss, el_acc, logits, labels = eval_out
        mySQL.rcd_log(el_loss, el_acc, writer, epoch_log, epoch, 'test')
        mySQL.rcd_result(logits, labels, epoch_log)

        # The mean of the evaluation accuracies drives both checkpoint helpers.
        val_acc = np.mean(el_acc, axis=0)
        MF.save_best(val_acc, model, arg)
        mySQL.save_final(epoch, model, val_acc, arg, opt)

    mySQL.save_recorder(recorder, arg, 'flow')
    print('best acc: %.4f' % (arg.best_acc))



## 基础对比实验

In [12]:
# Registry of baseline architectures. The key is used both to look up the
# class and as the model name passed to arg.m_info in the experiment loop.
# 'dgcn1' must map to DGCN: it previously pointed at GCN, so the "dgcn1"
# runs were silently training a second copy of the GCN baseline.
models = {'eegnet2': EEGNet2, 'gcn1': GCN, 'dgcn1': DGCN}
# First argument given to arg.d_prepare -- presumably subject identifiers;
# TODO confirm against surpport.Args.
names = ['dzq', 'kly', 'lbg']
# NOTE(review): not referenced by the cells visible here; the experiment loop
# passes a literal 2 to arg.d_prepare -- confirm whether this list should feed it.
fs = [2]

In [13]:
# Every (model name, entry of `names`) combination, model-major so all runs
# of one architecture are adjacent. A comprehension keeps the loop variables
# out of the notebook's global namespace.
pairs = [(model_name, subject) for model_name in models for subject in names]
pairs

[('eegnet2', 'dzq'),
 ('eegnet2', 'kly'),
 ('eegnet2', 'lbg'),
 ('gcn1', 'dzq'),
 ('gcn1', 'kly'),
 ('gcn1', 'lbg'),
 ('dgcn1', 'dzq'),
 ('dgcn1', 'kly'),
 ('dgcn1', 'lbg')]

In [17]:
# Sanity check: look up the class registered for the first experiment pair.
# Index `pairs` explicitly instead of relying on a loop variable leaked from
# another cell, so this cell also works after Restart Kernel -> Run All.
model = models[pairs[0][0]]
model

Experiment.Baseline.NoneGraph.EEGNet.EEGNet2

In [None]:
# Train every (model, subject) pair in turn. `count` is the 1-based run index
# passed to arg.m_info.
for count, p in enumerate(pairs, start=1):
    model_name, subject = p

    arg = Args()
    # None of the registered model classes take constructor arguments.
    model = models[model_name]()
    model = model.cuda()
    arg.d_prepare(subject, 2)
    # NOTE: update the m_task tag below for every new experiment run.
    arg.m_info(m_name=model_name, m_task='220911_exp', num=count,)
    recorder = recorder_build(arg)

    # os.path.join keeps the log directory valid on every platform; the
    # previous hard-coded '\\Journal' suffix only worked on Windows.
    writer = SummaryWriter(os.path.join(arg.tar_path, 'Journal'))
    try:
        # Adam with its default hyper-parameters. An alternative setting
        # (lr=6e-5, eps=1e-8, weight_decay=0.1) was left disabled by the author.
        opt = torch.optim.Adam(model.parameters())

        arg, tr_dataloader, ts_dataloader = dataload(arg, model, opt)

        train(arg, model, recorder, opt, tr_dataloader, ts_dataloader, writer)
    finally:
        # Always close the writer, even when a run fails part-way, so the
        # event file is released before the next pair starts.
        writer.close()