Simple NLP model for TinyStories dataset

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import re
from collections import Counter
import os
import sys
import time

current_path = os.path.abspath('.')
project_name = 'TinyStoriesProject'
project_path = os.path.join(current_path.split(project_name)[0], project_name)
sys.path.append(project_path)
print(project_path)

  from .autonotebook import tqdm as notebook_tqdm


/Users/shawn/Documents/sjsu/2025-1/DL_CMPE258/TinyStoriesProject


# Loading dataset

In [3]:
train_dataset = load_dataset("roneneldan/TinyStories", split="train")
valid_dataset = load_dataset("roneneldan/TinyStories", split="validation")

In [4]:
print(f'total train dataset length = {len(train_dataset)}')
print(f'total valid dataset length = {len(valid_dataset)}')

total train dataset length = 2119719
total valid dataset length = 21990


In [5]:
print(train_dataset[0]['text'])

One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.


In [6]:
max_length = 0
max_story = None
for story in train_dataset:
    length = len(story['text'].split(' '))
    if length > max_length:
        max_story = story['text']
        max_length = length

for story in valid_dataset:
    length = len(story['text'].split(' '))
    if length > max_length:
        max_story = story['text']
        max_length = length

print(max_length)
print(max_story)


959
Lily and Ben were friends. They liked to play with toys and run in the park. One day, they found a big cake on the table. It looked yummy and sweet. They wanted to eat some.

But Mom said, "No, no, no. That cake is for Grandma's birthday. You can't have any. It is not for you."

Lily and Ben were sad and angry. They did not like Mom's words. They did not want to wait for Grandma. They wanted cake now.

They had a bad idea. They decided to sneak some cake when Mom was not looking. They took a big knife and cut a slice. They put it on a plate and ran to the corner.

They took a bite of the cake. But it was not yummy and sweet. It was disgusting and bitter. It had salt and pepper and vinegar and mustard and garlic and onion and cheese and fish and pickle and soap and dirt and bugs and worms and hair and nails and glass and metal and rocks and sticks and bones and blood and poop and pee and spit and snot and pus and vomit and slime and goo and mold and rot and rust and dust and ash and

In [7]:
def simple_tokenize(text):
    # Lowercase text
    text = text.lower()
    # Replace any punctuation/special-chars with spaces
    text = re.sub(r"[^a-zA-Z0-9]+", " ", text)
    # Split on whitespace
    tokens = text.strip().split()
    return tokens

sample_token = simple_tokenize(train_dataset[0]['text'])
print(sample_token)

