In [1]:
# Standard library
import os
# Set before `transformers` is imported below, so the variable is already in
# the environment when the tokenizer backend loads.
# NOTE(review): presumably this disables the tokenizers library's internal
# parallelism (and its fork-related warning) — confirm against the HF docs.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# 3rd Party
import torch
from transformers import AutoTokenizer
# Local (gatbert project)
from gatbert.data import parse_graph_tsv
from gatbert.graph_sample import GraphSample, Edge
from gatbert.constants import Stance

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the fast RoBERTa (MNLI) tokenizer. add_prefix_space=True is needed
# because the input below is passed as pre-split words.
rob_tokenizer = AutoTokenizer.from_pretrained(
    "textattack/roberta-base-MNLI",
    add_prefix_space=True,
    use_fast=True,
)

# Encode a multi-word phrase handed over as ONE pre-split "word", and keep the
# character offsets so every sub-token can be traced back into the phrase.
rob_encoding = rob_tokenizer(
    ['abbreviation of European Union'],
    return_offsets_mapping=True,
    is_split_into_words=True,
)
print(rob_encoding)

# Trailing expression: shown as the cell's result (the sub-token strings).
rob_tokenizer.convert_ids_to_tokens(rob_encoding['input_ids'])

{'input_ids': [0, 45986, 46021, 9, 796, 1332, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 5), (5, 12), (13, 15), (16, 24), (25, 30), (0, 0)]}


['<s>', 'Ġabbre', 'viation', 'Ġof', 'ĠEuropean', 'ĠUnion', '</s>']

In [3]:
# Same experiment as the RoBERTa cell, this time with the uncased BERT
# tokenizer (input is lower-cased by hand to match the checkpoint).
pretrained_model_name = 'bert-base-uncased'

# NOTE(review): add_prefix_space is kept to mirror the RoBERTa call; it is a
# byte-level-BPE option and probably has no effect for BERT — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name,
    add_prefix_space=True,
    use_fast=True,
)

# One multi-word phrase passed as a single pre-split "word", with offsets.
encoding = tokenizer(
    ['abbreviation of european union'],
    return_offsets_mapping=True,
    is_split_into_words=True,
)
print(encoding)

# Trailing expression: shown as the cell's result (the sub-token strings).
tokenizer.convert_ids_to_tokens(encoding['input_ids'])

{'input_ids': [101, 22498, 1997, 2647, 2586, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 12), (13, 15), (16, 24), (25, 30), (0, 0)]}


['[CLS]', 'abbreviation', 'of', 'european', 'union', '[SEP]']

In [3]:
# Hand-built toy example. The edge indices below are computed as if the nodes
# were numbered over target + context + kb, in that order.
target = ["Pakistan", "government"]
context = "We need to stop supporting governmentalists who harbor terrorists .".split()
kb = ["/c/en/governmentalists", "/c/en/pakistan"]

sample = GraphSample(
    stance=Stance.FAVOR,
    target=target,
    context=context,
    kb=kb,
    edges=[
        # "governmentalists" in the context -> /c/en/governmentalists in the KB.
        # NOTE(review): 42 is presumably an edge/relation type id — confirm
        # against the Edge definition.
        Edge(
            len(target) + context.index("governmentalists"),
            len(target) + len(context) + kb.index("/c/en/governmentalists"),
            42,
        ),
    ],
)


In [4]:
# Open the scratch TSV of graph samples; the result is consumed with next()
# in the following cell, so it is used as an iterator.
# NOTE(review): 'scrap2.tsv' is a relative path resolved against the kernel's
# working directory — confirm the file is available on a fresh run.
graph_gen = parse_graph_tsv('scrap2.tsv')

In [5]:
# Pull the first parsed sample from the iterator. Re-running this cell
# advances the iterator, so it yields a different sample each time.
# NOTE(review): the notebook-level name `self` looks like a leftover from
# debugging a method body; a descriptive name would read better, but the
# following cell depends on it.
self = next(graph_gen)

In [6]:
# Encode the sample with `tokenizer` (the 'bert-base-uncased' tokenizer
# created earlier). The result is a mapping whose values expose `.shape`.
# NOTE(review): the recorded output warns that the token sequence (597) is
# longer than the model maximum (512) — check whether truncation is needed
# before feeding this to the model.
result_dict = self.encode(tokenizer)

Token indices sequence length is longer than the specified maximum sequence length for this model (597 > 512). Running this sequence through the model will result in indexing errors


In [7]:
# Summarise the encoded sample: one line per entry, name followed by shape.
for field, tensor in result_dict.items():
    print(field, tensor.shape)

input_ids torch.Size([307])
node_mask torch.Size([1, 197, 307])
edge_indices torch.Size([4, 9754])
