In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from nlp481.distillation import (
    getEmptyFrameDict,
    cacheFrameDict,
    loadFrameDict,
    inferDataFrameDict,
    loadDatasetFromCachedDataframe
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-large")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-large")

In [3]:
DEVICE = "cuda:0"

model.eval()
model.to(DEVICE)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=4096, bias=False)
              (wo): Linear(in_features=4096, out_features=1024, bias=False)
              (d

In [4]:
dataset = load_dataset("EdinburghNLP/xsum")

INPUT_KEY = "document"
DATASET_NAME = "xsum_t5_distill"

## Run Cells if Creating Fresh Frame Dict

In [None]:
ds_frame_dict = getEmptyFrameDict(INPUT_KEY, "t5_large_output", dataset)

In [None]:
cacheFrameDict("./cache", ds_frame_dict, DATASET_NAME)

## Run Cell if Using Cached Frame Dict

In [5]:
ds_frame_dict = loadFrameDict("./cache", dataset.keys(), DATASET_NAME)

In [6]:
ds_frame_dict["train"]

Unnamed: 0,document,t5_large_output
0,"The full cost of damage in Newton Stewart, one...",flood damage in Newton Stewart still being ass...
1,A fire alarm went off at the Holiday Inn in Ho...,fire alarm went off at the holiday inn in hope...
2,Ferrari appeared in a position to challenge un...,Mercedes will start the race on pole ahead of ...
3,"John Edward Bates, formerly of Spalding, Linco...",67-year-old accused of 22 charges including tw...
4,Patients and staff were evacuated from Cerahpa...,a man receiving treatment at a clinic in Istan...
...,...,...
204040,The initial figure released in July was booste...,
204041,"MEPs, including European Parliament chief Brex...",
204042,Lincoln Red Imps will bring a 1-0 lead to Glas...,
204043,Former Liverpool defender Mark Lawrenson expan...,


In [7]:
inferDataFrameDict(
    ds_frame_dict,
    model,
    tokenizer,
    16,
    cache_location = "./cache",
    dataset_name = DATASET_NAME,
    input_key = INPUT_KEY,
    batches_per_cache_write = 64
)

# Move model out of VRAM (so NLPG admins don't get mad at us)
model.to("cpu")

  return bound(*args, **kwds)
100%|██████████| 12753/12753 [2:00:09<00:00,  1.77it/s]  
  return bound(*args, **kwds)
100%|██████████| 709/709 [38:02<00:00,  3.22s/it]
100%|██████████| 709/709 [37:42<00:00,  3.19s/it]


T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=4096, bias=False)
              (wo): Linear(in_features=4096, out_features=1024, bias=False)
              (d

In [8]:
ds_frame_dict["train"]

Unnamed: 0,document,t5_large_output
0,"The full cost of damage in Newton Stewart, one...",flood damage in Newton Stewart still being ass...
1,A fire alarm went off at the Holiday Inn in Ho...,fire alarm went off at the holiday inn in hope...
2,Ferrari appeared in a position to challenge un...,Mercedes will start the race on pole ahead of ...
3,"John Edward Bates, formerly of Spalding, Linco...",67-year-old accused of 22 charges including tw...
4,Patients and staff were evacuated from Cerahpa...,a man receiving treatment at a clinic in Istan...
...,...,...
204040,The initial figure released in July was booste...,net trade boosted GDP by one percentage point ...
204041,"MEPs, including European Parliament chief Brex...","eu leaders say the proposal is a ""damp squib"" ..."
204042,Lincoln Red Imps will bring a 1-0 lead to Glas...,defender sviatchenko says he is surprised by d...
204043,Former Liverpool defender Mark Lawrenson expan...,former defender marks lawrenson says the reds'...


In [9]:
distill_dataset = loadDatasetFromCachedDataframe("./cache", dataset.keys(), DATASET_NAME)

Generating train split: 204045 examples [00:09, 22644.90 examples/s]
Generating validation split: 11332 examples [00:00, 42530.97 examples/s]
Generating test split: 11334 examples [00:00, 41903.37 examples/s]


In [11]:
distill_dataset["train"]

 't5_large_output': 'flood damage in Newton Stewart still being assessed by authorities. repairs are ongoing in the town after the river cree overflowed into the town. a flood alert remains in place across the borders because of the constant rain.'}

In [12]:
for curr_name, curr_dataset in distill_dataset.items():
    curr_base_dataset = dataset[curr_name]
    curr_base_columns = set(curr_base_dataset.column_names)
    curr_base_columns.remove(INPUT_KEY)

    for curr_col_name in curr_base_columns:
        curr_dataset = curr_dataset.add_column(
            curr_col_name,
            dataset[curr_name][curr_col_name]
        )

    distill_dataset[curr_name] = curr_dataset

In [14]:
distill_dataset["train"]

Dataset({
    features: ['document', 't5_large_output', 'id', 'summary'],
    num_rows: 204045
})

In [15]:
distill_dataset.push_to_hub(
    f"lilferrit/xsum_t5_distillation",
)

Creating parquet from Arrow format: 100%|██████████| 103/103 [00:02<00:00, 50.14ba/s]
Creating parquet from Arrow format: 100%|██████████| 103/103 [00:02<00:00, 48.82ba/s]
Uploading the dataset shards: 100%|██████████| 2/2 [00:16<00:00,  8.29s/it]
Creating parquet from Arrow format: 100%|██████████| 12/12 [00:00<00:00, 41.26ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.44s/it]
Creating parquet from Arrow format: 100%|██████████| 12/12 [00:00<00:00, 50.46ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.19s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/lilferrit/xsum_t5_distillation/commit/d0a1c16a76631d3f9fc12d2a669b40466d5727fa', commit_message='Upload dataset', commit_description='', oid='d0a1c16a76631d3f9fc12d2a669b40466d5727fa', pr_url=None, pr_revision=None, pr_num=None)