['one', 'day', 'a', 'little', 'girl', 'named', 'lily', 'found', 'a', 'needle', 'in', 'her', 'room', 'she', 'knew', 'it', 'was', 'difficult', 'to', 'play', 'with', 'it', 'because', 'it', 'was', 'sharp', 'lily', 'wanted', 'to', 'share', 'the', 'needle', 'with', 'her', 'mom', 'so', 'she', 'could', 'sew', 'a', 'button', 'on', 'her', 'shirt', 'lily', 'went', 'to', 'her', 'mom', 'and', 'said', 'mom', 'i', 'found', 'this', 'needle', 'can', 'you', 'share', 'it', 'with', 'me', 'and', 'sew', 'my', 'shirt', 'her', 'mom', 'smiled', 'and', 'said', 'yes', 'lily', 'we', 'can', 'share', 'the', 'needle', 'and', 'fix', 'your', 'shirt', 'together', 'they', 'shared', 'the', 'needle', 'and', 'sewed', 'the', 'button', 'on', 'lily', 's', 'shirt', 'it', 'was', 'not', 'difficult', 'for', 'them', 'because', 'they', 'were', 'sharing', 'and', 'helping', 'each', 'other', 'after', 'they', 'finished', 'lily', 'thanked', 'her', 'mom', 'for', 'sharing', 'the', 'needle', 'and', 'fixing', 'her', 'shirt', 'they', 'both',

In [8]:
def build_vocab(train_dataset, max_size=1000000):
    """
    Build a vocabulary (word->index) from the training dataset.
    We only keep the top 'max_size' tokens that appear at least 'min_freq' times.
    """
    counter = Counter()
    for example in train_dataset:
        text = example["text"]
        tokens = simple_tokenize(text)
        counter.update(tokens)

    most_common = counter.most_common(max_size)
    filtered = [(token, freq) for token, freq in most_common]
    
    # Reserve some special tokens
    # <pad> for padding, <unk> for unknown words, <bos> for beginning of sequence, <eos> for end of sequence.
    # <unk> for unkown words because there are words we don't know in validation data.
    special_tokens = ["<pad>", "<unk>", "<bos>", "<eos>"]
    
    vocab = special_tokens + [token for token, _ in filtered]
    word2idx = {w: i for i, w in enumerate(vocab)}
    
    return vocab, word2idx

In [9]:
class TinyStoriesDataset(Dataset):
    def __init__(self, dataset, word2idx, max_length=50, overlap=25):
        self.dataset = dataset
        self.word2idx = word2idx
        self.max_length = max_length
        self.overlap = overlap

        self.processed_stories = []
        
        for data in dataset:
            text = data['text']
            tokens = ['<bos>'] + simple_tokenize(text) + ['<eos>']
            token_ids = [word2idx[t] if t in word2idx else word2idx["<unk>"] for t in tokens]

            if len(token_ids) <= max_length:
                self.processed_stories.append(token_ids)
                continue
        
            # if story is too long, we split
            start_idx = 0
            while start_idx < len(token_ids):
                end_idx = min(start_idx + max_length, len(token_ids))
                chunk = token_ids[start_idx: end_idx]

                if len(chunk) < max_length // 2:
                    break
                    
                self.processed_stories.append(chunk)
                start_idx += (end_idx - overlap)
    
    def __len__(self):
        return len(self.processed_stories)
    
    def __getitem__(self, idx):
        """
        Return input_ids and target_ids for next token prediction.
        
        For a sequence [w1, w2, w3, w4],
        input_ids = [w1, w2, w3]
        target_ids = [w2, w3, w4] (the sequence shifted by 1).
        
        Padding will be handled in outer function
        """
        token_ids = self.processed_stories[idx]
        input_ids = token_ids[:-1]
        target_ids = token_ids[1:]
        
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)


In [10]:
class SimpleTinyStoriesDataset(Dataset):
    """
    Not using sliding windows
    """
    def __init__(self, dataset, word2idx, max_length=50):
        self.dataset = dataset
        self.word2idx = word2idx
        self.max_length = max_length

        self.processed_sequences = []
        for data in dataset:
            text = data['text']
            tokens = ['<bos>'] + simple_tokenize(text) + ['<eos>']
            token_ids = [word2idx[t] if t in word2idx else word2idx["<unk>"] for t in tokens]
            token_ids = token_ids[:self.max_length]
            
            self.processed_sequences.append(token_ids)

    def __len__(self):
        return len(self.processed_sequences)
    
    def __getitem__(self, idx):
        token_ids = self.processed_sequences[idx]
        input_ids = token_ids[:-1]
        target_ids = token_ids[1:]

        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)


In [11]:
def padding_fn(batch):
    input_ids_list, target_ids_list = zip(*batch)
    max_len = max(len(x) for x in input_ids_list)

    padded_inputs = []
    padded_targets = []

    for input, target in zip(input_ids_list, target_ids_list):
        pad_len = max_len - len(input)
        input_padded = torch.cat([input, torch.zeros(pad_len, dtype=torch.long)])  # possible because we used <pad> as 0
        target_padded = torch.cat([target, torch.zeros(pad_len, dtype=torch.long)])

        padded_inputs.append(input_padded)
        padded_targets.append(target_padded)

    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)

    return padded_inputs, padded_targets

            

