In [1]:
import argparse
import time

# Project-local modules. Kept ahead of the explicit imports below so that, as
# before, the explicitly imported names take precedence over anything these
# star imports happen to export.
from dataset import *
from models import *
from generation import *

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): nothing visible in this notebook moves the models or the batches
# to `device`, so this value is currently unused.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# Paths and special tokens used to build the dataset
DIR_PATH = "dataset/poems.csv"
SOV_TOKEN = "<SOV>"  # start-of-verse marker
EOV_TOKEN = "<EOV>"  # end-of-verse marker

# Where the encoder/decoder weights are written during training and read for generation
encoder_ckpt = "saved_models/encoder_1.pt"
decoder_ckpt = "saved_models/decoder_1.pt"

# Run options. parse_known_args() collects unrecognised arguments in `unknown`
# instead of exiting with an error, presumably so this cell also works inside a
# Jupyter kernel that was started with its own extra command-line flags.
parser = argparse.ArgumentParser()
parser.add_argument('--max_epochs', default=40, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--sequence_length', default=10, type=int)
args, unknown = parser.parse_known_args()

# Dataset and its batch loader.
# NOTE(review): shuffle=False serves the batches in the same order every epoch;
# consider shuffle=True for training.
dataset = SpanishPoemsDataset(DIR_PATH, SOV_TOKEN, EOV_TOKEN, args)
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=False,
)

# Sanity check: show one (input, target) sample and the vocabulary size
print(dataset[3])
print('Vocab:', len(dataset.unique_words))

(tensor([2449, 1375,    1,    0,    6,   15, 2450, 2451,    5,   66]), tensor([1375,    1,    0,    6,   15, 2450, 2451,    5,   66, 2452]))
Vocab: 7506


In [3]:
# Set hyperparameters

# Inputs and predictions are drawn from the same word vocabulary, so the
# encoder's input size and the decoder's output size are both its length.
input_size = output_size = len(dataset.unique_words)
hidden_size = 128
num_layers = 1
learning_rate = 0.001
save_epochs = 2  # checkpoint interval, in epochs

# Values taken from the parsed command-line options (`args`)
batch_size = args.batch_size
sequence_length = args.sequence_length
num_epochs = args.max_epochs

# Instantiate the encoder and decoder
encoder = GRUEncoder(input_size, hidden_size, num_layers)
decoder = GRUDecoder(hidden_size, output_size, num_layers)

# Loss function, and one Adam optimizer that updates encoder and decoder jointly
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=learning_rate)

# TRAINING

In [4]:
# Training loop
# Two loss curves are recorded:
#   loss_history       - mean training loss per epoch (one entry per epoch, which
#                        matches the "Epoch" x-axis of the loss plot)
#   batch_loss_history - loss of every individual optimizer step
start_time = time.time()
loss_history = []
batch_loss_history = []
encoder.train()
decoder.train()

# NOTE(review): `device` is never used here. The models and batches stay wherever
# they were created. Moving them requires init_hidden() and the generation code to
# agree on the device. TODO confirm before enabling GPU training.

for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_inputs, batch_targets in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Fresh encoder hidden state for every batch. It is sized by sequence_length,
        # not by the batch size. Presumably the GRUs are not batch_first and treat
        # dim 1 of the (batch, sequence_length) tensors as their batch axis.
        # TODO confirm in models.py.
        hidden = encoder.init_hidden(sequence_length)

        optimizer.zero_grad()

        # Forward pass - Encoder
        encoder_outputs, hidden = encoder(batch_inputs, hidden)

        # The decoder starts from the encoder's returned hidden state, restricted to
        # indices 1..sequence_length-1 of dim 1. That gives sequence_length-1 entries,
        # matching dim 1 of decoder_inputs below.
        # NOTE(review): this drops index 0 rather than selecting a "final" state.
        # Verify that this alignment is intended.
        decoder_hidden = hidden[:, 1:sequence_length, :].contiguous()

        # Teacher forcing: the decoder reads targets[:, :-1] and predicts targets[:, 1:]
        decoder_inputs = batch_targets[:, :-1]
        decoder_targets = batch_targets[:, 1:]

        # Forward pass - Decoder
        decoder_outputs, _ = decoder(decoder_inputs, decoder_hidden)

        # Flatten to (N, output_size) scores vs. N class indices for CrossEntropyLoss
        loss = criterion(decoder_outputs.reshape(-1, output_size), decoder_targets.reshape(-1))

        # Backpropagation
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        batch_loss_history.append(batch_loss)

    # Average loss for the epoch. Only this value goes into loss_history.
    avg_loss = total_loss / len(dataloader)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    # Checkpoint every `save_epochs` completed epochs, and always after the last one,
    # so the fully trained weights are saved. The previous `epoch % save_epochs == 0`
    # test skipped the final epoch unless (num_epochs - 1) was a multiple of save_epochs.
    completed_epochs = epoch + 1
    if completed_epochs % save_epochs == 0 or completed_epochs == num_epochs:
        print("-> Saving checkpoint")
        torch.save(encoder.state_dict(), encoder_ckpt)
        torch.save(decoder.state_dict(), decoder_ckpt)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time / 3600:.2f} h")

Epoch 1/5:   5%|▍         | 16/323 [00:05<01:47,  2.84it/s]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl

# rcParams is global: this enlarges text in this and every later matplotlib figure.
mpl.rcParams['font.size'] = 25

# NOTE(review): the x-axis is labelled "Epoch", so loss_history is expected to hold
# exactly one value per epoch. Check how the training cell fills it.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(len(loss_history)), loss_history, label="Model_1")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Encoder: GRU - Decoder: GRU")
ax.legend()
# fig.savefig("loss_512.jpg", bbox_inches='tight')
plt.show()

# GENERATE POEM 

In [5]:
# Generate a poem and its perplexity. The checkpoint paths are passed in,
# presumably so generate_poem_GRU can load the saved encoder/decoder weights.
# NOTE(review): the meaning of the positional arguments 0 and 1 is defined in
# generation.py. TODO confirm and pass them by keyword.
poem, perplexity = generate_poem_GRU(
    encoder,
    decoder,
    dataset,
    0,
    1,
    max_length=50,
    temperature=0.8,
    top_k=10,
    encoder_type='GRU',
    encoder_ckpt=encoder_ckpt,
    decoder_ckpt=decoder_ckpt,
)

# Blank line, the poem, then its perplexity
print('', poem, sep='\n')
print('Perplexity: ', perplexity)


entre lentos pescados
sobrevuelan nuestros cuerpos
dice ven conmigo
tú eres la cabeza
cuando los ojos y mi vida es que este camino
no hay angustia comparable a no respirarte
no verte
pensar tú y todo
Perplexity:  13.702206871289045
