In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
from torch_geometric.loader import DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv,EdgeConv
from torch_geometric.nn import global_mean_pool,global_max_pool
from torch.utils.data import random_split
import torch.nn.functional as F
from tqdm import trange
from sklearn.metrics import r2_score

In [None]:
class final_Classification_W(nn.Module):
    """Fully connected binary classifier over 100-feature inputs.

    The network is a stack of ``nn.Linear`` layers (100 -> 128 -> 256 -> 128
    -> 64 -> 32 -> 1) ending in a sigmoid, trained with ``nn.BCELoss`` and an
    Adam optimizer that the module owns itself.

    Attributes:
        loss_function: ``nn.BCELoss`` applied to the sigmoid output.
        optimizer: Adam over all parameters (``weight_decay=5e-4``).
        counter: Integer initialised to 0; not used inside this class.
        progress: List of per-step training losses (Python floats).
    """

    def __init__(self):
        super().__init__()
        # Seed the global torch RNG so the layer initialisation is repeatable.
        torch.manual_seed(42)

        self.nn1 = nn.Linear(100, 128)
        self.nn2 = nn.Linear(128,256)
        self.nn3 = nn.Linear(256,128)
        self.nn4 = nn.Linear(128,64)
        self.nn5 = nn.Linear(64,32)
        self.out = nn.Linear(32, 1)

        self.loss_function = nn.BCELoss()
        # NOTE: `lr` is read from module scope, so it must be defined before
        # the class is instantiated.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=5e-4)
        self.counter=0
        self.progress = []

    def forward(self,x,batch, batch_size):
        """Return sigmoid scores of shape ``(batch_size, 1)``.

        Args:
            x: Feature tensor that can be viewed as ``(batch_size, 100)``.
            batch: Accepted for interface compatibility; not used by the MLP.
            batch_size: Number of samples contained in ``x``.
        """
        # Follow the device the weights actually live on instead of relying on
        # a module-level `device` variable being defined and in sync.
        x = x.to(self.nn1.weight.device)
        x = x.view(batch_size,100)
        x = self.nn1(x)
        # NOTE(review): there is no non-linearity between nn1 and nn2, so the
        # two layers compose into a single affine map. Left unchanged so that
        # existing checkpoints behave the same -- confirm whether a ReLU was
        # intended here.
        x = self.nn2(x).relu()
        x = self.nn3(x).relu()
        x = self.nn4(x).relu()
        x = self.nn5(x).relu()
        return self.out(x).sigmoid()

    def train(self,data=True):
        """Run one optimisation step, or switch train/eval mode.

        This method shadows ``nn.Module.train``. PyTorch itself calls
        ``self.train(False)`` from ``eval()``, so a boolean argument is
        forwarded to the base implementation; without that, ``model.eval()``
        would fail while trying to treat ``False`` as a batch.

        Args:
            data: Either a batch object exposing ``x``, ``y``, ``batch`` and
                ``batch_size`` (one training step is performed), or a bool
                (the module's training mode is set).

        Returns:
            The loss tensor of the step, or ``self`` when ``data`` is a bool.
        """
        if isinstance(data, bool):
            return super().train(data)

        self.optimizer.zero_grad()

        outputs = self.forward(data.x.float(),data.batch,data.batch_size)
        # BCELoss needs float targets with the same (N, 1) shape as the output.
        y = data.y.to(outputs.device).float().unsqueeze(1)
        loss = self.loss_function(outputs,y)

        self.progress.append(loss.item())
        loss.backward()
        self.optimizer.step()

        return loss


    def pred(self,data):
        """Return sigmoid scores for ``data`` without recording gradients."""
        with torch.no_grad():
            return self.forward(data.x.float(),data.batch,data.batch_size)

In [None]:
def conf_matrix(data_loader,model):
    """Count confusion-matrix cells on the first batch of ``data_loader``.

    The batch is fetched exactly once so that predictions and labels always
    refer to the same samples; creating two separate iterators would pair
    scores and labels from different batches whenever the loader shuffles.

    Args:
        data_loader: Iterable yielding batches that expose a ``y`` tensor.
        model: Object whose ``pred(batch)`` returns one score row per sample;
            column 0 of each row is compared against the 0.5 threshold.

    Returns:
        ``[tp, fp, fn, tn]`` as Python ints. Only the first batch is used.
    """
    batch = next(iter(data_loader))
    scores = model.pred(batch).detach().cpu().numpy()
    labels = batch.y.detach().cpu().numpy()

    tp,fp,fn,tn = 0,0,0,0
    for label, score in zip(labels, scores):
        if label == 1 and score[0] >= 0.5:
            tp += 1
        elif label == 0 and score[0] >= 0.5:
            fp += 1
        elif label == 1 and score[0] < 0.5:
            fn += 1
        else:
            # Everything else (including label 0 with score < 0.5) is a
            # true negative, exactly as before.
            tn += 1
    return [tp,fp,fn,tn]

def evaluation(tp,fp,fn,tn):
    """Compute accuracy, precision, recall and F1 from confusion counts.

    A ratio whose denominator is zero (no predicted positives, no actual
    positives, or no samples at all) is reported as 0.0 instead of raising
    ``ZeroDivisionError``.

    Args:
        tp, fp, fn, tn: Confusion-matrix counts.

    Returns:
        Tuple ``(acc, pre, recall, f1)``.
    """
    def _ratio(numerator, denominator):
        # Safe division: undefined ratios collapse to 0.0.
        return numerator / denominator if denominator else 0.0

    acc = _ratio(tp+tn, tp+fp+fn+tn)
    pre = _ratio(tp, tp+fp)
    recall = _ratio(tp, tp+fn)
    f1 = _ratio(2*recall*pre, recall+pre)
    return acc,pre,recall,f1

