In [1]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
from torch.utils.data import random_split
from tqdm import trange

In [2]:
# Locations of the pre-computed train/validation splits (machine-specific absolute paths).
TRAIN_DATA_PATH = '/home/chengc/workspace/cc/0815/data/cal_data/dude_82/train_data.pt'
VAL_DATA_PATH = '/home/chengc/workspace/cc/0815/data/cal_data/dude_82/val_data.pt'

# NOTE(review): torch.load deserialises pickled objects -- only point these paths at trusted files.
train_data = torch.load(TRAIN_DATA_PATH)
val_data = torch.load(VAL_DATA_PATH)

In [3]:
class final_Classification_W(nn.Module):
    """Fully connected binary classifier: 100 input features -> one sigmoid output.

    The module owns its loss function (``nn.BCELoss``) and its Adam optimizer,
    so a single call to :meth:`train` with a batch performs a complete
    optimisation step.

    Attributes:
        loss_function: binary cross-entropy applied to the sigmoid output.
        optimizer: Adam over all parameters (lr=1e-4, weight_decay=5e-4).
        counter: step counter kept for callers; not updated by this class.
        progress: list of Python floats, one training loss per optimisation step.
    """

    def __init__(self):
        super().__init__()
        # Seeds the *global* torch RNG so that the layer initialisation is reproducible.
        torch.manual_seed(42)

        self.nn1 = nn.Linear(100, 128)
        self.nn2 = nn.Linear(128, 256)
        self.nn3 = nn.Linear(256, 128)
        self.nn4 = nn.Linear(128, 64)
        self.nn5 = nn.Linear(64, 32)
        self.out = nn.Linear(32, 1)

        self.loss_function = nn.BCELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, weight_decay=5e-4)
        self.counter = 0
        self.progress = []

    def _current_device(self):
        """Return the device the parameters currently live on."""
        return next(self.parameters()).device

    def forward(self, x):
        """Map a float tensor whose last dimension is 100 to probabilities in (0, 1).

        The input is moved to the model's own device, so the caller does not
        have to define a module-level ``device`` variable beforehand.
        """
        x = x.to(self._current_device())
        for hidden in (self.nn1, self.nn2, self.nn3, self.nn4, self.nn5):
            x = hidden(x).relu()
        return self.out(x).sigmoid()

    def train(self, dataloader=True):
        """Run one optimisation step on a batch, or toggle train/eval mode.

        This method shadows ``nn.Module.train``. To keep ``model.eval()`` and
        ``model.train()`` / ``model.train(False)`` working, a boolean argument
        is forwarded to the base-class implementation.

        Args:
            dataloader: an indexable batch (supports ``len()`` and integer
                indexing) whose items expose ``.x`` and ``.y`` tensors, or a
                bool meaning "set training mode".

        Returns:
            ``self`` when called with a bool; otherwise the detached scalar
            loss tensor averaged over the items of the batch.
        """
        if isinstance(dataloader, bool):
            return super().train(dataloader)

        target_device = self._current_device()
        losses = []
        # Index-based access is deliberate: iterating the batch object directly
        # is not guaranteed to yield its samples (TODO confirm for the batch type in use).
        for i in range(len(dataloader)):
            data = dataloader[i]
            outputs = self.forward(data.x.float())
            y = data.y.to(target_device).float()
            losses.append(self.loss_function(outputs, y))
        loss = torch.stack(losses, dim=0).mean(dim=0)
        self.progress.append(loss.item())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Detach so callers that collect the returned losses do not keep graph references alive.
        return loss.detach()

    def pred(self, data):
        """Return the predicted probability for ``data.x`` without tracking gradients."""
        with torch.no_grad():
            return self.forward(data.x.float())

def conf_matrix(dataset,model):
    """Count confusion-matrix cells for ``model`` on ``dataset`` at a 0.5 threshold.

    Each sample is scored with ``model.pred`` and compared to its ``.y`` label.
    Anything that is not a true positive, false positive or false negative
    (including labels other than 0/1) is counted as a true negative.

    Returns:
        ``[tp, fp, fn, tn]``.
    """
    true_pos = false_pos = false_neg = true_neg = 0
    for idx in range(len(dataset)):
        sample = dataset[idx]
        score = model.pred(sample)
        label = sample.y
        if score >= 0.5:
            # Predicted positive: the label decides between tp and fp.
            if label == 1:
                true_pos += 1
            elif label == 0:
                false_pos += 1
            else:
                true_neg += 1
        elif label == 1 and score < 0.5:
            false_neg += 1
        else:
            true_neg += 1
    return [true_pos, false_pos, false_neg, true_neg]

