In [1]:
import torch

from transformers import BertTokenizer, BertModel, AutoTokenizer, DebertaModel, AutoModel, PreTrainedModel


class deberta:

    def __init__(self):
        self.__name__ = 'microsoft/deberta-base'
        self.__num_node_features__ = 768 
        self.device = 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
# Load model directly
        self.model = AutoModel.from_pretrained("microsoft/deberta-base")
        # self.model = DebertaModel.from_pretrained("microsoft/deberta-base")
        
        # self.__output_dim__ = self.__model__.
    # @property
    def parameters(self):
        return self.model.parameters()

    @property
    def num_node_features(self):
        return 768

    def to(self, device):
        self.model = self.model.to(device)
        self.device = device
        return self

    def forward(self, text):

        def model_forward_input(input):
            input = self.tokenizer(input, return_tensors='pt').to(self.device)
            output = self.model(**input).last_hidden_state.mean(dim=1)
            # print(output.shape)
            # return self.model(**input).last_hidden_state.mean(dim=1)
            # print(output.shape)
            return torch.squeeze(output)

        return torch.stack(list(map(model_forward_input, text)))

    def __call__(self, data):
        if isinstance(data, str):
            return self.forward([data])
        if isinstance(data, list):
            return self.forward(data)
        # x = self.forward(data.text)
        # data.x = x
        # return data

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device(0)     # change to 'cuda' for GPU
# device = torch.device(1)     # change to 'cuda' for GPU

In [3]:
lm = deberta().to(device)
# lm('I love you', input_text=True)



In [4]:
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = dgl.data.CoraGraphDataset()
# device = 'cuda'      # change to 'cuda' for GPU
graph = dataset[0]

train_mask = graph.ndata['train_mask']
train_nids = torch.nonzero(train_mask, as_tuple=False).squeeze()
val_mask = graph.ndata['val_mask']
val_nids = torch.nonzero(val_mask, as_tuple=False).squeeze()
test_mask = graph.ndata['test_mask']
test_nids = torch.nonzero(test_mask, as_tuple=False).squeeze()

sampler = dgl.dataloading.NeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DGL's DataLoader.
    graph,              # The graph
    train_nids,         # The node IDs to iterate over in minibatches
    sampler,            # The neighbor sampler
    device=device,      # Put the sampled MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=8,    # Batch size
    shuffle=True,       # Whether to shuffle the nodes for every epoch
    drop_last=True,    # Whether to drop the last incomplete batch
    num_workers=0       # Number of sampler processes
)


  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [5]:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type='mean')
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # Lines that are changed are marked with an arrow: "<---"
        
        h_dst = x[:mfgs[0].num_dst_nodes()]  # <---
        h = self.conv1(mfgs[0], (x, h_dst))  # <---
        h = F.relu(h)
        h_dst = h[:mfgs[1].num_dst_nodes()]  # <---
        h = self.conv2(mfgs[1], (h, h_dst))  # <---
        return h

model = Model(in_feats=lm.__num_node_features__, h_feats=64, num_classes=7).to(device)
# model = Model(in_feats=1433, h_feats=64, num_classes=7).to(device)


In [6]:
# opt = torch.optim.Adam(list(model.parameters())+list(lm.parameters())) # 
opt = torch.optim.Adam([
    {'params': lm.parameters(), 'lr': 1e-4},
    {'params': model.parameters(), 'lr': 1e-1}])
valid_dataloader = dgl.dataloading.DataLoader(
    graph, val_nids, sampler,
    batch_size=8,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    device=device
)

In [7]:
import tqdm
import sklearn.metrics

from data_utils import load_data

