In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
from tqdm import tqdm
import numpy as np
import os
import json

In [None]:
# Data pipeline settings.
# Number of sequences per DataLoader batch.
BATCH_SIZE = 64
# Upper bound on the encoded length of one story, counting <BOS> and <EOS>.
MAX_SEQ_LENGTH = 128

# Model and optimisation hyperparameters.
# Width of the token embedding vectors.
EMBED_DIM = 256
# Hidden state size of each LSTM layer.
HIDDEN_SIZE = 256
# Number of stacked LSTM layers.
NUM_LAYERS = 3
# Dropout probability handed to the model (embedding dropout and inter-layer LSTM dropout).
DROPOUT = 0.3
# Initial AdamW learning rate; the plateau scheduler may lower it during training.
LEARNING_RATE = 1e-3

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# word_tokenize needs NLTK's Punkt sentence-tokenizer data. Recent NLTK releases
# load it from the "punkt_tab" package rather than "punkt", so request both;
# a package id the installed NLTK does not know is reported, not raised.
nltk.download("punkt")
nltk.download("punkt_tab")
dataset = load_dataset("roneneldan/TinyStories")

# Number of training stories used to build the vocabulary and the training set.
NUM_TRAIN_STORIES = 200000

# Slice the rows first and only then select the column: this reads just the
# requested stories instead of materialising the entire text column.
train_lines = dataset['train'][:NUM_TRAIN_STORIES]['text']

def preprocess(text):
    """Lowercase `text` and split it into a list of NLTK word tokens."""
    lowered = text.lower()
    tokens = word_tokenize(lowered)
    return tokens

# tokenize
train_lines = list(map(preprocess, train_lines))

# building vocab
# Count every token occurrence across all tokenized stories.
word_counts = Counter(token for story in train_lines for token in story)

# Special tokens take the first four ids; only words seen more than 10 times
# get their own id, everything else is encoded as <UNK>.
special_tokens = ["<UNK>", "<PAD>", "<EOS>", "<BOS>"]
frequent_words = [word for word, count in word_counts.items() if count > 10]
vocab = special_tokens + frequent_words
word_to_idx = dict(zip(vocab, range(len(vocab))))

PAD_idx, UNK_idx, EOS_idx, BOS_idx = (
    word_to_idx[token] for token in ('<PAD>', '<UNK>', '<EOS>', '<BOS>')
)

# Two positions of every encoded story are reserved for <BOS> and <EOS>.
content_budget = MAX_SEQ_LENGTH - 2
indexed_data = [
    [
        BOS_idx,
        *(word_to_idx.get(word, UNK_idx) for word in line[:content_budget]),
        EOS_idx,
    ]
    for line in train_lines
]

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
class TinyStoriesDataset(Dataset):
    """Map-style dataset that exposes a sequence of pre-encoded stories."""

    def __init__(self, data):
        """Keep a reference to `data`, an indexable collection of examples."""
        self.data = data

    def __len__(self):
        """Return how many examples the dataset holds."""
        n_examples = len(self.data)
        return n_examples

    def __getitem__(self, idx):
        """Return the example stored at position `idx`, unchanged."""
        example = self.data[idx]
        return example

def collate_fn(batch, pad_idx=None):
    """Pad a batch of variable-length token-id sequences into one tensor.

    Args:
        batch: list of token-id sequences (lists of ints).
        pad_idx: id used to fill the padded positions. Defaults to the
            module-level ``PAD_idx`` when omitted.

    Returns:
        A tuple ``(sequences_padded, lengths)``: a ``(batch, max_len)`` long
        tensor ordered from longest to shortest sequence, and a 1-D tensor
        with the unpadded length of each row.
    """
    if pad_idx is None:
        pad_idx = PAD_idx

    # sorted() builds a new list, so the caller's batch is left untouched
    # (list.sort() would reorder it in place).
    ordered = sorted(batch, key=len, reverse=True)
    sequences = [torch.tensor(seq, dtype=torch.long) for seq in ordered]
    lengths = torch.tensor([len(seq) for seq in ordered])
    sequences_padded = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=pad_idx
    )
    return sequences_padded, lengths

