In [1]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import math
import ipywidgets as widgets
from IPython.display import display


In [17]:
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' 
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'


In [2]:
def load_data(prompt_file, story_file, max_length=512, dataset_size=None, chunk_size=1024):
    def clean_punctuation(text):
        for p in '!,.:;?':
            text = text.replace(' ' + p, p)
        text = text.replace(' ' + 'n\'t', 'n\'t')
        text = text.replace(' ' + '\'s', '\'s')
        return text

    def read_file_chunks(file_path, chunk_size):
        with open(file_path, encoding='utf-8-sig') as file:  # Change encoding to 'utf-8-sig'
            while True:
                chunk = file.read(chunk_size)
                if not chunk:
                    break
                yield chunk

    prompts = open(prompt_file, encoding='utf-8-sig').readlines()  # Change encoding to 'utf-8-sig'
    stories = []

    for chunk in read_file_chunks(story_file, chunk_size):
        stories.extend(chunk.splitlines())

    if dataset_size:
        prompts = prompts[:dataset_size]
        stories = stories[:dataset_size]

    data = []
    for prompt, story in zip(prompts, stories):
        combined_text = prompt.strip() + ' <sep> ' + " ".join(story.split()[:300])
        cleaned_text = clean_punctuation(combined_text)
        data.append(cleaned_text[:max_length])

    return data


In [3]:
prompt_file_train = 'train.wp_target'
story_file_train = 'train.wp_source'
prompt_file_valid = 'valid.wp_target'
story_file_valid = 'valid.wp_target'
prompt_file_test = 'test.wp_target'
story_file_test = 'test.wp_source'

# Load a reduced dataset size (e.g., 100 for training, 20 for validation)
train_text = load_data(prompt_file_train, story_file_train, dataset_size=100)
valid_text = load_data(prompt_file_valid, story_file_valid, dataset_size=20)
test_text = load_data(prompt_file_test, story_file_test, dataset_size=20)

In [4]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

inputs_train = tokenizer(train_text, padding=True, truncation=True, max_length=512)
inputs_valid = tokenizer(valid_text, padding=True, truncation=True, max_length=512)
inputs_test = tokenizer(test_text, padding=True, truncation=True, max_length=512)

labels_train = [ids[:sum(mask)] + [-100] * (len(mask) - sum(mask)) for ids, mask in zip(inputs_train['input_ids'], inputs_train['attention_mask'])]
labels_valid = [ids[:sum(mask)] + [-100] * (len(mask) - sum(mask)) for ids, mask in zip(inputs_valid['input_ids'], inputs_valid['attention_mask'])]
labels_test = [ids[:sum(mask)] + [-100] * (len(mask) - sum(mask)) for ids, mask in zip(inputs_test['input_ids'], inputs_test['attention_mask'])]


In [5]:
class StoryDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, labels):
        self.ids = inputs['input_ids']
        self.attention_mask = inputs['attention_mask']
        self.labels = labels

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, item):
        return torch.tensor(self.ids[item], dtype=torch.long), torch.tensor(self.attention_mask[item], dtype=torch.long), torch.tensor(self.labels[item], dtype=torch.long)


In [6]:
train_dataset = StoryDataset(inputs_train, labels_train)
valid_dataset = StoryDataset(inputs_valid, labels_valid)
test_dataset = StoryDataset(inputs_test, labels_test)

train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=2)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, batch_size=2)
test_dataloader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=2)



In [7]:
model = GPT2LMHeadModel.from_pretrained('gpt2')

total_num_training_steps = len(train_dataloader) * 1  # Assuming 1 epoch
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)


In [10]:
for epoch in range(1):
    model.train()
    train_loss = 0.0

    for inputs in train_dataloader:
        input_ids, attention_mask, labels = [x for x in inputs]

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss += loss.item()

    average_train_loss = train_loss / len(train_dataloader)
    print(f'Average training loss for Epoch {epoch + 1}: {average_train_loss}')


Average training loss for Epoch 1: 3.622979373931885


In [11]:
model.eval()
eval_loss = []

for inputs in valid_dataloader:
    input_ids, attention_mask, labels = [x for x in inputs]

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        batch_loss = outputs.loss

    eval_loss.append(batch_loss.cpu().item())

average_eval_loss = sum(eval_loss) / len(valid_dataloader)
perplexity = math.exp(average_eval_loss)
print(f'Average validation loss: {average_eval_loss}')
print(f'Perplexity for the validation dataset: {perplexity}')


Average validation loss: 3.5467012882232667
Perplexity for the validation dataset: 34.69866758966632


In [12]:
user_prompt_widget = widgets.Text(
    placeholder='Enter your prompt',
    description='Prompt:',
    disabled=False
)

generate_button = widgets.Button(description='Generate Stories')
output_area = widgets.Output()


In [14]:
def generate_button_click(b):
    user_prompt = user_prompt_widget.value.strip()

    if not user_prompt:
        with output_area:
            print("Prompt cannot be empty. Please enter a prompt.")
        return

    encoded_user_prompt = tokenizer.encode(user_prompt, add_special_tokens=True, return_tensors="pt")

    output_sequences = model.generate(
        input_ids=encoded_user_prompt,
        max_length=300,
        temperature=0.8,
        top_k=30,
        top_p=0.9,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1
    )

    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_story = ""
    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        text = text[: text.find(tokenizer.eos_token)]
        generated_story += text

    output_area.clear_output(wait=True)
    with output_area:
        print("\nGenerated Story:")
        print(generated_story)

generate_button.on_click(generate_button_click)


In [15]:
display(user_prompt_widget)
display(generate_button)
display(output_area)


Text(value='', description='Prompt:', placeholder='Enter your prompt')

Button(description='Generate Stories', style=ButtonStyle())

Output()