best_accuracy = 0
best_model_path = 'model.pt'
dataset, num_classes, text = load_data('cora', use_dgl=True, use_text=True)
for epoch in range(100):
    model.train()

    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
            # print(input_nodes, output_nodes)
            # break
            # feature copy from CPU to GPU takes place here
            # inputs = mfgs[0].srcdata['feat']
            # print(inputs.shape, input_nodes.shape)
            inputs = [text[i] for i in input_nodes]
            inputs = lm(inputs).to(device)
            # mfgs = mfgs.to(device)
            labels = mfgs[-1].dstdata['label']
            # print(len(labels), inputs.shape, input_nodes.shape, output_nodes.shape)    
            # break        
            predictions = model(mfgs, inputs)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())

            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)

            del input_nodes, output_nodes, mfgs, inputs, labels, predictions, loss
            torch.cuda.empty_cache()

    model.eval()

    predictions = []
    labels = []
    with torch.no_grad() and tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, mfgs in tq:
            inputs = [text[i] for i in input_nodes]
            # inputs = mfgs[0].srcdata['feat']
            labels.append(mfgs[-1].dstdata['label'].cpu().numpy())
            inputs = lm(inputs).to(device)
            predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)
        #     inputs = mfgs[0].srcdata['feat']
        #     labels.append(mfgs[-1].dstdata['label'].cpu().numpy())
        #     predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
        # predictions = np.concatenate(predictions)
        # labels = np.concatenate(labels)
        # accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {} Best Accuracy {}'.format(epoch, accuracy, best_accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

        # Note that this tutorial do not train the whole model to the end.
        # break

100%|██████████| 17/17 [01:43<00:00,  6.10s/it, loss=2.046, acc=0.125] 
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 0 Validation Accuracy 0.072
Epoch 0 Validation Accuracy 0.072 Best Accuracy 0.072


100%|██████████| 17/17 [01:42<00:00,  6.03s/it, loss=1.995, acc=0.125]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 1 Validation Accuracy 0.072
Epoch 1 Validation Accuracy 0.072 Best Accuracy 0.072


100%|██████████| 17/17 [01:41<00:00,  5.97s/it, loss=1.931, acc=0.250]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 2 Validation Accuracy 0.316
Epoch 2 Validation Accuracy 0.316 Best Accuracy 0.316


100%|██████████| 17/17 [01:44<00:00,  6.17s/it, loss=1.952, acc=0.250]
100%|██████████| 63/63 [01:22<00:00,  1.31s/it]


Epoch 3 Validation Accuracy 0.114
Epoch 3 Validation Accuracy 0.114 Best Accuracy 0.316


100%|██████████| 17/17 [01:43<00:00,  6.09s/it, loss=1.954, acc=0.125]
100%|██████████| 63/63 [01:21<00:00,  1.30s/it]


Epoch 4 Validation Accuracy 0.122
Epoch 4 Validation Accuracy 0.122 Best Accuracy 0.316


100%|██████████| 17/17 [01:43<00:00,  6.09s/it, loss=1.959, acc=0.125]
100%|██████████| 63/63 [01:20<00:00,  1.28s/it]


Epoch 5 Validation Accuracy 0.162
Epoch 5 Validation Accuracy 0.162 Best Accuracy 0.316


100%|██████████| 17/17 [01:46<00:00,  6.27s/it, loss=1.987, acc=0.125]
100%|██████████| 63/63 [01:22<00:00,  1.30s/it]


Epoch 6 Validation Accuracy 0.156
Epoch 6 Validation Accuracy 0.156 Best Accuracy 0.316


100%|██████████| 17/17 [01:41<00:00,  6.00s/it, loss=1.960, acc=0.125]
100%|██████████| 63/63 [01:22<00:00,  1.31s/it]


Epoch 7 Validation Accuracy 0.316
Epoch 7 Validation Accuracy 0.316 Best Accuracy 0.316


100%|██████████| 17/17 [01:42<00:00,  6.03s/it, loss=1.967, acc=0.125]
100%|██████████| 63/63 [01:22<00:00,  1.31s/it]


Epoch 8 Validation Accuracy 0.162
Epoch 8 Validation Accuracy 0.162 Best Accuracy 0.316


100%|██████████| 17/17 [01:39<00:00,  5.84s/it, loss=2.025, acc=0.000]
100%|██████████| 63/63 [01:22<00:00,  1.30s/it]


