In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
from tqdm import trange
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch.nn as nn
from torch.utils.data import random_split

In [None]:
class final_Classification_W(nn.Module):
    """Fully connected binary classifier mapping 100 input features to one sigmoid output.

    The network is a stack of six ``nn.Linear`` layers
    (100 -> 128 -> 256 -> 128 -> 64 -> 32 -> 1) with ReLU between them and a
    sigmoid on the final output. The module owns its loss (``nn.BCELoss``) and
    its optimizer (Adam, ``weight_decay=5e-4``), so a single call to
    :meth:`train` with a batch performs one full optimisation step.

    Attributes:
        loss_function: ``nn.BCELoss`` instance used by :meth:`train`.
        optimizer: Adam optimizer over this module's parameters.
        counter: Initialised to 0; never updated by this class (kept for
            backward compatibility).
        progress: List of Python floats, one mean loss per :meth:`train` step.
    """

    def __init__(self, learning_rate=None):
        """Build the layers, loss and optimizer.

        Args:
            learning_rate: Adam learning rate. When ``None`` (the default) the
                value is read from the notebook-level name ``lr``, which is
                what the original zero-argument constructor did.
        """
        super().__init__()
        # Re-seeds the *global* torch RNG so layer initialisation is repeatable.
        torch.manual_seed(42)

        self.nn1 = nn.Linear(100, 128)
        self.nn2 = nn.Linear(128, 256)
        self.nn3 = nn.Linear(256, 128)
        self.nn4 = nn.Linear(128, 64)
        self.nn5 = nn.Linear(64, 32)
        self.out = nn.Linear(32, 1)

        self.loss_function = nn.BCELoss()
        if learning_rate is None:
            # Backward compatible fallback to the notebook-level global.
            learning_rate = lr
        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=learning_rate, weight_decay=5e-4
        )
        self.counter = 0
        self.progress = []

    def _device(self):
        """Return the device the module's parameters currently live on."""
        return next(self.parameters()).device

    def forward(self, x):
        """Return sigmoid outputs for ``x`` (last dimension must be 100).

        ``x`` is moved to the device of the module's parameters, so the input
        always ends up on the same device as the weights.
        """
        x = x.to(self._device())
        for layer in (self.nn1, self.nn2, self.nn3, self.nn4, self.nn5):
            x = layer(x).relu()
        return self.out(x).sigmoid()

    def train(self, dataloader=True):
        """Run one optimisation step, or toggle training mode.

        This method shares its name with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` calls internally. To keep both usages working:

        * ``train()`` / ``train(True)`` / ``train(False)`` behave exactly like
          ``nn.Module.train`` and return ``self``.
        * ``train(batch)`` performs one optimisation step on ``batch``.

        Args:
            dataloader: Either a bool (training-mode flag) or an object that
                supports ``len()`` and integer indexing and whose items expose
                ``.x`` (features) and ``.y`` (targets). In this notebook it is
                one mini-batch yielded by the training ``DataLoader``.

        Returns:
            ``self`` when called with a bool; otherwise the mean loss of the
            step as a detached 0-dim tensor.
        """
        if isinstance(dataloader, bool):
            return super().train(dataloader)

        target_device = self._device()
        losses = []
        # Index-based access is deliberate: only len() and integer indexing are
        # assumed of `dataloader` (plain iteration may mean something else for
        # the batch type in use).
        for i in range(len(dataloader)):
            data = dataloader[i]
            outputs = self(data.x.float())
            y = data.y.to(target_device).float()
            losses.append(self.loss_function(outputs, y))
        loss = torch.stack(losses, dim=0).mean(dim=0)
        self.progress.append(loss.item())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Detached so callers that collect losses do not keep autograd nodes alive.
        return loss.detach()

    def pred(self, data):
        """Return the model output for ``data.x`` without tracking gradients."""
        with torch.no_grad():
            return self(data.x.float())