In [12]:
t = time.time()
vocab, word2idx = build_vocab(train_dataset)

idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(vocab)
print(f"Vocabulary size: {vocab_size}")

train_data = TinyStoriesDataset(train_dataset, word2idx, max_length=50, overlap=15)
valid_data = TinyStoriesDataset(valid_dataset, word2idx, max_length=50, overlap=15)

print(f"Dataset time: {time.time() - t:.2f}")
t = time.time()

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=padding_fn, num_workers=0)
valid_loader = DataLoader(valid_data, batch_size=32, shuffle=False, collate_fn=padding_fn, num_workers=0)

print(f"Data loader time: {time.time() - t:.2f}")

Vocabulary size: 42537
Dataset time: 242.64
Data loader time: 0.00


In [13]:
t = time.time()
vocab, word2idx = build_vocab(train_dataset)

idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(vocab)
print(f"Vocabulary size: {vocab_size}")

simple_train_data = SimpleTinyStoriesDataset(train_dataset, word2idx, max_length=50)
simple_valid_data = SimpleTinyStoriesDataset(valid_dataset, word2idx, max_length=50)

print(f"Dataset time: {time.time() - t:.2f}")
t = time.time()

simple_train_loader = DataLoader(simple_train_data, batch_size=32, shuffle=True, collate_fn=padding_fn, num_workers=0)
simple_valid_loader = DataLoader(simple_valid_data, batch_size=32, shuffle=False, collate_fn=padding_fn, num_workers=0)

print(f"Data loader time: {time.time() - t:.2f}")

Vocabulary size: 42537
Dataset time: 268.97
Data loader time: 0.00


# Model

In [14]:

class TransformerDecoderLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=256,       # embedding dimension
        n_heads=4,         # number of attention heads
        num_layers=4,      # number of Transformer blocks
        dim_feedforward=512,
        max_seq_len=512,
        dropout=0.1
    ):
        super().__init__()
        
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)

        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model, n_heads, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])
        
        self.fc_out = nn.Linear(d_model, vocab_size)

        self.d_model = d_model
        self.max_seq_len = max_seq_len
        
    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape
        
        token_embs = self.token_emb(input_ids)  # (batch_size, seq_len, d_model)

        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)
        pos_embs = self.pos_emb(positions)  # (1, seq_len, d_model)

        embeddings = token_embs + pos_embs  # (batch_size, seq_len, d_model)
        embeddings = embeddings.transpose(0, 1)

        mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()
        for layer in self.layers:
            embeddings = layer(embeddings, embeddings, tgt_mask=mask)

        logits = self.fc_out(embeddings).transpose(0, 1)  # (batch_size, seq_len, vocab_size)

        return logits

In [15]:
model = TransformerDecoderLM(vocab_size=vocab_size, d_model=128, n_heads=2, num_layers=2, dim_feedforward=128)
params_cnt = sum(param.numel() for param in model.parameters())
print(params_cnt)


11329321


Total paramters are 24M, which is a good amount for small language model

In [16]:
def train_one_epoch(model, dataloader, optimizer, device, vocab_size):
    t = time.time()
    model.train()
    total_loss = 0
    for i, batch in enumerate(dataloader):
        input_ids, target_ids = [b.to(device) for b in batch]
        logits = model(input_ids)
        logits = logits.reshape(-1, vocab_size)
        targets = target_ids.reshape(-1)

        loss = F.cross_entropy(logits, targets, ignore_index=0)
        total_loss += loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10000 == 0:
            print(f"batch {i}: {time.time() - t:.2f}")
            t = time.time()
    
    return total_loss / len(dataloader)

