## 模型代码优化测试

In [45]:
import os
from importlib import reload

import torch
import dgl
import torch.nn as nn
import dgl.function as fn
from dgl.nn.pytorch.conv import EdgeWeightNorm, GraphConv
import numpy as np

import surpport.Args as A
import surpport.mySQL as mySQL
import surpport.dataprocess as DP
import surpport.myfunction as MF

## 模型

In [46]:
class Classifier7(nn.Module):
    """Graph classifier: dilated Conv1d + GraphConv node embedding, GraphConv stack, sum readout.

    Data flow (see ``forward``):
      1. ``BN_embedding``: node features ``graph.ndata['h']`` (128 per node) go through
         ``pip_num`` Conv1d branches (kernel 3, padding == dilation == i + 1, 128 -> 96
         channels, convolving along the node axis). Branch i and branch (i + 1) % pip_num
         are merged by ``cross_connect`` and fed to GraphConv i (96 -> 54) weighted by
         ``graph.edata['f']``; the branch outputs are fused by ``pressConv`` and batch-normed.
      2. Dropout on that node embedding.
      3. ``Presentation``: three GraphConv layers (54 -> 34 -> 25 -> 16); their outputs are
         concatenated (34 + 25 + 16 = 75 features), summed over each graph's nodes and passed
         through two Linear layers (75 -> 120 -> n_classes).

    Args:
        arg: experiment config object; only ``arg.num_labels`` (number of classes) is read.
    """

    def __init__(self, arg):
        super(Classifier7, self).__init__()
        self.n_classes = arg.num_labels
        self.pip_num = 3  # number of parallel Conv1d/GraphConv branches
        self.pt_num = 3   # NOTE(review): not used inside this class

        # TODO: make the branch widths configurable, or use one width for every branch.
        # Modules are created in the interleaved order Conv0, Gcn0, Conv1, ... exactly as
        # before, so random initialisation under a fixed seed is unchanged.
        Conv = []
        Gcn = []
        for i in range(self.pip_num):
            # kernel 3 with padding == dilation keeps the output length equal to the input length
            Conv.append(nn.Conv1d(128, 96, kernel_size=3, stride=1, padding=i + 1, dilation=i + 1))
            Gcn.append(GraphConv(96, 54, norm='right', weight=True, bias=True))
        self.ConvList = nn.ModuleList(Conv)
        self.GcnList = nn.ModuleList(Gcn)

        # Fuses the pip_num branch outputs (used as input channels) into a single channel.
        self.pressConv = nn.Conv1d(self.pip_num, 1, kernel_size=self.pip_num, stride=1, padding='same')

        self.PR_GIN = nn.ModuleList([
            GraphConv(54, 34, norm='right', weight=True, bias=True),
            GraphConv(34, 25, norm='right', weight=True, bias=True),
            GraphConv(25, 16, norm='right', weight=True, bias=True),
        ])
        self.classify1 = nn.Linear(75, 120)  # 75 = 34 + 25 + 16 (concatenated PR_GIN outputs)
        self.classify2 = nn.Linear(120, self.n_classes)
        self.batchnorm1 = nn.BatchNorm1d(96)
        self.PR_BN = nn.ModuleList([nn.BatchNorm1d(54), nn.BatchNorm1d(34), nn.BatchNorm1d(16)])
        # NOTE(review): batchnorm3, tanh and PR_BN[1:] are not used in forward. They are left in
        # place because removing the BatchNorm ones would change the state_dict keys of saved
        # checkpoints.
        self.batchnorm3 = nn.BatchNorm1d(75)
        self.actf = nn.ReLU()
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, graph):
        """Classify a (batched) DGL graph.

        Args:
            graph: DGL graph with node features ``ndata['h']`` (128 per node) and edge
                weights ``edata['f']``.

        Returns:
            ``(logits, emd_h)``: per-graph logits (``n_classes`` wide) and the node
            embedding (54 wide) after dropout.
        """
        emd_h = self.BN_embedding(graph)
        emd_h = self.dropout(emd_h)
        logits = self.Presentation(graph, emd_h)
        return logits, emd_h

    def cross_connect(self, h01, h02) -> torch.Tensor:
        """Merge two Conv1d branch outputs into node features for a GraphConv.

        Both inputs are laid out ``[1, channels, num_nodes]`` (as produced in
        ``BN_embedding``). Returns ``batchnorm1(relu(h01) + h02)`` in
        ``[num_nodes, channels]`` layout.
        """
        h01 = h01.permute(0, 2, 1).squeeze(0)
        h02 = h02.permute(0, 2, 1).squeeze(0)
        return self.batchnorm1(self.actf(h01) + h02)

    # Brain Network embedding
    def BN_embedding(self, graph) -> torch.Tensor:
        """Node embedding from dilated 1-D convolutions followed by edge-weighted GraphConvs.

        A plain Conv1d is used here (the original note says a TCN was not a good fit). It
        slides along the node index, which is why the features need their own reshaping.
        NOTE(review): in a batched graph the convolution also mixes the last nodes of one
        graph with the first nodes of the next one — confirm this is intended.

        Returns:
            Tensor ``[num_nodes, 54]`` normalised by ``PR_BN[0]`` (no dropout here; ``forward``
            applies it).
        """
        # [num_nodes, 128] -> [1, 128, num_nodes]   (original note: num_nodes = 18 * n)
        h0 = graph.ndata['h'].float().unsqueeze(0).permute(0, 2, 1)
        h0x = [conv(h0) for conv in self.ConvList]  # each [1, 96, num_nodes]

        edge_w = graph.edata['f'].float()
        h2x = []
        for i in range(self.pip_num):
            # pair branch i with the next branch (wrapping around at the end)
            h1x = self.cross_connect(h0x[i], h0x[(i + 1) % self.pip_num])
            temp = self.GcnList[i](graph, h1x, edge_weight=edge_w)
            h2x.append(self.actf(temp).unsqueeze(0))  # back to [1, num_nodes, 54]

        # [pip_num, num_nodes, 54] -> [num_nodes, pip_num, 54]: branches become Conv1d channels
        stacked = torch.cat(h2x, dim=0).permute(1, 0, 2)
        ac_h1 = self.pressConv(stacked).permute(1, 0, 2).squeeze(0)  # [num_nodes, 54]
        ac_h1 = self.PR_BN[0](ac_h1)  # no dropout here; forward applies it separately
        return ac_h1

    # presentation module
    def Presentation(self, graph, emd_h) -> torch.Tensor:
        """Stacked GraphConvs, per-graph sum readout and a linear classifier head.

        Args:
            graph: the DGL graph passed to ``forward`` (``edata['f']`` is used as edge weights).
            emd_h: node embedding ``[num_nodes, 54]``.

        Returns:
            Logits ``[num_graphs, n_classes]``.
        """
        edge_w = graph.edata['f'].float()
        h0x = [emd_h]
        for layer in self.PR_GIN:
            h0x.append(self.actf(layer(graph, h0x[-1], edge_weight=edge_w)))

        with graph.local_scope():  # keeps 'nh' from leaking into the caller's graph
            graph.ndata['nh'] = torch.cat(h0x[1:], 1)  # [num_nodes, 34 + 25 + 16 = 75]
            hg = dgl.readout_nodes(graph, 'nh', None, op='sum', ntype=None)  # [num_graphs, 75]
            # NOTE(review): there is no activation between classify1 and classify2, so the two
            # Linear layers compose to a single affine map.
            return self.classify2(self.classify1(hg))

    @property
    def num_labels(self):
        """Number of output classes (``arg.num_labels`` given at construction)."""
        return self.n_classes


