In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import Parameter
from torch_geometric.nn.pool.topk_pool import topk,filter_adj
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, ChebConv, GraphConv,DataParallel

from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

import argparse
import time
from tqdm import tqdm
import copy as cp

from torch.utils.data import random_split
from torch_geometric.loader import DataLoader, DataListLoader

import joblib
from torch_geometric.data import Data
import os
from sklearn.metrics import classification_report,confusion_matrix
# torch.cuda.set_device(1)

# GSAPool 模块

In [2]:
class GSAPool(torch.nn.Module):

    def __init__(self, in_channels, pooling_ratio=0.5, alpha=0.6, pooling_conv="GCNConv", fusion_conv="false",
                    min_score=None, multiplier=1, non_linearity=torch.tanh):
        super(GSAPool,self).__init__()
        self.in_channels = in_channels

        self.ratio = pooling_ratio
        self.alpha = alpha

        self.sbtl_layer = self.conv_selection(pooling_conv, in_channels)
        self.fbtl_layer = nn.Linear(in_channels, 1)
        self.fusion = self.conv_selection(fusion_conv, in_channels, conv_type=1)

        self.min_score = min_score
        self.multiplier = multiplier
        self.fusion_flag = 0
        if(fusion_conv!="false"):
            self.fusion_flag = 1
        self.non_linearity = non_linearity

    def conv_selection(self, conv, in_channels, conv_type=0):
        if(conv_type == 0):
            out_channels = 1
        elif(conv_type == 1):
            out_channels = in_channels
        if(conv == "GCNConv"):
            return GCNConv(in_channels,out_channels)
        elif(conv == "ChebConv"):
            return ChebConv(in_channels,out_channels,1)
        elif(conv == "SAGEConv"):
            return SAGEConv(in_channels,out_channels)
        elif(conv == "GATConv"):
            return GATConv(in_channels,out_channels, heads=1, concat=True)
        elif(conv == "GraphConv"):
            return GraphConv(in_channels,out_channels)
        else:
            raise ValueError

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        #SBTL
        score_s = self.sbtl_layer(x,edge_index).squeeze()
        #FBTL
        score_f = self.fbtl_layer(x).squeeze()
        #hyperparametr alpha
        score = score_s*self.alpha + score_f*(1-self.alpha)

        score = score.unsqueeze(-1) if score.dim()==0 else score

        if self.min_score is None:
            score = self.non_linearity(score)
        else:
            score = softmax(score, batch)
        perm = topk(score, self.ratio, batch)

        #fusion
        if(self.fusion_flag == 1):
            x = self.fusion(x, edge_index)
    
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x
        
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm

# 评论网络特征提取模块

In [3]:
class GSAPoolModel(torch.nn.Module):
    def __init__(self,args):
        super(GSAPoolModel, self).__init__()

        self.args = args
        self.nhid = args.r_nhid

        self.num_features = args.r_num_features
        
        self.alpha = args.r_alpha
        self.pooling_ratio = args.r_pooling_ratio
        self.dropout_ratio = args.dropout_ratio

        self.pooling_layer_type = args.r_pooling_layer_type
        self.feature_fusion_type = args.r_feature_fusion_type
        self.out_feature_nhid = args.out_feature_nhid
        
        self.conv1 = self.conv_selection(args.conv,self.num_features, self.nhid)
        self.pool1 = GSAPool(self.nhid, pooling_ratio=self.pooling_ratio, alpha = self.alpha, 
                     pooling_conv=self.pooling_layer_type, fusion_conv=self.feature_fusion_type)
        
        self.conv2 = self.conv_selection(args.conv,self.nhid, self.nhid)
        self.pool2 = GSAPool(self.nhid, pooling_ratio=self.pooling_ratio, alpha = self.alpha, 
                     pooling_conv=self.pooling_layer_type, fusion_conv=self.feature_fusion_type)
        
        self.conv3 = self.conv_selection(args.conv,self.nhid, self.nhid)
        self.pool3 = GSAPool(self.nhid, pooling_ratio=self.pooling_ratio, alpha = self.alpha, 
                     pooling_conv=self.pooling_layer_type, fusion_conv=self.feature_fusion_type)

        self.lin1 = torch.nn.Linear(self.nhid*6, self.out_feature_nhid)
      
    def conv_selection(self, conv, in_channels,out_channels):
        if conv == "gcn":
            return GCNConv(in_channels, out_channels)
        elif conv =="gat":
            return GATConv(in_channels, out_channels)
        return None
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch


        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = torch.cat([x1, x2, x3], dim=1)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        # 返回隐藏层状态
        return x

