In [1]:
from torch.utils.tensorboard import SummaryWriter
import os
import surpport.mySQL as mySQL
import surpport.myfunction as MF
from surpport.Args import Args
from surpport.dataprocess import dataload
import torch.nn.functional as F
import torch
import numpy as np

Using backend: pytorch


## dataload 重构

### dataload

In [2]:
from torch.utils.data import Dataset
from dgl.data.utils import load_graphs
from os.path import join
from dgl.dataloading import GraphDataLoader

class cross_Dataset(Dataset):
    """Map-style dataset over one pre-merged cross-subject graph file.

    The file ``<root>/<arg.select_f>_dgl_graph`` is read with ``torch.load`` and
    must be a dict with ``'graphs'`` and ``'labels'`` sequences (as written by the
    merge cell of this notebook). Sample ``i`` is ``(graphs[i], labels[i])``.
    """

    # Default directory of the merged files; override per instance via ``root``.
    DEFAULT_ROOT = r'E:\DATABASE\FirstGNN\CrossData'

    def __init__(self, arg, root=None) -> None:
        """
        Args:
            arg (Args): experiment config; ``arg.select_f`` picks the feature file.
            root (str, optional): directory holding the merged files.
                Defaults to ``DEFAULT_ROOT``.
        """
        # Plain super() so the MRO starts at the class after cross_Dataset;
        # super(Dataset, self) skipped Dataset itself.
        super().__init__()
        self.arg = arg  # type: Args
        self.root = self.DEFAULT_ROOT if root is None else root
        self.load(arg)

    def load(self, arg):
        """Load all graphs and labels for ``arg.select_f`` into memory."""
        f = arg.select_f  # type: str
        path = join(self.root, f + '_dgl_graph')
        # NOTE: torch.load unpickles arbitrary objects - only load trusted files.
        # NOTE(review): on torch >= 2.6 the default weights_only=True may reject
        # DGL graph objects; pass weights_only=False there - TODO confirm.
        gdata = torch.load(path)
        self.graphs, self.labels = gdata['graphs'], gdata['labels']

    def __getitem__(self, index):
        return self.graphs[index], self.labels[index]

    def __len__(self):
        return len(self.graphs)

def c_dataload(arg, model, opt, alpha, per_patient=1875):
    """Build train/test GraphDataLoaders for the cross-subject experiment.

    When ``arg.new_train`` is false, ``MF.continue_tr(arg.dir, model, opt, arg)``
    is called and is expected to restore ``arg.tr_id`` / ``arg.ts_id``
    (presumably from a saved run - TODO confirm). Otherwise the first
    ``alpha * per_patient`` samples of the merged file form the training split
    and every remaining sample forms the test split.

    Args:
        arg (Args): experiment config. Reads ``select_f``, ``new_train``, ``dir``
            and ``batch_size``; writes ``tr_id`` / ``ts_id`` for a new run.
        model (nn.Module): our model, passed to ``MF.continue_tr`` when resuming.
        opt (torch.optim.Optimizer): optimizer, passed to ``MF.continue_tr`` when resuming.
        alpha (int): number of leading patients used for training (1~11).
        per_patient (int): samples per patient in the merged file (default 1875).

    Returns:
        tuple: (arg, tr_dataloader, ts_dataloader)

    Raises:
        ValueError: for a new run whose training split would be empty or would
            leave no test samples.
    """
    dataset = cross_Dataset(arg)
    if not arg.new_train:
        MF.continue_tr(arg.dir, model, opt, arg)
    else:
        end_num = alpha * per_patient  # samples are stored patient by patient
        length = len(dataset)
        if not 0 < end_num < length:
            raise ValueError(
                f'alpha={alpha} gives {end_num} training samples but the dataset '
                f'has {length}; need 0 < alpha * per_patient < len(dataset)')
        arg.tr_id = list(range(end_num))
        # range() excludes its stop value, so stop at `length` to keep the last
        # sample (the previous `length - 1` silently dropped it).
        arg.ts_id = list(range(end_num, length))

    train_dataset = torch.utils.data.Subset(
        dataset=dataset, indices=arg.tr_id)
    test_dataset = torch.utils.data.Subset(
        dataset=dataset, indices=arg.ts_id)
    tr_dataloader = dataloader(
        train_dataset, arg.batch_size, collate=MF.collate, shuffle=True)
    # NOTE(review): shuffling the test loader is not needed for the metrics and
    # makes the recorded logits/labels order non-reproducible; consider False.
    ts_dataloader = dataloader(
        test_dataset, arg.batch_size, collate=MF.collate, shuffle=True)
    return arg, tr_dataloader, ts_dataloader