## 初始化

In [47]:
# Project experiment configuration; later cells read attributes such as num_labels, tar_path, tr_id and best_acc from it.
arg = A.Args()

In [48]:
import surpport.nnstructure as mynn
# Re-execute the module so edits to surpport/nnstructure.py take effect without a kernel restart.
# NOTE(review): `%load_ext autoreload` + `%autoreload 2` in the setup cell would make this automatic.
reload(mynn)

<module 'surpport.nnstructure' from 'e:\\CODEBASE\\myDGL\\FirstDGL\\surpport\\nnstructure.py'>

In [49]:
# Instantiate the Classifier7 defined above and move its parameters to the GPU.
model = Classifier7(arg).cuda()

In [50]:
# Prepare the data side of the config, then register this run's model name, task tag and id.
arg.d_prepare()
arg.m_info(
    m_name='m7',
    m_task='220808_test',
    num=7,
)

In [51]:
# Adam with its default hyper-parameters written out explicitly.
# Previously tried alternative: lr=6e-5, eps=1e-8, weight_decay=0.1.
opt = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)
arg, tr_dataloader, ts_dataloader = DP.dataload(arg, model, opt)

In [52]:
# Sanity check: number of entries in arg.tr_id (presumably the training-sample ids).
len(arg.tr_id)

1500

In [53]:
# Start the run record with its base entry; the training cells add one entry per epoch.
base_rcd = mySQL.gen_base_rcd(arg)
recorder = dict(base=base_rcd)

In [54]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logs go to <tar_path>/Journal. os.path.join inserts the platform's separator;
# the former hard-coded '\\Journal' suffix only produced a sub-directory on Windows.
writer = SummaryWriter(os.path.join(arg.tar_path, 'Journal'))

In [32]:
# Quick-iteration cell: rebuild the model after editing Classifier7 above.
# The optimizer is rebuilt as well: the previous `opt` holds the old model's parameters,
# so training with it would never update this new model.
# NOTE(review): if DP.dataload stored references to the old model/opt, re-run it too.
model = Classifier7(arg).cuda()
opt = torch.optim.Adam(model.parameters())

