In [3]:
from sentence_transformers import SentenceTransformer
import torch
from torch_geometric.nn import GCNConv  # or GraphSAGE, GAT, etc.
from torch_geometric.data import Data
import numpy as np
import pandas as pd

import pickle

import json
import os

In [None]:
#Get Original Train dataset
with open(os.path.join("./Released-Microsoft-dataset/", "train.txt"), encoding='utf-8') as f:
    train_json_data = json.load(f)

#Get the edge information that I already computed
with open('citationEdges.pkl', 'rb') as file:
    allCitations = pickle.load(file)

#Get the encoded features from the pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb model
encodedFeatures = np.load("./encodedTrainFeatures.npy")
communities = pd.read_csv("./community_assignments.csv")

In [5]:
#Find the IDs of all the journels and set custom IDs for training (0-N)
allIds = []

for idx, doc in enumerate(train_json_data):
    allIds.append(int(doc["publication_ID"]))

sortedIds = sorted(allIds)
pubIdToId = {id: idx for idx, id in enumerate(sortedIds)}

In [6]:
#Remove NA rows and sort communities by node id
communities = communities.dropna(subset=['node_id'])
communities['node_id'] = communities['node_id'].astype(int)
communities = communities.sort_values(by='node_id')


In [7]:
#Create mapping of paper to community (not complete cause of data)
pubIdToCommunity = {}
for index, row in communities.iterrows():
    pubIdToCommunity[int(row["node_id"])] = int(row["community_id"])

In [None]:
#Get the communities of each node as label tensor
allCommunities = []


for idx, doc in enumerate(train_json_data):
    docId = int(doc["publication_ID"])

    #Most of the training data missing so skipped
    if docId not in list(pubIdToCommunity.keys()):
        continue

    allCommunities.append(pubIdToCommunity[docId])
    

In [9]:
y = torch.tensor(allCommunities, dtype=torch.long)

In [27]:
#Data object needed for GNNs
data = Data(
    x=torch.tensor(encodedFeatures),         # [num_nodes, embedding_dim]
    edge_index=torch.tensor(allCitations).T,  # [2, num_edges]
    y=y,                    # [num_nodes]
)

In [13]:
#Simple GNN model
class GNNModel(torch.nn.Module):
    def __init__(self):
        super(GNNModel, self).__init__()
        self.gcn1 = GCNConv(768, 256)
        self.gcn2 = GCNConv(256, 128)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(0.3)

    def forward(self, x, edge_index):
        x = self.relu(self.gcn1(x, edge_index))
        x = self.dropout(x)
        x = self.gcn2(x, edge_index)
        return x

In [None]:
#Train loop, errors out cause data incomplete
model = GNNModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Assuming `data` is your PyG Data object (with .x, .edge_index, .y, .train_mask)
model.train()
for epoch in range(100):
    optimizer.zero_grad()

    out = model(data.x, data.edge_index)             # forward pass
    loss = torch.nn.functional.cross_entropy(out, data.y)  # compute loss
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


ValueError: Expected input batch_size (42000) to match target batch_size (939).