In [1]:
import json
from transformers import BertTokenizer
from torch_geometric.nn import SAGEConv
import torch
from torch_geometric.data import Data


class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [10]:
def text2vec(_text: str):
    return tokenizer(_text, truncation=False, return_attention_mask=False, return_token_type_ids=False)['input_ids']

In [37]:
import re
def id2tokens(_ids):
    original_tokens = tokenizer.convert_ids_to_tokens(_ids)
    # 去除特殊字符和标点符号
    cleaned_tokens = [token.replace("##", "") for token in original_tokens if token != '[CLS]' and token != '[SEP]']
    return "".join(cleaned_tokens)

In [11]:
with open("bugs_psi/AaltoXml_1b-all.json") as f:
    psi = json.load(f)
encoded_input = text2vec("x")
encoded_input

[101, 193, 102]

In [32]:
model = GNN(hidden_channels=64, out_channels=16)
data = Data()
nodes = [[int(n['id']), [] if 'text' not in n else text2vec(n['text']), [] if 'qualified-name' not in n else text2vec(n['qualified-name'])] for n in psi['nodes']]
nodes

[[1,
  [],
  [101,
   120,
   188,
   19878,
   120,
   1514,
   120,
   179,
   15677,
   120,
   3254,
   120,
   4946,
   1775,
   1306,
   1233,
   120,
   170,
   1348,
   2430,
   120,
   1112,
   27250,
   120,
   1249,
   27250,
   2064,
   14300,
   1592,
   10582,
   1183,
   1708,
   7804,
   2511,
   119,
   179,
   15677,
   102]],
 [2,
  [101,
   120,
   115,
   138,
   1348,
   2430,
   24868,
   14538,
   115,
   115,
   22359,
   113,
   172,
   114,
   1386,
   118,
   22515,
   7926,
   18613,
   15186,
   1777,
   117,
   27629,
   7926,
   119,
   21718,
   24171,
   13130,
   137,
   178,
   2293,
   119,
   20497,
   115,
   115,
   24689,
   1181,
   1223,
   1103,
   24689,
   9467,
   1107,
   1103,
   4956,
   149,
   9741,
   11680,
   12649,
   1134,
   1110,
   115,
   1529,
   1114,
   1103,
   2674,
   3463,
   119,
   115,
   1192,
   1336,
   1136,
   1329,
   1142,
   4956,
   2589,
   1107,
   14037,
   1114,
   1103,
   24689,
   119,
   115,
   115

In [38]:
for node in nodes:
    text = id2tokens(node[2])
    if text == "com.fasterxml.aalto.async.AsyncByteScanner":
        print(node[0])

1111


In [42]:
data.x = nodes
data.x

[[1,
  [],
  [101,
   120,
   188,
   19878,
   120,
   1514,
   120,
   179,
   15677,
   120,
   3254,
   120,
   4946,
   1775,
   1306,
   1233,
   120,
   170,
   1348,
   2430,
   120,
   1112,
   27250,
   120,
   1249,
   27250,
   2064,
   14300,
   1592,
   10582,
   1183,
   1708,
   7804,
   2511,
   119,
   179,
   15677,
   102]],
 [2,
  [101,
   120,
   115,
   138,
   1348,
   2430,
   24868,
   14538,
   115,
   115,
   22359,
   113,
   172,
   114,
   1386,
   118,
   22515,
   7926,
   18613,
   15186,
   1777,
   117,
   27629,
   7926,
   119,
   21718,
   24171,
   13130,
   137,
   178,
   2293,
   119,
   20497,
   115,
   115,
   24689,
   1181,
   1223,
   1103,
   24689,
   9467,
   1107,
   1103,
   4956,
   149,
   9741,
   11680,
   12649,
   1134,
   1110,
   115,
   1529,
   1114,
   1103,
   2674,
   3463,
   119,
   115,
   1192,
   1336,
   1136,
   1329,
   1142,
   4956,
   2589,
   1107,
   14037,
   1114,
   1103,
   24689,
   119,
   115,
   115

In [43]:
data.y = [1111]

In [45]:
edges = [[int(e['source']), int(e['target'])] for e in psi['edges']]
edges

[[1, 2],
 [1, 3],
 [1, 4],
 [1, 5],
 [1, 6],
 [1, 7],
 [1, 8],
 [1, 9],
 [4, 10],
 [4, 11],
 [4, 12],
 [4, 13],
 [6, 14],
 [6, 15],
 [6, 16],
 [6, 17],
 [6, 18],
 [6, 19],
 [6, 20],
 [6, 21],
 [6, 22],
 [8, 23],
 [8, 24],
 [8, 25],
 [8, 26],
 [8, 27],
 [8, 28],
 [8, 29],
 [8, 30],
 [8, 31],
 [8, 32],
 [8, 33],
 [8, 34],
 [8, 35],
 [8, 36],
 [8, 37],
 [8, 38],
 [8, 39],
 [8, 40],
 [8, 41],
 [8, 42],
 [8, 43],
 [8, 44],
 [8, 45],
 [8, 46],
 [8, 47],
 [8, 48],
 [8, 49],
 [8, 50],
 [8, 51],
 [8, 52],
 [8, 53],
 [8, 54],
 [8, 55],
 [8, 56],
 [8, 57],
 [8, 58],
 [8, 59],
 [8, 60],
 [8, 61],
 [8, 62],
 [8, 63],
 [8, 64],
 [8, 65],
 [8, 66],
 [8, 67],
 [8, 68],
 [8, 69],
 [8, 70],
 [8, 71],
 [8, 72],
 [8, 73],
 [8, 74],
 [8, 75],
 [8, 76],
 [8, 77],
 [8, 78],
 [8, 79],
 [8, 80],
 [8, 81],
 [8, 82],
 [8, 83],
 [8, 84],
 [8, 85],
 [8, 86],
 [8, 87],
 [8, 88],
 [8, 89],
 [8, 90],
 [8, 91],
 [8, 92],
 [8, 93],
 [8, 94],
 [8, 95],
 [8, 96],
 [8, 97],
 [8, 98],
 [8, 99],
 [8, 100],
 [8, 101],
 [8, 1

In [46]:
data.edge_index = edges

In [47]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
import torch.nn.functional as F

def train(data):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    mask = data['paper'].train_mask
    loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)

In [55]:
device = torch.device('cuda')
data = data.to(device)

In [57]:
model = model.to(device)

In [59]:
model(data, data.edge_index)

ValueError: `MessagePassing.propagate` only supports integer tensors of shape `[2, num_messages]`, `torch_sparse.SparseTensor` or `torch.sparse.Tensor` for argument `edge_index`.

In [61]:
id2tokens([1313, 2312312, 3])

'home[UNK][unused3]'