In [None]:
# Standard library
import os
from typing import Optional, Dict, Any
# Set before `transformers` / `tokenizers` are imported and any tokenizer is
# created below; disables the HF tokenizers library's internal parallelism
# (presumably to avoid fork-related warnings/deadlocks when a DataLoader uses
# worker processes — confirm).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# 3rd Party
import torch
from torch.utils.data import DataLoader
# Seed torch's global RNG. Note the random-tensor cell further down draws from
# its own explicitly seeded torch.Generator, not from this global state.
torch.manual_seed(0)
from transformers import BertModel, AutoModel, BertTokenizerFast, AutoTokenizer, PreTrainedTokenizerFast
from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer
# Local
# NOTE(review): Optional, Dict, Any, BertModel, AutoModel, BertTokenizerFast,
# PreTrainedTokenizerFast, PreTokenizer and DEFAULT_MODEL are not used by any
# cell in this notebook; consider pruning them.
from gatbert.constants import DEFAULT_MODEL, Stance
from gatbert.data import MapDataset, make_encoder, make_collate_fn

In [2]:
import dataclasses


@dataclasses.dataclass(init=True, repr=True, eq=True)
class Sample:
    """A single labelled example: a text passage, a target, and a stance label.

    Constructed positionally or by keyword in the order
    ``context``, ``target``, ``stance``.
    """

    # Free-text passage whose stance is being labelled.
    context: str
    # The topic/claim the passage's stance is judged against.
    target: str
    # Label drawn from gatbert.constants.Stance (e.g. Stance.FAVOR, Stance.NONE).
    stance: Stance

In [4]:
# Load the cased BERT tokenizer, explicitly requesting the "fast" implementation.
# NOTE(review): DEFAULT_MODEL is imported from gatbert.constants but the
# checkpoint name is hard-coded here — confirm which one is intended.
tokenizer = AutoTokenizer.from_pretrained(
    'bert-base-cased',
    use_fast=True,
)
# Build the sample encoder from the tokenizer and a fresh BERT pre-tokenizer
# (how make_encoder uses the pre-tokenizer is not visible in this notebook).
encoder = make_encoder(
    tokenizer,
    BertPreTokenizer(),
)

In [5]:
# Two hand-written toy examples (context, target, stance) used only to
# smoke-test the encode -> dataset -> DataLoader pipeline in the next cells.
fake_samples = [
    Sample(context=context, target=target, stance=stance)
    for context, target, stance in (
        (
            "We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.",
            "Independence from Britain",
            Stance.FAVOR,
        ),
        (
            "Four score and seven years ago our fathers brought forth on this continent, a new nation, conceived in Liberty, and dedicated to the proposition that all men are created equal.",
            "Social Security",
            Stance.NONE,
        ),
    )
]

In [7]:
# Eagerly encode every fake sample, then wrap the encoded list in MapDataset
# (presumably an indexable, map-style torch Dataset — confirm in gatbert.data).
ds = MapDataset(list(map(encoder, fake_samples)))

In [None]:
# Batch the dataset in its original order. batch_size=2 puts both fake samples
# into a single batch; the project's collate function handles padding/merging.
loader = DataLoader(
    ds,
    collate_fn=make_collate_fn(tokenizer),
    batch_size=2,
    shuffle=False,
)

In [None]:
# Print the 'kb_ids' entry of every collated batch. That key is produced by
# make_collate_fn; its exact contents are defined in gatbert.data.
for batch in loader:
    print(batch['kb_ids'])

In [4]:
# Toy dimensions for a random graph input. The names suggest a relational
# graph-attention layer with basis decomposition — TODO confirm. Only
# in_features, n_relations, max_nodes and batch_size are used in this cell;
# attention_units, out_features, n_heads and n_bases are not used anywhere
# else in the notebook either.
in_features = 123
attention_units = 53
out_features = 264
n_heads = 6
n_relations = 7
n_bases = 3
max_nodes = 10
batch_size = 5
# Dedicated, seeded generator: the draws below are reproducible regardless of
# the global torch RNG state.
gen = torch.Generator().manual_seed(1)
# Node features of shape (batch_size, max_nodes, in_features): zero-mean
# Gaussian with std 5. torch.randn is already centred at 0; the former `- .5`
# offset (an idiom for torch.rand, which is uniform on [0, 1)) only shifted
# every feature to mean -2.5.
random_features = 5 * torch.randn(batch_size, max_nodes, in_features, generator=gen)
# Random 0/1 entries of shape (batch_size, max_nodes, max_nodes, n_relations),
# converted to a sparse (COO) tensor.
random_adj = torch.randint(
    0, 2, size=[batch_size, max_nodes, max_nodes, n_relations], generator=gen
).to_sparse()
# Only a cell's last expression is rendered, so show both results together
# (a bare `random_features.shape` in the middle of the cell displays nothing).
random_features.shape, random_adj

In [None]:
# Inspect the token-type id this tokenizer assigns to padding positions
# (presumably to check how make_collate_fn pads token_type_ids — confirm).
tokenizer.pad_token_type_id