# 传播网络特征提取模块(ASAPooling)

In [4]:
from torch_geometric.nn.pool import ASAPooling
class ASAPoolModel(torch.nn.Module):
    def __init__(self,args):
        super(ASAPoolModel, self).__init__()

        self.args = args
        self.nhid = args.p_nhid

        self.num_features = args.p_num_features
        
        self.pooling_ratio = args.p_pooling_ratio
        self.dropout_ratio = args.dropout_ratio
        self.out_feature_nhid = args.out_feature_nhid

        self.conv1 = self.conv_selection(args.conv, self.num_features, self.nhid)
        self.pool1 = ASAPooling(self.nhid, pooling_ratio=self.pooling_ratio, 
                     dropout=self.dropout_ratio)
        
        self.conv2 = self.conv_selection(args.conv, self.nhid, self.nhid)
        self.pool2 = ASAPooling(self.nhid, pooling_ratio=self.pooling_ratio, 
                     dropout=self.dropout_ratio)
        
        self.conv3 = self.conv_selection(args.conv, self.nhid, self.nhid)
        self.pool3 = ASAPooling(self.nhid, pooling_ratio=self.pooling_ratio, 
                     dropout=self.dropout_ratio)

        self.lin1 = torch.nn.Linear(self.nhid*6, self.out_feature_nhid)
        
    def conv_selection(self, conv, in_channels,out_channels):
        if conv == "gcn":
            return GCNConv(in_channels, out_channels)
        elif conv =="gat":
            return GATConv(in_channels, out_channels)
        return None
  
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = torch.cat([x1, x2, x3], dim=1)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        # 返回隐藏层状态
        return x

# 虚假新闻检测模型

In [5]:
class FakeNewsDetecModel(torch.nn.Module):
    def __init__(self,args):
        super(FakeNewsDetecModel, self).__init__()
        self.relpy_layer = GSAPoolModel(args)
        self.propagate_layer = ASAPoolModel(args)
        self.r_nhid = args.r_nhid
        self.p_nhid = args.p_nhid
        self.num_classes = args.num_classes
        self.lin1 = torch.nn.Linear(args.out_feature_nhid * 2, args.out_feature_nhid)
        self.lin2 = torch.nn.Linear(args.out_feature_nhid, args.num_classes)
    def forward(self, reply_nwtwork, propagate_network):
        x1 = self.relpy_layer(reply_nwtwork)
        x2 = self.propagate_layer(propagate_network)
        
        x = torch.cat([x1, x2], dim=1)
        x = F.relu(self.lin1(x))
        x = F.log_softmax(self.lin2(x), dim=-1)
        return x
        

# 超参数设置

In [6]:
parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=777, help='random seed')
# cuda:id
parser.add_argument('--device', type=str, default='cuda:1', help='specify cuda devices')
parser.add_argument('--num_classes', type=int, default=2,
                    help='num of classed')
# hyper-parameters
parser.add_argument('--dataset', type=str, default='gossipcop', help='[politifact, gossipcop]')
parser.add_argument('--batch_size', type=int, default = 32, help='batch size')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay')
parser.add_argument('--dropout_ratio', type=float, default=0.5, help='dropout ratio')
parser.add_argument('--epochs', type=int, default=200, help='maximum number of epochs')
parser.add_argument('--min_loss', type=float, default=1e10,help='min loss value')
parser.add_argument('--patience', type=int, default=15,help='patience for earlystopping')
parser.add_argument('--save_path', type=str, default='/home/dhc/workspace/GSAPool',help='path to save result')
parser.add_argument('--training_times', type=int, default=20, help='')

