# TinyStories Middle Story Generation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer
import torch.optim as optim
import re
from collections import Counter
import os
import sys
import time
from tqdm import tqdm

current_path = os.path.abspath('.')
project_name = 'TinyStoriesProject'
project_path = os.path.join(current_path.split(project_name)[0], project_name)
sys.path.append(project_path)
print(project_path)

  from .autonotebook import tqdm as notebook_tqdm


/Users/shawn/Documents/sjsu/2025-1/DL_CMPE258/TinyStoriesProject


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps")

## 1. Data Loading and Preprocessing

Load the TinyStories dataset. The dataset consists of short stories with a limited vocabulary.

In [3]:
train_dataset = load_dataset("roneneldan/TinyStories", split="train")
valid_dataset = load_dataset("roneneldan/TinyStories", split="validation")

In [4]:
print(f'total train dataset length = {len(train_dataset)}')
print(f'total valid dataset length = {len(valid_dataset)}')

total train dataset length = 2119719
total valid dataset length = 21990


## 2. Tokenization

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0')
tokenizer.pad_token = tokenizer.eos_token

In [6]:
def tokenize_function(examples):
    return tokenizer(['text'], padding='max_length', truncation=True, max_length=512)

train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
valid_dataset = valid_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

Map: 100%|██████████| 21990/21990 [00:00<00:00, 1190560.80 examples/s]


In [7]:
import torch
from torch.utils.data import Dataset

class TinyStoriesDataset(Dataset):
    def __init__(self, dataset, tokenizer):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        input_ids = torch.tensor(item['input_ids'])
        # labels = input_ids shifted left by one (next token prediction)
        labels = input_ids.clone()

        return {
            'input_ids': input_ids,
            'labels': labels
        }

In [15]:
len(train_dataset[0]['input_ids'])

512

In [10]:
from torch.utils.data import DataLoader

batch_size = 32

train_loader = DataLoader(
    TinyStoriesDataset(train_dataset, tokenizer),
    batch_size=batch_size,
    shuffle=True
)

valid_loader = DataLoader(
    TinyStoriesDataset(valid_dataset, tokenizer),
    batch_size=batch_size,
    shuffle=False
)


## 3. Model Architecture

In [11]:
import torch
import torch.nn as nn


# TransformerBlock
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, causal_mask=None, key_padding_mask=None):
        attn_output, _ = self.attention(
            x, x, x,
            attn_mask=causal_mask,
            key_padding_mask=key_padding_mask,
            is_causal=True
        )
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

# Generate infiling attention mask
def create_infilling_attention_mask(x, blank_token_id):
    batch_size, seq_len = x.size()
    device = x.device
    attn_mask = torch.zeros(batch_size, seq_len, seq_len, device=device)

    for batch_index, input_seq in enumerate(x):
        try:
            blank_idx = (input_seq == blank_token_id).nonzero(as_tuple=True)[0].item()
        except IndexError:
            blank_idx = seq_len

        # causal mask (make attention mask -inf for words after blank)
        for i in range(blank_idx + 1, seq_len):
            for j in range(i+1, seq_len):
                attn_mask[batch_index, i, j] = float('-inf')

    return attn_mask


# DecoderOnlyTransformer
class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, ff_dim, max_seq_length, dropout=0.1, pad_token_id=None, blank_token_id=None):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
        self.position_embedding = nn.Embedding(max_seq_length, embed_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

        self.pad_token_id = pad_token_id
        self.blank_token_id = blank_token_id
        self.max_seq_length = max_seq_length

    def generate_causal_mask(self, seq_len, device):
        # GPT-style causal mask (standard lower triangular mask)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
    
    def forward(self, x):
        batch_size, seq_len = x.size()
        device = x.device

        # --- Generate attention mask ---
        attn_mask = self.generate_causal_mask(seq_len, device)  # <=== causal mask를 만든다

        key_padding_mask = (x == self.pad_token_id) if self.pad_token_id is not None else None

        # --- Embedding ---
        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        token_embed = self.token_embedding(x)
        pos_embed = self.position_embedding(positions)
        x = self.dropout(token_embed + pos_embed)

        # --- Transformer Blocks 통과 ---
        for layer in self.layers:
            x = layer(x, causal_mask=attn_mask, key_padding_mask=key_padding_mask)

        # --- Output ---
        logits = self.fc_out(x)  # (batch_size, seq_len, vocab_size)
        return logits

In [12]:
# from src.models import DecoderOnlyTransformer

vocab_size = tokenizer.vocab_size
embed_dim = 512
num_heads = 8
num_layers = 6
ff_dim = 2048
max_seq_length = 512

model = DecoderOnlyTransformer(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    ff_dim=ff_dim,
    max_seq_length=max_seq_length,
    dropout=0.1,
    pad_token_id=tokenizer.pad_token_id,
    blank_token_id=None
)

device = torch.device('cuda' if torch.cuda.is_available() else 'mps')
model = model.to(device)
model

DecoderOnlyTransformer(
  (token_embedding): Embedding(32000, 512, padding_idx=2)
  (position_embedding): Embedding(512, 512)
  (layers): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2048, out_features=512, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc_out): Linear(in_features=512, out_features=32000, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

## 4. Training

In [13]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)


In [14]:
from tqdm import tqdm

def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0

    loop = tqdm(dataloader, leave=True)
    for batch in loop:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        
        # forward
        outputs = model(input_ids)  # (batch_size, seq_len, vocab_size)

        # reshape for loss
        outputs = outputs.view(-1, outputs.size(-1))  # (batch_size*seq_len, vocab_size)
        labels = labels.view(-1)                     # (batch_size*seq_len)

        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        loop.set_description(f"Train Loss {total_loss / (loop.n+1):.4f}")

    return total_loss / len(dataloader)


In [15]:
num_epochs = 3

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    print(f"Train loss: {train_loss:.4f}")


Epoch 1/3


Train Loss nan: 100%|██████████| 67/67 [01:50<00:00,  1.65s/it]


Train loss: nan
Epoch 2/3


Train Loss nan: 100%|██████████| 67/67 [02:09<00:00,  1.93s/it]


Train loss: nan
Epoch 3/3


Train Loss nan: 100%|██████████| 67/67 [02:05<00:00,  1.87s/it]

Train loss: nan



