In [1]:
# STL
import os
from typing import Optional, Dict, Any
# Disable the HuggingFace `tokenizers` library's internal parallelism. This must run
# before `transformers` / `tokenizers` are imported further down in this cell.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# 3rd Party
import torch
from torch.utils.data import DataLoader
# Seed torch's global RNG up front so later random ops in this notebook are
# reproducible on Restart & Run All.
torch.manual_seed(0)
from transformers import BertModel, AutoModel, BertTokenizerFast, AutoTokenizer, PreTrainedTokenizerFast
from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer
# Local
# NOTE(review): several names imported in this cell (e.g. BertModel, AutoModel,
# DEFAULT_MODEL, NodeType, NODE_PAD_ID, RGATLayer) are not used in the visible cells.
from gatbert.constants import DEFAULT_MODEL, Stance, NodeType, NODE_PAD_ID
from gatbert.models import GATBert
# NOTE(review): star import. Sample, MapDataset, make_encoder, make_collate_fn and
# DummyRelationType used below presumably come from here; consider importing them explicitly.
from gatbert.data import *
from gatbert.rgat_layer import RGATLayer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face checkpoint name, used by both the tokenizer and the GAT model below.
pretrained_model_name = "bert-base-cased"

In [3]:
# Load the fast tokenizer variant that matches the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name,
    use_fast=True,
)

In [4]:
# Two hand-written samples (one FAVOR, one NONE). With batch_size=2 they form a single batch,
# which is enough for a forward-pass smoke test.
encoder = make_encoder(tokenizer, BertPreTokenizer())
fake_samples = [
    Sample(
        context="We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.",
        target="Independence from Britain",
        stance=Stance.FAVOR,
    ),
    Sample(
        context="Four score and seven years ago our fathers brought forth on this continent, a new nation, conceived in Liberty, and dedicated to the proposition that all men are created equal.",
        target="Social Security",
        stance=Stance.NONE,
    ),
]
# Encode every sample eagerly, then wrap the encoded list in a map-style dataset.
ds = MapDataset(list(map(encoder, fake_samples)))
loader = DataLoader(
    ds,
    collate_fn=make_collate_fn(tokenizer),
    shuffle=False,  # keep sample order stable across runs
    batch_size=2,
)

In [5]:
# Build the GAT-augmented BERT classifier. The output and relation sizes come from the enums,
# so they stay in sync if those enums change.
gat_model = GATBert(
    n_classes=len(Stance),
    n_relations=len(DummyRelationType),
    # NOTE(review): magic number. Presumably the number of knowledge-base nodes; confirm
    # and consider moving it into a named constant in the config cell.
    n_kb_nodes=1001,
    pretrained_model=pretrained_model_name,
)

In [6]:
# Forward-pass smoke test. eval() disables dropout (if the model has any), so the output does
# not depend on dropout randomness. no_grad() skips building an autograd graph, because no
# backward pass is run here. Remove the no_grad context if a later cell needs gradients
# from `output`.
gat_model.eval()
with torch.no_grad():
    for batch in loader:
        # Pop the gold labels so only the model inputs are splatted into forward();
        # presumably forward() takes no `stance` argument.
        stance = batch.pop('stance')
        output = gat_model(**batch)

In [8]:
# The model was built with n_classes=len(Stance), so expect one row per sample in the
# (last) batch and one column per stance class.
assert output.shape == (len(stance), len(Stance)), f"unexpected output shape {tuple(output.shape)}"
output.shape

torch.Size([2, 3])