Simple NLP model for TinyStories dataset

In [54]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import re
from collections import Counter
import os
import sys

current_path = os.path.abspath('.')
project_name = 'TinyStoriesProject'
project_path = os.path.join(current_path.split(project_name)[0], project_name)
print(project_path)

/Users/shawn/Documents/sjsu/2025-1/DL_CMPE258/TinyStoriesProject


# Loading dataset

In [24]:
train_dataset = load_dataset("roneneldan/TinyStories", split="train")
valid_dataset = load_dataset("roneneldan/TinyStories", split="validation")

In [38]:
print(f'total train dataset length = {len(train_dataset)}')
print(f'total valid dataset length = {len(valid_dataset)}')

total train dataset length = 2119719
total valid dataset length = 21990


In [27]:
print(train_dataset[0]['text'])

One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.


In [34]:
max_length = 0
max_story = None
for story in train_dataset:
    length = len(story['text'].split(' '))
    if length > max_length:
        max_story = story['text']
        max_length = length

for story in valid_dataset:
    length = len(story['text'].split(' '))
    if length > max_length:
        max_story = story['text']
        max_length = length

print(max_length)
print(max_story)


959
Lily and Ben were friends. They liked to play with toys and run in the park. One day, they found a big cake on the table. It looked yummy and sweet. They wanted to eat some.

But Mom said, "No, no, no. That cake is for Grandma's birthday. You can't have any. It is not for you."

Lily and Ben were sad and angry. They did not like Mom's words. They did not want to wait for Grandma. They wanted cake now.

They had a bad idea. They decided to sneak some cake when Mom was not looking. They took a big knife and cut a slice. They put it on a plate and ran to the corner.

They took a bite of the cake. But it was not yummy and sweet. It was disgusting and bitter. It had salt and pepper and vinegar and mustard and garlic and onion and cheese and fish and pickle and soap and dirt and bugs and worms and hair and nails and glass and metal and rocks and sticks and bones and blood and poop and pee and spit and snot and pus and vomit and slime and goo and mold and rot and rust and dust and ash and

In [29]:
def simple_tokenize(text):
    # Lowercase text
    text = text.lower()
    # Replace any punctuation/special-chars with spaces
    text = re.sub(r"[^a-zA-Z0-9]+", " ", text)
    # Split on whitespace
    tokens = text.strip().split()
    return tokens

sample_token = simple_tokenize(train_dataset[0]['text'])
print(sample_token)

['one', 'day', 'a', 'little', 'girl', 'named', 'lily', 'found', 'a', 'needle', 'in', 'her', 'room', 'she', 'knew', 'it', 'was', 'difficult', 'to', 'play', 'with', 'it', 'because', 'it', 'was', 'sharp', 'lily', 'wanted', 'to', 'share', 'the', 'needle', 'with', 'her', 'mom', 'so', 'she', 'could', 'sew', 'a', 'button', 'on', 'her', 'shirt', 'lily', 'went', 'to', 'her', 'mom', 'and', 'said', 'mom', 'i', 'found', 'this', 'needle', 'can', 'you', 'share', 'it', 'with', 'me', 'and', 'sew', 'my', 'shirt', 'her', 'mom', 'smiled', 'and', 'said', 'yes', 'lily', 'we', 'can', 'share', 'the', 'needle', 'and', 'fix', 'your', 'shirt', 'together', 'they', 'shared', 'the', 'needle', 'and', 'sewed', 'the', 'button', 'on', 'lily', 's', 'shirt', 'it', 'was', 'not', 'difficult', 'for', 'them', 'because', 'they', 'were', 'sharing', 'and', 'helping', 'each', 'other', 'after', 'they', 'finished', 'lily', 'thanked', 'her', 'mom', 'for', 'sharing', 'the', 'needle', 'and', 'fixing', 'her', 'shirt', 'they', 'both',

In [47]:
def build_vocab(train_dataset, max_size=1000000):
    """
    Build a vocabulary (word->index) from the training dataset.
    We only keep the top 'max_size' tokens that appear at least 'min_freq' times.
    """
    counter = Counter()
    for example in train_dataset:
        text = example["text"]
        tokens = simple_tokenize(text)
        counter.update(tokens)

    most_common = counter.most_common(max_size)
    filtered = [(token, freq) for token, freq in most_common]
    
    # Reserve some special tokens
    # <pad> for padding, <unk> for unknown words, <bos> for beginning of sequence, <eos> for end of sequence.
    # <unk> for unkown words because there are words we don't know in validation data.
    special_tokens = ["<pad>", "<unk>", "<bos>", "<eos>"]
    
    vocab = special_tokens + [token for token, _ in filtered]
    word2idx = {w: i for i, w in enumerate(vocab)}
    
    return vocab, word2idx

In [48]:
class TinyStoriesDataset(Dataset):
    def __init__(self, dataset, word2idx, max_length=50, overlap=25):
        self.dataset = dataset
        self.word2idx = word2idx
        self.max_length = max_length
        self.overlap = overlap

        self.processed_stories = []
        
        for data in dataset:
            text = data['text']
            tokens = ['<bos>'] + simple_tokenize(text) + ['<eos>']
            token_ids = [word2idx[t] if t in word2idx else word2idx["<unk>"] for t in tokens]

            if len(token_ids) <= max_length:
                self.processed_stories.append(token_ids)
                continue
        
            # if story is too long, we split
            start_idx = 0
            while start_idx < len(token_ids):
                end_idx = min(start_idx + max_length, len(token_ids))
                chunk = token_ids[start_idx: end_idx]

                if len(chunk) < max_length // 2:
                    break
                    
                self.processed_stories.append(chunk)
                start_idx += (end_idx - overlap)
    
    def __len__(self):
        return len(self.processed_stories)
    
    def __getitem__(self, idx):
        """
        Return input_ids and target_ids for next token prediction.
        
        For a sequence [w1, w2, w3, w4],
        input_ids = [w1, w2, w3]
        target_ids = [w2, w3, w4] (the sequence shifted by 1).
        
        Padding will be handled in outer function
        """
        token_ids = self.processed_stories[idx]
        input_ids = token_ids[:-1]
        target_ids = token_ids[1:]
        
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)


