In [1]:
import os

# Restrict this process to one GPU. The variables are set in the first cell,
# before `torch` is imported in the next one.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',   # enumerate devices by PCI bus id
    'CUDA_VISIBLE_DEVICES': '5',         # expose only device index 5
})

In [2]:
import transformers as tr
import datasets as ds
import torch

In [3]:
# Load only the `test` split of the `Rexhaif/mintaka-qa-en` dataset.
dataset = ds.load_dataset(
    path="Rexhaif/mintaka-qa-en",
    split="test",
)

Using custom data configuration Rexhaif--mintaka-qa-en-a309a4f0b6175fde
Found cached dataset parquet (/root/.cache/huggingface/datasets/Rexhaif___parquet/Rexhaif--mintaka-qa-en-a309a4f0b6175fde/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


In [4]:
# Tokenizer for the same checkpoint name that the model cell loads.
tokenizer = tr.AutoTokenizer.from_pretrained(
    "google/t5-xxl-ssm-nq",
)

Downloading:   0%|          | 0.00/1.86k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/631 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

In [5]:
def process_fn(examples):
    """Tokenize the ``question`` entry of a batch of examples.

    Args:
        examples: Mapping that contains a ``"question"`` entry; its value is
            passed unchanged to the module-level ``tokenizer``.

    Returns:
        The tokenizer's output for those questions, with truncation enabled
        and ``max_length`` set to 512.
    """
    return tokenizer(examples["question"], truncation=True, max_length=512)

In [6]:
# Tokenize the questions in batches of 8. Every original column is removed,
# so the mapped dataset holds only what `process_fn` returns.
questions_dataset = dataset.map(
    function=process_fn,
    remove_columns=dataset.column_names,
    batched=True,
    batch_size=8,
)

# Reference answers, read from the untouched source dataset.
answers_dataset = dataset["answer"]

  0%|          | 0/500 [00:00<?, ?ba/s]

In [7]:
# Load the seq2seq checkpoint with 8-bit loading enabled; device placement is
# delegated to `device_map='auto'`.
model = tr.AutoModelForSeq2SeqLM.from_pretrained(
    "google/t5-xxl-ssm-nq",
    load_in_8bit=True,
    device_map='auto',
)

Downloading:   0%|          | 0.00/44.5G [00:00<?, ?B/s]

In [7]:
# Collator that pads each batch with the tokenizer ("longest" strategy,
# multiple-of-8 padding) and returns PyTorch tensors.
collate_fn = tr.DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding="longest",
    max_length=512,
    pad_to_multiple_of=8,
    return_tensors="pt",
)

# No `shuffle` argument is passed, so the loader's default ordering applies.
data_loader = torch.utils.data.DataLoader(
    dataset=questions_dataset,
    collate_fn=collate_fn,
    batch_size=8,
)

In [62]:
# Take just the first batch from the loader for a quick inspection.
loader_iter = iter(data_loader)
batch = next(loader_iter)

In [63]:
# Move every tensor of the batch onto the device reported by the model.
batch = {name: tensor.to(model.device) for name, tensor in batch.items()}

# Beam search without sampling: 5 beams, all 5 returned per input, at most
# 10 newly generated tokens. The dict-style output is requested together with
# scores, which the following cells read via `outputs.sequences_scores`.
outputs = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    # search configuration
    do_sample=False,
    num_beams=5,
    num_return_sequences=5,
    max_new_tokens=10,
    # output format
    return_dict_in_generate=True,
    output_scores=True,
)

In [64]:
# NOTE: this rebinds the name `print` in the notebook namespace to rich's
# implementation; the report cell below relies on it to render the
# `[bold ...]` markup tags in its f-string.
from rich import print
# Load rich's IPython extension into the running kernel.
%load_ext rich

The rich extension is already loaded. To reload it, use:
  %reload_ext rich


In [65]:
# Reshape the flat sequence scores into rows of 5 (one row per input, matching
# num_return_sequences=5 in the generate call) and softmax across each row.
scores = outputs.sequences_scores.view(-1, 5).softmax(dim=-1)

In [66]:
# Rows of 5 sequence scores, one row per input.
beam_scores = outputs.sequences_scores.view(-1, 5)
# Absolute gap between the first and second returned sequence's score for each
# input (presumably an uncertainty estimate, given the `ue_` prefix -- confirm).
ue_scores = (beam_scores[:, 0] - beam_scores[:, 1]).abs()

In [67]:
# Decode the first returned sequence for each input.
# `outputs.sequences` is viewed as (num_inputs, 5, seq_len), where 5 matches
# num_return_sequences in the generate call. The leading dimension is inferred
# from the tensor instead of being hard-coded to 8, so this also works for a
# batch that is not exactly 8 inputs (e.g. the final, smaller batch).
generated = outputs.sequences
first_candidates = generated.view(-1, 5, generated.size(-1))[:, 0, :]
decoded = tokenizer.batch_decode(first_candidates, skip_special_tokens=True)

In [68]:
# Print question / reference answer / prediction / score for each decoded input.
# Assumes `decoded` and `ue_scores` correspond, in order, to the first rows of
# `dataset` (the loader above is built without shuffling).
#
# The rows are fetched with one slice up front: indexing a column by name
# inside the loop (`dataset['question'][i]`) can re-read the entire column on
# every iteration. The row count comes from `decoded` rather than a literal 8.
num_shown = len(decoded)
shown_rows = dataset[:num_shown]
report_rows = zip(
    shown_rows["question"],
    shown_rows["answer"],
    decoded,
    ue_scores.tolist(),
)
for question, answer, prediction, ue_score in report_rows:
    print(f"[bold blue]Question:[/bold blue] {question} [bold green]Answer:[/bold green] {answer} [bold red]Prediction:[/bold red] {prediction} [bold yellow]Score:[/bold yellow] {ue_score:.4f}")