In [None]:
import pandas as pd
import torch

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

In [None]:
# MODIFY THE PARAMS HERE
# Identifier handed to the tokenizer/config/model `from_pretrained` calls; it is
# also embedded in the name of the output CSV.
model_type = "gpt2-large"
# Checkpoint file whose state dict is loaded into the model after construction.
model_path = "./models/gpt2-large/arxiv_50000/blk0.7/checkpoint-47104/pytorch_model.bin"
# Base name (without extension) of the prompt CSV; also embedded in the name of
# the output CSV.
data_file = "arxiv_5000" # Please put data_file (.csv) in ./data/

In [None]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

config = AutoConfig.from_pretrained(model_type)
tokenizer = AutoTokenizer.from_pretrained(model_type)

model = AutoModelForCausalLM.from_pretrained(model_type, config=config)
# Size the embedding matrix to the tokenizer's vocabulary.
model.resize_token_embeddings(len(tokenizer))

# Tokenizer settings for batch generation: reuse EOS as the padding token,
# pad on the left and truncate on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.truncation_side = "right"
tokenizer.padding_side = "left"

In [None]:
# Load the checkpoint's state dict into the already-constructed model.
# `map_location="cpu"` deserializes tensors onto the CPU even if they were saved
# from a GPU, so this cell also works when `device` fell back to "cpu"; the
# model is moved to `device` right after.
# NOTE(review): `torch.load` unpickles the file - only load checkpoints that
# come from a trusted source.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model = model.to(device)

In [None]:
# Read ./data/<data_file>.csv and keep the rows from the 15% mark to the end.
raw_datasets = load_dataset(
    "csv",
    data_files={"train": "./data/{}.csv".format(data_file)},
    split="train[15%:]",
)

# Use the column named "text" when the CSV has one; otherwise fall back to the
# first column.
column_names = list(raw_datasets.features)
if "text" in column_names:
    text_column_name = "text"
else:
    text_column_name = column_names[0]

In [None]:
def tokenize_prompt(examples, max_length=50):
    """Tokenize the text column of a batch of examples.

    Args:
        examples: Batch mapping that contains the ``text_column_name`` column.
        max_length: Upper bound on tokens per prompt; longer inputs are
            truncated.

    Returns:
        The tokenizer's output for the batch, with padding enabled and
        PyTorch tensors requested.
    """
    texts = examples[text_column_name]
    encoded = tokenizer(
        texts,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    return encoded

def generate_text(data, batch_size=100, prompt_length=50, max_new_tokens=100, num_beams=1):
    """Generate text for every row of ``data`` and return the decoded strings.

    The rows are tokenized in batches with ``tokenize_prompt``, each batch is
    passed to ``model.generate``, and the resulting token ids are decoded with
    special tokens stripped.

    Args:
        data: Dataset holding the prompts (read through ``tokenize_prompt``).
        batch_size: Number of rows per tokenization batch and per generation
            batch.
        prompt_length: Maximum number of tokens kept from each prompt.
        max_new_tokens: Forwarded to ``model.generate``.
        num_beams: Forwarded to ``model.generate``.

    Returns:
        A list with one decoded string per row of ``data``.
        NOTE(review): for a causal LM ``generate`` normally returns the prompt
        tokens followed by the continuation, so each string presumably starts
        with the (truncated) prompt - confirm if only the continuation is wanted.
    """
    # Create batches of tokenized prompts
    prompts = data.map(
        tokenize_prompt,
        batched=True,
        batch_size=batch_size,
        # Drop the columns of the dataset that was actually passed in, rather
        # than relying on the module-level column list of one specific dataset.
        remove_columns=data.column_names,
        desc="Running tokenizer on dataset",
        fn_kwargs={"max_length": prompt_length},
    )
    prompts.set_format(type="torch")

    # Map the generation function over each batch.
    # NOTE(review): padding is applied per tokenization batch; the same
    # batch_size is used here, presumably so that every generation batch lines
    # up with one tokenization batch and its rows share a length - verify
    # before changing either batch size independently.
    def model_gen(examples):
        return {"model_gen": model.generate(
            input_ids=examples["input_ids"].to(device),
            attention_mask=examples["attention_mask"].to(device),
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            pad_token_id=tokenizer.eos_token_id,
        )}

    model_outputs = prompts.map(
        model_gen,
        batched=True,
        batch_size=batch_size,
        desc="Generating",
    )

    # Decode the tokenized outputs
    return tokenizer.batch_decode(model_outputs["model_gen"], skip_special_tokens=True)

def run_and_save_gen(prompt_lengths, num_beams):
    """Generate text for every (prompt length, beam count) pair and save a CSV.

    The output frame has a ``text`` column with the source texts plus one
    column per combination, named ``promptLength{P}_numBeams{B}``. It is
    written to ``./model_gen_{model_type}_{data_file}.csv``.

    Args:
        prompt_lengths: Iterable of prompt lengths, in tokens.
        num_beams: Iterable of beam counts (1 means greedy search).
    """
    # Read the source texts from the detected text column so a CSV without a
    # column literally named "text" still works; the output column keeps the
    # name "text".
    output_df = pd.DataFrame({"text": list(raw_datasets[text_column_name])})

    for prompt_length in prompt_lengths:
        for beams in num_beams:
            # Named `gen_kwargs` so it is not confused with the module-level
            # model `config`.
            gen_kwargs = {
                "batch_size": 100, # Larger is faster, but costs more RAM
                "prompt_length": prompt_length, # Number of tokens, NOT words
                "max_new_tokens": 100,
                "num_beams": beams, # 1 for greedy search
            }
            print(gen_kwargs)

            outputs = generate_text(raw_datasets, **gen_kwargs)

            col_name = "promptLength{}_numBeams{}".format(prompt_length, beams)
            output_df[col_name] = outputs

    # A model id may contain "/" (namespace/name); replace it so the file is
    # created in the working directory instead of a non-existent subdirectory.
    safe_model_name = model_type.replace("/", "_")
    output_path = "./model_gen_{}_{}.csv".format(safe_model_name, data_file)
    print("Saving generation output to", output_path)

    output_df.to_csv(output_path, index=False)

In [None]:
# Run the full sweep: every prompt length is paired with every beam count
# (num_beams=1 is greedy search) and all results are written to a single CSV.
run_and_save_gen(
    prompt_lengths=[50, 100, 200, 500],
    num_beams=[1, 100],
)