In [None]:
def conf_matrix(dataset,model):
    """Count confusion-matrix cells for ``model`` over ``dataset`` at threshold 0.5.

    Args:
        dataset: Object supporting ``len()`` and integer indexing; each item
            exposes a single-valued label ``.y`` (1 = positive, 0 = negative).
        model: Object with a ``pred(sample)`` method returning a single score.

    Returns:
        ``[tp, fp, fn, tn]``. A sample whose label is neither 0 nor 1 falls
        into ``tn`` (same as the original branch order).
    """
    tp, fp, fn, tn = 0, 0, 0, 0
    # Pure inference: no autograd graph is needed for counting.
    with torch.no_grad():
        for i in range(len(dataset)):
            # Fetch once; indexing a dataset can be non-trivial work.
            sample = dataset[i]
            out = model.pred(sample)
            if sample.y == 1 and out >= 0.5:
                tp += 1
            elif sample.y == 0 and out >= 0.5:
                fp += 1
            elif sample.y == 1 and out < 0.5:
                fn += 1
            else:
                tn += 1
    return [tp, fp, fn, tn]

def evaluation(tp,fp,fn,tn):
    """Compute accuracy, precision, recall and F1 from confusion-matrix counts.

    Any metric whose denominator is zero (no samples, no predicted positives,
    no actual positives, or precision + recall == 0) is reported as 0.0
    instead of raising ``ZeroDivisionError``.

    Args:
        tp, fp, fn, tn: True-positive, false-positive, false-negative and
            true-negative counts.

    Returns:
        Tuple ``(acc, pre, recall, f1)``.
    """
    total = tp + fp + fn + tn
    predicted_pos = tp + fp
    actual_pos = tp + fn

    acc = (tp + tn) / total if total else 0.0
    pre = tp / predicted_pos if predicted_pos else 0.0
    recall = tp / actual_pos if actual_pos else 0.0
    # Harmonic mean of precision and recall; undefined when both are zero.
    f1 = (2 * recall * pre) / (recall + pre) if (recall + pre) else 0.0
    return acc, pre, recall, f1

def get_db(dataset,model):
    """Pair every label with its prediction, ranked by prediction (highest first).

    For each item of ``dataset`` the label is ``item.y.numpy()`` and the score
    is ``model.pred(item)`` detached, moved to CPU and converted to NumPy.

    Returns:
        List of ``[label, prediction]`` pairs sorted by prediction in
        descending order.
    """
    ranked = [
        [sample.y.numpy(), model.pred(sample).detach().cpu().numpy()]
        for sample in dataset
    ]
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked

def roc_coord(db,tp,fp,fn,tn):
    """Turn a ranked ``[label, prediction]`` list into cumulative ROC points.

    Walks ``db`` in order, keeping running counts of positives (label added
    as-is) and negatives (``1 - label``), and emits one point per entry.

    Args:
        db: Sequence whose entries hold the label at index 0.
        tp, fp, fn, tn: Confusion-matrix counts; ``tp + fn`` and ``tn + fp``
            are used as the positive and negative totals.

    Returns:
        List of ``[false_positive_rate, true_positive_rate]`` pairs.
    """
    total_pos = tp + fn
    total_neg = tn + fp
    hits = 0
    false_alarms = 0
    points = []
    for entry in db:
        label = entry[0]
        hits += label
        false_alarms += 1 - label
        points.append([false_alarms / total_neg, hits / total_pos])
    return points

def auc(xy_arr):
    """Sum rectangle areas under a sequence of ``(x, y)`` points.

    Starting from ``x = 0``, every point whose ``x`` differs from the previous
    accepted ``x`` contributes ``(x - previous_x) * y``; points that repeat the
    previous ``x`` are skipped.

    Returns:
        The accumulated area (``0`` for an empty sequence).
    """
    area = 0
    last_x = 0
    for x, y in xy_arr:
        if x == last_x:
            continue
        area += (x - last_x) * y
        last_x = x
    return area

