In [1]:
import os

import numpy as np
from sklearn.metrics import precision_recall_fscore_support
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import BertConfig


from dataloading_utils import load_dataset, load_deprels
from dataloader import get_data_loaders
from modeling.bert import BertRGCNRelationClassifierResidual

In [2]:
eval_dataset = "mscorpus"
base_path = "/projects/flow_graphs/"
batch_size = 16
graph_data_source = "dep"
max_seq_len = 512
bert_model = "bert-base-uncased"
gnn = "rgcn"
gnn_depth = 4
node_emb_dim =  768
device = "cuda:0"

In [3]:
dataset = load_dataset(base_path, eval_dataset)

In [4]:
train_data = dataset["train"]["rels"]
dev_data = dataset["dev"]["rels"]
test_data = dataset["test"]["rels"]

deprel_dict = load_deprels(
    path=os.path.join(base_path, "data", "enh_dep_rel.txt"), enhanced=False
)

print(
    "train size: {}, dev size {}, test size: {}".format(
        len(train_data), len(dev_data), len(test_data)
    )
)

src_labels = set([data["label"] for data in dataset["train"]["rels"]])
train_labels = [data["label"] for data in train_data]
labels = sorted(list(set(train_labels)))
lbl2id = {lbl: idx for idx, lbl in enumerate(labels)}

100%|██████████| 365/365 [00:00<00:00, 502270.66it/s]

train size: 12330, dev size 2287, test size: 3782





In [5]:
train_loader, dev_loader, test_loader = get_data_loaders(
    train_data,
    dev_data,
    test_data,
    lbl2id,
    graph_data_source,
    max_seq_len,
    batch_size,
)

In [6]:
bertconfig = BertConfig.from_pretrained(bert_model, num_labels=len(labels))
if "bert-large" in bert_model:
    bertconfig.relation_emb_dim = 1024
elif "bert-base" in bert_model:
    bertconfig.relation_emb_dim = 768

bertconfig.node_emb_dim = node_emb_dim
bertconfig.dep_rels = len(deprel_dict)
bertconfig.gnn_depth = gnn_depth
bertconfig.gnn = gnn

In [7]:
checkpoint_file = "/projects/flow_graphs/checkpoints/transfer-risec-mscorpus-fewshot_0.01-dep_residual-residual-rgcn-depth_4-seed_1-lr_2e-05.pt"
state_dict = torch.load(checkpoint_file)

In [8]:
model = BertRGCNRelationClassifierResidual.from_pretrained(bert_model, config=bertconfig, use_graph_data=True)
model.load_state_dict(state_dict)
model.eval()
model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertRGCNRelationClassifierResidual: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertRGCNRelationClassifierResidual from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertRGCNRelationClassifierResidual from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertRGCNRelationClassifierResidual were not initialized from the

BertRGCNRelationClassifierResidual(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

In [9]:
y_true = []
y_pred = []

for data in tqdm(dev_loader):
    tokens_tensors = data["tokens_tensors"].to(device)
    segments_tensors = data["segments_tensors"].to(device)
    e1_mask = data["e1_mask"].to(device)
    e2_mask = data["e2_mask"].to(device)
    masks_tensors = data["masks_tensors"].to(device)
    labels = data["label_ids"].to(device)

    if data["graph_data"] is not None:
        graph_data = data["graph_data"].to(device)
    else:
        graph_data = None

    print(graph_data)

    with torch.no_grad():
        output_dict = model(
            input_ids=tokens_tensors,
            token_type_ids=segments_tensors,
            e1_mask=e1_mask,
            e2_mask=e2_mask,
            attention_mask=masks_tensors,
            labels=labels,
            graph_data=graph_data,
                    # can optionally pass in dependency/amr tensors if we want to use both.
        )

    _, pred = torch.max(output_dict["logits"], 1)
    y_pred.extend(list(np.array(pred.cpu().detach())))
    y_true.extend(list(np.array(labels.cpu().detach())))
    

  0%|          | 0/143 [00:00<?, ?it/s]

DataBatch(x=[663, 512], edge_index=[2, 647], edge_type=[647], n1_mask=[663], n2_mask=[663], batch=[663], ptr=[17])
DataBatch(x=[579, 512], edge_index=[2, 563], edge_type=[563], n1_mask=[579], n2_mask=[579], batch=[579], ptr=[17])
DataBatch(x=[671, 512], edge_index=[2, 655], edge_type=[655], n1_mask=[671], n2_mask=[671], batch=[671], ptr=[17])
DataBatch(x=[554, 512], edge_index=[2, 538], edge_type=[538], n1_mask=[554], n2_mask=[554], batch=[554], ptr=[17])
DataBatch(x=[512, 512], edge_index=[2, 496], edge_type=[496], n1_mask=[512], n2_mask=[512], batch=[512], ptr=[17])
DataBatch(x=[1160, 512], edge_index=[2, 1144], edge_type=[1144], n1_mask=[1160], n2_mask=[1160], batch=[1160], ptr=[17])
DataBatch(x=[594, 512], edge_index=[2, 578], edge_type=[578], n1_mask=[594], n2_mask=[594], batch=[594], ptr=[17])
DataBatch(x=[610, 512], edge_index=[2, 594], edge_type=[594], n1_mask=[610], n2_mask=[610], batch=[610], ptr=[17])
DataBatch(x=[704, 512], edge_index=[2, 688], edge_type=[688], n1_mask=[704

In [10]:
precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred, average="macro")
precision, recall, f1, support

  _warn_prf(average, modifier, msg_start, len(result))


(0.6311509707651913, 0.5109942515561031, 0.5199184395532496, None)