# Wrap the encoded stories and serve them in shuffled, padded batches.
dataset = TinyStoriesDataset(indexed_data)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
class LSTMLanguageModel(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> LayerNorm -> vocabulary logits."""

    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers, dropout):
        """Build the layers and print the trainable parameter count.

        Args:
            vocab_size: number of token ids (embedding rows and output logits).
            embed_dim: size of each token embedding vector.
            hidden_size: hidden state size of every LSTM layer.
            num_layers: number of stacked LSTM layers.
            dropout: dropout probability applied to the embeddings and
                between stacked LSTM layers.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_size,
            num_layers,
            # nn.LSTM only applies dropout between stacked layers and warns
            # when a non-zero value is combined with a single layer.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        print(f"{self._count_parameters()/1e6:.2f}M parameters")

    def forward(self, x):
        """Return next-token logits for every position of `x`.

        Args:
            x: long tensor of token ids, laid out ``(batch, seq_len)`` since
                the LSTM is built with ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)`` with unnormalised scores.
        """
        emb = self.dropout(self.embedding(x))
        # The final (h, c) state is not needed here; only per-step outputs are used.
        output, _ = self.lstm(emb)
        output = self.layer_norm(output)
        logits = self.fc(output)
        return logits

    def _count_parameters(self):
        """Return the number of trainable scalar parameters in the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# Build the network and place it on the selected device.
model = LSTMLanguageModel(
    vocab_size=len(vocab),
    embed_dim=EMBED_DIM,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)
model = model.to(device)

# Padded positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_idx)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate once the monitored loss has stopped improving for 2 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
)

6.81M parameters


In [None]:
def save_vocabulary(vocab, word_to_idx, save_dir):
    """Write the vocabulary to ``<save_dir>/vocabulary.json``.

    Args:
        vocab: list of tokens, where the list position is the token id.
        word_to_idx: mapping from token to id.
        save_dir: target directory; created (with parents) when missing.

    The JSON file holds both structures under the keys ``'vocab'`` and
    ``'word_to_idx'``.
    """
    # Create the directory here so the function also works when called on its own.
    os.makedirs(save_dir, exist_ok=True)
    vocab_path = os.path.join(save_dir, 'vocabulary.json')
    vocab_data = {
        'vocab': vocab,
        'word_to_idx': word_to_idx
    }
    with open(vocab_path, 'w', encoding='utf-8') as f:
        json.dump(vocab_data, f)
    print(f"Vocabulary saved to {vocab_path}")

In [None]:
def train(model, dataloader, criterion, optimizer, scheduler, num_epochs, save_dir='model_checkpoints'):
    """Train `model` as a next-token predictor and checkpoint it after every epoch.

    Args:
        model: module mapping a ``(batch, seq)`` tensor of token ids to
            ``(batch, seq, vocab)`` logits.
        dataloader: sized iterable yielding ``(sequences, lengths)`` batches;
            only the padded ``sequences`` tensor is used.
        criterion: loss taking flattened logits and flattened target ids.
        optimizer: optimizer over the model parameters.
        scheduler: learning-rate scheduler stepped once per epoch with the
            epoch's average training loss.
        num_epochs: number of passes over `dataloader`.
        save_dir: directory receiving the vocabulary file, ``best_model.pth``
            and one ``checkpoint_epoch_<n>.pth`` per epoch.

    Raises:
        ValueError: if an epoch is started with an empty `dataloader`.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    # save vocabulary at the start of training
    save_vocabulary(vocab, word_to_idx, save_dir)

    def checkpoint_state(epoch, loss):
        """Collect everything needed to resume training after `epoch`."""
        return {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Without the scheduler state a resumed run would lose the LR schedule.
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss,
        }

    num_batches = len(dataloader)
    best_loss = float('inf')
    for epoch in range(num_epochs):
        if num_batches == 0:
            raise ValueError("dataloader is empty; there is nothing to train on")

        # Re-enable training mode every epoch in case it was switched off in between.
        model.train()
        total_loss = 0.0
        for batch, (sequence, _lengths) in enumerate(dataloader):
            sequence = sequence.to(device)

            # Teacher forcing: position t of the input predicts token t + 1.
            inputs = sequence[:, :-1]
            targets = sequence[:, 1:].reshape(-1)

            outputs = model(inputs)
            outputs = outputs.reshape(-1, outputs.size(2))

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

            total_loss += loss.item()

            if batch % 300 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch+1}/{num_batches}], Loss: {loss.item():.4f}')

        # Mean of the per-batch losses over the epoch.
        avg_loss = total_loss / num_batches
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

        scheduler.step(avg_loss)  # update learning rate

        # save the model if it's the best so far
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(checkpoint_state(epoch, best_loss), os.path.join(save_dir, 'best_model.pth'))
            print(f'Model saved with loss: {best_loss:.4f}')

        # save a checkpoint every epoch
        torch.save(
            checkpoint_state(epoch, avg_loss),
            os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'),
        )

In [None]:
# Train for three epochs; checkpoints are written to train()'s default save_dir.
train(
    model=model,
    dataloader=dataloader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=3,
)

