In [6]:
# STL
import os
import json
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# 3rd Party
import torch
from copy import deepcopy
from transformers import  BertTokenizerFast
# Local
from gatbert.data import parse_graph_tsv, Sample
from gatbert.graph_sample import GraphSample
from gatbert.constants import NodeType, Stance

In [42]:
# Name of the pretrained checkpoint whose vocabulary the tokenizer is loaded from.
pretrained_model_name = 'bert-base-uncased'
# The *Fast* tokenizer class is used on purpose: offset mappings, which are
# requested when the sample is tokenized, are only provided by fast tokenizers.
tokenizer = BertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path=pretrained_model_name,
)

In [49]:
# Hand-built example used to exercise the tokenization and pooling logic.
# Alternative: read the first sample from disk instead:
#   sample = next(parse_graph_tsv('scrap.tsv'))
sample = GraphSample(
    stance=Stance.FAVOR,
    target=["Pakistan", "government"],
    # Context is supplied pre-split into words (punctuation is its own word).
    context=[
        "We", "need", "to", "stop", "supporting",
        "governmentalists", "who", "harbor", "terrorists", ".",
    ],
    kb=["/c/en/governmentalists", "/c/en/pakistan"],
    edges=[],
)

# '/c/en/<term>'.split('/') yields ['', 'c', 'en', '<term>'], so index 3 is the
# bare term without its URI prefix.
clean_kb_sample = [concept_uri.split('/')[3] for concept_uri in sample.kb]

In [50]:
# Target and context are encoded together as a sentence pair. Both are already
# split into words, and offsets are requested so that subwords can later be
# mapped back to the word they came from.
tokenized_text = tokenizer(
    text=sample.target,
    text_pair=sample.context,
    is_split_into_words=True,
    return_offsets_mapping=True,
    return_tensors='pt',
)
# The cleaned KB terms are encoded separately, as a single pre-split sequence
# (one "word" per KB term).
tokenized_kb = tokenizer(
    text=clean_kb_sample,
    is_split_into_words=True,
    return_offsets_mapping=True,
    return_tensors='pt',
)

In [45]:
# Inspect the vocabulary ids of the encoded (target, context) pair; the result
# carries a leading batch dimension of size 1.
tokenized_text['input_ids']

tensor([[  101,  4501,  2231,   102,  2057,  2342,  2000,  2644,  4637, 10605,
          5130,  2040,  6496, 15554,  1012,   102]])

In [46]:
# Map the ids back to wordpiece strings; squeeze() drops the size-1 batch
# dimension so a flat sequence of ids is passed in.
tokenizer.convert_ids_to_tokens(tokenized_text['input_ids'].squeeze())

['[CLS]',
 'pakistan',
 'government',
 '[SEP]',
 'we',
 'need',
 'to',
 'stop',
 'supporting',
 'governmental',
 '##ists',
 'who',
 'harbor',
 'terrorists',
 '.',
 '[SEP]']

In [47]:
# Inspect the (start, end) character offsets of every subword. Because the
# input was pre-split, offsets restart at 0 for each word: a start > 0 marks a
# continuation piece ('##ists'), and (0, 0) marks a special token.
tokenized_text['offset_mapping']

tensor([[[ 0,  0],
         [ 0,  8],
         [ 0, 10],
         [ 0,  0],
         [ 0,  2],
         [ 0,  4],
         [ 0,  2],
         [ 0,  4],
         [ 0, 10],
         [ 0, 12],
         [12, 16],
         [ 0,  3],
         [ 0,  6],
         [ 0, 10],
         [ 0,  1],
         [ 0,  0]]])

In [51]:
# Wordpieces of the KB terms; this encoding is wrapped in its own [CLS]/[SEP].
tokenizer.convert_ids_to_tokens(tokenized_kb['input_ids'].squeeze())

['[CLS]', 'governmental', '##ists', 'pakistan', '[SEP]']

In [52]:
# Per-subword (start, end) offsets for the KB terms: each term restarts at 0,
# continuation pieces have start > 0, and special tokens are (0, 0).
tokenized_kb['offset_mapping']

tensor([[[ 0,  0],
         [ 0, 12],
         [12, 16],
         [ 0,  8],
         [ 0,  0]]])

In [56]:
from collections import defaultdict

# pool_inds: one (node_index, subword_index) pair per subword, covering the
# text encoding first and the KB encoding after it.
pool_inds = []
# expand_list: original node index -> indices of the new nodes created for it.
expand_list = defaultdict(list)

# Both counters are incremented before they are used, so they start at -1.
new_nodes_index = -1
orig_nodes_index = -1

# --- Text tokens: every subword becomes a node of its own ---
token_offset_mapping = tokenized_text['offset_mapping'].squeeze()
for subword_index, (start, end) in enumerate(token_offset_mapping.tolist()):
    new_nodes_index += 1
    pool_inds.append((new_nodes_index, subword_index))
    if start == end:
        # Special token: it gets a node but belongs to no original node.
        continue
    if start == 0:
        # First piece of a word, so advance to the next original node.
        orig_nodes_index += 1
    expand_list[orig_nodes_index].append(new_nodes_index)

# --- KB terms: all subwords of one term are pooled into a single node ---
kb_offset_mapping = tokenized_kb['offset_mapping'].squeeze()
# Subword numbering continues from where the text encoding stopped.
for subword_index, (start, end) in enumerate(kb_offset_mapping.tolist(), start=subword_index + 1):
    is_continuation = start != end and start != 0
    if not is_continuation:
        # Special tokens and term-initial pieces each open a new node ...
        new_nodes_index += 1
        if start != end:
            # ... but only a real term corresponds to an original node.
            orig_nodes_index += 1
            expand_list[orig_nodes_index].append(new_nodes_index)
    # Continuation pieces reuse the node opened by their term's first piece.
    pool_inds.append((new_nodes_index, subword_index))

num_new_nodes = new_nodes_index + 1