In [1]:
import torch
from torch_geometric.data import Data
from torch.utils.data import Dataset
import json
from gensim.models.keyedvectors import KeyedVectors
from torch import nn
import numpy as np

In [2]:
from data import *
dataset = semEval2017Dataset(
    dataPath='../dataset/semeval2017-task8/', 
    type='train'
)
glove25d = KeyedVectors.load_word2vec_format(
    '../dataset/glove/glove.twitter.27B.25d.gensim.txt',
    binary=False
)

In [10]:
vocabSize = len(glove25d) + 3
vectorSize = glove25d.vector_size
glove25d.add_vectors(["<start>", "<end>", "<unk>"] ,np.random.randn(3, vectorSize))


array([ 0.21294 ,  0.31035 ,  0.17694 ,  0.87498 ,  0.067926,  0.59171 ,
       -0.098218,  1.5896  , -0.428   , -1.3655  , -0.15278 , -2.501   ,
       -5.5652  , -0.10232 ,  0.39577 ,  0.1555  , -0.55181 ,  0.34671 ,
       -0.57379 , -0.30717 ,  0.043623, -0.39707 ,  0.64551 , -0.33537 ,
        0.020467], dtype=float32)

In [3]:
from torch.utils.data import DataLoader
from data import collate
cls = torch.randn(25).tolist()
loader = DataLoader(
    dataset,
    shuffle = True,
    num_workers = 4,
    collate_fn = lambda x: collate(x, glove25d, cls)
)

In [4]:
from ABGCN import *
from torch import optim
model = ABGCN(
    w2vDim = 25,
    s2vDim = 64, # 使用的句嵌入的维度
    gcnHiddenDim = 64, # GCN隐藏层的维度（GCNconv1的输出维度）
    rumorFeatureDim = 64, # GCN输出层的维度
    numRumorTag = 3, # 谣言标签种类数
    numStanceTag = 4, # 立场标签种类数
    numHeads = 5
)
device = torch.device('cpu')
model = model.set_device(device)
loss_func = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-2)

In [5]:
for data in loader:
    print(data['nodeFeature'].shape)

    nodeFeature = model.forwardStance(data)

    print(nodeFeature.shape)

    break

torch.Size([7, 29, 25])
torch.Size([7, 4])


In [None]:
from tqdm import tqdm
from torch.nn.functional import softmax

totalLoss = 0.
testdata = dataset[0]
for data in tqdm(iter(loader), 
                 desc="[epoch {:d}, rumor]".format(1),
                 leave=False, 
                 ncols=100):

                 
    optimizer.zero_grad()
    p = model.forwardRumor(data)
    loss = loss_func(p, rumorTag)
    totalLoss += loss
    loss.backward()
    optimizer.step()

    p = softmax(p, dim=1)


In [8]:
from tqdm import tqdm
from torch.nn.functional import softmax
from sklearn.metrics import f1_score

for epoch in range(1, 100):
    stanceTrue = []
    stancePre = []
    totalLoss = 0.

    for data in loader:
        stanceTag = data['stanceTag']
        stanceTrue += data['stanceTag'].tolist()

        model.train()
        optimizer.zero_grad()
        p = model.forwardStance(data)
        loss = loss_func(p, stanceTag)
        totalLoss += loss
        loss.backward()
        optimizer.step()

        p = softmax(p, dim=1)
        stancePre += p.max(dim=1)[1].tolist()
        
    accuracy = (np.array(stancePre) == np.array(stanceTrue)).sum() / len(stanceTrue)
    macroF1Stance = f1_score(stanceTrue, stancePre, labels=[0,1,2,3], average='macro')
    print(totalLoss / len(loader), macroF1Stance, accuracy)


tensor(0.8827, grad_fn=<DivBackward0>) 0.3664517167458828 0.6809815950920245
tensor(0.8805, grad_fn=<DivBackward0>) 0.3739524186044866 0.6835771590372818
tensor(0.8814, grad_fn=<DivBackward0>) 0.37313296977890675 0.6885323265691364
tensor(0.8790, grad_fn=<DivBackward0>) 0.3692088799601987 0.6826333176026428
tensor(0.8803, grad_fn=<DivBackward0>) 0.35746072011779106 0.683341198678622
tensor(0.8761, grad_fn=<DivBackward0>) 0.37023320623264133 0.6798017932987258
tensor(0.8817, grad_fn=<DivBackward0>) 0.36673998929704593 0.6864086833411986
tensor(0.8808, grad_fn=<DivBackward0>) 0.38842461678550755 0.6882963662104766
tensor(0.8771, grad_fn=<DivBackward0>) 0.3722053150436 0.6847569608305805
tensor(0.8797, grad_fn=<DivBackward0>) 0.375022099724999 0.6864086833411986
tensor(0.8785, grad_fn=<DivBackward0>) 0.365707365726919 0.6819254365266635
tensor(0.8811, grad_fn=<DivBackward0>) 0.3664137676632411 0.6828692779613025
tensor(0.8776, grad_fn=<DivBackward0>) 0.37172717601441924 0.6826333176026428