Vocabulary saved to model_checkpoints/vocabulary.json
Epoch [1/3], Batch [1/3125], Loss: 2.3645
Epoch [1/3], Batch [301/3125], Loss: 2.2720
Epoch [1/3], Batch [601/3125], Loss: 2.2532
Epoch [1/3], Batch [901/3125], Loss: 2.3526
Epoch [1/3], Batch [1201/3125], Loss: 2.3472
Epoch [1/3], Batch [1501/3125], Loss: 2.2434
Epoch [1/3], Batch [1801/3125], Loss: 2.1862
Epoch [1/3], Batch [2101/3125], Loss: 2.3557
Epoch [1/3], Batch [2401/3125], Loss: 2.3001
Epoch [1/3], Batch [2701/3125], Loss: 2.2825
Epoch [1/3], Batch [3001/3125], Loss: 2.2108
Epoch [1/3], Average Loss: 2.2630
Model saved with loss: 2.2630
Epoch [2/3], Batch [1/3125], Loss: 2.2328
Epoch [2/3], Batch [301/3125], Loss: 2.0815
Epoch [2/3], Batch [601/3125], Loss: 2.1563
Epoch [2/3], Batch [901/3125], Loss: 2.1574
Epoch [2/3], Batch [1201/3125], Loss: 2.1599
Epoch [2/3], Batch [1501/3125], Loss: 2.1940
Epoch [2/3], Batch [1801/3125], Loss: 2.0760
Epoch [2/3], Batch [2101/3125], Loss: 2.1683
Epoch [2/3], Batch [2401/3125], Loss: 2

In [None]:
def generate_text(model, start_sequence, max_length=100, temperature=0.8):
    """Sample a continuation of `start_sequence` from the language model.

    Args:
        model: trained language model returning ``(batch, seq, vocab)`` logits.
        start_sequence: iterable of prompt tokens (already preprocessed words);
            it is not modified.
        max_length: maximum number of tokens to generate.
        temperature: softmax temperature; values below 1 sharpen the
            distribution, and 0 selects the most likely token (greedy decoding).

    Returns:
        The prompt followed by the generated tokens, joined with single
        spaces. A trailing ``<EOS>`` is removed.

    Raises:
        ValueError: if `temperature` is negative.
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")

    model.eval()
    generated_sequence = list(start_sequence)

    # Training inputs always start with <BOS>, so condition on it here as well.
    token_ids = [BOS_idx] + [word_to_idx.get(word, UNK_idx) for word in generated_sequence]
    # <PAD> is ignored by the loss and <BOS> never appears as a target, so
    # their logits carry no training signal and must not be sampled.
    blocked_ids = [PAD_idx, BOS_idx]
    # Longest input seen in training: MAX_SEQ_LENGTH tokens minus the shifted-off last one.
    max_context = max(MAX_SEQ_LENGTH - 1, 2)

    with torch.inference_mode():
        for _ in range(max_length):
            # Feed the whole history (the model keeps no state between calls),
            # truncated to the training length while keeping the leading <BOS>.
            if len(token_ids) > max_context:
                window = [BOS_idx] + token_ids[-(max_context - 1):]
            else:
                window = token_ids
            input_seq = torch.tensor([window], device=device)

            output = model(input_seq)
            last_word_logits = output[0, -1, :].clone()
            last_word_logits[blocked_ids] = float('-inf')

            if temperature == 0:
                # Greedy decoding; also avoids dividing the logits by zero.
                next_word_idx = int(torch.argmax(last_word_logits))
            else:
                scaled_logits = last_word_logits / temperature
                probs = F.softmax(scaled_logits, dim=0).cpu().numpy()
                # sample next word
                next_word_idx = int(np.random.choice(len(probs), p=probs))

            next_word = vocab[next_word_idx]
            generated_sequence.append(next_word)

            # stop if we generate an <EOS> token
            if next_word == '<EOS>':
                break

            # extend the context with the sampled token
            token_ids.append(next_word_idx)

    if generated_sequence and generated_sequence[-1] == '<EOS>':
        generated_sequence = generated_sequence[:-1]

    return ' '.join(generated_sequence)

# Tokenize the prompt with the same preprocessing used for the training stories.
prompt = preprocess("Once upon a time")
# Generate at most 50 additional tokens with the default temperature.
output = generate_text(model, prompt, max_length=50)
print(output)

once upon a time , there was a man who was feeling very tired . they looked around and saw a huge , hairy dog . he was so happy when he was done , but then , something strange happened . in the end , but he did not notice it . it


In [None]:
from google.colab import drive
import shutil


# Mount Google Drive so the checkpoints outlive the Colab session.
drive.mount('/content/drive')

source_dir = '/content/model_checkpoints'
destination_dir = '/content/drive/My Drive/model_checkpoints'

# exist_ok makes the cell safe to re-run once the folder exists.
os.makedirs(destination_dir, exist_ok=True)

# Copy all files from source to destination
for filename in os.listdir(source_dir):
    source_file = os.path.join(source_dir, filename)
    # shutil.copy2 handles regular files only; skip directories and other entries.
    if not os.path.isfile(source_file):
        continue
    destination_file = os.path.join(destination_dir, filename)
    shutil.copy2(source_file, destination_file)
    print(f"Copied {filename} to Google Drive")