In [1]:
# STL
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# 3rd Party
import torch
from transformers import AutoTokenizer
# Local
from gatbert.data import parse_graph_tsv
from gatbert.graph_sample import GraphSample, Edge
from gatbert.constants import Stance

  from .autonotebook import tqdm as notebook_tqdm


In [46]:
rob_tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-MNLI", use_fast=True, add_prefix_space=True)
rob_encoding = rob_tokenizer(['abbreviation of European Union'], is_split_into_words=True, return_offsets_mapping=True)
print(rob_encoding)
rob_tokenizer.convert_ids_to_tokens(rob_encoding['input_ids'])

{'input_ids': [0, 45986, 46021, 9, 796, 1332, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 5), (5, 12), (13, 15), (16, 24), (25, 30), (0, 0)]}


['<s>', 'Ġabbre', 'viation', 'Ġof', 'ĠEuropean', 'ĠUnion', '</s>']

In [2]:
pretrained_model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, use_fast=True, add_prefix_space=True)
encoding = tokenizer(['abbreviation of european union'], is_split_into_words=True, return_offsets_mapping=True)
print(encoding)
tokenizer.convert_ids_to_tokens(encoding['input_ids'])

{'input_ids': [101, 22498, 1997, 2647, 2586, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 12), (13, 15), (16, 24), (25, 30), (0, 0)]}


['[CLS]', 'abbreviation', 'of', 'european', 'union', '[SEP]']

In [3]:
target = ["Pakistan", "government"]
context = "We need to stop supporting governmentalists who harbor terrorists .".split()
kb =["/c/en/governmentalists", "/c/en/pakistan"]
sample = GraphSample(
    stance=Stance.FAVOR,
    target=target,
    context=context,
    kb=kb,
    edges=[
        Edge(len(target) + 5, len(target) + len(context) + 0, 42), # govermenalists in context to /c/en/governmentalists
    ]
)


In [3]:
graph_gen = parse_graph_tsv('scrap2.tsv')

In [4]:
self = next(graph_gen)

In [15]:
from gatbert.constants import MAX_KB_NODES, TOKEN_TO_TOKEN_RELATION_ID
import logging
from itertools import product
from collections import defaultdict, OrderedDict

# FIXME: This is assuming all the KB tokens are conceptnet URIs
clean_kb = [uri.split('/')[3] for uri in self.kb]
clean_kb = [uri.replace("_", ' ') for uri in clean_kb]

tokenized_text = tokenizer(text=self.target,
                           text_pair=self.context,
                           is_split_into_words=True,
                           return_offsets_mapping=True,
                           return_tensors='pt',
                           truncation='longest_first')
tokenized_kb = tokenizer(text=clean_kb,
                         is_split_into_words=True,
                         return_offsets_mapping=True,
                         return_tensors='pt')
device = tokenized_text['input_ids'].device

relevant_keys = ['input_ids', 'offset_mapping']
tokenized_text = {k:tokenized_text[k] for k in relevant_keys}
tokenized_kb = {k:tokenized_kb[k] for k in relevant_keys}
# Assumes the tokens with 0-length offsets are special tokens
real_toks = tokenized_kb['offset_mapping'][0, :, 1] != 0
for k in relevant_keys:
    tokenized_kb[k] = tokenized_kb[k][:, real_toks]

In [16]:
len(self.kb)

335

In [17]:
# old_node_index -> [new_node_indices]
expand_list = defaultdict(list)
# new_node_index -> [subword_indices]
pool_inds = OrderedDict()

new_nodes_index = -1
orig_nodes_index = -1

# For token subwords, we will split a token's nodes into subwords
token_offset_mapping = tokenized_text['offset_mapping'].squeeze()
# Handle splitting of token nodes into subword nodes
for (subword_index, (start, end)) in enumerate(token_offset_mapping):
    new_nodes_index += 1
    pool_inds[new_nodes_index] = []

    if start != end: # Real character, not a special character
        if start == 0: # Start of a token
            orig_nodes_index += 1
        expand_list[orig_nodes_index].append(new_nodes_index)
    pool_inds[new_nodes_index].append(subword_index)


