In [15]:
# !pip install wget
!pip install torch -q
!pip install transformers -q
!pip install datasets -q

In [16]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [17]:
import torch

if torch.cuda.is_available():
  device = torch.device("cuda")
  device_count = torch.cuda.device_count()
  device_name = torch.cuda.get_device_name(0)

  print(f"There are {device_count} GPU(s) available.")
  print(f"We will use the GPU: {device_name}")


else:
  print("No GPU available, using the CPU instead.")
  device = torch.device("cpu")

No GPU available, using the CPU instead.


In [18]:
import torch
from transformers import DistilBertTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split


In [19]:
class PoemDataset(Dataset):
    def __init__(self, sentences, poems, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.inputs = []

        for sentence, poem in zip(sentences, poems):
            self.inputs.append(f"{sentence} {poem} {tokenizer.eos_token}")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_text = self.inputs[idx]
        encodings = self.tokenizer(input_text, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
        return {
            'input_ids': encodings['input_ids'].squeeze(),
            'attention_mask': encodings['attention_mask'].squeeze(),
        }


In [20]:

def prepare_poem_dataset(angry_sentences, funny_poems, model_name='gpt2', max_length=128, batch_size=4):
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    train_sentences, test_sentences, train_poems, test_poems = train_test_split(angry_sentences, funny_poems, test_size=0.2, random_state=42)

    train_dataset = PoemDataset(train_sentences, train_poems, tokenizer, max_length)
    test_dataset = PoemDataset(test_sentences, test_poems, tokenizer, max_length)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    return train_dataloader, test_dataloader, tokenizer

In [21]:
angry_sentences = [
    "I can't believe they forgot my birthday!",
    "This traffic is driving me crazy!",
    "Why is the WiFi so slow today?",
    "I'm so tired of eating the same thing every day!",
    "My phone battery always dies when I need it most!",
    "Why do I always lose my keys right when I'm late?",
    "I hate it when people chew with their mouth open!",
    "How come the line is always longest when I'm in a hurry?",
    "Why does it always rain when I forget my umbrella?",
    "I can't stand it when people don't use their turn signals!"
]

funny_poems = [
    "Forgotten day, oh what a blight! / But who needs cake at midnight? / Perhaps they plan a grand surprise / Or simply can't read calendar's guise.",
    "Cars crawl like snails on hot concrete / A turtle race can't be beat / In this jam, I'll grow a beard / Road rage? Nah, I'm just weird.",
    "Internet crawls, my patience thins / Loading bar becomes my frenemy / I could've trained a pigeon / To deliver emails more speedy.",
    "Monotonous meals, day after day / My taste buds threaten to run away / Perhaps I'll start a food rebellion / And eat my socks for this meal's hellion.",
    "Battery drains, oh cruel device! / Always fails at moments precise / I'll invent a phone powered by sighs / Or just yell my messages to the skies.",
    "Keys play hide and seek, what a game! / As I'm rushing out, they're to blame / I'll tie them to a giant balloon / So finding them won't spell my doom.",
    "Open-mouthed chewers, please beware / Your dinner sounds pollute the air / I'll invent a mute button for mouths / Or dine exclusively down south.",
    "Lines stretch long when time is tight / A cosmic joke, an endless plight / I'll master teleportation soon / Or just camp out since last June.",
    "Raindrops fall as umbrellas hide / Weather forecasts have surely lied / I'll grow a waterproof hairdo / Or just pretend I'm at the zoo.",
    "Turn signals forgotten, cars swerve / Testing each driver's last nerve / I'll invent telepathic cars / Or stick big arrows to their fars."
]

In [23]:
print("Batch keys:", batch.keys())

Batch keys: dict_keys(['input_ids', 'attention_mask'])


In [24]:
for batch in train_dataloader:
    print("Batch keys:", batch.keys())
    print("Input shape:", batch['input_ids'].shape)
    print("Attention mask shape:", batch['attention_mask'].shape)
    # Comment out or remove the following line for now
    # print("Labels shape:", batch['labels'].shape)
    break

Batch keys: dict_keys(['input_ids', 'attention_mask'])
Input shape: torch.Size([4, 128])
Attention mask shape: torch.Size([4, 128])


In [25]:
def train_model(train_dataloader, model, optimizer, scheduler, device, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in tqdm(train_dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_dataloader)}")


In [26]:
def evaluate_model(test_dataloader, model, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += outputs.loss.item()

    avg_loss = total_loss / len(test_dataloader)
    print(f"Average test loss: {avg_loss}")

In [27]:
def generate_poem(sentence, model, tokenizer, device, max_length=128):
    model.eval()
    input_ids = tokenizer.encode(sentence, return_tensors='pt').to(device)
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.eos_token_id
        )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    poem = generated_text[len(sentence):].strip()  # Remove the input sentence from the output
    lines = poem.split('.')[:4]  # Get first 4 sentences
    return '\n'.join(line.strip() for line in lines if line.strip())


In [28]:
train_dataloader, test_dataloader, tokenizer = prepare_poem_dataset(angry_sentences, funny_poems)

# Set up the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.to(device)

# Set up optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * 10)

# Train the model
train_model(train_dataloader, model, optimizer, scheduler, device, num_epochs=10)

# Generate a poem
new_angry_sentence = "I can't believe I missed my bus!"
generated_poem = generate_poem(new_angry_sentence, model, tokenizer, device)
print(f"Input: {new_angry_sentence}")
print(f"Generated poem:\n{generated_poem}")

100%|██████████| 2/2 [00:20<00:00, 10.19s/it]


Epoch 1/10, Loss: 8.178539752960205


100%|██████████| 2/2 [00:17<00:00,  8.83s/it]


Epoch 2/10, Loss: 4.259954929351807


100%|██████████| 2/2 [00:16<00:00,  8.17s/it]


Epoch 3/10, Loss: 2.306539297103882


100%|██████████| 2/2 [00:16<00:00,  8.32s/it]


Epoch 4/10, Loss: 1.912907361984253


100%|██████████| 2/2 [00:18<00:00,  9.06s/it]


Epoch 5/10, Loss: 1.8905556201934814


100%|██████████| 2/2 [00:16<00:00,  8.23s/it]


Epoch 6/10, Loss: 1.8167877197265625


100%|██████████| 2/2 [00:16<00:00,  8.10s/it]


Epoch 7/10, Loss: 1.7516024708747864


100%|██████████| 2/2 [00:16<00:00,  8.40s/it]


Epoch 8/10, Loss: 1.6807235479354858


100%|██████████| 2/2 [00:17<00:00,  8.77s/it]


Epoch 9/10, Loss: 1.657581090927124


100%|██████████| 2/2 [00:17<00:00,  8.76s/it]

Epoch 10/10, Loss: 1.6297737956047058
Input: I can't believe I missed my bus!
Generated poem:




