In [None]:
# Standard library
import os
from typing import Optional, Dict, Any
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# 3rd Party
import torch
from torch.utils.data import DataLoader
torch.manual_seed(0)
from transformers import BertModel, AutoModel, BertTokenizerFast, AutoTokenizer, PreTrainedTokenizerFast
from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer
# Local
from gatbert.constants import DEFAULT_MODEL, Stance
from gatbert.datasets import MapDataset, make_encoder, make_collate_fn

In [2]:
import dataclasses
@dataclasses.dataclass
class Sample:
    """A single (context, target, stance) example.

    Instances are built by hand in this notebook and passed one at a time to
    the encoder returned by ``make_encoder``.
    """
    # Free-text passage.
    context: str
    # Short text naming the target that is paired with the context.
    target: str
    # Label for the pair; a member of ``gatbert.constants.Stance``.
    stance: Stance

In [3]:
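# NOTE: Draft of `make_encoder`, kept for reference only. The notebook uses the
# version imported from `gatbert.datasets` in the imports cell.
# from itertools import product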
# import numpy as np
# def make_encoder(tokenizer: PreTrainedTokenizerFast, pretokenizer: Optional[PreTokenizer] = None):
#     relation_types = [
#         0, # Token-to-Token
#         1, # Token-to-CN
#         2, # CN-to-Token
#         3, # CN-to-CN
#     ]

#     # Node indices start at 1 so we can use 0 for a "padding node"
#     node_universe = np.arange(1, 1001)

#     gen = torch.Generator().manual_seed(0)

#     def encode_sample(sample: Sample):
#         context = sample.context
#         target = sample.target
#         stance = sample.stance.value

#         if pretokenizer:
#             pre_context = [pair[0] for pair in pretokenizer.pre_tokenize_str(context)]
#             pre_target = [pair[0] for pair in pretokenizer.pre_tokenize_str(target)]
#             result = tokenizer(text=pre_target, text_pair=pre_context, is_split_into_words=True, return_tensors='pt')
#         else:
#             result = tokenizer(text=target, text_pair=context)

#         result = {k: torch.squeeze(v) for (k, v) in result.items()}
#         n_text_nodes = len(result['input_ids'])

#         edge_ids = []
#         for head in range(n_text_nodes):
#             edge_ids.append( (head, head, 0) )
#             for tail in range(head + 1, n_text_nodes):
#                 edge_ids.append( (head, tail, 0) )
#                 edge_ids.append( (tail, head, 0) )

#         # Right now can only make fake KB nodes
#         num_kb_nodes = torch.randint(3, 10, size=(), generator=gen)
#         chosen_nodes = torch.tensor(np.random.choice(node_universe, size=int(num_kb_nodes), replace=False))
#         for (token_id, cn_id) in zip(*torch.where(torch.randn(n_text_nodes, num_kb_nodes, generator=gen) < .1)):
#             edge_ids.append( (token_id, cn_id + n_text_nodes, 1) )
#         for (token_id, cn_id) in zip(*torch.where(torch.randn(n_text_nodes, num_kb_nodes, generator=gen) < .1)):
#             edge_ids.append( (cn_id + n_text_nodes, token_id, 2) )
#         for (cn_id, cn_id_b) in zip(*torch.where(torch.randn(num_kb_nodes, num_kb_nodes, generator=gen) < .1)):
#             edge_ids.append( (cn_id + n_text_nodes, cn_id_b + n_text_nodes, 3) )
#         result['kb_ids'] = chosen_nodes

#         edge_ids.sort()
#         sparse_ids = torch.tensor(edge_ids).transpose(1, 0)
#         result['edges'] = sparse_ids
#         result['stance'] = torch.tensor(stance)

#         return result

#     return encode_sample


In [4]:
# Load the fast tokenizer for the cased BERT checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    'bert-base-cased',
    use_fast=True,
)
# Build the per-sample encoder from the tokenizer and a BERT pre-tokenizer.
encoder = make_encoder(
    tokenizer,
    BertPreTokenizer(),
)

In [5]:
# Two hand-written samples used to smoke-test the encoder and the collate function.
fake_samples = [
    # Sample labelled FAVOR.
    Sample(
        target="Independence from Britain",
        stance=Stance.FAVOR,
        context="We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.",
    ),
    # Sample labelled NONE.
    Sample(
        target="Social Security",
        stance=Stance.NONE,
        context="Four score and seven years ago our fathers brought forth on this continent, a new nation, conceived in Liberty, and dedicated to the proposition that all men are created equal.",
    ),
]

In [6]:
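# NOTE: Draft of `make_collate_fn`, kept for reference only. The notebook uses the
# version imported from `gatbert.datasets` in the imports cell.
# def make_collate_fn(tokenizer):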
#     def collate_fn(samples: Dict[str, Any]):
#         batched = {}
#         token_padding = tokenizer.pad_token_id
#         type_padding = tokenizer.pad_token_type_id
#         batched['input_ids'] = torch.nn.utils.rnn.pad_sequence([s['input_ids'] for s in samples], batch_first=True, padding_value=token_padding)
#         batched['token_type_ids'] = torch.nn.utils.rnn.pad_sequence([s['token_type_ids'] for s in samples], batch_first=True, padding_value=type_padding)
#         batched['kb_ids'] = torch.nn.utils.rnn.pad_sequence([s['kb_ids'] for s in samples], batch_first=True, padding_value=0)

#         batched['attention_mask'] = batched['input_ids'] != token_padding
#         batched['stance'] = torch.stack([s['stance'] for s in samples], dim=0)

#         batch_edges = []
#         for (i, sample_edges) in enumerate(map(lambda s: s['edges'], samples)):
#             batch_edges.append(torch.concatenate([
#                 torch.full(size=(1, sample_edges.shape[1]), fill_value=i),
#                 sample_edges
#             ]))
#         batched['edges'] = torch.concatenate(batch_edges, dim=-1)
#         return batched
#     return collate_fn

In [7]:
# Encode every fake sample eagerly and wrap the resulting list in a MapDataset.
ds = MapDataset(list(map(encoder, fake_samples)))

In [None]:
# Batch both samples together, in order, using the project's collate function.
loader = DataLoader(
    ds,
    collate_fn=make_collate_fn(tokenizer),
    batch_size=2,
    shuffle=False,
)

In [None]:
# Iterate over the collated batches and print each batch's 'kb_ids' entry.
for d in loader:
    print(d['kb_ids'])

In [4]:
# Sizes for a synthetic batch. Several of these (attention_units, out_features,
# n_heads, n_bases) are not used in this cell; they are kept because other
# cells may read them.
in_features = 123
attention_units = 53
out_features = 264
n_heads = 6
n_relations = 7
n_bases = 3
max_nodes = 10
batch_size = 5

# Dedicated generator so the draws below do not depend on the global RNG state.
gen = torch.Generator().manual_seed(1)

# Dense features: standard-normal draws shifted by -0.5 and then scaled by 5.
random_features = 5 * (torch.randn(batch_size, max_nodes, in_features, generator=gen) - .5)
# A bare `random_features.shape` in the middle of a cell is evaluated and then
# discarded (only the last expression is displayed), so check the shape instead.
assert random_features.shape == (batch_size, max_nodes, in_features)

# Random 0/1 entries, one per (batch, node, node, relation), stored as a sparse tensor.
random_adj = torch.randint(0, 2, size=[batch_size, max_nodes, max_nodes, n_relations], generator=gen).to_sparse()
assert random_adj.shape == (batch_size, max_nodes, max_nodes, n_relations)

# Last expression of the cell: this is what the notebook renders as output.
random_adj

In [None]:
# Inspect the token-type id the tokenizer uses for padding positions.
tokenizer.pad_token_type_id