## Imports

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import sys
from pathlib import Path

import torch

from datasets import load_dataset
from transformers import AutoConfig, T5Tokenizer, T5ForConditionalGeneration

In [5]:
# Location of the local `elk` checkout, relative to this notebook.
ELK_PATH = Path("../../../elk/")
print(ELK_PATH.resolve())

# Directories that must be importable for the `reporter` and `templates` imports below.
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "training",
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    # Test membership with the same resolved string that gets inserted.
    # Comparing the unresolved relative path never matches an entry already
    # in sys.path, so every re-run of this cell would add duplicates.
    module_dir = str(module.resolve())
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)

print(sys.path[:3])

# These imports only resolve once the directories above are on sys.path.
from reporter import Reporter
from templates import DatasetTemplates

/fsx/home-augustas/elk
['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk/elk/training', '/fsx/home-augustas/elk']


## Config

In [6]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device

device(type='cuda')

## Dataset

In [7]:
# Load only the first 64 examples of the IMDB train split to keep the
# experiment small; the bare name below displays the dataset summary.
dataset = load_dataset("imdb", split="train[:64]")
dataset

Found cached dataset imdb (/admin/home-augustas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


Dataset({
    features: ['text', 'label'],
    num_rows: 64
})

In [8]:
dataset_template_path = "AugustasM/burns-datasets-VINC/all"
dataset_templates = DatasetTemplates(dataset_template_path)

# Keep only the templates that yield an answer-choice list for the first
# dataset row, and re-key the survivors by template name.
first_example = dataset[0]
usable_templates = {}
for candidate in dataset_templates.templates.values():
    if candidate.get_answer_choices_list(first_example) is not None:
        usable_templates[candidate.name] = candidate
dataset_templates.templates = usable_templates
print(dataset_templates.templates)

# Work with the first remaining template for the rest of the notebook.
template = list(dataset_templates.templates.values())[0]
template

{'Consider the text below': <templates.Template object at 0x7f70f406e860>, 'Is the given text truthful': <templates.Template object at 0x7f70f406e890>, 'Text first': <templates.Template object at 0x7f70f406e8c0>, 'No question no choices': <templates.Template object at 0x7f70f406d750>, 'No question with choices': <templates.Template object at 0x7f70f406e800>}


<templates.Template at 0x7f70f406e860>

In [9]:
# Render the chosen template on the first example; `apply` returns a pair that
# is used below as the prompt text (`q`) and its target answer (`a`).
q, a = template.apply(dataset[0])
a

'no'

## Tokenizer

In [10]:
# Checkpoint to inspect; the commented-out lines are alternative model sizes.
# model_name = "allenai/unifiedqa-v2-t5-11b-1363200"
model_name = "allenai/unifiedqa-v2-t5-3b-1363200"
# model_name = "allenai/unifiedqa-t5-small"
# `truncation_side="left"`: presumably so that, if truncation is ever enabled,
# tokens are dropped from the start of the prompt rather than its end - confirm.
tokenizer = T5Tokenizer.from_pretrained(model_name, truncation_side="left")

In [11]:
# Tokenize the prompt as the model input and the stripped answer as the
# target (`text_target`), then move the resulting tensors to `device`.
text_batch = tokenizer(q, add_special_tokens=True, return_tensors="pt", text_target=a.strip()).to(device)
text_batch["input_ids"].shape

torch.Size([1, 459])

In [12]:
# List the fields the tokenizer produced for this prompt/answer pair.
text_batch.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

## Model

In [13]:
%%time

# Load the full encoder-decoder checkpoint and place it on the selected device.
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

# Sanity checks: the config should flag an encoder-decoder model, and the
# model should offer a callable `get_encoder` accessor.
print(model.config.is_encoder_decoder)
encoder_getter = getattr(model, "get_encoder", None)
print(encoder_getter is not None and callable(encoder_getter))

# Total number of parameter elements across the whole model.
n_params = sum(param.numel() for param in model.parameters())
print(f"# params: {n_params}")

True
True
# params: 2851598336
CPU times: user 29.5 s, sys: 21.6 s, total: 51.1 s
Wall time: 55.4 s


### Encoder and decoder

In [14]:
# Forward pass through the full model without tracking gradients, asking for
# the hidden states of every layer to be returned as well.
with torch.no_grad():
    outputs = model(**text_batch, output_hidden_states=True)

# Show which fields the model output contains.
outputs.keys()

odict_keys(['loss', 'logits', 'past_key_values', 'decoder_hidden_states', 'encoder_last_hidden_state', 'encoder_hidden_states'])

In [15]:
# Number of hidden-state tensors returned for the encoder and for the decoder.
len(outputs["encoder_hidden_states"]), len(outputs["decoder_hidden_states"])

(25, 25)

In [16]:
# Inspect the decoder side: one shape per returned hidden-state tensor.
hiddens = outputs["decoder_hidden_states"]
[x.shape for x in hiddens]

[torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024]),
 torch.Size([1, 2, 1024])]