def evaluation(tp,fp,fn,tn):
    """Compute accuracy, precision, recall and F1 from confusion-matrix counts.

    A metric whose denominator is zero (e.g. precision when the model predicts
    no positives at all) is reported as 0.0 instead of raising
    ``ZeroDivisionError``.

    Args:
        tp, fp, fn, tn: true-positive, false-positive, false-negative and
            true-negative counts.

    Returns:
        Tuple ``(acc, pre, recall, f1)``.
    """
    def _ratio(numerator, denominator):
        # Safe division: an undefined metric is reported as 0.0.
        return numerator / denominator if denominator else 0.0

    acc = _ratio(tp + tn, tp + fp + fn + tn)
    pre = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * recall * pre, recall + pre)
    return acc,pre,recall,f1

def get_db(dataset,model):
    """Collect ``[label, prediction]`` pairs sorted by prediction, highest first.

    Labels and predictions are both detached and moved to the CPU before the
    NumPy conversion, so samples whose ``.y`` lives on another device or
    requires grad are handled the same way as the model output.

    Args:
        dataset: iterable of samples exposing a ``.y`` tensor.
        model: object whose ``pred(sample)`` returns a tensor score.

    Returns:
        List of ``[label_array, prediction_array]`` in descending prediction
        order (the sort is stable, so ties keep dataset order).
    """
    db = []
    for data in dataset:
        label = data.y.detach().cpu().numpy()
        score = model.pred(data).detach().cpu().numpy()
        db.append([label, score])
    db.sort(key=lambda pair: pair[1], reverse=True)
    return db

def roc_coord(db,tp,fp,fn,tn):
    """Turn a score-sorted ``[label, prediction]`` list into ROC curve points.

    Walks ``db`` in order, accumulating positives and negatives seen so far,
    and emits one ``[false_positive_rate, true_positive_rate]`` pair per entry.
    The totals come from the confusion-matrix counts: negatives = ``tn + fp``,
    positives = ``tp + fn``.
    """
    total_neg = tn + fp
    total_pos = tp + fn
    hits = 0
    false_alarms = 0
    points = []
    for entry in db:
        label = entry[0]
        hits += label
        false_alarms += 1 - label
        points.append([false_alarms / total_neg, hits / total_pos])
    return points

def auc(xy_arr):
    """Area under a step curve given as ``[x, y]`` points in ascending ``x`` order.

    Whenever ``x`` differs from the previously used ``x``, a rectangle of width
    ``x - previous_x`` and height ``y`` is added; further points at the same
    ``x`` are ignored.
    """
    area = 0
    last_x = 0
    for x, y in xy_arr:
        if x == last_x:
            continue
        area += (x - last_x) * y
        last_x = x
    return area

In [4]:
# Preferred GPU is index 3; fall back to GPU 0 when the host has fewer GPUs,
# and to the CPU when CUDA is unavailable, instead of crashing on an invalid ordinal.
gpu_index = 3
if torch.cuda.is_available():
    if gpu_index >= torch.cuda.device_count():
        gpu_index = 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

# Shuffled mini-batches of the training set.
train_loader = DataLoader(train_data, batch_size = 3072, shuffle = True)
model = final_Classification_W()
# Last expression of the cell: moves the model and displays its layer summary.
model.to(device)

final_Classification_W(
  (nn1): Linear(in_features=100, out_features=128, bias=True)
  (nn2): Linear(in_features=128, out_features=256, bias=True)
  (nn3): Linear(in_features=256, out_features=128, bias=True)
  (nn4): Linear(in_features=128, out_features=64, bias=True)
  (nn5): Linear(in_features=64, out_features=32, bias=True)
  (out): Linear(in_features=32, out_features=1, bias=True)
  (loss_function): BCELoss()
)

