In [3]:
import sys

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from nlp481.distillation import (
    getEmptyFrameDict,
    cacheFrameDict,
    loadFrameDict,
    inferDataFrameDict,
    loadDatasetFromCachedDataframe
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

tokenizer = AutoTokenizer.from_pretrained("kssteven/T5-large-cnndm")
model = AutoModelForSeq2SeqLM.from_pretrained("kssteven/T5-large-cnndm")

In [3]:
DEVICE = "cuda:0"

model.eval()
model.to(DEVICE)

T5ForConditionalGeneration(
  (shared): Embedding(32100, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32100, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=4096, bias=False)
              (wo): Linear(in_features=4096, out_features=1024, bias=False)
              (d

In [4]:
dataset = load_dataset("cnn_dailymail", "3.0.0")

INPUT_KEY = "article"
DATASET_NAME = "cnndm_t5_distill"

## Run Cells if Creating Fresh Frame Dict

In [6]:
ds_frame_dict = getEmptyFrameDict(INPUT_KEY, "t5_large_output", dataset)

In [7]:
cacheFrameDict("./cache", ds_frame_dict, DATASET_NAME)

## Run Cell if Using Cached Frame Dict

In [5]:
ds_frame_dict = loadFrameDict("./cache", dataset.keys(), DATASET_NAME)

In [6]:
ds_frame_dict["train"]

Unnamed: 0,article,t5_large_output
0,"LONDON, England (Reuters) -- Harry Potter star...",Harry Potter star Daniel Radcliffe gains acces...
1,Editor's note: In our Behind the Scenes series...,"Mentally ill inmates housed on Miami's ""forgot..."
2,"MINNEAPOLIS, Minnesota (CNN) -- Drivers who we...","Driver describes 30-, 35-foot free fall from M..."
3,WASHINGTON (CNN) -- Doctors removed five small...,NEW: Bush will resume his activities at Camp D...
4,(CNN) -- The National Football League has ind...,NEW: NFL Commissioner Roger Goodell says he'll...
...,...,...
287108,"The nine-year-old daughter of a black, unarmed...","Rumain Brisbon, 34, was shot dead by a Phoenix..."
287109,Legalising assisted suicide is a slippery slop...,Dutch death toll has doubled in six years.
287110,A group calling itself 'The Women of the 99 Pe...,The automated calls are illegal because they d...
287111,Most men enjoy a good pint of lager or real al...,"Peter Hill, 56, John Hill, 81, John Drew, 48, ..."


In [7]:
inferDataFrameDict(
    ds_frame_dict,
    model,
    tokenizer,
    16,
    cache_location = "./cache",
    dataset_name = DATASET_NAME,
    input_key = INPUT_KEY,
    batches_per_cache_write = 64
)

# Move model out of VRAM (so NLPG admins don't get mad at us)
model.to("cpu")

# For some reason moving the model to cpu doesn't actually free VRAM
# so just exit from the process
sys.exit(0)

  return bound(*args, **kwds)
100%|██████████| 17945/17945 [00:01<00:00, 12169.50it/s]
100%|██████████| 836/836 [31:43<00:00,  2.28s/it]
  return bound(*args, **kwds)
100%|██████████| 719/719 [45:41<00:00,  3.81s/it]


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [8]:
ds_frame_dict["train"]

Unnamed: 0,article,t5_large_output
0,"LONDON, England (Reuters) -- Harry Potter star...",Harry Potter star Daniel Radcliffe gains acces...
1,Editor's note: In our Behind the Scenes series...,"Mentally ill inmates housed on Miami's ""forgot..."
2,"MINNEAPOLIS, Minnesota (CNN) -- Drivers who we...","Driver describes 30-, 35-foot free fall from M..."
3,WASHINGTON (CNN) -- Doctors removed five small...,NEW: Bush will resume his activities at Camp D...
4,(CNN) -- The National Football League has ind...,NEW: NFL Commissioner Roger Goodell says he'll...
...,...,...
287108,"The nine-year-old daughter of a black, unarmed...","Rumain Brisbon, 34, was shot dead by a Phoenix..."
287109,Legalising assisted suicide is a slippery slop...,Dutch death toll has doubled in six years.
287110,A group calling itself 'The Women of the 99 Pe...,The automated calls are illegal because they d...
287111,Most men enjoy a good pint of lager or real al...,"Peter Hill, 56, John Hill, 81, John Drew, 48, ..."


In [5]:
distill_dataset = loadDatasetFromCachedDataframe("./cache", dataset.keys(), DATASET_NAME)

Generating train split: 287113 examples [00:23, 12353.65 examples/s]
Generating validation split: 13368 examples [00:02, 5221.30 examples/s]
Generating test split: 11490 examples [00:00, 11910.93 examples/s]


In [6]:
distill_dataset["train"]

Dataset({
    features: ['article', 't5_large_output'],
    num_rows: 287113
})

In [7]:
for curr_name, curr_dataset in distill_dataset.items():
    curr_base_dataset = dataset[curr_name]
    curr_base_columns = set(curr_base_dataset.column_names)
    curr_base_columns.remove(INPUT_KEY)

    for curr_col_name in curr_base_columns:
        curr_dataset = curr_dataset.add_column(
            curr_col_name,
            dataset[curr_name][curr_col_name]
        )

    distill_dataset[curr_name] = curr_dataset

In [8]:
distill_dataset["train"]

Dataset({
    features: ['article', 't5_large_output', 'id', 'highlights'],
    num_rows: 287113
})

In [9]:
distill_dataset.push_to_hub(
    f"lilferrit/cnn_dailymail_t5_distillation",
    revision = "cnndm-checkpoints"
)

Creating parquet from Arrow format: 100%|██████████| 96/96 [00:05<00:00, 18.15ba/s]
Creating parquet from Arrow format: 100%|██████████| 96/96 [00:04<00:00, 21.16ba/s]
Creating parquet from Arrow format: 100%|██████████| 96/96 [00:04<00:00, 22.20ba/s]
Uploading the dataset shards: 100%|██████████| 3/3 [00:39<00:00, 13.17s/it]
Creating parquet from Arrow format: 100%|██████████| 14/14 [00:00<00:00, 23.77ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.22s/it]
Creating parquet from Arrow format: 100%|██████████| 12/12 [00:00<00:00, 23.32ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/lilferrit/cnn_dailymail_t5_distillation/commit/7f69fe7b38ad72f82bfdc31f00d89866e751959a', commit_message='Upload dataset', commit_description='', oid='7f69fe7b38ad72f82bfdc31f00d89866e751959a', pr_url=None, pr_revision=None, pr_num=None)