## 训练

In [None]:
# Classifier7 (c7): train for 30 epochs numbered from `st`; each epoch gets a record entry
# keyed like '30-th'.
st = 30
n_epochs = 30
for epoch in range(st, st + n_epochs):
    print(epoch)
    rd = recorder[f'{epoch}-th'] = {}

    # The graphs are batched together, hence the extra DataLoader pass inside MF.train.
    # The first two return values are lists.
    tr_loss, tr_acc = MF.train(epoch, model, opt, tr_dataloader, arg, writer)
    mySQL.rcd_log(tr_loss, tr_acc, writer, rd, epoch, 'train')

    el_loss, el_acc, logits, labels = MF.evaluate(epoch, model, opt, ts_dataloader, arg, writer)
    mySQL.rcd_log(el_loss, el_acc, writer, rd, epoch, 'test')
    mySQL.rcd_result(logits, labels, rd)

    # average test accuracy of this epoch drives best-model and final checkpoint saving
    val_acc = np.mean(el_acc, axis=0)
    MF.save_best(val_acc, model, arg)
    mySQL.save_final(epoch, model, val_acc, arg, opt)

mySQL.save_recorder(recorder, arg, 'flow')
print('best acc: %.4f' % (arg.best_acc))
# NOTE(review): later cells keep using `writer` after this close.
writer.close()

In [None]:
# Switch to the Classifier4 model from surpport.nnstructure.
from surpport.nnstructure import Classifier4

model = Classifier4(arg).cuda()
# A fresh optimizer is required: the existing `opt` was built over the previous model's
# parameters and would not update this one.
# NOTE(review): arg.m_info still describes 'm7'; confirm whether results should be saved
# under a different model name.
opt = torch.optim.Adam(model.parameters())

In [None]:
# Train the current `model` for epochs 1..10 (record entries '1-th' .. '10-th').
# NOTE(review): `recorder`, `writer` and `arg.best_acc` carry over from the earlier training
# cell, so the saved 'flow' record and the best-accuracy tracking presumably mix both runs.
for epoch in range(1, 11):
    print(epoch)
    rd = recorder[f'{epoch}-th'] = {}

    # The graphs are batched together, hence the extra DataLoader pass inside MF.train.
    # The first two return values are lists.
    tr_loss, tr_acc = MF.train(epoch, model, opt, tr_dataloader, arg, writer)
    mySQL.rcd_log(tr_loss, tr_acc, writer, rd, epoch, 'train')

    el_loss, el_acc, logits, labels = MF.evaluate(epoch, model, opt, ts_dataloader, arg, writer)
    mySQL.rcd_log(el_loss, el_acc, writer, rd, epoch, 'test')
    mySQL.rcd_result(logits, labels, rd)

    val_acc = np.mean(el_acc, axis=0)
    MF.save_best(val_acc, model, arg)
    mySQL.save_final(epoch, model, val_acc, arg, opt)

mySQL.save_recorder(recorder, arg, 'flow')
print('best acc: %.4f' % (arg.best_acc))
writer.close()

## 可变对象的讨论：遍历时通过循环变量修改内层列表会改动原列表

In [1]:
# Iterating over a list of lists yields references to the inner lists, so assigning
# through the loop variable mutates `g` itself.
g = [[value] * 3 for value in (1, 2, 3)]

for row in g:
    row[1] = 4

g
# Every inner list now has 4 at index 1. To leave the original untouched, work on a copy,
# e.g. `[row[:] for row in g]` or `copy.deepcopy(g)`.

[[1, 4, 1], [2, 4, 2], [3, 4, 3]]

## Weights & Biases (wandb) logging experiments

In [58]:
import wandb

# Notebook name reported to wandb; it has to be in the environment before wandb.init runs.
# (`os` is imported in the setup cell at the top; it was previously never imported.)
os.environ['WANDB_NOTEBOOK_NAME'] = '220808'
wandb.init(project="my-test-project")

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [60]:
# Record hyper-parameters in the active run's config. Assigning a plain dict to
# `wandb.config` only rebinds the module attribute and leaves the run's config empty;
# `update` writes into the run's config object.
# NOTE(review): batch_size is hard-coded here; confirm it matches the DataLoaders.
wandb.config.update({
    "learning_rate": 0.001,
    "epochs": 30,
    "batch_size": 32,
})

In [61]:
# Log the mean test accuracy of the last finished training epoch.
wandb.log(dict(ts_acc=val_acc))

In [62]:
%%wandb

UsageError: %%wandb is a cell magic, but the cell body is empty. Did you mean the line magic %wandb (single %)?


In [None]:
# Track the model's gradients every 40 steps. `criterion` takes the loss function itself;
# calling cross_entropy() with no arguments raised a TypeError.
wandb.watch(model, criterion=torch.nn.functional.cross_entropy, log_freq=40)