def dataloader(dataset, batch_size, collate, shuffle):
    """Wrap ``dataset`` in a DGL ``GraphDataLoader``.

    Args:
        dataset: map-style dataset (here a ``torch.utils.data.Subset``).
        batch_size (int): samples per batch.
        collate (callable): merges a list of samples into one batch.
        shuffle (bool): whether to reshuffle every epoch.

    Returns:
        GraphDataLoader: loader that keeps the final partial batch.
    """
    loader_kwargs = dict(
        collate_fn=collate,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,  # keep the last, possibly smaller, batch
    )
    return GraphDataLoader(dataset, **loader_kwargs)

In [9]:
# Smoke test: build the cross-subject dataset for feature set 2.
# NOTE(review): assumes d_prepare('Cross', 2) sets arg.select_f to 'f2' - TODO confirm.
arg = Args()
arg.d_prepare('Cross', 2)
dataset = cross_Dataset(arg)

In [11]:
# Inspect one sample; __getitem__ returns (graph, label).
# NOTE(review): the saved output shows the string 'labels' as the label, so it is
# probably stale (from an earlier loader or data file) - re-run to confirm.
dataset[1]

(Graph(num_nodes=18, num_edges=98,
       ndata_schemes={'h': Scheme(shape=(128,), dtype=torch.float32)}
       edata_schemes={'f': Scheme(shape=(1,), dtype=torch.float32)}),
 'labels')

### 构造测试

In [2]:
def load(f: int, name: str, graphs: list, labels: list,
         root: str = r'E:\DATABASE\FirstGNN\graphbin'):
    """Append one patient's graphs and labels for feature set ``f`` to the given lists.

    Reads ``<root>/<name>/f<f>_dgl_graph.bin`` with DGL's ``load_graphs`` and
    extends ``graphs`` and ``labels`` in place; nothing is returned.

    Args:
        f (int): feature-set index; the file prefix is ``f{f}``.
        name (str): patient sub-directory name.
        graphs (list): accumulator, extended with the loaded graphs.
        labels (list): accumulator, extended with the elements of the file's
            ``'labels'`` entry.
        root (str): directory holding one sub-directory per patient.
    """
    # Build the filename without rebinding the int parameter `f` to a str.
    path = join(root, name, f'f{f}_dgl_graph.bin')
    graph_list, label_dict = load_graphs(path)
    graphs.extend(graph_list)
    labels.extend(label_dict['labels'])

In [6]:
# Merge every patient's per-feature graph files into one file per feature set,
# so cross_Dataset can load a whole feature set with a single torch.load.
arg = Args()
save_path = r'E:\DATABASE\FirstGNN\CrossData'
fs = [2, 6, 7, 15]  # feature-set indices to merge
names = arg.patient  # presumably the list of patient sub-directory names - TODO confirm
for f in fs:
    graphs = []
    labels = []
    # Patients are appended in arg.patient order; c_dataload's contiguous
    # train/test split (alpha patients first) relies on this order.
    for name in names:
        load(f, name, graphs, labels)
    gdata = {'graphs': graphs, 'labels': labels}
    file_path = join(save_path, f'f{f}_dgl_graph')
    torch.save(gdata, file_path)
# NOTE: graphs, labels, f and gdata stay alive as kernel globals after this cell.

In [None]:
# NOTE(review): leftover, never-executed cell (In [None]). It repeats the save from
# the last iteration of the merge loop and only works because `graphs`, `labels`
# and `f` leaked from that loop; consider deleting it.
gdata = {'graphs': graphs, 'labels': labels}
save_path = r'E:\DATABASE\FirstGNN\CrossData'
file_path = join(save_path, f'f{f}_dgl_graph') 
torch.save(gdata, file_path)

In [7]:
# Reload one merged file to sanity-check what the merge cell wrote.
# NOTE: torch.load unpickles arbitrary objects - only use it on trusted files.
gdata = torch.load(r'E:\DATABASE\FirstGNN\CrossData\f2_dgl_graph')

In [8]:
# Labels list of the reloaded merged file.
l = gdata['labels']

In [11]:
# Per the saved output, each stored label is a one-element tensor.
l[15]

tensor([0])

In [12]:
# Release the merged data held by the kernel (graphs/labels leaked from the merge loop).
del gdata
del graphs
del labels

### 运行测试

In [3]:
# Make the sibling Baseline directory importable so the EEGNet2 baseline can be used.
# NOTE(review): absolute, machine-specific path; consider a configurable base directory.
import sys
# sys.path.append(r'E:\CODEBASE\myDGL\FirstDGL\Experiment\Baseline\GraphModel\GNN')
sys.path.append(r'E:\CODEBASE\myDGL\FirstDGL\Experiment\Baseline')
from NoneGraph.EEGNet import EEGNet2

In [10]:
from surpport.nnstructure import Classifier7

def m_pre():
    """Create a fresh ``Args`` config and an ``EEGNet2`` model on the GPU.

    Earlier runs used ``GNN.DGCN()`` or ``Classifier7(arg)`` here instead.

    Returns:
        tuple: (arg, model)
    """
    config = Args()
    net = EEGNet2().cuda()
    return config, net