In [17]:
# Keep only the last position along dim 1 of each hidden-state tensor.
[h[:, -1, :].shape for h in hiddens]

[torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024])]

### Encoder only

In [81]:
# Load a fresh copy of the checkpoint on the CPU first: only the encoder is
# used below, so the decoder weights never need to be copied to the GPU
# (where the previously loaded full model may still be resident).
full_model = T5ForConditionalGeneration.from_pretrained(model_name)
print(full_model.config.is_encoder_decoder)

# Keep just the encoder stack when the architecture exposes one.
if hasattr(full_model, "get_encoder") and callable(full_model.get_encoder):
    model = full_model.get_encoder()
else:
    model = full_model
del full_model  # drop the extra reference so unused decoder weights can be freed

model = model.to(device)

# NOTE(review): the outputs recorded for this section (about 35M parameters,
# 512-wide hidden states, 7 hidden-state tensors) do not match the checkpoint
# named in `model_name` above, and the execution counts jump from 17 to 81.
# They presumably come from a run with a smaller checkpoint - re-run the
# notebook top to bottom to refresh them.
print(f"# params: {sum(p.numel() for p in model.parameters())}")

True
# params: 35330816


In [82]:
# Fetch only the configuration of the checkpoint; its `architectures` field
# is what the autoregressive check below inspects.
model_cfg = AutoConfig.from_pretrained(model_name)

# Ordered by preference
_DECODER_ONLY_SUFFIXES = [
    "CausalLM",
    "LMHeadModel",
]
# Includes encoder-decoder models
_AUTOREGRESSIVE_SUFFIXES = ["ConditionalGeneration"] + _DECODER_ONLY_SUFFIXES

def is_autoregressive(model_cfg, include_enc_dec: bool) -> bool:
    """Check if a model config is autoregressive."""
    archs = model_cfg.architectures
    if not isinstance(archs, list):
        return False

    suffixes = _AUTOREGRESSIVE_SUFFIXES if include_enc_dec else _DECODER_ONLY_SUFFIXES
    return any(arch_str.endswith(suffix) for arch_str in archs for suffix in suffixes)

# Check this config against the decoder-only suffixes only
# (the encoder-decoder suffix is excluded).
has_lm_preds = is_autoregressive(model_cfg, include_enc_dec=False)
has_lm_preds

False

In [84]:
# Token ids of the prompt, as produced by the earlier tokenizer call.
question_ids = text_batch.input_ids
print(question_ids.shape)

# Tokenize the stripped answer on its own, as a tensor on `device`.
answer_tokenized = tokenizer(
    a.strip(),
    # Don't add the tokenizer's special tokens to the answer
    add_special_tokens=False,
    return_tensors="pt",
).to(device)

# Append the answer ids to the prompt ids along the last dimension.
# NOTE(review): the prompt was tokenized with add_special_tokens=True, so any
# end-of-sequence token it carries presumably sits before the answer -
# confirm this ordering is intended. This cell also does not consult
# `has_lm_preds`; the answer is appended unconditionally.
answer_ids = answer_tokenized.input_ids
ids = torch.cat([question_ids, answer_ids], -1)
ids.shape

torch.Size([1, 459])


torch.Size([1, 460])

In [89]:
# Run `model` on the concatenated ids without tracking gradients, requesting
# the hidden states of every layer. `.long()` casts the ids to int64.
with torch.no_grad():
    outputs = model(input_ids=ids.long(), output_hidden_states=True)

# Show which fields this output contains.
outputs.keys()

odict_keys(['last_hidden_state', 'hidden_states'])

In [90]:
# One shape per hidden-state tensor returned by the forward pass above.
hiddens = outputs["hidden_states"]
[x.shape for x in hiddens]

[torch.Size([1, 460, 512]),
 torch.Size([1, 460, 512]),
 torch.Size([1, 460, 512]),
 torch.Size([1, 460, 512]),
 torch.Size([1, 460, 512]),
 torch.Size([1, 460, 512]),
 torch.Size([1, 460, 512])]

In [91]:
# Keep only the last position along dim 1 of each hidden-state tensor.
[h[:, -1, :].shape for h in hiddens]

[torch.Size([1, 512]),
 torch.Size([1, 512]),
 torch.Size([1, 512]),
 torch.Size([1, 512]),
 torch.Size([1, 512]),
 torch.Size([1, 512]),
 torch.Size([1, 512])]