parser.add_argument('--out_feature_nhid', type=int, default=30, help='')
# ASAPool超参数

parser.add_argument('--p_nhid', type=int, default=25, help='hidden size')
parser.add_argument('--p_concat', type=bool, default=True, help='node feature and graph embedding')
parser.add_argument('--p_model', type=str, default='gcn', help='model type, [gcn, gat, sage]')
parser.add_argument('--p_pooling_ratio', type=float, default=0.5,help='pooling ratio')
parser.add_argument('--conv',type=str, default='gcn', help='model type, [gcn, gat, sage]')
#GSAPool超参数
parser.add_argument('--r_nhid', type=int, default=256, help='hidden size')
parser.add_argument('--r_pooling_ratio', type=float, default=0.5,help='pooling ratio')
parser.add_argument('--r_alpha', type=float, default=0.5,help='combination_ratio')
parser.add_argument('--r_pooling_layer_type', type=str, default='GCNConv',help='GCNConv/SAGEConv/ChebConv/GATConv/GraphConv')
parser.add_argument('--r_feature_fusion_type', type=str, default='GATConv',help='GATConv')

args = parser.parse_known_args()[0]

# 加载数据集

In [7]:
def load_fake_news_net_dataset(reply_network_file, propagate_network_file):

    #gossipcop  politifact
    #gossipcop_reply_network__no_empty_text_with_tweet_top_5_edge_by_similar_include_tweet_reply_edge_60_85base
    reply_network_data_file_path = "/home/dhc/dataset/fake_news_net/processed/" + reply_network_file
    propagate_network_data_file_path = "/home/dhc/dataset/fake_news_net/processed/" + propagate_network_file
    
    reply_pkl = open(reply_network_data_file_path, 'rb')  ## 以二进制方式打开文件
    propagate_pkl = open(propagate_network_data_file_path, 'rb')  ## 以二进制方式打开文件
    
    reply_network_data = joblib.load(reply_pkl)  ##用load()方法把文件内容序列化为Python对象
    propagate_network_data = joblib.load(propagate_pkl)  ##用load()方法把文件内容序列化为Python对象
    reply_pkl.close()
    propagate_pkl.close()
    dataset = []
    nodes = 0
    edges = 0
    for item1, item2 in zip(reply_network_data, propagate_network_data):
        nodes = nodes + len(item1.x)
        edges = edges + len(item1.edge_index[0])
        dataset.append([item1, item2])
        #print(len(item1.edge_index))
    print(nodes, edges)
    return dataset

def data_builder(args):
    
    dataset = load_fake_news_net_dataset(args.reply_network_file, args.propagate_network_file)
    ratio = 0.8
    # 训练集数目
    num_training = int(len(dataset) * ratio)
    # 测试集数目
    num_test = len(dataset) - (num_training)
    training_set,test_set = random_split(dataset,[num_training,num_test])
    
    train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_set,batch_size=args.batch_size,shuffle=False)
    r_num_features = dataset[0][0].x.shape[1]
    p_num_features = dataset[0][1].x.shape[1]

    return train_loader, test_loader,r_num_features, p_num_features


# 模型准备工作

In [8]:
#training configuration    politifact_
args.reply_network_file ="politifact_reply_network.pkl"
args.propagate_network_file = "politifact_propagate_network.pkl"
train_loader, test_loader,r_num_features,p_num_features = data_builder(args)

args.r_num_features = r_num_features
args.p_num_features = p_num_features

model = FakeNewsDetecModel(args).to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

361585 7243580


In [9]:
print(model)