def get_db(data_loader,model):
    """Pair labels with scores for the first batch, sorted by score.

    The batch is fetched exactly once; building two iterators would pair
    labels and scores from different batches whenever the loader shuffles.

    Args:
        data_loader: Iterable yielding batches that expose a ``y`` tensor.
        model: Object whose ``pred(batch)`` returns one score row per sample.

    Returns:
        List of ``[label, score_row]`` entries ordered by ``score_row``
        descending. Only the first batch of the loader is used.
    """
    batch = next(iter(data_loader))
    labels = batch.y.detach().cpu().numpy()
    preds = model.pred(batch).detach().cpu().numpy()

    db = [[label, pred] for label, pred in zip(labels, preds)]
    # Highest score first, as required when sweeping the threshold downwards.
    db.sort(key=lambda row: row[1], reverse=True)
    return db

def roc_coord(db,tp,fp,fn,tn):
    """Turn a score-sorted label list into cumulative ROC points.

    Args:
        db: Sequence of entries whose first element is the label (1 for
            positive, 0 for negative), walked in the given order.
        tp, fp, fn, tn: Confusion counts; ``tp + fn`` is used as the number
            of positives and ``tn + fp`` as the number of negatives.

    Returns:
        List of ``[false_positive_rate, true_positive_rate]`` pairs, one per
        entry of ``db``. There is no guard against either class count being
        zero.
    """
    total_pos = tp+fn
    total_neg = tn+fp
    seen_pos = 0
    seen_neg = 0
    curve = []
    for entry in db:
        label = entry[0]
        seen_pos += label
        seen_neg += 1 - label
        curve.append([seen_neg/total_neg, seen_pos/total_pos])
    return curve

def auc(xy_arr):
    """Sum rectangle areas under a sequence of ``(x, y)`` curve points.

    A rectangle of width ``x - previous_x`` and height ``y`` is added each
    time ``x`` differs from the last x that contributed; points that repeat
    the current x are skipped.

    Args:
        xy_arr: Iterable of ``(x, y)`` pairs.

    Returns:
        The accumulated area (``0`` for an empty input).
    """
    area = 0
    last_x = 0
    for point in xy_arr:
        x, y = point
        if x == last_x:
            continue
        area += (x - last_x) * y
        last_x = x
    return area

In [None]:
# Adam learning rate; final_Classification_W.__init__ reads `lr` from module
# scope, so it has to be assigned before the model is constructed.
lr=0.0001
model = final_Classification_W()
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# files from a trusted source. The absolute path is machine specific.
data = torch.load("/home/chengc/workspace/cc/0815/data/cal_data/dude_2200/MLMS_FISD.pt")
# Keep using the second GPU when there is one, but fall back to the first GPU
# (or the CPU) instead of failing on machines that expose fewer devices.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
# Fixed-seed 1760/330/110 split so train/val/test membership is repeatable.
train_data,val_data,test_data = random_split(dataset = data,lengths = [1760,330,110],generator = torch.Generator().manual_seed(42))
train_loader = DataLoader(train_data, batch_size = 64, shuffle = True)
# The training cell evaluates on `val_loader`, which was never defined here.
# conf_matrix/get_db only look at the first batch of the loader they receive,
# so a single unshuffled full-size batch makes the metrics cover the whole
# validation split.
val_loader = DataLoader(val_data, batch_size = len(val_data), shuffle = False)
model.to(device)

In [None]:
# Best validation AUC seen so far; the "best" checkpoint is rewritten whenever
# this value is beaten.
thresh = 0
basic = "/home/chengc/workspace/cc/1125修改/model"
model_path = os.path.join(basic,"whole_spec+FISD.pth")
model_path_ckp = os.path.join(basic,"VS2Net_2200.pth")
# Create the checkpoint folder up front so the first torch.save cannot fail
# on a missing directory after an epoch of training.
os.makedirs(basic, exist_ok=True)
# Uncomment to resume from the best checkpoint:
#model.load_state_dict(torch.load(model_path))
for epoch in trange(500):
    batch_losses=[]
    # Use `batch` as the loop variable so the dataset stored in `data` is not
    # overwritten by the last mini-batch.
    for batch in train_loader:
        loss = model.train(batch)
        # Keep plain floats rather than graph-attached tensors.
        batch_losses.append(loss.item())
    train_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
    # `val_loader` must be a DataLoader over the validation split.
    ttp,tfp,tfn,ttn = conf_matrix(val_loader,model)
    tdb = get_db(val_loader,model)
    txy_arr = roc_coord(tdb,ttp,tfp,tfn,ttn)
    tauc = auc(txy_arr)
    # Recall is undefined without positive samples; report 0.0 instead of
    # aborting the run with ZeroDivisionError.
    trecall = ttp/(ttp+tfn) if (ttp+tfn) else 0.0
    if tauc > thresh:
        print(f"Epoch {epoch} | Trainloss {train_loss} | ValAuc {tauc} | Valrecall {trecall}")
        torch.save(model.state_dict(),model_path)
        thresh = tauc
    # Periodic checkpoint every 100 epochs, independent of validation AUC.
    if epoch % 100 == 0 :
        print(f"___Epoch {epoch} | Trainloss {train_loss} | ValAuc {tauc} | Valrecall {trecall}")
        torch.save(model.state_dict(),model_path_ckp)