In [None]:
from torch_geometric.nn.models import MLP
from torch_geometric.nn.conv import GCNConv, GINConv, SAGEConv
import torch
import torch.nn.functional as F
# from sentence_transformers import SentenceTransformer
import torch.nn as nn
from torch_geometric.nn import LabelPropagation
from torch_geometric.nn.models import GAT
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv as PYGGATConv
# import rev.memgcn as memgcn
# from rev.rev_layer import SharedDropout
import copy
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree
import numpy as np
import math
import tqdm
from dgl import function as fn
from dgl._ffi.base import DGLError
from dgl.nn.pytorch.utils import Identity
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
from torch_geometric.loader import NeighborLoader

In [None]:
class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers with dropout, optional batch norm and ReLU.

    Args:
        num_layers: number of graph convolutions applied by ``forward``.
        input_dim: size of the incoming node features.
        hidden_dimension: width of every intermediate layer.
        num_classes: output size of the last convolution.
        dropout: dropout probability applied to the input of each convolution.
        norm: when truthy every hidden layer is followed by ``BatchNorm1d``;
            otherwise an ``Identity`` placeholder is stored instead.
    """

    def __init__(self, num_layers, input_dim, hidden_dimension, num_classes, dropout, norm=None) -> None:
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        self.num_layers = num_layers
        self.dropout = dropout

        def build_conv(in_dim, out_dim):
            # Every convolution in the stack shares the same GCNConv settings.
            return GCNConv(in_dim, out_dim, cached=False, normalize=True)

        def build_norm():
            if norm:
                return torch.nn.BatchNorm1d(hidden_dimension)
            return torch.nn.Identity()

        # A single-layer network maps the input straight to the classes
        # and therefore has no hidden normalisation layers.
        if num_layers == 1:
            self.convs.append(build_conv(input_dim, num_classes))
            return

        # Input sizes of the hidden convolutions: the raw features first,
        # then ``hidden_dimension`` for each additional hidden layer.
        hidden_inputs = [input_dim] + [hidden_dimension] * (num_layers - 2)
        for in_dim in hidden_inputs:
            self.convs.append(build_conv(in_dim, hidden_dimension))
            self.norms.append(build_norm())

        self.convs.append(build_conv(hidden_dimension, num_classes))

    def forward(self, x, edge_index):
        """Apply the first ``num_layers`` convolutions to ``x`` over ``edge_index``.

        Dropout precedes every convolution; normalisation and ReLU follow
        every convolution except the last one, whose raw output is returned.
        """
        final_idx = self.num_layers - 1
        for idx in range(self.num_layers):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.convs[idx](x, edge_index)
            if idx == final_idx:
                continue
            x = F.relu(self.norms[idx](x))
        return x

In [None]:
import torch

from transformers import BertTokenizer, BertModel, AutoTokenizer, DebertaModel, AutoModel, PreTrainedModel

class deberta:
    """Thin wrapper around the Hugging Face ``microsoft/deberta-base`` encoder.

    The wrapper is a plain Python object (not an ``nn.Module``); it forwards
    ``parameters``/``train``/``eval``/``to`` to the wrapped model and turns
    strings into one embedding vector each by averaging the token states of
    the model's last hidden layer.
    """

    def __init__(self):
        self.__name__ = 'microsoft/deberta-base'
        self.__num_node_features__ = 768
        self.device = 'cpu'
        # Tokenizer and encoder are loaded from the same checkpoint name.
        self.tokenizer = AutoTokenizer.from_pretrained(self.__name__)
        self.model = AutoModel.from_pretrained(self.__name__)

    def parameters(self):
        """Return the parameters of the wrapped model (e.g. for an optimizer)."""
        return self.model.parameters()

    def train(self):
        """Put the wrapped model into training mode."""
        self.model.train()

    def eval(self):
        """Put the wrapped model into evaluation mode."""
        self.model.eval()

    @property
    def num_node_features(self):
        """Size of one embedding vector (hard-coded to 768)."""
        return 768

    def to(self, device):
        """Move the wrapped model to ``device``, remember it, and return ``self``."""
        self.model = self.model.to(device)
        self.device = device
        return self

    def forward(self, text):
        """Embed every string of the iterable ``text``.

        Each string is tokenized and encoded on its own; its embedding is the
        mean of ``last_hidden_state`` over the token axis.

        Returns:
            A tensor with one row per input string. An empty input yields an
            empty ``(0, num_node_features)`` tensor instead of failing inside
            ``torch.stack``.
        """

        def embed_one(sentence):
            encoded = self.tokenizer(sentence, return_tensors='pt').to(self.device)
            pooled = self.model(**encoded).last_hidden_state.mean(dim=1)
            # Presumably (1, hidden) for a single string — drop only that
            # leading batch axis rather than every size-1 axis.
            return pooled.squeeze(0)

        embeddings = [embed_one(sentence) for sentence in text]
        if not embeddings:
            return torch.empty((0, self.num_node_features), device=self.device)
        return torch.stack(embeddings)

    def __call__(self, data):
        """Embed a single string or a list/tuple of strings.

        Raises:
            TypeError: for any other input type (previously ``None`` was
                returned silently, which only failed later in the pipeline).
        """
        if isinstance(data, str):
            return self.forward([data])
        if isinstance(data, (list, tuple)):
            return self.forward(list(data))
        raise TypeError(
            f"deberta expects a str or a list/tuple of str, got {type(data).__name__}"
        )

In [None]:
def seed(seed_val):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed_val``."""
    import random
    import numpy as np
    import torch

    # Same value for every library, applied in a fixed order.
    for set_seed in (random.seed, np.random.seed, torch.manual_seed):
        set_seed(seed_val)