In [49]:
def padding_fn(batch):
    input_ids_list, target_ids_list = zip(*batch)
    max_len = max(len(x) for x in input_ids_list)

    padded_inputs = []
    padded_targets = []

    for input, target in zip(input_ids_list, target_ids_list):
        pad_len = max_len - len(input)
        input_padded = torch.cat([input, torch.zeros(pad_len, dtype=torch.long)])  # possible because we used <pad> as 0
        target_padded = torch.cat([target, torch.zeros(pad_len, dtype=torch.long)])

        padded_inputs.append(input_padded)
        padded_targets.append(target_padded)

    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)

    return padded_inputs, padded_targets

            

In [50]:
vocab, word2idx = build_vocab(train_dataset)

idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(vocab)
print(f"Vocabulary size: {vocab_size}")

train_data = TinyStoriesDataset(train_dataset, word2idx, max_length=50, overlap=15)
valid_data = TinyStoriesDataset(valid_dataset, word2idx, max_length=50, overlap=15)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=padding_fn)
valid_loader = DataLoader(valid_data, batch_size=32, shuffle=False, collate_fn=padding_fn)

Vocabulary size: 42537


In [55]:
torch.save(train_data, os.path.join(project_path, "data/processed/index_encoded/train_dataset_ml50_ol15.pt"))
torch.save(valid_data, os.path.join(project_path, "data/processed/index_encoded/valid_dataset_ml50_ol15.pt"))

# Model

In [None]:

class TransformerLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=256,       # embedding dimension
        n_heads=4,         # number of attention heads
        num_layers=4,      # number of Transformer blocks
        dim_feedforward=512,
        max_seq_len=512,
        dropout=0.1
    ):
        super().__init__()
        
        # Token embedding
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Positional embedding
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        
        # Transformer blocks
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        
        # Final linear layer to predict next token
        self.fc_out = nn.Linear(d_model, vocab_size)
        
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        
    def forward(self, x):
        """
        x shape: (batch_size, seq_len)
        We return logits of shape (batch_size, seq_len, vocab_size).
        """
        batch_size, seq_len = x.shape
        
        # 1) Create token embeddings
        token_embs = self.token_emb(x)  # (batch_size, seq_len, d_model)
        
        # 2) Create positional embeddings
        #    Positions: [0, 1, 2, ..., seq_len-1]
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0)  # (1, seq_len)
        pos_embs = self.pos_emb(positions)  # (1, seq_len, d_model)
        
        # 3) Sum token + positional embeddings
        embeddings = token_embs + pos_embs  # (batch_size, seq_len, d_model)
        
        # 4) Transformer Encoder expects (seq_len, batch_size, d_model)
        embeddings = embeddings.transpose(0, 1)  # (seq_len, batch_size, d_model)
        
        # 5) Generate an attention mask to prevent attending to future positions
        #    For language modeling, we typically want a causal mask (no future positions).
        #    We'll create a standard upper-triangular mask for seq_len x seq_len.
        device = x.device
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
        
        # 6) Pass through the Transformer encoder
        encoded = self.transformer_encoder(embeddings, mask=mask)  # (seq_len, batch_size, d_model)
        
        # 7) Project back to vocab size
        logits = self.fc_out(encoded)  # (seq_len, batch_size, vocab_size)
        
        # 8) Reshape to (batch_size, seq_len, vocab_size)
        logits = logits.transpose(0, 1)
        
        return logits