In [26]:
import yaml
import torch
import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import DataLoader
from dataset import SingleStage, DoubleStage, PromptToOutline  # Adjust if your dataset class is named differently
from utils import calculate_perplexity, calculate_distinct_n
from tqdm import tqdm
import os

Prompt to Outline

In [28]:
with open('config.yaml') as file:
    config = yaml.full_load(file)

training_args = config['training_args']

# Checkpoint directory
checkpoint_dir = os.path.join(config['model']['save_path'], 'pto_stage')
model_path = os.path.join(checkpoint_dir, 'model.safetensors')

In [29]:
decoder_name = config['model']['decoder']
tokenizer = GPT2Tokenizer.from_pretrained(decoder_name, padding_side="left")

if os.path.exists(model_path):
    print("Loading model from checkpoint...")
    model = GPT2LMHeadModel.from_pretrained(checkpoint_dir, use_safetensors=True)
else:
    print("Initializing new model...")
    raise FileNotFoundError(f"Checkpoint not found in {checkpoint_dir}. Ensure the checkpoint exists before resuming.")

Loading model from checkpoint...


In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
model.config.pad_token_id = tokenizer.pad_token_id

In [31]:
train_dataset = PromptToOutline(config['dataset']['test_path'], tokenizer)
train_data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [52]:
sample_index = 12 # Change this index to sample a different example
sample_data = train_dataset[sample_index]

# Ensure 'input_ids' and 'attention_mask' are in the correct format
if 'input_ids' in sample_data and 'attention_mask' in sample_data:
    input_ids = sample_data['input_ids'].to(device)  # No need to unsqueeze, it's already 2D
    attention_mask = sample_data['attention_mask'].to(device)  # Move to device
else:
    raise KeyError("'input_ids' or 'attention_mask' not found in sample_data")

# Check the shape of input_ids
print("Input IDs Shape:", input_ids.shape)

Input IDs Shape: torch.Size([1, 128])


In [53]:
with torch.no_grad():
    # Use max_new_tokens to specify how many tokens to generate
    generated_ids = model.generate(input_ids, attention_mask=attention_mask,  repetition_penalty=2.0, max_new_tokens=50, num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)

In [54]:
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

In [55]:
print("Input Text:\n", tokenizer.decode(input_ids[0], skip_special_tokens=True))
print("Generated Text:\n", generated_text)

Input Text:
 [ WP ] Everyone has a reaper . The further away it is , the longer you have left to live . Every day it inches a little bit closer , but it is always there . Except yours , which disappeared three weeks ago <SEP> shrieking twisting screaming tearing countless incredible technological achievements could never keep tied back white sweater expanding ever ethereal smoke trailing behind thousand thoughts crossed snake oil supplements many things ran lovely chest beneath great explosive conglomeration
Generated Text:
 [ WP ] Everyone has a reaper . The further away it is , the longer you have left to live . Every day it inches a little bit closer , but it is always there . Except yours , which disappeared three weeks ago <SEP> shrieking twisting screaming tearing countless incredible technological achievements could never keep tied back white sweater expanding ever ethereal smoke trailing behind thousand thoughts crossed snake oil supplements many things ran lovely chest beneath

Double Stage

In [56]:
with open('config.yaml') as file:
    config = yaml.full_load(file)

training_args = config['training_args']

# Checkpoint directory
checkpoint_dir = os.path.join(config['model']['save_path'], 'double_stage')
model_path = os.path.join(checkpoint_dir, 'model.safetensors')

In [57]:
decoder_name = config['model']['decoder']
tokenizer = GPT2Tokenizer.from_pretrained(decoder_name, padding_side="left")

if os.path.exists(model_path):
    print("Loading model from checkpoint...")
    model = GPT2LMHeadModel.from_pretrained(checkpoint_dir, use_safetensors=True)
else:
    print("Initializing new model...")
    raise FileNotFoundError(f"Checkpoint not found in {checkpoint_dir}. Ensure the checkpoint exists before resuming.")

Loading model from checkpoint...


In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
model.config.pad_token_id = tokenizer.pad_token_id

In [59]:
train_dataset = DoubleStage(config['dataset']['test_path'], tokenizer)
train_data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [72]:
sample_index = 10
# Change this index to sample a different example
sample_data = train_dataset[sample_index]

# Ensure 'input_ids' and 'attention_mask' are in the correct format
if 'input_ids' in sample_data and 'attention_mask' in sample_data:
    input_ids = sample_data['input_ids'].to(device)  # No need to unsqueeze, it's already 2D
    attention_mask = sample_data['attention_mask'].to(device)  # Move to device
