In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
from tqdm import tqdm
import numpy as np

In [3]:
# --- Reproducibility ---
SEED = 42  # one seed shared by every random number generator used below

# --- Data loading ---
BATCH_SIZE = 32

# --- Model / optimisation hyperparameters ---
EMBED_DIM = 100
HIDDEN_SIZE = 128
NUM_LAYERS = 2
DROPOUT = 0.2
LEARNING_RATE = 3e-4

# Seed torch (weight initialisation, dropout, DataLoader shuffling) and NumPy
# (np.random is used when sampling tokens) so that "Restart & Run All" gives
# comparable results from run to run.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Tokenizer resources for `word_tokenize`. Recent NLTK releases load the
# "punkt_tab" resource instead of the older pickled "punkt" models, so both are
# requested; a resource that is unknown to the installed NLTK is simply skipped
# (nltk.download reports failure through its return value, not an exception).
for nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(nltk_resource, quiet=True)

# Number of training stories to keep (the full train split is much larger).
NUM_TRAIN_STORIES = 100000

# Kept under its own name so the raw Hugging Face object is not confused with
# the torch Dataset that is built from it later in the notebook.
raw_dataset = load_dataset("roneneldan/TinyStories")

# Slice the rows first and only then pick the column, so that just the first
# NUM_TRAIN_STORIES stories are pulled into Python rather than the whole
# "text" column of the split.
train_lines = raw_dataset["train"][:NUM_TRAIN_STORIES]["text"]

def preprocess(text):
    """Lower-case ``text`` and split it into tokens with NLTK's ``word_tokenize``.

    Args:
        text: The raw string to tokenize.

    Returns:
        The result of ``word_tokenize`` applied to the lower-cased string.
    """
    lowered = text.lower()
    tokens = word_tokenize(lowered)
    return tokens

# tokenize: from here on `train_lines` holds one token list per story
train_lines = [preprocess(line) for line in train_lines]

# building vocab
# A word gets its own id only if it occurs MORE than this many times;
# everything rarer is mapped to <UNK>.
MIN_WORD_COUNT = 5
SPECIAL_TOKENS = ["<UNK>", "<PAD>", "<EOS>"]

# Count story by story instead of first building one flat list of every token
# in the corpus; the resulting counts (and their insertion order) are the same.
word_counts = Counter()
for tokens in train_lines:
    word_counts.update(tokens)

vocab = SPECIAL_TOKENS + [
    word for word, count in word_counts.items() if count > MIN_WORD_COUNT
]
word_to_idx = {word: i for i, word in enumerate(vocab)}

# Every vocab entry must own a distinct id, otherwise indexing back into
# `vocab` would not round-trip.
assert len(word_to_idx) == len(vocab), "duplicate entries in vocab"

PAD_idx = word_to_idx['<PAD>']
UNK_idx = word_to_idx['<UNK>']
EOS_idx = word_to_idx['<EOS>']

# Map every story to ids (unknown words -> <UNK>) and terminate it with <EOS>.
indexed_data = [
    [word_to_idx.get(word, UNK_idx) for word in line] + [EOS_idx] for line in train_lines
]

# (vocabulary size after filtering, number of distinct tokens seen)
len(vocab), len(word_counts)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

(…)-00000-of-00004-2d5a1467fff1081b.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(…)-00001-of-00004-5852b56a2bd28fd9.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00002-of-00004-a26307300439e943.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(…)-00003-of-00004-d243063613e5a057.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00000-of-00001-869c898b519ad725.parquet:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

(10259, 21570)

In [5]:
class TinyStoriesDataset(Dataset):
    """Thin ``Dataset`` wrapper around an in-memory, indexable collection.

    Items are returned exactly as stored; no conversion to tensors happens here.
    """

    def __init__(self, data):
        """Store ``data`` (any object supporting ``len`` and integer indexing)."""
        self.data = data

    def __len__(self):
        """Return the number of stored items."""
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.data[idx]
        return item

