In [41]:
import torch

from transformers import AutoModel, AutoTokenizer, BertModel, BertTokenizer, DebertaModel


class deberta:
    """Frozen language-model encoder that turns raw strings into node features.

    Each text is tokenized on its own, passed through the transformer, and
    mean-pooled over the token axis, giving one vector per text.  Calling the
    instance on a data object reads ``data.text`` and writes the stacked
    embeddings to ``data.x``.

    Args:
        model_name: Hugging Face checkpoint to load.  Defaults to
            ``"microsoft/deberta-v3-base"``, the checkpoint the class was
            originally hard-wired to.
    """

    def __init__(self, model_name="microsoft/deberta-v3-base"):
        self.__name__ = 'deberta-base'
        self.device = 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The v3 checkpoints use the DeBERTa-v2 architecture; loading them
        # through the v1 ``DebertaModel`` class mismatches the weights.
        # ``AutoModel`` selects the class from the checkpoint's own config.
        self.model = AutoModel.from_pretrained(model_name)

    @property
    def num_node_features(self):
        """Width of one embedding vector, read from the loaded model's config."""
        return self.model.config.hidden_size

    def to(self, device):
        """Move the model to ``device`` and remember it for tokenizer outputs.

        Returns:
            ``self``, so the call can be chained.
        """
        self.model = self.model.to(device)
        self.device = device
        return self

    def _embed(self, sentence):
        """Encode one text and return its mean-pooled hidden state as a 1-D tensor."""
        encoded = self.tokenizer(sentence, return_tensors='pt').to(self.device)
        hidden = self.model(**encoded).last_hidden_state
        # Only the batch axis (size 1) is dropped; a bare squeeze() would also
        # collapse any other axis that happens to have length 1.
        return hidden.mean(dim=1).squeeze(0)

    def forward(self, text):
        """Embed every string in ``text``.

        Args:
            text: Iterable of strings; each item is encoded independently.

        Returns:
            Tensor of shape ``(len(text), num_node_features)``.  An empty
            input yields an empty ``(0, num_node_features)`` tensor instead of
            failing inside ``torch.stack``.
        """
        # The encoder is used purely as a feature extractor, so no autograd
        # graph is needed; skipping it keeps memory flat across many texts.
        with torch.no_grad():
            embeddings = [self._embed(sentence) for sentence in text]

        if not embeddings:
            return torch.empty((0, self.num_node_features), device=self.device)
        return torch.stack(embeddings)

    def __call__(self, data):
        """Attach embeddings of ``data.text`` to ``data.x`` and return ``data``."""
        x = self.forward(data.text)
        data.x = x
        return data

In [42]:
# Build the text encoder. The constructor loads "microsoft/deberta-v3-base" via
# from_pretrained, which (as the recorded SSLError shows) contacts huggingface.co.
lm = deberta()




SSLError: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /microsoft/deberta-v3-base/resolve/main/tokenizer_config.json (Caused by SSLError(SSLZeroReturnError(6, 'TLS/SSL connection has been closed (EOF) (_ssl.c:1135)')))"), '(Request ID: 4a1d9b88-9e74-452c-acc6-98e8c291150a)')

In [30]:
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = dgl.data.CoraGraphDataset()
# Use the GPU when one is visible and fall back to CPU otherwise, so the
# notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
graph = dataset[0]

# torch.nonzero(..., as_tuple=False) on a 1-D mask yields a (k, 1) tensor.
# squeeze(1) removes only that column axis, so a split containing a single
# node stays a 1-D tensor of IDs instead of collapsing to a 0-d scalar.
train_mask = graph.ndata['train_mask']
train_nids = torch.nonzero(train_mask, as_tuple=False).squeeze(1)
val_mask = graph.ndata['val_mask']
val_nids = torch.nonzero(val_mask, as_tuple=False).squeeze(1)
test_mask = graph.ndata['test_mask']
test_nids = torch.nonzero(test_mask, as_tuple=False).squeeze(1)