Epoch 9 Validation Accuracy 0.122
Epoch 9 Validation Accuracy 0.122 Best Accuracy 0.316


100%|██████████| 17/17 [01:43<00:00,  6.10s/it, loss=2.029, acc=0.250]
100%|██████████| 63/63 [01:22<00:00,  1.31s/it]


Epoch 10 Validation Accuracy 0.114
Epoch 10 Validation Accuracy 0.114 Best Accuracy 0.316


100%|██████████| 17/17 [01:42<00:00,  6.06s/it, loss=1.992, acc=0.000]
100%|██████████| 63/63 [01:21<00:00,  1.30s/it]


Epoch 11 Validation Accuracy 0.316
Epoch 11 Validation Accuracy 0.316 Best Accuracy 0.316


100%|██████████| 17/17 [01:41<00:00,  5.96s/it, loss=1.976, acc=0.000]
100%|██████████| 63/63 [01:22<00:00,  1.31s/it]


Epoch 12 Validation Accuracy 0.072
Epoch 12 Validation Accuracy 0.072 Best Accuracy 0.316


100%|██████████| 17/17 [01:45<00:00,  6.23s/it, loss=1.973, acc=0.250]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 13 Validation Accuracy 0.058
Epoch 13 Validation Accuracy 0.058 Best Accuracy 0.316


100%|██████████| 17/17 [01:41<00:00,  5.94s/it, loss=2.089, acc=0.000]
100%|██████████| 63/63 [01:22<00:00,  1.30s/it]


Epoch 14 Validation Accuracy 0.316
Epoch 14 Validation Accuracy 0.316 Best Accuracy 0.316


100%|██████████| 17/17 [01:40<00:00,  5.93s/it, loss=1.971, acc=0.250]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 15 Validation Accuracy 0.122
Epoch 15 Validation Accuracy 0.122 Best Accuracy 0.316


100%|██████████| 17/17 [01:43<00:00,  6.10s/it, loss=1.976, acc=0.000]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 16 Validation Accuracy 0.122
Epoch 16 Validation Accuracy 0.122 Best Accuracy 0.316


100%|██████████| 17/17 [01:45<00:00,  6.18s/it, loss=1.977, acc=0.250]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 17 Validation Accuracy 0.316
Epoch 17 Validation Accuracy 0.316 Best Accuracy 0.316


100%|██████████| 17/17 [01:43<00:00,  6.08s/it, loss=1.998, acc=0.000]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 18 Validation Accuracy 0.072
Epoch 18 Validation Accuracy 0.072 Best Accuracy 0.316


100%|██████████| 17/17 [01:43<00:00,  6.06s/it, loss=1.950, acc=0.125]
100%|██████████| 63/63 [01:22<00:00,  1.31s/it]


Epoch 19 Validation Accuracy 0.122
Epoch 19 Validation Accuracy 0.122 Best Accuracy 0.316


100%|██████████| 17/17 [01:40<00:00,  5.90s/it, loss=1.923, acc=0.125]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 20 Validation Accuracy 0.162
Epoch 20 Validation Accuracy 0.162 Best Accuracy 0.316


100%|██████████| 17/17 [01:41<00:00,  6.00s/it, loss=2.015, acc=0.000]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 21 Validation Accuracy 0.072
Epoch 21 Validation Accuracy 0.072 Best Accuracy 0.316


100%|██████████| 17/17 [01:40<00:00,  5.93s/it, loss=1.985, acc=0.000]
100%|██████████| 63/63 [01:22<00:00,  1.30s/it]


Epoch 22 Validation Accuracy 0.114
Epoch 22 Validation Accuracy 0.114 Best Accuracy 0.316


100%|██████████| 17/17 [01:42<00:00,  6.04s/it, loss=2.025, acc=0.000]
100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 23 Validation Accuracy 0.072
Epoch 23 Validation Accuracy 0.072 Best Accuracy 0.316


 47%|████▋     | 8/17 [00:51<00:58,  6.49s/it, loss=1.996, acc=0.000]


KeyboardInterrupt: 