In [95]:
# STL
import os
from typing import Any, Dict, List, Optional
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# 3rd Party
import torch
from torch.utils.data import DataLoader
torch.manual_seed(0)
from transformers import BertModel, AutoModel, BertTokenizerFast, AutoTokenizer, PreTrainedTokenizerFast
from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer
# Local
from gatbert.constants import DEFAULT_MODEL
from gatbert.datasets import MapDataset

In [96]:
from itertools import product
def make_encoder(tokenizer: PreTrainedTokenizerFast, pretokenizer: Optional[PreTokenizer] = None, make_fake_edges=True):
    """
    Build a closure that converts one raw stance sample into tensors plus a graph.

    Args:
        tokenizer: Called as ``tokenizer(text=..., text_pair=..., return_tensors='pt')``;
            ``is_split_into_words=True`` is added when a pretokenizer is supplied.
        pretokenizer: Optional object exposing ``pre_tokenize_str``. When given, target and
            context are split into words with it before being handed to the tokenizer.
        make_fake_edges: When true, 3 or 4 placeholder "CN" nodes are appended after the
            text nodes and randomly wired to the tokens and to each other.

    Returns:
        ``encode_sample(sample)``. ``sample`` must have the keys ``'context'``, ``'target'``
        and ``'stance'``. The result holds the tokenizer outputs with the leading batch
        axis removed, plus:
          * ``'edges'``: integer tensor of shape ``(3, n_edges)`` whose rows are
            head, tail and relation id; columns are sorted lexicographically.
          * ``'nodes'``: total node count as a plain ``int``.
          * ``'stance'``: the stance label as a 0-d tensor.

    Note:
        All calls of the returned closure share one seeded generator, so the random
        edges of a sample depend on how many samples were encoded before it.
    """
    # Relation ids written into the third row of 'edges'.
    token_to_token = 0
    token_to_cn = 1
    cn_to_token = 2
    cn_to_cn = 3

    gen = torch.Generator().manual_seed(0)

    def random_pairs(n_rows: int, n_cols: int):
        """Return the (row, col) index pairs whose standard-normal draw is below 0.1."""
        rows, cols = torch.where(torch.randn(n_rows, n_cols, generator=gen) < .1)
        # Plain ints keep the edge list homogeneous and cheap to sort.
        return zip(rows.tolist(), cols.tolist())

    def encode_sample(sample: Dict[str, Any]):
        context: str = sample['context']
        target: str = sample['target']
        stance: int = sample['stance']

        if pretokenizer is not None:
            pre_context = [pair[0] for pair in pretokenizer.pre_tokenize_str(context)]
            pre_target = [pair[0] for pair in pretokenizer.pre_tokenize_str(target)]
            result = tokenizer(text=pre_target, text_pair=pre_context, is_split_into_words=True, return_tensors='pt')
        else:
            # Tensors are required here as well; the squeeze below cannot handle lists.
            result = tokenizer(text=target, text_pair=context, return_tensors='pt')

        # Drop only the batch axis so a length-1 sequence stays one-dimensional.
        result = {k: torch.squeeze(v, dim=0) for (k, v) in result.items()}
        n_text_nodes = len(result['input_ids'])

        # Fully connect the text nodes in both directions, self-loops included.
        edge_ids = [
            (head, tail, token_to_token)
            for head in range(n_text_nodes)
            for tail in range(n_text_nodes)
        ]

        total_nodes = n_text_nodes
        if make_fake_edges:
            # .item() keeps 'nodes' an int regardless of make_fake_edges.
            n_fake_cn_nodes = torch.randint(3, 5, size=(), generator=gen).item()
            # CN nodes are numbered after the text nodes, hence the offset.
            for (token_id, cn_id) in random_pairs(n_text_nodes, n_fake_cn_nodes):
                edge_ids.append((token_id, cn_id + n_text_nodes, token_to_cn))
            for (token_id, cn_id) in random_pairs(n_text_nodes, n_fake_cn_nodes):
                edge_ids.append((cn_id + n_text_nodes, token_id, cn_to_token))
            for (cn_id, cn_id_b) in random_pairs(n_fake_cn_nodes, n_fake_cn_nodes):
                edge_ids.append((cn_id + n_text_nodes, cn_id_b + n_text_nodes, cn_to_cn))
            total_nodes += n_fake_cn_nodes

        edge_ids.sort()
        # reshape(-1, 3) keeps the (3, 0) layout even when there are no edges at all.
        sparse_ids = torch.tensor(edge_ids, dtype=torch.long).reshape(-1, 3).transpose(1, 0)
        result['edges'] = sparse_ids
        result['nodes'] = total_nodes
        result['stance'] = torch.tensor(stance)

        return result

    return encode_sample


In [89]:
# Fast 'bert-base-cased' tokenizer; the encoder pre-splits text with BertPreTokenizer
# and then tokenizes the resulting words.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=True)
encoder = make_encoder(tokenizer, pretokenizer=BertPreTokenizer())

