In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from nlp481.distillation import (
    getEmptyFrameDict,
    cacheFrameDict,
    loadFrameDict,
    inferDataFrameDict,
    loadDatasetFromCachedDataframe
)

In [None]:

# Pretrained T5-large checkpoint from the Hugging Face Hub; the tokenizer and
# model must come from the same checkpoint.
MODEL_NAME = "google-t5/t5-large"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

In [None]:
DEVICE = "cuda:0"

# Switch to inference mode (disables dropout) and move the weights onto DEVICE.
# Both `eval()` and `to()` return the module itself, so they can be chained.
model.eval().to(DEVICE)

In [None]:
# Column of the source dataset whose text is fed to the model.
INPUT_KEY = "document"
# Name passed to the cache helpers to identify this run's cached frames.
DATASET_NAME = "xsum_t5_distill"

dataset = load_dataset("EdinburghNLP/xsum")

## Run These Cells Only When Creating a Fresh Frame Dict (skip them when resuming from the cache)

In [None]:
# NOTE(review): presumably the column that will hold the T5-large generations — confirm against getEmptyFrameDict.
OUTPUT_KEY = "t5_large_output"
ds_frame_dict = getEmptyFrameDict(INPUT_KEY, OUTPUT_KEY, dataset)

In [None]:
# Write the freshly created frame dict to the local cache directory.
CACHE_DIR = "./cache"
cacheFrameDict(CACHE_DIR, ds_frame_dict, DATASET_NAME)

## Run This Cell When Resuming from a Cached Frame Dict

In [None]:
# Reload one cached frame per split of the source dataset.
split_names = dataset.keys()
ds_frame_dict = loadFrameDict("./cache", split_names, DATASET_NAME)

In [None]:
# Inspect the train split before inference.
train_frame = ds_frame_dict["train"]
train_frame

In [None]:
# Run the model over every split in ds_frame_dict, writing progress to ./cache
# every `batches_per_cache_write` batches.
# NOTE(review): 4th positional argument of inferDataFrameDict; assumed to be the batch size — confirm.
INFERENCE_BATCH_SIZE = 16

try:
    inferDataFrameDict(
        ds_frame_dict,
        model,
        tokenizer,
        INFERENCE_BATCH_SIZE,
        cache_location = "./cache",
        dataset_name = DATASET_NAME,
        input_key = INPUT_KEY,
        batches_per_cache_write = 64
    )
finally:
    # Move model out of VRAM (so NLPG admins don't get mad at us), even if
    # inference raises or the cell is interrupted. Moving the weights alone
    # leaves the freed blocks reserved by PyTorch's caching allocator, so also
    # hand them back to the driver; empty_cache is a no-op if CUDA is unused.
    model.to("cpu")
    torch.cuda.empty_cache()

In [None]:
# Inspect the train split again now that inference has filled in outputs.
train_frame = ds_frame_dict["train"]
train_frame

In [None]:
# Build the distillation dataset (one entry per source split) from the cached frames.
distill_dataset = loadDatasetFromCachedDataframe(
    "./cache",
    dataset.keys(),
    DATASET_NAME,
)

In [None]:
# Inspect the train split of the distillation dataset.
distill_train = distill_dataset["train"]
distill_train

In [None]:
# Re-attach the reference summaries and example ids from the original XSum
# dataset to every split of the distillation dataset. XSum's columns are
# "document", "summary" and "id" (there is no CNN/DailyMail-style "highlights").
# NOTE(review): add_column aligns rows by position, so this assumes each split of
# distill_dataset is in the same row order as the matching split of `dataset`.
REFERENCE_COLUMNS = ("summary", "id")

for split_name, split_dataset in distill_dataset.items():
    source_split = dataset[split_name]

    for column in REFERENCE_COLUMNS:
        # Skip columns already attached so re-running this cell does not fail.
        if column not in split_dataset.column_names:
            split_dataset = split_dataset.add_column(column, source_split[column])

    distill_dataset[split_name] = split_dataset

In [None]:
# Upload the XSum distillation dataset to the Hugging Face Hub.
# NOTE(review): repo id mirrors the CNN/DailyMail naming scheme — confirm before
# pushing. The CNN-specific "3.0.0" config name does not apply to XSum, so the
# default config is used.
HUB_REPO_ID = "lilferrit/xsum_t5_distillation"

distill_dataset.push_to_hub(HUB_REPO_ID)