In [None]:
# Validation AUC that must be beaten before a "best" model is saved; raised after every save.
thresh =0.92
# Minimum validation accuracy required for a "best" model.
min_val_acc = 0.83
basic = "/home/chengc/workspace/cc/0815/model"
model_path = os.path.join(basic,"VS2Net_dude82.pth")
model_path_ckp = os.path.join(basic,"VS2Net_dude82_ckp.pth")
# Create the output directory up front so a long epoch is never lost to a failed save.
os.makedirs(basic, exist_ok=True)
#model.load_state_dict(torch.load(model_path_ckp))
for epoch in trange(1000):
    # --- training: one optimisation step per mini-batch ---
    losses=[]
    for batch in train_loader:
        loss = model.train(batch)
        # Keep only the value; no autograd history is needed for logging.
        losses.append(loss.detach())
    lossmean = torch.stack(losses,dim=0).mean(dim=0)

    # --- validation: no gradients are needed here ---
    with torch.no_grad():
        ttp,tfp,tfn,ttn = conf_matrix(val_data,model)
        tdb = get_db(val_data,model)
    txy_arr = roc_coord(tdb,ttp,tfp,tfn,ttn)
    tauc = auc(txy_arr)
    # Guard against a validation set without positive samples.
    trecall = ttp/(ttp+tfn) if (ttp+tfn) else 0.0
    tacc = (ttp+ttn)/(ttp+tfp+tfn+ttn)

    # Save a new best model when both the AUC and the accuracy criteria are met.
    if tauc > thresh and tacc > min_val_acc:
        print(f"Epoch {epoch} | Trainloss {lossmean} | ValAuc {tauc} | ValAcc {tacc}")
        torch.save(model.state_dict(),model_path)
        thresh = tauc
    # Periodic checkpoint every 50 epochs.
    if epoch % 50 == 0 :
        torch.save(model.state_dict(),model_path_ckp)
    # Progress line: every epoch for the first 100, then every 50th -- printed once per epoch.
    if epoch % 50 == 0 or epoch < 100 :
        print(f"___Epoch {epoch} | Trainloss {lossmean} | ValAuc {tauc} | ValAcc {tacc}")

  0%|▏                                                                                                                                                | 1/1000 [05:42<94:59:07, 342.29s/it]

___Epoch 0 | Trainloss 0.5226327776908875 | ValAuc 0.500986277205372 | ValAcc 0.9090909090909091
___Epoch 0 | Trainloss 0.5226327776908875 | ValAuc 0.500986277205372 | ValAcc 0.9090909090909091


  0%|▎                                                                                                                                               | 2/1000 [12:12<102:45:45, 370.69s/it]

___Epoch 1 | Trainloss 0.3151915967464447 | ValAuc 0.5021975633387795 | ValAcc 0.9090909090909091


  0%|▍                                                                                                                                               | 3/1000 [18:22<102:34:27, 370.38s/it]

___Epoch 2 | Trainloss 0.31131917238235474 | ValAuc 0.504043317138138 | ValAcc 0.9090909090909091


  0%|▌                                                                                                                                               | 4/1000 [24:12<100:13:18, 362.25s/it]

___Epoch 3 | Trainloss 0.30988556146621704 | ValAuc 0.5062852318098005 | ValAcc 0.9090909090909091


  0%|▋                                                                                                                                                | 5/1000 [30:12<99:51:01, 361.27s/it]

___Epoch 4 | Trainloss 0.3086230456829071 | ValAuc 0.5096569502255507 | ValAcc 0.9090909090909091


  1%|▊                                                                                                                                                | 6/1000 [36:03<98:48:22, 357.85s/it]

___Epoch 5 | Trainloss 0.30920180678367615 | ValAuc 0.5130252531366176 | ValAcc 0.9090909090909091


  1%|█                                                                                                                                                | 7/1000 [42:13<99:48:01, 361.81s/it]

___Epoch 6 | Trainloss 0.3077200949192047 | ValAuc 0.5209421418529455 | ValAcc 0.9090909090909091


  1%|█▏                                                                                                                                               | 8/1000 [48:00<98:24:36, 357.13s/it]

___Epoch 7 | Trainloss 0.30527469515800476 | ValAuc 0.5276967041886669 | ValAcc 0.9090909090909091


  1%|█▎                                                                                                                                               | 9/1000 [53:57<98:19:17, 357.17s/it]

___Epoch 8 | Trainloss 0.30381280183792114 | ValAuc 0.5373818323278562 | ValAcc 0.9090909090909091


  1%|█▍                                                                                                                                              | 10/1000 [59:45<97:25:23, 354.27s/it]

___Epoch 9 | Trainloss 0.30474376678466797 | ValAuc 0.5511496438080672 | ValAcc 0.9090909090909091


  1%|█▌                                                                                                                                            | 11/1000 [1:05:34<96:55:17, 352.80s/it]

___Epoch 10 | Trainloss 0.30387574434280396 | ValAuc 0.5632012771978409 | ValAcc 0.9090909090909091


  1%|█▋                                                                                                                                            | 12/1000 [1:11:36<97:34:29, 355.54s/it]

___Epoch 11 | Trainloss 0.3024972379207611 | ValAuc 0.5805436152413895 | ValAcc 0.9090909090909091