In [93]:
# Two hand-written samples in the schema the encoder reads: 'context', 'target', 'stance'.
fake_samples = [
    {
        "context": "We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.",
        "target": "Independence from Britain",
        "stance": 2,
    },
    {
        "context": "Four score and seven years ago our fathers brought forth on this continent, a new nation, conceived in Liberty, and dedicated to the proposition that all men are created equal.",
        "target": "Social Security",
        "stance": 0,
    },
]

In [128]:
def make_collate_fn(tokenizer):
    """
    Build a DataLoader ``collate_fn`` that pads encoded samples into one batch.

    Args:
        tokenizer: Object providing ``pad_token_id`` and ``pad_token_type_id``; these are
            the fill values for ``input_ids`` and ``token_type_ids`` respectively.

    Returns:
        ``collate_fn(samples)``. Each sample is a dict with 1-D ``'input_ids'`` and
        ``'token_type_ids'`` tensors, a 0-d ``'stance'`` tensor and a ``(3, n_edges)``
        ``'edges'`` tensor. The returned dict contains:
          * ``'input_ids'`` / ``'token_type_ids'``: right-padded, batch-first tensors.
          * ``'attention_mask'``: boolean tensor, true wherever ``input_ids`` differs
            from the pad id.
          * ``'stance'``: stacked labels, one per sample.
          * ``'edges'``: ``(4, total_edges)`` tensor; row 0 is the index of the sample
            inside the batch, the remaining rows are that sample's edge rows.
    """
    def collate_fn(samples: List[Dict[str, Any]]):
        token_padding = tokenizer.pad_token_id
        type_padding = tokenizer.pad_token_type_id
        pad_sequence = torch.nn.utils.rnn.pad_sequence

        batched = {}
        batched['input_ids'] = pad_sequence(
            [s['input_ids'] for s in samples], batch_first=True, padding_value=token_padding
        )
        batched['token_type_ids'] = pad_sequence(
            [s['token_type_ids'] for s in samples], batch_first=True, padding_value=type_padding
        )
        batched['attention_mask'] = batched['input_ids'] != token_padding
        batched['stance'] = torch.stack([s['stance'] for s in samples], dim=0)

        batch_edges = []
        for (i, sample) in enumerate(samples):
            sample_edges = sample['edges']
            # Match dtype/device of the edges so the concatenation never promotes or
            # mixes devices.
            batch_index = torch.full(
                (1, sample_edges.shape[1]), i,
                dtype=sample_edges.dtype, device=sample_edges.device,
            )
            # torch.cat is the canonical spelling and exists in every torch release.
            batch_edges.append(torch.cat([batch_index, sample_edges], dim=0))
        batched['edges'] = torch.cat(batch_edges, dim=-1)
        return batched
    return collate_fn

In [129]:
# Encode the samples in order (the encoder's random generator is stateful) and wrap them.
ds = MapDataset(list(map(encoder, fake_samples)))

In [130]:
# Single unshuffled batch holding both samples, padded by the custom collate function.
loader = DataLoader(
    dataset=ds,
    batch_size=2,
    shuffle=False,
    collate_fn=make_collate_fn(tokenizer),
)

In [131]:
# Print every collated batch. The loop variable `d` intentionally stays bound to the
# last batch after the loop: a later cell inspects d['edges'].
for d in loader:
    print(d)

{'input_ids': tensor([[  101,  7824,  1121,  2855,   102,  1284,  2080,  1292,  3062,  1116,
          1106,  1129,  2191,   118, 10238,   117,  1115,  1155,  1441,  1132,
          1687,  4463,   117,  1115,  1152,  1132, 22868,  1118,  1147,   140,
         26284,  1114,  2218,  8362, 10584,  7076,  2165,  5399,   117,  1115,
          1621,  1292,  1132,  2583,   117,  8146,  1105,  1103,  9542,  1104,
         25410,   119,   102],
        [  101,  3563,  4354,   102,  3396,  2794,  1105,  1978,  1201,  2403,
          1412, 15920,  1814,  5275,  1113,  1142, 10995,   117,   170,  1207,
          3790,   117, 10187,  1107,  8146,   117,  1105,  3256,  1106,  1103,
         21133,  1115,  1155,  1441,  1132,  1687,  4463,   119,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 

In [132]:
# Batched edge tensor: 4 rows (batch index + the three per-sample edge rows) by total edge count.
d['edges'].size()

torch.Size([4, 4591])

In [4]:
# Sizes for a randomly generated toy batch.
in_features = 123
attention_units = 53
out_features = 264
n_heads = 6
n_relations = 7
n_bases = 3
max_nodes = 10
batch_size = 5

# Dedicated generator so this cell is reproducible on its own.
gen = torch.Generator().manual_seed(1)

# Node features of shape (batch_size, max_nodes, in_features): normal draws shifted by -0.5, scaled by 5.
random_features = 5 * (
    torch.randn(batch_size, max_nodes, in_features, generator=gen) - .5
)

# Random 0/1 tensor of shape (batch_size, max_nodes, max_nodes, n_relations), stored sparse.
random_adj = torch.randint(
    0, 2,
    size=(batch_size, max_nodes, max_nodes, n_relations),
    generator=gen,
).to_sparse()
random_adj

In [114]:
# Inspect the value collate_fn uses as padding_value for token_type_ids.
tokenizer.pad_token_type_id

0