# Two sampling layers with a fanout of 4 neighbors each, matching the
# two-layer model trained below.
sampler = dgl.dataloading.NeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DGL's DataLoader.
    graph,              # The graph
    train_nids,         # The node IDs to iterate over in minibatches
    sampler,            # The neighbor sampler
    device=device,      # Put the sampled MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=64,      # Batch size
    shuffle=True,       # Whether to shuffle the nodes for every epoch
    drop_last=False,    # Whether to drop the last incomplete batch
    num_workers=0       # Number of sampler processes
)


  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [31]:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    """Two-layer GraphSAGE network (mean aggregator) for sampled mini-batches.

    Args:
        in_feats: Width of the input node features.
        h_feats: Width of the hidden representation after the first layer.
        num_classes: Width of the output (one logit per class).
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.h_feats = h_feats
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type='mean')

    def forward(self, mfgs, x):
        """Run both layers over a pair of message-flow graphs.

        Args:
            mfgs: Sequence whose first two items are the blocks for layer 1
                and layer 2; each must provide ``num_dst_nodes()``.
            x: Features of the source nodes of ``mfgs[0]``.

        Returns:
            Output of the second convolution for the destination nodes of
            ``mfgs[1]``.
        """
        first_block = mfgs[0]
        # The leading num_dst_nodes() rows are passed as destination features.
        hidden = self.conv1(first_block, (x, x[:first_block.num_dst_nodes()]))
        hidden = F.relu(hidden)

        second_block = mfgs[1]
        return self.conv2(second_block, (hidden, hidden[:second_block.num_dst_nodes()]))

# Size the network from the loaded data instead of hard-coding Cora's
# 1433 input features and 7 classes, so the cell keeps working if the
# dataset above is swapped out.
model = Model(
    in_feats=graph.ndata['feat'].shape[1],
    h_feats=64,
    num_classes=dataset.num_classes,
).to(device=device)


In [32]:
# Adam with its default hyper-parameters over every trainable weight.
opt = torch.optim.Adam(params=model.parameters())

# Validation loader: same graph and sampler as training, but iterating the
# validation node IDs in a fixed order with a larger batch size.
valid_dataloader = dgl.dataloading.DataLoader(
    graph,
    val_nids,
    sampler,
    device=device,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0,
)

In [33]:
import tqdm
import sklearn.metrics

# Best validation accuracy seen so far, and where the matching weights go.
best_accuracy = 0
best_model_path = 'model.pt'

for epoch in range(10):
    # ---- training pass over the sampled mini-batches ----
    model.train()

    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
            # Input features come from the first block's source nodes,
            # labels from the last block's destination nodes.
            inputs = mfgs[0].srcdata['feat']
            labels = mfgs[-1].dstdata['label']

            predictions = model(mfgs, inputs)
            loss = F.cross_entropy(predictions, labels)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Per-batch accuracy, shown in the progress bar only.
            accuracy = sklearn.metrics.accuracy_score(
                labels.cpu().numpy(),
                predictions.argmax(1).detach().cpu().numpy(),
            )
            tq.set_postfix(
                {'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy},
                refresh=False,
            )

    # ---- validation pass: collect per-batch results, then score once ----
    model.eval()

    predictions, labels = [], []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, mfgs in tq:
            inputs = mfgs[0].srcdata['feat']
            labels.append(mfgs[-1].dstdata['label'].cpu().numpy())
            predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())

        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))

        # Checkpoint only on a strict improvement over the previous best.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

100%|██████████| 3/3 [00:00<00:00, 56.07it/s, loss=1.937, acc=0.167]
100%|██████████| 1/1 [00:00<00:00, 43.96it/s]


Epoch 0 Validation Accuracy 0.078


100%|██████████| 3/3 [00:00<00:00, 54.62it/s, loss=1.921, acc=0.083]
100%|██████████| 1/1 [00:00<00:00, 41.16it/s]


Epoch 1 Validation Accuracy 0.086


100%|██████████| 3/3 [00:00<00:00, 50.66it/s, loss=1.915, acc=0.250]
100%|██████████| 1/1 [00:00<00:00, 43.46it/s]


Epoch 2 Validation Accuracy 0.104


100%|██████████| 3/3 [00:00<00:00, 62.09it/s, loss=1.909, acc=0.333]
100%|██████████| 1/1 [00:00<00:00, 46.52it/s]


Epoch 3 Validation Accuracy 0.138


100%|██████████| 3/3 [00:00<00:00, 55.89it/s, loss=1.840, acc=0.667]
100%|██████████| 1/1 [00:00<00:00, 45.07it/s]


Epoch 4 Validation Accuracy 0.204


100%|██████████| 3/3 [00:00<00:00, 58.14it/s, loss=1.830, acc=0.750]
100%|██████████| 1/1 [00:00<00:00, 44.77it/s]


Epoch 5 Validation Accuracy 0.304


100%|██████████| 3/3 [00:00<00:00, 53.85it/s, loss=1.829, acc=0.917]
100%|██████████| 1/1 [00:00<00:00, 39.74it/s]


Epoch 6 Validation Accuracy 0.386


100%|██████████| 3/3 [00:00<00:00, 45.91it/s, loss=1.769, acc=0.917]
100%|██████████| 1/1 [00:00<00:00, 34.88it/s]


Epoch 7 Validation Accuracy 0.47


100%|██████████| 3/3 [00:00<00:00, 49.34it/s, loss=1.719, acc=1.000]
100%|██████████| 1/1 [00:00<00:00, 41.43it/s]


Epoch 8 Validation Accuracy 0.518


100%|██████████| 3/3 [00:00<00:00, 49.48it/s, loss=1.736, acc=0.917]
100%|██████████| 1/1 [00:00<00:00, 34.28it/s]


Epoch 9 Validation Accuracy 0.546