In [None]:
# --- Experiment 1 setup: MLMS_FISD dataset ---------------------------------
# Learning rate; the model class picks this up from the notebook-level name `lr`.
lr=0.0001
model = final_Classification_W()
# NOTE(review): torch.load unpickles the file — only load datasets from a trusted source.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
data = torch.load("/home/chengc/workspace/cc/0815/data/cal_data/dude_2200/MLMS_FISD.pt")
# Target device: CUDA device index 1 when CUDA is available, otherwise CPU.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
# Fixed-size train/val/test split (1760/330/110) drawn with a seeded generator so the
# partition is repeatable. The lengths sum to 2200, so `data` must hold exactly 2200 items.
train_data,val_data,test_data = random_split(dataset = data,lengths = [1760,330,110],generator = torch.Generator().manual_seed(42))
# Training batches of 64, reshuffled every epoch; val/test subsets are indexed directly later.
train_loader = DataLoader(train_data, batch_size = 64, shuffle = True)
model.to(device)

In [None]:
# --- Experiment 1: train on MLMS_FISD, checkpoint on validation AUC, then test ---
# `thresh` is the validation AUC to beat; it is raised to the new best each time a
# "best" checkpoint is written.
thresh=0.9
# NOTE(review): absolute, machine-specific output directory.
basic = "/home/chengc/workspace/cc/0815/model"
model_path = os.path.join(basic,"VS2Net_2200_MLMSFISD.pth")          # best-on-validation weights
model_path_ckp = os.path.join(basic,"VS2Net_2200_MLMSFISD_ckp.pth")  # periodic checkpoint
for epoch in trange(1000):
    losses=[]
    # Each `loader` is one mini-batch yielded by train_loader; model.train(batch)
    # performs one optimizer step and returns that step's mean loss.
    for loader in train_loader:
        loss = model.train(loader)
        losses.append(loss)
    lossmean = torch.stack(losses,dim=0).mean(dim=0)
    # Validation metrics for this epoch (the t* names are reused for the test set below).
    ttp,tfp,tfn,ttn = conf_matrix(val_data,model)
    tdb = get_db(val_data,model)
    txy_arr = roc_coord(tdb,ttp,tfp,tfn,ttn)
    tauc = auc(txy_arr)
    trecall = ttp/(ttp+tfn)
    tacc = (ttp+ttn)/(ttp+tfp+tfn+ttn) 
    # Save a new "best" model only when validation AUC improves on `thresh`
    # AND validation accuracy exceeds 0.85.
    if tauc > thresh and tacc > 0.85:
        print(f"Epoch {epoch} | Trainloss {lossmean} | ValAuc {tauc} | ValAcc {tacc}")
        torch.save(model.state_dict(),model_path)
        thresh = tauc
    # Periodic checkpoint every 100 epochs, regardless of validation quality.
    if epoch % 100 == 0 :
        print(f"___Epoch {epoch} | Trainloss {lossmean} | ValAuc {tauc} | ValAcc {tacc}")
        torch.save(model.state_dict(),model_path_ckp)
    # Verbose logging for the first 100 epochs.
    # NOTE(review): epoch 0 satisfies both this and the branch above, so it is printed twice.
    if epoch < 100 :
        print(f"___Epoch {epoch} | Trainloss {lossmean} | ValAuc {tauc} | ValAcc {tacc}")

# Recorded results from an earlier run:
# MLMS_FISD
# Epoch 194 | Trainloss 0.07332178205251694 | ValAuc 0.9407058823529391 | ValAcc 0.896969696969697
# 500 epochs
# TestAuc 0.9237973605789697 | TestAcc 0.8727272727272727

# Reload the best-on-validation weights and evaluate on the held-out test split.
# NOTE(review): model_path is only written when the threshold condition above was met.
# If it never was during this run, this either fails (file missing) or loads weights
# left behind by an earlier run.
model.load_state_dict(torch.load(model_path))
ttp,tfp,tfn,ttn = conf_matrix(test_data,model)
tdb = get_db(test_data,model)
txy_arr = roc_coord(tdb,ttp,tfp,tfn,ttn)
tauc = auc(txy_arr)
trecall = ttp/(ttp+tfn)
tacc = (ttp+ttn)/(ttp+tfp+tfn+ttn)
print(tauc)
print(trecall)
print(tacc)

