In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset, load_from_disk
from datasets import Dataset, DatasetDict
from typing import Dict, List
from functools import partial
from pathlib import Path
from tqdm import tqdm

import pandas as pd
import os
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# A single checkpoint id keeps the tokenizer and the model weights in sync
MODEL_CHECKPOINT = "google-t5/t5-large"

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

In [3]:
# Switch to inference mode and move the weights to the first GPU;
# both calls return the module, so the cell still displays the model
model.eval().to("cuda:0")

T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=4096, bias=False)
              (wo): Linear(in_features=4096, out_features=1024, bias=False)
              (d

In [4]:
# CNN/DailyMail, configuration 3.0.0 (every split the loader returns)
cnn_dataset = load_dataset(path = "cnn_dailymail", name = "3.0.0")

## Utility Functions

In [5]:
def getEmptyFrameDict(
    input_key: str,
    output_key: str,
    dataset_dict: DatasetDict
) -> Dict[str, pd.DataFrame]:
    """Build one DataFrame per split holding the inputs and a blank output column.

    Args:
        input_key: Column to copy out of every split of ``dataset_dict``.
        output_key: Name of the column that will later receive generated text;
            it is initialised to the empty string for every row.
        dataset_dict: Mapping of split name to a dataset indexable by column name.

    Returns:
        A dict keyed by split name; each value is a two-column DataFrame
        (``input_key``, ``output_key``) with one row per input.
    """
    dataframe_dict = dict()

    for curr_name, curr_dataset in dataset_dict.items():
        # Read the column a single time: indexing a dataset by column name can
        # materialise the whole column, so doing it twice doubles that cost.
        input_values = list(curr_dataset[input_key])

        dataframe_dict[curr_name] = pd.DataFrame({
            input_key: input_values,
            output_key: [""] * len(input_values)
        })

    return dataframe_dict

def cacheFrame(
    data_frame: pd.DataFrame,
    cache_dir: Path,
    cache_entry_name: str
) -> None:
    """Write ``data_frame`` to ``<cache_dir>/<cache_entry_name>.cache.parquet``.

    The file is written uncompressed with the pyarrow engine, replacing any
    previous file of the same name.

    Args:
        data_frame: Frame to persist.
        cache_dir: Directory that holds the cache files; created if missing.
        cache_entry_name: File stem; ``.cache.parquet`` is appended.
    """
    # The first write on a fresh checkout would otherwise fail because the
    # cache directory does not exist yet. An empty cache_dir means "current
    # directory", which needs no creation.
    if cache_dir:
        os.makedirs(cache_dir, exist_ok = True)

    data_frame.to_parquet(
        os.path.join(cache_dir, f"{cache_entry_name}.cache.parquet"),
        engine = "pyarrow",
        compression = None
    )

def cacheFrameDict(
    cache_dir: Path,
    dataframe_dict: Dict[str, pd.DataFrame],
    prefix_name: str = None,
) -> None:
    """Persist every frame in ``dataframe_dict`` through ``cacheFrame``.

    Args:
        cache_dir: Directory handed to ``cacheFrame`` for each entry.
        dataframe_dict: Frames to write, keyed by entry name.
        prefix_name: When given, each cache entry is named
            ``<prefix_name>_<key>``; otherwise the key alone is used.
    """
    for entry_key, entry_frame in dataframe_dict.items():
        entry_name = (
            entry_key
            if prefix_name is None
            else f"{prefix_name}_{entry_key}"
        )
        cacheFrame(entry_frame, cache_dir, entry_name)

def loadFrameDict(
    cache_dir: Path,
    dataframe_names: List[str],
    prefix_name: str = None,
) -> Dict[str, pd.DataFrame]:
    """Read cached frames back from ``cache_dir``.

    Args:
        cache_dir: Directory containing the ``*.cache.parquet`` files.
        dataframe_names: Entry names to load; they become the keys of the result.
        prefix_name: When given, files are looked up as
            ``<prefix_name>_<name>.cache.parquet`` instead of
            ``<name>.cache.parquet``.

    Returns:
        A dict mapping each requested name to the DataFrame read from its file.
    """
    name_prefix = "" if prefix_name is None else f"{prefix_name}_"

    return {
        entry_name: pd.read_parquet(
            os.path.join(cache_dir, f"{name_prefix}{entry_name}.cache.parquet"),
            engine = "pyarrow",
        )
        for entry_name in dataframe_names
    }

## Run Cells if Creating Fresh Frame Dict

In [6]:
# Fresh start: one frame per split with the articles and a blank output column
cnn_frame_dict = getEmptyFrameDict(
    input_key = "article",
    output_key = "t5_large_output",
    dataset_dict = cnn_dataset
)

In [7]:
# Writes ./cache/cnn_dm_distill_<split>.cache.parquet for every split.
# NOTE: this replaces any existing cache files of the same name.
cacheFrameDict(
    cache_dir = "./cache",
    dataframe_dict = cnn_frame_dict,
    prefix_name = "cnn_dm_distill"
)

## Run Cell if Using Cached Frame Dict

In [8]:
# Resume path: reload the cached frames and also bind them to `cnn_frame_dict`,
# the name the inference call below reads, so this path works on a fresh kernel
# without first running the "fresh frame dict" cells.
cached_cnn_frame_dict = loadFrameDict("./cache", cnn_dataset.keys(), "cnn_dm_distill")
cnn_frame_dict = cached_cnn_frame_dict

In [9]:
def _summarizeBatch(
    documents: List[str],
    prefix: str,
    max_input_length: int,
    max_output_length: int
) -> List[str]:
    """Generate one decoded output string per document using the notebook-level ``tokenizer`` and ``model``."""
    encoded = tokenizer(
        [prefix + doc for doc in documents],
        return_tensors = "pt",
        max_length = max_input_length,
        truncation = True,
        padding = True,
    ).to(model.device)  # follow the model instead of hard-coding a device

    # The batch is padded, so hand the attention mask over explicitly rather
    # than relying on generate() to infer it from the pad token.
    outputs = model.generate(
        input_ids = encoded.input_ids,
        attention_mask = encoded.attention_mask,
        max_new_tokens = max_output_length,
    )

    # Tensor.to returns the moved tensor; the result has to be rebound
    outputs = outputs.to("cpu")

    return tokenizer.batch_decode(outputs, skip_special_tokens = True)

def inferDataFrameDict(
    dataframe_dict: Dict[str, pd.DataFrame],
    batch_size: int,
    input_key: str = "article",
    output_key: str = "t5_large_output",
    prefix: str = "summarize: ",
    max_input_length: int = 1024,
    max_output_length: int = 1024,
    cache_location: str = None,
    dataset_name: str = None,
    batches_per_cache_write: int = None
) -> None:
    """Fill ``output_key`` of every frame in place with generated text.

    Rows are processed in consecutive batches of at most ``batch_size``. A
    batch whose outputs are all non-empty is treated as already done and is
    skipped, which lets an interrupted run resume from a reloaded cache.

    Args:
        dataframe_dict: Frames to process, keyed by split name; each must
            contain ``input_key`` and ``output_key`` columns.
        batch_size: Maximum number of rows sent to the model at once.
        input_key: Column holding the source documents.
        output_key: Column that receives the decoded generations.
        prefix: Text prepended to every document before tokenization.
        max_input_length: Token limit used when truncating inputs.
        max_output_length: ``max_new_tokens`` passed to generation.
        cache_location: Directory for cache files. Caching is active only
            when both ``cache_location`` and ``dataset_name`` are given.
        dataset_name: Prefix of the cache entry (``<dataset_name>_<split>``).
        batches_per_cache_write: When caching is active and this is a
            positive integer, the frame is also written after every
            batch whose 1-based index is a multiple of it.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    can_cache = cache_location is not None and dataset_name is not None
    write_periodically = (
        can_cache
        and batches_per_cache_write is not None
        and batches_per_cache_write > 0
    )

    for curr_name, curr_dataframe in dataframe_dict.items():
        cache_entry_name = f"{dataset_name}_{curr_name}"
        out_column_index = curr_dataframe.columns.get_loc(output_key)

        # Slice by position instead of np.array_split: splitting a DataFrame
        # that way goes through the deprecated DataFrame.swapaxes and copies
        # the frame into chunks up front.
        batch_starts = range(0, len(curr_dataframe), batch_size)

        for batch_idx, start_row in enumerate(tqdm(batch_starts)):
            curr_chunk = curr_dataframe.iloc[start_row : start_row + batch_size]

            # Already populated (e.g. restored from cache): nothing to do
            if all(x != "" for x in curr_chunk[output_key]):
                continue

            decoded_output = _summarizeBatch(
                list(curr_chunk[input_key]),
                prefix,
                max_input_length,
                max_output_length
            )

            end_row = start_row + len(decoded_output)
            curr_dataframe.iloc[start_row : end_row, out_column_index] = decoded_output

            # Check the caching flag first so a None interval is never used
            # as a modulus.
            if write_periodically and (batch_idx + 1) % batches_per_cache_write == 0:
                cacheFrame(curr_dataframe, cache_location, cache_entry_name)

        # Final flush for this split; skipped entirely when caching is off
        if can_cache:
            cacheFrame(curr_dataframe, cache_location, cache_entry_name)

# Generate outputs for every split in batches of 16 rows, writing progress to
# ./cache on every 128th batch
inferDataFrameDict(
    dataframe_dict = cnn_frame_dict,
    batch_size = 16,
    cache_location = "./cache",
    dataset_name = "cnn_dm_distill",
    batches_per_cache_write = 128
)

  return bound(*args, **kwds)
  0%|          | 10/14356 [01:47<42:45:54, 10.73s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacty of 7.79 GiB of which 3.62 MiB is free. Including non-PyTorch memory, this process has 7.78 GiB memory in use. Of the allocated memory 6.63 GiB is allocated by PyTorch, and 1.03 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Inspect the "test" split frame: the input column plus whatever outputs have
# been written into it so far
cnn_frame_dict["test"]

In [None]:
# Rather than converting the in-memory DataFrames, loadDatasetFromCachedDataframe
# reads the cache files on disk straight into a DatasetDict
def loadDatasetFromCachedDataframe(
    cache_dir: Path,
    dataframe_names: List[str],
    prefix_name: str = None,
) -> DatasetDict:
    """Build a DatasetDict from the cached parquet files.

    Args:
        cache_dir: Directory containing the ``*.cache.parquet`` files.
        dataframe_names: Entry names to load; they become the split names.
        prefix_name: When given, each file is resolved as
            ``<prefix_name>_<name>.cache.parquet`` instead of
            ``<name>.cache.parquet``.

    Returns:
        The result of ``DatasetDict.from_parquet`` on the name-to-path mapping.
    """
    name_prefix = "" if prefix_name is None else f"{prefix_name}_"

    parquet_paths = {
        split_name: os.path.join(cache_dir, f"{name_prefix}{split_name}.cache.parquet")
        for split_name in dataframe_names
    }

    return DatasetDict.from_parquet(parquet_paths)

# Rebuild a DatasetDict from the cache files written during inference
cnn_distill_dataset = loadDatasetFromCachedDataframe(
    cache_dir = "./cache",
    dataframe_names = cnn_dataset.keys(),
    prefix_name = "cnn_dm_distill"
)