else:
    raise KeyError("'input_ids' or 'attention_mask' not found in sample_data")

# Check the shape of input_ids
print("Input IDs Shape:", input_ids.shape)

Input IDs Shape: torch.Size([1, 128])


In [79]:
with torch.no_grad():
    # Use max_new_tokens to specify how many tokens to generate
    generated_ids = model.generate(input_ids, attention_mask=attention_mask, repetition_penalty=2.0,  max_new_tokens=46, num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)

In [80]:
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

In [81]:
print("Input Text:\n", tokenizer.decode(input_ids[0], skip_special_tokens=True))
print("Generated Text:\n", generated_text)

Input Text:
 [ WP ] A magical mirror shows your reflection and your future soulmate . You only see your reflection . <S> stand reading another old junky sci teen girl interrupted knick knacks decorated carnival music seemed carnival director walks “ ten dollars another idiot ten dollars another day start packing <SEP> It was just another day at the carnival . ” <newline> <newline> “ And it works ? I put the tent and everything in the back of my trailer with a skip in my step .
Generated Text:
 [ WP ] A magical mirror shows your reflection and your future soulmate . You only see your reflection . <S> stand reading another old junky sci teen girl interrupted knick knacks decorated carnival music seemed carnival director walks “ ten dollars another idiot ten dollars another day start packing <SEP> It was just another day at the carnival . ” <newline> <newline> “ And it works ? I put the tent and everything in the back of my trailer with a skip in my step . The next morning , as she walked

Single Stage

In [82]:
with open('config.yaml') as file:
    config = yaml.full_load(file)

training_args = config['training_args']

# Checkpoint directory
checkpoint_dir = os.path.join(config['model']['save_path'], 'single_stage')
model_path = os.path.join(checkpoint_dir, 'model.safetensors')

In [83]:
decoder_name = config['model']['decoder']
tokenizer = GPT2Tokenizer.from_pretrained(decoder_name, padding_side="left")

if os.path.exists(model_path):
    print("Loading model from checkpoint...")
    model = GPT2LMHeadModel.from_pretrained(checkpoint_dir, use_safetensors=True)
else:
    print("Initializing new model...")
    raise FileNotFoundError(f"Checkpoint not found in {checkpoint_dir}. Ensure the checkpoint exists before resuming.")

Loading model from checkpoint...


In [84]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
model.config.pad_token_id = tokenizer.pad_token_id

In [85]:
train_dataset = DoubleStage(config['dataset']['test_path'], tokenizer)
train_data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [90]:
sample_index = 5 # Change this index to sample a different example
sample_data = train_dataset[sample_index]

# Ensure 'input_ids' and 'attention_mask' are in the correct format
if 'input_ids' in sample_data and 'attention_mask' in sample_data:
    input_ids = sample_data['input_ids'].to(device)  # No need to unsqueeze, it's already 2D
    attention_mask = sample_data['attention_mask'].to(device)  # Move to device
else:
    raise KeyError("'input_ids' or 'attention_mask' not found in sample_data")

# Check the shape of input_ids
print("Input IDs Shape:", input_ids.shape)

Input IDs Shape: torch.Size([1, 128])


In [91]:
with torch.no_grad():
    # Use max_new_tokens to specify how many tokens to generate
    generated_ids = model.generate(input_ids, attention_mask=attention_mask,  repetition_penalty=2.0, max_new_tokens=50, num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)

In [92]:
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

In [93]:
print("Input Text:\n", tokenizer.decode(input_ids[0], skip_special_tokens=True))
print("Generated Text:\n", generated_text)


Input Text:
 [ TT ] `` Shut the dog up . '' <S> sides heaved one last breath skittish animal moved skinny animal circling rough sandy ground keep watch like blood trailed behind running circles around right away sir “ yes sir head officer yelled <SEP> The dog stopped its barking , and shifted its black eyes to me . <newline> <newline> “ GRUNT , ” the sound of my officer ’ s voice rang out , scaring the dog away . <newline> <newline> “ Grunt , get back in the jeep , were running off schedule , ”
Generated Text:
 [ TT ] `` Shut the dog up . '' <S> sides heaved one last breath skittish animal moved skinny animal circling rough sandy ground keep watch like blood trailed behind running circles around right away sir “ yes sir head officer yelled <SEP> The dog stopped its barking , and shifted its black eyes to me . <newline> <newline> “ GRUNT , ” the sound of my officer ’ s voice rang out , scaring the dog away . <newline> <newline> “ Grunt , get back in the jeep , were running off schedule 