# Need to fast-forward past the token nodes to the external ones
# Some of the token nodes may have been truncated by the tokenizer
orig_nodes_index = len(self.target) + len(self.context) - 1

# For KB subwords, we plan to pool each into one combined node
kb_offset_mapping = tokenized_kb['offset_mapping'].squeeze()
n_kb_nodes = 0
for (subword_index, (start, end)) in enumerate(kb_offset_mapping, start=subword_index + 1):
    if start == 0:
        assert end != 0, "Special tokens should have been scrubbed"
        if n_kb_nodes >= MAX_KB_NODES:
            logging.warning("Discarded %s/%s of external nodes", len(self.kb) - n_kb_nodes, len(self.kb))
            break
        n_kb_nodes += 1
        new_nodes_index += 1
        pool_inds[new_nodes_index] = []
        orig_nodes_index += 1
        expand_list[orig_nodes_index].append(new_nodes_index)
    pool_inds[new_nodes_index].append(subword_index)
else:
    # Needs to be 1 greater than the last subword we included
    subword_index += 1

concat_ids = torch.concatenate([tokenized_text['input_ids'], tokenized_kb['input_ids']], dim=-1).squeeze()
# The tokenizer already did truncation for tokens, but this is where we do truncation for external nodes
concat_ids = concat_ids[..., :subword_index]

num_new_nodes = new_nodes_index + 1

mask_indices = []
mask_values = []
for (new_node_ind, subword_inds) in pool_inds.items():
    mask_indices.extend((0, new_node_ind, subword_ind) for subword_ind in subword_inds)
    v = 1 / len(subword_inds)
    mask_values.extend(v for _ in subword_inds)

mask_indices = torch.tensor(mask_indices, device=device).transpose(1, 0)
mask_values = torch.tensor(mask_values, device=device)
node_mask = torch.sparse_coo_tensor(
    indices=mask_indices,
    values=mask_values,
    size=(1, num_new_nodes, concat_ids.shape[-1]),
    is_coalesced=True,
    dtype=torch.float,
    device=device
)

# Indices into a sparse array (batch, max_new_nodes, max_new_nodes, relation)
# Need a 0 at the beginning for batch
new_edges = []
# The original token-to-token edges of a standard BERT model
num_text_tokens = tokenized_text['input_ids'].shape[-1]
new_edges.extend((0, head, tail, TOKEN_TO_TOKEN_RELATION_ID) for (head, tail) in product(range(num_text_tokens), range(num_text_tokens)))
new_edges.extend((0, tail, head, TOKEN_TO_TOKEN_RELATION_ID) for (head, tail) in product(range(num_text_tokens), range(num_text_tokens)))

# The edges that we read from the file.
# Update their head/tail indices to account for subwords and special tokens
discarded = 0
for edge in self.edges:
    if edge.head_node_index not in expand_list or edge.tail_node_index not in expand_list:
        discarded += 1
        continue
    expand_list[edge.tail_node_index]
    new_edges.extend((0, head, tail, edge.relation_id) for (head, tail) in product(expand_list[edge.head_node_index], expand_list[edge.tail_node_index]))
new_edges.sort()
logging.debug("Discarded %s/%s edges.", discarded, len(self.edges))

new_edges = torch.tensor(new_edges, device=device).transpose(1, 0)
rval = {
    "input_ids" : concat_ids,
    "node_mask" : node_mask,
    "edge_indices": new_edges
}




In [18]:
rval['input_ids'].shape, rval['node_mask'].shape, rval['edge_indices'].shape

(torch.Size([307]), torch.Size([1, 197, 307]), torch.Size([4, 9754]))

In [19]:
tokenized_text['input_ids'].shape

torch.Size([1, 69])