seed(42)

In [None]:
# Load ogbn-arxiv: the graph object (later handed to the neighbor loaders),
# the number of target classes, and the per-node texts fed to the language model.
# NOTE(review): `load_data` is not defined or imported in the cells shown here —
# confirm it is brought into scope before this cell runs on a fresh kernel.
data, num_classes, text = load_data('ogbn-arxiv', use_dgl=False, use_text=True)

In [None]:
# Text encoder that produces the node features, followed by the GCN
# classifier that consumes them (its input size is taken from the encoder).
lm = deberta()
gcn = GCN(
    num_layers=2,
    input_dim=lm.num_node_features,
    hidden_dimension=128,
    num_classes=num_classes,
    dropout=0.1,
    norm=True,
)

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

lm = lm.to(device)
gcn = gcn.to(device)

# Mini-batch neighbor sampling: one loader per split, two neighbors per hop
# for two hops, four seed nodes per batch. Only the training loader shuffles.
train_loader = NeighborLoader(
    data,
    input_nodes=data.train_mask,
    num_neighbors=[2, 2],
    batch_size=4,
    shuffle=True,
)
valid_loader = NeighborLoader(
    data,
    input_nodes=data.val_mask,
    num_neighbors=[2, 2],
    batch_size=4,
    shuffle=False,
)
test_loader = NeighborLoader(
    data,
    input_nodes=data.test_mask,
    num_neighbors=[2, 2],
    batch_size=4,
    shuffle=False,
)

# One Adam optimizer with two parameter groups: a small learning rate for the
# language model and a larger one for the GCN; both share the weight decay.
optimizer = torch.optim.Adam(
    [
        dict(params=lm.parameters(), lr=1e-4, weight_decay=5e-4),
        dict(params=gcn.parameters(), lr=0.01, weight_decay=5e-4),
    ]
)

In [None]:
# Count the mini-batches produced by one full pass over the training loader.
# The batches are consumed without binding them to `data`, so the full graph
# loaded earlier is not replaced by the last sampled subgraph, and only the
# final count is printed instead of one line per batch.
cnt = sum(1 for _ in train_loader)
print(cnt)

In [None]:
def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``lm``, ``gcn``, ``optimizer``, ``device`` and
    ``text``. For every sampled subgraph the node texts are embedded by the
    language model, passed through the GCN, and the cross-entropy loss is
    taken on the seed nodes only (the first ``batch.batch_size`` rows): the
    remaining rows are sampled neighbours whose own neighbourhoods are
    truncated, so they are context, not training targets.

    Returns:
        The mean loss per batch as a float, or ``0.0`` for an empty loader.
    """
    lm.train()
    gcn.train()
    total_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        num_seeds = batch.batch_size
        # Index the texts with plain ints rather than 0-d (possibly GPU) tensors.
        out = lm([text[i] for i in batch.n_id.tolist()])
        out = gcn(out, batch.edge_index)
        # reshape(-1) keeps a 1-D target even when the batch holds a single
        # seed node; a bare squeeze() would collapse it to a 0-d tensor.
        target = batch.y[:num_seeds].reshape(-1)
        loss = F.cross_entropy(out[:num_seeds], target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Release the activations before the next batch to limit peak GPU memory.
        del out
        torch.cuda.empty_cache()
    num_batches = len(train_loader)
    return total_loss / num_batches if num_batches else 0.0

def test(loader):
    lm.eval()
    gcn.eval()
    correct = 0
    total = 0
    for batch in loader:
        batch = batch.to(device)
        out = lm([text[i] for i in batch.n_id])
        out = gcn(out, batch.edge_index)
        pred = out.argmax(dim=1)
        correct += (pred[batch.test_mask] == batch.y[batch.test_mask].squeeze()).sum().item()
        total += batch.test_mask.sum().item()
        del out
        torch.cuda.empty_cache()
    return correct / total

In [None]:
# Train for 200 epochs; after each one, report the mean training loss and the
# accuracy measured with the train, validation and test loaders.
for epoch in range(1, 201):
    loss = train()
    # Evaluation only: no gradients are needed for the accuracy passes.
    with torch.no_grad():
        train_acc, valid_acc, test_acc = [
            test(split_loader)
            for split_loader in (train_loader, valid_loader, test_loader)
        ]
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, '
            f'Valid Acc: {valid_acc:.4f}, Test Acc: {test_acc:.4f}')