def recorder_build(arg):
    """Start a run recorder seeded with the base record for ``arg``.

    Returns:
        dict: ``{'base': mySQL.gen_base_rcd(arg)}``.
    """
    return {'base': mySQL.gen_base_rcd(arg)}

In [7]:
# Build config, model, optimizer, recorder, TensorBoard writer and loaders for each
# feature set in `fs`. Only the objects from the last iteration are kept as globals
# for the training cell below.
names = ['kly']  # NOTE(review): not used by this cell
fs = [2, ]
N_TRAIN_PATIENTS = 9  # `alpha` for c_dataload: number of leading patients used for training
for f in fs:
    arg, model = m_pre()
    arg.d_prepare('Cross', f)
    # ! Change on every run
    arg.m_info(m_name='eegnet2', m_task='220822_cross_test', num=2)
    # Reuse the shared helper instead of duplicating its body.
    recorder = recorder_build(arg)
    # os.path.join instead of a hard-coded '\\' separator.
    # NOTE(review): with more than one entry in `fs`, earlier writers are never closed.
    writer = SummaryWriter(os.path.join(arg.tar_path, 'Journal'))

    # ! Adjust per model
    opt_arg = {'params': model.parameters(),}
    # opt_arg = {'params': model.parameters(),'lr': 6e-5, 'eps': 1e-8, 'weight_decay': 0.1}
    opt = torch.optim.Adam(**opt_arg)

    arg, tr_dataloader, ts_dataloader = c_dataload(arg, model, opt, N_TRAIN_PATIENTS)

In [18]:
# Debug: pull a single batch from the training loader to inspect its structure.
# NOTE(review): the batch layout depends on MF.collate - presumably (batched graph, labels).
td = iter(tr_dataloader)
g, l = next(td)

In [9]:
# Run `step` more epochs starting at epoch `st` (numbering continues a previous run),
# logging train/test metrics into `recorder` and TensorBoard each epoch.
st = 41
step = 40  # number of epochs to run (not a step size)
arg.display_freq = 50
try:
    for epoch in range(st, step+st):
        print(epoch)
        rd = recorder[f'{epoch}-th'] = dict()
        # The graphs are batched together, hence the extra DataLoader pass.
        # Per the original note, the first two values (loss, acc) are lists.
        tr_args = [epoch, model, opt, tr_dataloader, arg, writer]
        tr_loss, tr_acc = MF.train(*tr_args)
        mySQL.rcd_log(tr_loss, tr_acc, writer, rd, epoch, 'train')

        # Same positional arguments with the test loader; built separately
        # instead of mutating tr_args in place.
        el_args = [epoch, model, opt, ts_dataloader, arg, writer]
        el_loss, el_acc, logits, labels = MF.evaluate(*el_args)

        mySQL.rcd_log(el_loss, el_acc, writer, rd, epoch, 'test')
        mySQL.rcd_result(logits, labels, rd)

        val_acc = np.mean(el_acc, axis=0)
        MF.save_best(val_acc, model, arg)

        mySQL.save_final(epoch, model, val_acc, arg, opt)
finally:
    # Persist the recorder and flush/close the TensorBoard writer even if the loop
    # is interrupted (e.g. KeyboardInterrupt), so completed epochs are not lost.
    mySQL.save_recorder(recorder, arg, 'flow')
    writer.close()
print('best acc: %.4f' % (arg.best_acc))

41
Epoch 41, Iter 050,train loss = 1.2588, train acc = 0.5000
Epoch 41, Iter 100,train loss = 1.1670, train acc = 0.7500
Epoch 41, Iter 150,train loss = 1.1612, train acc = 0.6562
Epoch 41, Iter 200,train loss = 1.1390, train acc = 0.6875
Epoch 41, Iter 250,train loss = 1.2098, train acc = 0.5312
Epoch 41, Iter 300,train loss = 1.2138, train acc = 0.5625
Epoch 41, Iter 350,train loss = 1.2738, train acc = 0.5625
Epoch 41, Iter 400,train loss = 1.1632, train acc = 0.6250
Epoch 41, Iter 450,train loss = 1.2928, train acc = 0.6562
Epoch 41, Iter 500,train loss = 1.2136, train acc = 0.6562
Epoch 41, val loss = 1.5084, val acc = 0.3836
未达到期望，未保存模型
42
Epoch 42, Iter 050,train loss = 1.1159, train acc = 0.7188
Epoch 42, Iter 100,train loss = 1.2480, train acc = 0.5938
Epoch 42, Iter 150,train loss = 1.1888, train acc = 0.5625
Epoch 42, Iter 200,train loss = 1.1941, train acc = 0.5938
Epoch 42, Iter 250,train loss = 1.2300, train acc = 0.5000
Epoch 42, Iter 300,train loss = 1.2181, train acc =