In [16]:
# STL
import os
# Set before any tokenizer is created. Presumably this disables the Hugging Face
# tokenizers' internal parallelism (and the related fork warning) — confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import json
import copy
from itertools import product
from collections import defaultdict, OrderedDict
# 3rd Party
import torch
from transformers import  BertTokenizerFast
# Local
from gatbert.data import parse_graph_tsv, Sample
from gatbert.graph_sample import GraphSample, Edge
from gatbert.constants import NodeType, Stance

In [2]:
# Checkpoint whose vocabulary is used to split words into subwords.
pretrained_model_name = 'bert-base-uncased'
tokenizer = BertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path=pretrained_model_name,
)

In [3]:
# Hand-built toy sample. To load a sample from disk instead:
#   sample = next(parse_graph_tsv('scrap.tsv'))
target = ["Pakistan", "government"]
context = "We need to stop supporting governmentalists who harbor terrorists .".split()
kb = ["/c/en/governmentalists", "/c/en/pakistan"]

# The edge below indexes nodes as: target words, then context words, then KB entries.
context_start = len(target)
kb_start = context_start + len(context)

sample = GraphSample(
    stance=Stance.FAVOR,
    target=target,
    context=context,
    kb=kb,
    edges=[
        # "governmentalists" in the context -> /c/en/governmentalists in the KB
        Edge(context_start + 5, kb_start + 0, 42),
    ],
)

# Keep only the concept name, i.e. the fourth '/'-separated field of each URI.
clean_kb_sample = [concept_uri.split('/')[3] for concept_uri in sample.kb]

In [4]:
# Options shared by both encodings: inputs are already split into words,
# per-subword offsets are requested, and results are returned as PyTorch tensors.
encode_kwargs = dict(
    is_split_into_words=True,
    return_offsets_mapping=True,
    return_tensors='pt',
)

# Target and context are encoded together as a text pair.
tokenized_text = tokenizer(text=sample.target, text_pair=sample.context, **encode_kwargs)
# KB concepts are encoded as one sequence, one "word" per concept.
tokenized_kb = tokenizer(text=clean_kb_sample, **encode_kwargs)

In [5]:
# Token ids of the encoded target/context pair.
tokenized_text['input_ids']

tensor([[  101,  4501,  2231,   102,  2057,  2342,  2000,  2644,  4637, 10605,
          5130,  2040,  6496, 15554,  1012,   102]])

In [6]:
# Map the ids back to subword strings to see how each word was split.
tokenizer.convert_ids_to_tokens(tokenized_text['input_ids'].squeeze())

['[CLS]',
 'pakistan',
 'government',
 '[SEP]',
 'we',
 'need',
 'to',
 'stop',
 'supporting',
 'governmental',
 '##ists',
 'who',
 'harbor',
 'terrorists',
 '.',
 '[SEP]']

In [7]:
# Offset mapping of the target/context encoding (requested via return_offsets_mapping=True).
tokenized_text['offset_mapping']

tensor([[[ 0,  0],
         [ 0,  8],
         [ 0, 10],
         [ 0,  0],
         [ 0,  2],
         [ 0,  4],
         [ 0,  2],
         [ 0,  4],
         [ 0, 10],
         [ 0, 12],
         [12, 16],
         [ 0,  3],
         [ 0,  6],
         [ 0, 10],
         [ 0,  1],
         [ 0,  0]]])

In [8]:
# Subword strings of the encoded KB concepts.
tokenizer.convert_ids_to_tokens(tokenized_kb['input_ids'].squeeze())

['[CLS]', 'governmental', '##ists', 'pakistan', '[SEP]']

In [9]:
# Offset mapping of the KB encoding (requested via return_offsets_mapping=True).
tokenized_kb['offset_mapping']

tensor([[[ 0,  0],
         [ 0, 12],
         [12, 16],
         [ 0,  8],
         [ 0,  0]]])

In [10]:
# Build two lookup tables relating graph nodes to tokenizer subwords:
#   pool_inds:   new node index -> list of subword positions (in the concatenated
#                text + KB subword sequence) that are pooled into that node
#   expand_list: original node index -> list of new node indices that replace it
pool_inds = OrderedDict()
expand_list = defaultdict(list)

new_nodes_index = -1
orig_nodes_index = -1

# Token subwords: every subword (and every special token) becomes its own new node,
# so an original token node expands into one new node per subword.
# squeeze(0) drops only the batch axis, so a length-1 sequence axis would survive.
token_offset_mapping = tokenized_text['offset_mapping'].squeeze(0)
for (subword_index, (start, end)) in enumerate(token_offset_mapping.tolist()):
    new_nodes_index += 1
    pool_inds[new_nodes_index] = [subword_index]

    if start != end:  # Real subword; special tokens carry the empty span (0, 0)
        if start == 0:  # Offsets restart at 0 for every word, so 0 marks a new word
            orig_nodes_index += 1
        expand_list[orig_nodes_index].append(new_nodes_index)

