In [6]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset, load_from_disk
from datasets import Dataset, DatasetDict
from tinydb import TinyDB, Query
from functools import partial

In [7]:

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-large")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-large")

In [8]:
model.to("cuda:0")

T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=4096, bias=False)
              (wo): Linear(in_features=4096, out_features=1024, bias=False)
              (d

In [4]:
CACHE_LOCATION = "cache/cnn_t5_large_distillation"

## Run this cell if using fresh dataset

In [5]:
cnn_dataset = load_dataset("cnn_dailymail", "1.0.0")
cnn_dataset

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})

## Run this cell if using cached dataset

In [5]:
cnn_dataset = load_from_disk(CACHE_LOCATION)

In [9]:
BATCH_SIZE = 16
BATCHES_PER_CACHE = 1

def inferDataset(
    curr_batch: Dataset,
    input_key: str = "article",
    prefix: str = "summarize: ",
    max_input_length: int = 1024,
    max_output_length: int = 1024
) -> Dataset:
    inputs = [prefix + doc for doc in curr_batch[input_key]]
    input_ids = tokenizer(
        inputs, 
        return_tensors = "pt",
        max_length = max_input_length,
        truncation = True,
        padding = True,
    ).input_ids.to("cuda:0")

    outputs = model.generate(input_ids, max_new_tokens = max_output_length)
    outputs.to("cpu")

    decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens = True)
    curr_batch["t5_large_output"] = decoded_output

    return curr_batch

cnn_dataset.map(
    inferDataset,
    batched = True,
    batch_size = 16 
)
    

Map:   0%|          | 0/287113 [00:08<?, ? examples/s]


TypeError: string indices must be integers