def collate_fn(batch, padding_value=None):
    """Pad a list of token-id sequences into one batch tensor.

    Args:
        batch: List of sequences (lists of integer token ids). The list is not
            modified.
        padding_value: Id used to fill the padded positions. Defaults to the
            notebook-level ``PAD_idx`` when omitted.

    Returns:
        A tuple ``(sequences_padded, lengths)`` where ``sequences_padded`` is a
        ``(batch, longest_sequence)`` long tensor whose rows are ordered from
        longest to shortest, and ``lengths`` is a long tensor holding the
        unpadded length of each row in that same order.
    """
    if padding_value is None:
        padding_value = PAD_idx

    # sort by sequence length (descending) on a copy: the caller's list keeps
    # its original order. `sorted` is stable, like `list.sort`.
    ordered = sorted(batch, key=len, reverse=True)

    lengths = torch.tensor([len(seq) for seq in ordered], dtype=torch.long)
    # An explicit dtype keeps even an empty sequence an integer tensor.
    sequences = [torch.tensor(seq, dtype=torch.long) for seq in ordered]

    # pad sequences
    sequences_padded = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=padding_value
    )

    return sequences_padded, lengths

# Wrap the indexed stories and serve them as shuffled mini-batches that are
# padded by `collate_fn`.
dataset = TinyStoriesDataset(indexed_data)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