def evaluate(model, dataloader, device, vocab_size):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids, target_ids = [b.to(device) for b in batch]
            
            logits = model(input_ids)  # (batch_size, seq_len, vocab_size)
            logits = logits.reshape(-1, vocab_size)
            targets_reshaped = target_ids.reshape(-1)
            
            loss = F.cross_entropy(logits, targets_reshaped, ignore_index=0)
            total_loss += loss.item()
    
    return total_loss / len(dataloader)


In [17]:

def generate_text(model, prompt, word2idx, idx2word, max_new_tokens=30, device="cpu"):
    """
    model:        Our language model
    prompt:       A string prompt
    word2idx:     The vocabulary mapping from tokens to indices
    idx2word:     The inverse vocabulary mapping
    max_new_tokens: The maximum number of tokens to generate
    device:       "cpu" or "cuda"
    """
    model.eval()
    tokens = simple_tokenize(prompt.lower())
    
    # Convert to token IDs
    input_ids = [word2idx["<bos>"]]  # start with <bos>
    for t in tokens:
        input_ids.append(word2idx[t] if t in word2idx else word2idx["<unk>"])
    
    # Move to tensor
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Get logits for the current sequence
            logits = model(input_ids)  # (batch_size=1, seq_len, vocab_size)
            
            # Take the last timestep
            logits_last = logits[:, -1, :]  # (1, vocab_size)
            
            # Greedy pick
            next_token_id = torch.argmax(logits_last, dim=-1)  # (1,)
            
            # Append the predicted token
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)  # shape: (1, seq_len+1)
            
            # If we hit <eos>, we can stop early
            if next_token_id.item() == word2idx["<eos>"]:
                break
    
    # Convert token IDs back to text (skipping initial <bos>)
    generated_ids = input_ids[0].tolist()
    # The first token is <bos>, skip that for printing
    generated_ids = generated_ids[1:]
    
    # Stop if we see <eos>
    if word2idx["<eos>"] in generated_ids:
        eos_pos = generated_ids.index(word2idx["<eos>"])
        generated_ids = generated_ids[:eos_pos]
    
    generated_words = [idx2word[idx] for idx in generated_ids]
    
    return " ".join(generated_words)

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps")
model = TransformerDecoderLM(vocab_size=vocab_size, d_model=256, n_heads=4, num_layers=4, dim_feedforward=512)
model = model.to(device)

# Define an optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 10  # For demonstration; adjust as needed
for epoch in range(num_epochs):
    print(f"===============EPOCH {epoch:03} TRAINIG=================")
    train_loss = train_one_epoch(model, train_loader, optimizer, device, vocab_size)
    val_loss = evaluate(model, valid_loader, device, vocab_size)
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}")


batch 0: 1.15
batch 10000: 4304.00
batch 20000: 2603.23
batch 30000: 1226.73
batch 40000: 1057.74
batch 50000: 1061.85
batch 60000: 2758.75
batch 70000: 1084.28
batch 80000: 1086.91
batch 90000: 1079.12
batch 100000: 1058.44
batch 110000: 1054.31
batch 120000: 1078.14
batch 130000: 1074.79
batch 140000: 1087.31
batch 150000: 7048.28
batch 160000: 7391.83
batch 170000: 6186.90
batch 180000: 6337.92
batch 190000: 5771.79
Epoch 1/10, Train Loss: 0.1072, Valid Loss: 0.0677
batch 0: 0.27
batch 10000: 1187.35
batch 20000: 4093.66
batch 30000: 1068.52
batch 40000: 1053.23
batch 50000: 1058.50
batch 60000: 1849.67
batch 70000: 1235.81
batch 80000: 1063.00
batch 90000: 1061.04


KeyboardInterrupt: 

In [19]:
prompt = "Once upon a time"
generated = generate_text(model, prompt, word2idx, idx2word, max_new_tokens=30, device=device)
print("Prompt:", prompt)
print("Generated:", generated)

Prompt: Once upon a time
Generated: once upon a time once upon time once upon time once upon time once upon time time time once upon time time once upon time time time time time time time time time time
