In [13]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from nltk.tokenize import word_tokenize
import numpy as np

In [14]:
BATCH_SIZE = 64
MAX_SEQ_LENGTH = 128

# EMBED_DIM / HIDDEN_SIZE / NUM_LAYERS determine parameter shapes, so they must
# match the checkpoint restored below or `load_state_dict` fails.
EMBED_DIM = 256
HIDDEN_SIZE = 256
NUM_LAYERS = 3
DROPOUT = 0.3
LEARNING_RATE = 1e-3

# Text generation samples randomly; seed every RNG it may draw from so that
# Restart & Run All reproduces the same samples.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
class LSTMLanguageModel(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> LayerNorm -> linear.

    Args:
        vocab_size: number of tokens; sets the embedding rows and the logit width.
        embed_dim: width of the token embeddings fed to the LSTM.
        hidden_size: LSTM hidden width (also the LayerNorm and classifier input width).
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability, used between LSTM layers and on the embeddings.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers, dropout):
        super().__init__()
        # Keep this registration order unchanged: `parameters()` follows it, and
        # `optimizer.load_state_dict` pairs saved state with parameters by
        # position, so reordering would attach the saved state to the wrong tensors.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        n_trainable = self._count_parameters()
        print(f"{n_trainable / 1e6:.2f}M parameters")

    def forward(self, x):
        """Return logits over the vocabulary for every position of ``x``.

        ``x`` holds token ids laid out batch-first (the LSTM uses
        ``batch_first=True``); the result adds a trailing ``vocab_size`` axis.
        """
        embedded = self.embedding(x)
        embedded = self.dropout(embedded)
        lstm_out, _ = self.lstm(embedded)  # final (h, c) state is not used
        normalized = self.layer_norm(lstm_out)
        return self.fc(normalized)

    def _count_parameters(self):
        """Total number of elements across trainable (``requires_grad``) parameters."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total


In [16]:
def preprocess(text):
    """Lower-case ``text`` and split it into word tokens with NLTK's ``word_tokenize``.

    NOTE(review): presumably mirrors the normalization used when the vocabulary
    was built, so prompt words map to known ids — TODO confirm against training.
    """
    lowered = text.lower()
    tokens = word_tokenize(lowered)
    return tokens

def load_vocabulary(load_dir):
    """Read ``<load_dir>/vocabulary.json`` and return ``(vocab, word_to_idx)``.

    ``vocab`` is the file's ``'vocab'`` entry (indexed by token id when sampling)
    and ``word_to_idx`` is its ``'word_to_idx'`` entry (token -> id).

    Raises:
        FileNotFoundError: if the vocabulary file is missing.
        KeyError: if either expected key is absent from the JSON object.
    """
    vocab_path = os.path.join(load_dir, 'vocabulary.json')
    # JSON text is UTF-8; without an explicit encoding, open() uses the platform
    # locale (e.g. cp1252 on Windows) and would mis-decode non-ASCII tokens.
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab_data = json.load(f)
    return vocab_data['vocab'], vocab_data['word_to_idx']

def load_model(model, optimizer, load_dir, device):
    """Restore model and optimizer state from ``<load_dir>/best_model.pth``.

    Args:
        model: module whose architecture matches the saved ``model_state_dict``.
        optimizer: optimizer over ``model.parameters()`` in the same order as
            when the checkpoint was saved.
        load_dir: directory containing ``best_model.pth``.
        device: ``torch.device`` or device string (e.g. ``"cuda"``) to load onto.

    Returns:
        ``(model, optimizer, epoch, loss)``, where ``epoch`` and ``loss`` are the
        values stored in the checkpoint.
    """
    model_path = os.path.join(load_dir, 'best_model.pth')

    # Load tensors directly onto the target device.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # A torch.device does not compare equal to a plain string like 'cpu', so
    # normalize and compare the device *type* to decide whether to move state.
    if torch.device(device).type != 'cpu':
        for state in optimizer.state.values():
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device)

    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    print(f"Model loaded from epoch {epoch} with loss: {loss:.4f}")
    return model, optimizer, epoch, loss

In [17]:
# Restore the vocabulary and the `best_model.pth` checkpoint from `load_dir`.
load_dir = "model_checkpoints"
vocab, word_to_idx = load_vocabulary(load_dir)
UNK_idx = word_to_idx["<UNK>"]

model = LSTMLanguageModel(
    len(vocab), EMBED_DIM, HIDDEN_SIZE, NUM_LAYERS, DROPOUT
).to(device)
# The optimizer is built only so `load_model` can restore its saved state;
# text generation below uses the model alone.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
model, optimizer, epoch, loss = load_model(
    model, optimizer, load_dir, device
)

6.81M parameters
Model loaded from epoch 2 with loss: 2.1472


In [23]:
def generate_text(model, start_sequence, max_length=100, temperature=1,
                  max_context=None, vocab=None, word_to_idx=None,
                  unk_idx=None, device=None):
    """Autoregressively sample a continuation of ``start_sequence``.

    Args:
        model: language model called on a ``(1, seq_len)`` tensor of token ids;
            its output is indexed as ``[batch, position, vocab]``.
        start_sequence: non-empty list of prompt tokens (e.g. from ``preprocess``).
            It is not modified.
        max_length: maximum number of tokens to generate.
        temperature: softmax temperature (<1 sharpens, >1 flattens the
            distribution); ``0`` picks the most likely token (greedy decoding).
        max_context: maximum number of most-recent tokens fed to the model at
            each step. Defaults to the global ``MAX_SEQ_LENGTH``.
        vocab, word_to_idx, unk_idx, device: default to the notebook globals
            ``vocab``, ``word_to_idx``, ``UNK_idx`` and ``device``.

    Returns:
        Prompt plus generated tokens joined by single spaces, without a
        trailing ``<EOS>``.

    Raises:
        ValueError: if ``start_sequence`` is empty, ``temperature`` is negative,
            or ``max_context`` is smaller than 1.
    """
    # Fall back to notebook-level objects so existing call sites keep working.
    namespace = globals()
    if vocab is None:
        vocab = namespace['vocab']
    if word_to_idx is None:
        word_to_idx = namespace['word_to_idx']
    if unk_idx is None:
        unk_idx = namespace['UNK_idx']
    if device is None:
        device = namespace['device']
    if max_context is None:
        max_context = namespace['MAX_SEQ_LENGTH']

    if len(start_sequence) == 0:
        raise ValueError("start_sequence must contain at least one token")
    if temperature < 0:
        raise ValueError("temperature must be >= 0")
    if max_context < 1:
        raise ValueError("max_context must be >= 1")

    model.eval()
    # Encode the prompt once; afterwards only newly sampled ids are appended.
    # The window grows with the generated text (up to max_context) instead of
    # staying as short as the prompt, so the model sees the whole recent context.
    context_ids = [word_to_idx.get(word, unk_idx) for word in start_sequence][-max_context:]
    generated_sequence = list(start_sequence)

    with torch.inference_mode():
        for _ in range(max_length):
            input_seq = torch.tensor([context_ids], dtype=torch.long, device=device)
            last_word_logits = model(input_seq)[0, -1, :]

            if temperature == 0:
                next_word_idx = int(torch.argmax(last_word_logits))
            else:
                probs = F.softmax(last_word_logits / temperature, dim=-1)
                # Sample on-device; avoids a host copy and numpy's sum-to-1 check.
                next_word_idx = int(torch.multinomial(probs, num_samples=1))

            next_word = vocab[next_word_idx]
            generated_sequence.append(next_word)

            # stop if we generate an <EOS> token
            if next_word == '<EOS>':
                break

            context_ids.append(next_word_idx)
            if len(context_ids) > max_context:
                del context_ids[0]

    if generated_sequence[-1] == '<EOS>':
        generated_sequence = generated_sequence[:-1]

    return ' '.join(generated_sequence)

# Draw five continuations of the same prompt; sampling is random, so each
# printed line can differ.
NUM_SAMPLES = 5
prompt = preprocess("Once upon a time")
for _ in range(NUM_SAMPLES):
    output = generate_text(model, prompt, max_length=30)
    print(output)

once upon a time , a butterfly flew higher than any other child . they were sad to hear the problem . it was broken than yours . he wanted to make this own
once upon a time , there was a big cat . the cat said goodbye to the other side of the garden . the sun was shining and bright , it made everything wet
once upon a time , a fancy blanket with flowers . they were very happy . then , he had a big smile and a lot of friends . everyone started to complain that
once upon a time , there were big , hairy and black bear was not shaking . but then , something terrible happened . it was a soft , warm blanket . suddenly ,
once upon a time , a sandbox was very strong . the sun shone in the sky in the air sky . it was so scary with the loud bark . each day ,