In [6]:
class LSTMLanguageModel(nn.Module):
    """Language model: embedding -> dropout -> stacked LSTM -> linear layer.

    The linear layer projects every LSTM output step to ``vocab_size`` logits.
    Constructing the model prints its trainable parameter count in millions.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers, dropout):
        """Build the layers.

        Args:
            vocab_size: Number of embedding rows and of output logits.
            embed_dim: Size of each embedding vector (LSTM input size).
            hidden_size: LSTM hidden size (input size of the output layer).
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout probability, used both between LSTM layers and on
                the embeddings.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

        millions = self._count_parameters() / 1e6
        print(f"{millions:.2f}M parameters")

    def forward(self, x):
        """Return logits for every position of ``x`` (token ids, batch first)."""
        embedded = self.embedding(x)
        embedded = self.dropout(embedded)
        # The final (h, c) state is not needed; only per-step outputs are used.
        lstm_out, _ = self.lstm(embedded)
        return self.fc(lstm_out)

    def _count_parameters(self):
        """Return the total number of elements in all trainable parameters."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total


# Build the network on the selected device.
model = LSTMLanguageModel(
    vocab_size=len(vocab),
    embed_dim=EMBED_DIM,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
).to(device)

# Padded positions are excluded from the loss via `ignore_index`.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_idx)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

2.60M parameters


In [11]:
def train(model, dataloader, criterion, optimizer, num_epochs):
    """Train ``model`` to predict each next token of the batched sequences.

    For every batch the model receives all tokens but the last and is scored
    against the same sequence shifted left by one position. Gradients are
    clipped to a total norm of 1 before each optimizer step. Progress is
    printed every 300 batches and once at the end of each epoch.

    Args:
        model: Module mapping a ``(batch, seq)`` tensor of token ids to
            ``(batch, seq, vocab)`` logits. Must own at least one parameter.
        dataloader: Sized iterable yielding ``(sequence, lengths)`` pairs;
            only ``sequence`` is used.
        criterion: Loss taking ``(logits, targets)`` flattened over positions.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        None. ``model`` and ``optimizer`` are updated in place.
    """
    # Follow the device the model actually lives on instead of relying on a
    # notebook-level global that may not match it.
    model_device = next(model.parameters()).device
    num_batches = len(dataloader)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        batches_used = 0
        for batch, (sequence, _lengths) in enumerate(dataloader):
            # At least two positions are needed to form one (input, target)
            # pair; a narrower batch would give an empty target and a NaN loss.
            if sequence.size(1) < 2:
                continue

            sequence = sequence.to(model_device)
            inputs = sequence[:, :-1]
            targets = sequence[:, 1:].reshape(-1)

            logits = model(inputs)
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            batches_used += 1

            if batch % 300 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch+1}/{num_batches}], Loss: {batch_loss:.4f}')

        # Average over the batches that actually contributed; an epoch without
        # any usable batch reports NaN instead of dividing by zero.
        avg_loss = total_loss / batches_used if batches_used else float("nan")
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

In [18]:
# One pass over the data. `model` and `optimizer` are updated in place, so
# re-running this cell continues training from the current weights.
train(
    model=model,
    dataloader=dataloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=1,
)

Epoch [1/1], Batch [1/3125], Loss: 3.3775
Epoch [1/1], Batch [301/3125], Loss: 3.2441
Epoch [1/1], Batch [601/3125], Loss: 3.3249
Epoch [1/1], Batch [901/3125], Loss: 3.3974
Epoch [1/1], Batch [1201/3125], Loss: 3.2553
Epoch [1/1], Batch [1501/3125], Loss: 3.1945
Epoch [1/1], Batch [1801/3125], Loss: 3.1968
Epoch [1/1], Batch [2101/3125], Loss: 3.2693
Epoch [1/1], Batch [2401/3125], Loss: 3.0774
Epoch [1/1], Batch [2701/3125], Loss: 3.1563
Epoch [1/1], Batch [3001/3125], Loss: 3.2721
Epoch [1/1], Average Loss: 3.2540


In [19]:
def generate_text(model, start_sequence, max_length=100, temperature=1.0):
    """Sample a continuation of ``start_sequence`` from ``model``.

    The model is re-run on the whole sequence produced so far at every step,
    so each new word is conditioned on the full prompt plus everything already
    generated (not just on a window as long as the prompt).

    Args:
        model: Module mapping a ``(1, seq)`` tensor of token ids to
            ``(1, seq, vocab)`` logits. Must own at least one parameter.
        start_sequence: Non-empty list of word strings; words missing from
            ``word_to_idx`` are fed to the model as ``<UNK>``. Not modified.
        max_length: Maximum number of words to generate.
        temperature: Divides the logits before the softmax. ``0`` selects the
            most likely word deterministically.

    Returns:
        The prompt words followed by the generated words, joined by single
        spaces. Generation stops early at ``<EOS>``, which is not included.

    Raises:
        ValueError: If ``start_sequence`` is empty or ``temperature`` is
            negative.
    """
    if not start_sequence:
        raise ValueError("start_sequence must contain at least one token")
    if temperature < 0:
        raise ValueError("temperature must be non-negative")

    # Use the device the model lives on rather than a notebook-level global.
    model_device = next(model.parameters()).device

    generated_sequence = list(start_sequence)
    token_ids = [word_to_idx.get(word, UNK_idx) for word in generated_sequence]

    model.eval()
    with torch.inference_mode():
        for _ in range(max_length):
            input_seq = torch.tensor(
                token_ids, dtype=torch.long, device=model_device
            ).unsqueeze(0)

            output = model(input_seq)
            last_word_logits = output[0, -1, :]

            if temperature == 0:
                # Greedy decoding: dividing by zero would produce inf/NaN.
                next_word_idx = int(torch.argmax(last_word_logits))
            else:
                scaled_logits = last_word_logits / temperature
                probs = F.softmax(scaled_logits, dim=0).cpu().numpy()
                # sample next word
                next_word_idx = int(np.random.choice(len(probs), p=probs))

            next_word = vocab[next_word_idx]

            # stop if we generate an <EOS> token (it is never emitted)
            if next_word == '<EOS>':
                break

            generated_sequence.append(next_word)
            token_ids.append(next_word_idx)

    # Only reachable when the prompt itself ends with the marker and nothing
    # was generated after it.
    if generated_sequence[-1] == '<EOS>':
        generated_sequence = generated_sequence[:-1]

    return ' '.join(generated_sequence)

In [20]:
# NOTE(review): "an a" in the prompt looks like a typo for "in a"; it is left
# untouched here so the prompt still matches the recorded output of this cell.
prompt = preprocess("I live an a forest")

# Generate at most 10 additional words and display the result.
output = generate_text(model=model, start_sequence=prompt, max_length=10)
output

'i live an a forest and jack had a big surprise . the big bear'