FakeNewsDetecModel(
  (relpy_layer): GSAPoolModel(
    (conv1): GCNConv(768, 256)
    (pool1): GSAPool(
      (sbtl_layer): GCNConv(256, 1)
      (fbtl_layer): Linear(in_features=256, out_features=1, bias=True)
      (fusion): GATConv(256, 256, heads=1)
    )
    (conv2): GCNConv(256, 256)
    (pool2): GSAPool(
      (sbtl_layer): GCNConv(256, 1)
      (fbtl_layer): Linear(in_features=256, out_features=1, bias=True)
      (fusion): GATConv(256, 256, heads=1)
    )
    (conv3): GCNConv(256, 256)
    (pool3): GSAPool(
      (sbtl_layer): GCNConv(256, 1)
      (fbtl_layer): Linear(in_features=256, out_features=1, bias=True)
      (fusion): GATConv(256, 256, heads=1)
    )
    (lin1): Linear(in_features=1536, out_features=30, bias=True)
  )
  (propagate_layer): ASAPoolModel(
    (conv1): GCNConv(10, 25)
    (pool1): ASAPooling(25, ratio=0.5)
    (conv2): GCNConv(25, 25)
    (pool2): ASAPooling(25, ratio=0.5)
    (conv3): GCNConv(25, 25)
    (pool3): ASAPooling(25, ratio=0.5)
    (lin1): Li

In [10]:
# 测试函数
def test(model,loader,test = False):
    model.eval()
    correct = 0.
    loss = 0.
    test_y = []
    pre_y = []
    for data in loader:
        reply_network, propagate_network = data[0], data[1]
        reply_network = reply_network.to(args.device)
        propagate_network = propagate_network.to(args.device)
        out = model(reply_network, propagate_network)
        y = reply_network.y
        
        pred = out.max(dim=1)[1]
        test_y.extend(y.cpu().numpy())
        pre_y.extend(pred.cpu().numpy())
        correct += pred.eq(y).sum().item()
        loss += F.nll_loss(out,y,reduction='sum').item()
    if test:
        r = classification_report(test_y, pre_y)
        C = confusion_matrix(test_y, pre_y)
        print("confusion matrix: ", C, '\n')
        print(r, '\n')
    return correct / len(loader.dataset),loss / len(loader.dataset)
        

    #save result in txt
def save_result(test_acc, save_path):
    with open(os.path.join(save_path, 'result.txt'), 'a') as f:
        test_acc *= 100
        f.write(args.dataset+";")
        f.write("pooling_layer_type:"+args.r_pooling_layer_type+";")
        f.write("feature_fusion_type:"+args.r_feature_fusion_type+";")
        f.write(str(test_acc))
        f.write('\r\n')

In [11]:
# 训练
train_loss_his = []
val_loss_his = []
patience = 0
min_loss = args.min_loss
for epoch in range(80):
    model.train()
    tmp = []
    for i, data in enumerate(train_loader):
        reply_network, propagate_network = data[0], data[1]
        reply_network = reply_network.to(args.device)
        propagate_network = propagate_network.to(args.device)
        out = model(reply_network, propagate_network)
        y = reply_network.y
        loss = F.nll_loss(out, y)
        tmp.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    train_loss_his.append(sum(tmp)/len(tmp))
    val_acc,val_loss = test(model,test_loader)
    val_loss_his.append(val_loss)
    print("Validation loss:{}\taccuracy:{}".format(val_loss,val_acc))
    print("Epoch{}".format(epoch))
    if val_loss < min_loss:
        torch.save(model.state_dict(),'latest.pth')
        print("Model saved at epoch{}".format(epoch))
        min_loss = val_loss
        patience = 0
    else:
        patience += 1
    if patience > args.patience:
        break 

Validation loss:0.6740978726973901	accuracy:0.6538461538461539
Epoch0
Model saved at epoch0
Validation loss:0.6718902862988986	accuracy:0.6538461538461539
Epoch1
Model saved at epoch1
Validation loss:0.6672255396842957	accuracy:0.6538461538461539
Epoch2
Model saved at epoch2
Validation loss:0.6615683986590459	accuracy:0.6538461538461539
Epoch3
Model saved at epoch3
Validation loss:0.6463871231445899	accuracy:0.6538461538461539
Epoch4
Model saved at epoch4
Validation loss:0.6255476153813876	accuracy:0.6730769230769231
Epoch5
Model saved at epoch5
Validation loss:0.6313803562751183	accuracy:0.6730769230769231
Epoch6
Validation loss:0.6199916692880484	accuracy:0.6634615384615384
Epoch7
Model saved at epoch7
Validation loss:0.6188980065859281	accuracy:0.7019230769230769
Epoch8
Model saved at epoch8
Validation loss:0.6120821237564087	accuracy:0.6923076923076923
Epoch9
Model saved at epoch9
Validation loss:0.5638550474093511	accuracy:0.7307692307692307
Epoch10
Model saved at epoch10
Validati

In [12]:
#test step
model = FakeNewsDetecModel(args).to(args.device)
model.load_state_dict(torch.load('latest.pth'))
test_acc,test_loss = test(model,test_loader,True)
print("Test accuarcy:{}".format(test_acc))
args.save_path = "./"
save_result(test_acc, args.save_path)

confusion matrix:  [[30  6]
 [11 57]] 

              precision    recall  f1-score   support

           0       0.73      0.83      0.78        36
           1       0.90      0.84      0.87        68

    accuracy                           0.84       104
   macro avg       0.82      0.84      0.82       104
weighted avg       0.84      0.84      0.84       104
 

Test accuarcy:0.8365384615384616


In [13]:
ttrainLoss = train_loss_his

In [14]:
ttrainLoss

[0.682618838090163,
 0.6814675560364356,
 0.67916667002898,
 0.6783566658313458,
 0.6686973571777344,
 0.6706614723572364,
 0.673806791122143,
 0.660442943756397,
 0.6525921317247244,
 0.6434171933394212,
 0.62601637840271,
 0.641726920237908,
 0.6282647206233098,
 0.616279125213623,
 0.5978762209415436,
 0.5726625942266904,
 0.6472890056096591,
 0.6371900530961844,
 0.5934585745518024,
 0.5781888365745544,
 0.5730788822357471,
 0.5394675158537351,
 0.5136245007698352,
 0.5786612606965579,
 0.5851367872494918,
 0.5512387569134052,
 0.5116601723891038,
 0.4863892151759221,
 0.4887762986696683,
 0.47785725731116074,
 0.44302507776480454,
 0.7134305605521569,
 0.5790034028200003,
 0.5153266122707953,
 0.48597310598079974,
 0.4614940125208635,
 0.4598308847500728,
 0.41081584187654346,
 0.39574238657951355,
 0.4097527059224936,
 0.411186369565817,
 0.39067295766793764,
 0.38245190794651324,
 0.4072675234996356,
 0.4464028959090893,
 0.4495219634129451,
 0.40731653800377476,
 0.369147527676

In [15]:
val_loss_his

[0.6740978726973901,
 0.6718902862988986,
 0.6672255396842957,
 0.6615683986590459,
 0.6463871231445899,
 0.6255476153813876,
 0.6313803562751183,
 0.6199916692880484,
 0.6188980065859281,
 0.6120821237564087,
 0.5638550474093511,
 0.5909907359343308,
 0.5800386300453773,
 0.5479863927914546,
 0.5395508362696722,
 0.5664750200051528,
 0.6126467218765845,
 0.5917607958500202,
 0.567047866491171,
 0.5178944560197684,
 0.5104576624356784,
 0.4937013112581693,
 0.5326237449279199,
 0.7061361395395719,
 0.544244287105707,
 0.49822023740181554,
 0.5176096535645999,
 0.46392120306308454,
 0.5009263157844543,
 0.45696534560276914,
 0.4555428417829367,
 0.5255248409051162,
 0.5493715955660894,
 0.4824458360671997,
 0.4661150872707367,
 0.4602063252375676,
 0.4561242484129392,
 0.4960623887869028,
 0.43290387667142427,
 0.43390725209162784,
 0.434023937353721,
 0.43195908344708955,
 0.4383156643464015,
 0.5944016575813293,
 0.42762691928790164,
 0.5420397795163668,
 0.4897039097089034,
 0.466184