In [None]:
# --- Experiment 2 setup: DFT_FISD dataset -----------------------------------
# Rebinds the notebook-level names (lr, model, data, device, splits, loader) used by
# the cells that follow, replacing whatever a previous experiment left in them.
# Learning rate; the model class picks this up from the notebook-level name `lr`.
lr=0.0001
model = final_Classification_W()
# NOTE(review): torch.load unpickles the file — only load datasets from a trusted source.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
data = torch.load("/home/chengc/workspace/cc/0815/data/cal_data/dude_2200/DFT_FISD.pt")
# Target device: CUDA device index 1 when CUDA is available, otherwise CPU.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')
# Fixed-size train/val/test split (1760/330/110) drawn with a seeded generator so the
# partition is repeatable. The lengths sum to 2200, so `data` must hold exactly 2200 items.
train_data,val_data,test_data = random_split(dataset = data,lengths = [1760,330,110],generator = torch.Generator().manual_seed(42))
# Training batches of 64, reshuffled every epoch.
train_loader = DataLoader(train_data, batch_size = 64, shuffle = True)
model.to(device)

In [None]:
# --- Experiment 2: train on DFT_FISD, checkpoint on validation AUC ----------
# `thresh` is the validation AUC to beat; it is raised to the new best each time a
# "best" checkpoint is written.
thresh=0.9
# NOTE(review): absolute, machine-specific output directory.
basic = "/home/chengc/workspace/cc/0815/model"
model_path = os.path.join(basic,"VS2Net_2200_DFTFISD.pth")          # best-on-validation weights
model_path_ckp = os.path.join(basic,"VS2Net_2200_DFTFISD_ckp.pth")  # periodic checkpoint
for epoch in trange(500):
    losses=[]
    # Each `loader` is one mini-batch yielded by train_loader; model.train(batch)
    # performs one optimizer step and returns that step's mean loss.
    for loader in train_loader:
        loss = model.train(loader)
        losses.append(loss)
    lossmean = torch.stack(losses,dim=0).mean(dim=0)
    # Validation metrics for this epoch.
    ttp,tfp,tfn,ttn = conf_matrix(val_data,model)
    tdb = get_db(val_data,model)
    txy_arr = roc_coord(tdb,ttp,tfp,tfn,ttn)
    tauc = auc(txy_arr)
    trecall = ttp/(ttp+tfn)
    tacc = (ttp+ttn)/(ttp+tfp+tfn+ttn) 
    # Save a new "best" model only when validation AUC improves on `thresh`
    # AND validation accuracy exceeds 0.85.
    if tauc > thresh and tacc > 0.85:
        print(f"Epoch {epoch} | Trainloss {lossmean} | ValAuc {tauc} | ValAcc {tacc}")
        torch.save(model.state_dict(),model_path)
        thresh = tauc
    # Periodic checkpoint every 50 epochs, regardless of validation quality.
    if epoch % 50 == 0 :
        print(f"___Epoch {epoch} | Trainloss {lossmean} | ValAuc {tauc} | ValAcc {tacc}")
        torch.save(model.state_dict(),model_path_ckp)
    # Verbose logging for the first 100 epochs.
    # NOTE(review): epochs 0 and 50 satisfy both this and the branch above, so they are printed twice.
    if epoch < 100 :
        print(f"___Epoch {epoch} | Trainloss {lossmean} | ValAuc {tauc} | ValAcc {tacc}")

In [None]:
#Epoch 450 | Trainloss 0.0028689163736999035 | ValAuc 0.943006535947713 | ValAcc 0.896969696969697
# 500 epochs
# TestAuc 0.8654746700723712 | TestAcc 0.8545454545454545

In [None]:
# --- Experiment 2: evaluate the best-on-validation weights on the test split ---
# NOTE(review): relies on `model`, `model_path` and `test_data` left in the kernel by the
# preceding cells. `model_path` is only written when the validation threshold was met
# during training; otherwise this fails or loads weights from an earlier run.
model.load_state_dict(torch.load(model_path))
ttp,tfp,tfn,ttn = conf_matrix(test_data,model)
tdb = get_db(test_data,model)
txy_arr = roc_coord(tdb,ttp,tfp,tfn,ttn)
tauc = auc(txy_arr)
trecall = ttp/(ttp+tfn)
tacc = (ttp+ttn)/(ttp+tfp+tfn+ttn)
# Report test AUC, recall and accuracy, in that order.
print(tauc)
print(trecall)
print(tacc)