# KB subwords: all subwords of one KB entry are pooled into a single new node.
# Their positions continue after the text subwords in the concatenated sequence;
# the offset is taken from the text length rather than from the variable left
# behind by the loop above.
kb_offset_mapping = tokenized_kb['offset_mapping'].squeeze(0)
kb_subword_offset = token_offset_mapping.shape[0]
for (subword_index, (start, end)) in enumerate(kb_offset_mapping.tolist(), start=kb_subword_offset):
    if start == end:
        # Special token: gets a node of its own but maps to no original node
        new_nodes_index += 1
        pool_inds[new_nodes_index] = []
    elif start == 0:
        # First subword of a KB entry: open the node that pools the whole entry
        new_nodes_index += 1
        pool_inds[new_nodes_index] = []
        orig_nodes_index += 1
        expand_list[orig_nodes_index].append(new_nodes_index)
    # Continuation subwords fall through and join the most recently opened node
    pool_inds[new_nodes_index].append(subword_index)
num_new_nodes = new_nodes_index + 1

In [11]:
# Text ids followed by KB ids, joined along the sequence axis and flattened to 1-D.
id_parts = [encoding['input_ids'] for encoding in (tokenized_text, tokenized_kb)]
concat_ids = torch.concatenate(id_parts, dim=-1).squeeze()

In [12]:
# Sparse (num_new_nodes x num_subwords) averaging matrix: row i holds the weight
# 1 / k at each of the k subword positions pooled into new node i.
index_pairs = [
    (node_ind, subword_ind)
    for (node_ind, subword_inds) in pool_inds.items()
    for subword_ind in subword_inds
]
weights = [
    1 / len(subword_inds)
    for subword_inds in pool_inds.values()
    for _ in subword_inds
]

# sparse_coo_tensor expects indices shaped (2, nnz), hence the transpose.
mask_indices = torch.tensor(index_pairs).transpose(0, 1)
mask_values = torch.tensor(weights)
node_mask = torch.sparse_coo_tensor(
    indices=mask_indices,
    values=mask_values,
    size=(num_new_nodes, concat_ids.shape[-1]),
    dtype=torch.float,
    device=concat_ids.device,
    is_coalesced=True,
)

In [13]:
# Display the sparse pooling mask built above.
node_mask

tensor(indices=tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,
                        14, 15, 16, 17, 17, 18, 19],
                       [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,
                        14, 15, 16, 17, 18, 19, 20]]),
       values=tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
                      1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
                      1.0000, 1.0000, 1.0000, 0.5000, 0.5000, 1.0000, 1.0000]),
       size=(20, 21), nnz=21, layout=torch.sparse_coo)

In [14]:
# Edges as defined on the sample, before expansion to subword-level nodes.
sample.edges

[Edge(head_node_index=7, tail_node_index=12, relation_id=42)]

In [15]:
# Original node index -> new node indices it expands to.
expand_list

defaultdict(list,
            {0: [1],
             1: [2],
             2: [4],
             3: [5],
             4: [6],
             5: [7],
             6: [8],
             7: [9, 10],
             8: [11],
             9: [12],
             10: [13],
             11: [14],
             12: [17],
             13: [18]})

In [18]:
# Rewrite every edge in terms of the new node indices: an edge between two
# original nodes becomes one edge per (expanded head, expanded tail) pair.
new_edges = []
for edge in sample.edges:
    # First endpoint (head is checked before tail) that has no expansion, if any.
    # Membership is tested with `in` so the defaultdict is not populated by the lookup.
    missing_node = next(
        (
            node_index
            for node_index in (edge.head_node_index, edge.tail_node_index)
            if node_index not in expand_list
        ),
        None,
    )
    if missing_node is not None:
        print(f"Warning: found no expansions for node {missing_node}")
        continue

    expanded_heads = expand_list[edge.head_node_index]
    expanded_tails = expand_list[edge.tail_node_index]
    for (head, tail) in product(expanded_heads, expanded_tails):
        new_edges.append((head, tail, edge.relation_id))

In [19]:
# Edges after expansion, as (head, tail, relation_id) tuples over the new node indices.
new_edges

[(9, 17, 42), (10, 17, 42)]

In [20]:
# Subword strings of the concatenated text + KB sequence.
tokenizer.convert_ids_to_tokens(concat_ids)

['[CLS]',
 'pakistan',
 'government',
 '[SEP]',
 'we',
 'need',
 'to',
 'stop',
 'supporting',
 'governmental',
 '##ists',
 'who',
 'harbor',
 'terrorists',
 '.',
 '[SEP]',
 '[CLS]',
 'governmental',
 '##ists',
 'pakistan',
 '[SEP]']

In [21]:
# Shape is (num_new_nodes, length of concat_ids), as passed via `size` when the mask was built.
node_mask.shape

torch.Size([20, 21])