In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import numpy as np
from tqdm import tqdm
import random
import os
from datasets import load_dataset

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``(1, max_seq_length, d_model)`` is computed once and
    stored as the non-trainable buffer ``pe``. Even feature columns hold
    sines and odd columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both
            even and odd values are supported.
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe = torch.zeros(max_seq_length, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies get a cosine. For an even
        # d_model the slice covers every column and the table is unchanged.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        The slice of ``pe`` is taken along dimension 1, so ``x`` is expected
        to be laid out as ``(batch, seq, d_model)``.
        """
        return x + self.pe[:, :x.size(1)]

# Transformer Model
class TransformerStoryModel(nn.Module):
    """Encoder-decoder Transformer using one embedding table for source and target.

    Token id 0 is treated as padding: positions holding it are excluded as
    attention keys. Each sub-layer first normalises its input and then adds
    the sub-layer output to that normalised tensor. This ordering is kept
    exactly as it was so existing checkpoints load and behave the same.

    Args:
        vocab_size: Number of rows in the embedding table and output logits.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads in every attention module.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability for embeddings, attention and
            feed-forward outputs.
    """

    def __init__(self, vocab_size, d_model=384, nhead=8, num_encoder_layers=7, num_decoder_layers=7, dim_feedforward=1536, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.embedding_dropout = nn.Dropout(dropout)

        # Encoder
        self.encoder_layers = nn.ModuleList([
            nn.ModuleDict({
                'norm1': nn.LayerNorm(d_model),
                'attn': nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True),
                'norm2': nn.LayerNorm(d_model),
                'ffn': nn.Sequential(
                    nn.Linear(d_model, dim_feedforward),
                    nn.GELU(),
                    nn.Linear(dim_feedforward, d_model),
                    nn.Dropout(dropout)
                )
            }) for _ in range(num_encoder_layers)
        ])

        # Decoder
        self.decoder_layers = nn.ModuleList([
            nn.ModuleDict({
                'norm1': nn.LayerNorm(d_model),
                'self_attn': nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True),
                'norm2': nn.LayerNorm(d_model),
                'cross_attn': nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True),
                'norm3': nn.LayerNorm(d_model),
                'ffn': nn.Sequential(
                    nn.Linear(d_model, dim_feedforward),
                    nn.GELU(),
                    nn.Linear(dim_feedforward, d_model),
                    nn.Dropout(dropout)
                )
            }) for _ in range(num_decoder_layers)
        ])

        # Output layer
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.norm_out = nn.LayerNorm(d_model)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise every ``nn.Linear``; draw embeddings from N(0, 0.02)."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.02)

    def create_mask(self, src, tgt):
        """Build the four masks for a ``(src, tgt)`` batch of token ids.

        Returns:
            ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
            an all-False boolean source mask, the additive float causal mask
            for the target, and boolean masks that are True where the token
            id is 0. Kept for external callers; ``forward`` builds its masks
            through ``encode`` / ``decode``.
        """
        src_seq_len = src.shape[1]
        tgt_seq_len = tgt.shape[1]
        src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=src.device)
        tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len).to(tgt.device)
        src_padding_mask = (src == 0)
        tgt_padding_mask = (tgt == 0)
        return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

    def generate_square_subsequent_mask(self, sz):
        """Return a ``(sz, sz)`` float mask: 0.0 where column <= row, ``-inf`` above."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _embed(self, tokens):
        """Embed token ids, scale by sqrt(d_model), add positions, apply dropout."""
        scaled = self.token_embedding(tokens) * math.sqrt(self.d_model)
        return self.embedding_dropout(self.positional_encoding(scaled))

    def forward(self, src, tgt):
        """Return vocabulary logits for every position of ``tgt``.

        Equivalent to ``decode(tgt, encode(src))`` with the padded source
        positions additionally hidden from the decoder's cross-attention.
        """
        memory = self.encode(src)
        return self.decode(tgt, memory, memory_key_padding_mask=(src == 0))

    def encode(self, src):
        """Run the encoder stack over ``src`` token ids and return its output.

        Source positions holding token id 0 are ignored as attention keys.
        """
        src_padding_mask = (src == 0)
        memory = self._embed(src)
        for layer in self.encoder_layers:
            memory = layer['norm1'](memory)
            memory = memory + layer['attn'](memory, memory, memory, key_padding_mask=src_padding_mask)[0]
            memory = layer['norm2'](memory)
            memory = memory + layer['ffn'](memory)
        return memory

    def decode(self, tgt, memory, memory_key_padding_mask=None):
        """Run the decoder stack and return vocabulary logits for ``tgt``.

        Args:
            tgt: Target token ids; id 0 is masked out as a self-attention key.
            memory: Encoder output to cross-attend to.
            memory_key_padding_mask: Optional boolean mask that is True at
                ``memory`` positions cross-attention must ignore (typically
                ``src == 0``). ``None`` attends to every memory position,
                which is the previous behaviour of this method.
        """
        seq_len = tgt.size(1)
        # True strictly above the diagonal: position i may attend to j <= i.
        # Boolean so it has the same dtype as the key-padding masks; PyTorch
        # deprecates mixing a float attn_mask with a bool key_padding_mask.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )
        tgt_padding_mask = (tgt == 0)
        output = self._embed(tgt)
        for layer in self.decoder_layers:
            output = layer['norm1'](output)
            output = output + layer['self_attn'](
                output, output, output,
                attn_mask=causal_mask,
                key_padding_mask=tgt_padding_mask,
            )[0]
            output = layer['norm2'](output)
            output = output + layer['cross_attn'](
                output, memory, memory,
                key_padding_mask=memory_key_padding_mask,
            )[0]
            output = layer['norm3'](output)
            output = output + layer['ffn'](output)
        output = self.norm_out(output)
        return self.output_layer(output)

# Tokenizer
class SimpleTokenizer:
    """Word- or character-level tokenizer with four reserved tokens.

    ``[PAD]``, ``[UNK]``, ``[BOS]`` and ``[EOS]`` always occupy ids 0-3.
    ``tokenization_type='word'`` splits on whitespace; any other value
    splits the text into single characters.
    """

    def __init__(self, tokenization_type='word'):
        self.tokenization_type = tokenization_type
        self.pad_token = '[PAD]'
        self.unk_token = '[UNK]'
        self.bos_token = '[BOS]'
        self.eos_token = '[EOS]'
        self.word_to_idx = {}
        self.idx_to_word = {}
        self.add_special_tokens()

    def add_special_tokens(self):
        """Reset both lookup tables so they contain only the reserved tokens."""
        reserved = (self.pad_token, self.unk_token, self.bos_token, self.eos_token)
        self.word_to_idx = {token: index for index, token in enumerate(reserved)}
        self.idx_to_word = {index: token for token, index in self.word_to_idx.items()}

    def tokenize(self, text):
        """Split ``text`` into whitespace-separated words or into characters."""
        if self.tokenization_type == 'word':
            return text.split()
        return list(text)

    def fit(self, texts):
        """Add every unseen token from ``texts`` to the vocabulary in sorted order."""
        discovered = {piece for text in texts for piece in self.tokenize(text)}
        for piece in sorted(discovered):
            if piece in self.word_to_idx:
                continue
            self.word_to_idx[piece] = len(self.word_to_idx)
            self.idx_to_word[len(self.idx_to_word)] = piece

    def encode(self, text, add_special_tokens=True):
        """Convert ``text`` to ids, mapping unknown tokens to the ``[UNK]`` id.

        When ``add_special_tokens`` is true the result is wrapped in the
        ``[BOS]`` and ``[EOS]`` ids.
        """
        pieces = self.tokenize(text)
        if add_special_tokens:
            pieces = [self.bos_token, *pieces, self.eos_token]
        return [
            self.word_to_idx.get(piece, self.word_to_idx[self.unk_token])
            for piece in pieces
        ]

    def decode(self, ids, skip_special_tokens=True):
        """Convert ids back to text; ids not in the vocabulary become ``[UNK]``.

        With ``skip_special_tokens`` the four reserved tokens are dropped
        from the output. Words are joined with a space, characters with
        nothing.
        """
        reserved = (self.pad_token, self.unk_token, self.bos_token, self.eos_token)
        pieces = []
        for token_id in ids:
            piece = self.idx_to_word.get(token_id, self.unk_token)
            if skip_special_tokens and piece in reserved:
                continue
            pieces.append(piece)
        separator = ' ' if self.tokenization_type == 'word' else ''
        return separator.join(pieces)

    def vocab_size(self):
        """Return the number of tokens in the vocabulary, reserved ones included."""
        return len(self.word_to_idx)

# Dataset
class StoryDataset(Dataset):
    """Dataset of tokenised stories served as shifted ``(src, tgt)`` windows.

    Every story is encoded once at construction time (with special tokens).
    Each item is a window of ``seq_length + 1`` ids: stories that are too
    short are right-padded with id 0, longer ones are cropped at a random
    offset on every access.
    """

    def __init__(self, stories, tokenizer, seq_length=64):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.tokenized_stories = [
            tokenizer.encode(story, add_special_tokens=True) for story in stories
        ]

    def __len__(self):
        """Return the number of stories."""
        return len(self.tokenized_stories)

    def __getitem__(self, idx):
        """Return ``(src, tgt)`` where ``tgt`` is ``src`` shifted left by one token."""
        window = self.seq_length + 1
        ids = self.tokenized_stories[idx]
        shortfall = window - len(ids)
        if shortfall >= 0:
            # Story fits in one window: right-pad with id 0.
            chunk = ids + [0] * shortfall
        else:
            # Story is longer than the window: take a random contiguous crop.
            offset = random.randint(0, len(ids) - window)
            chunk = ids[offset:offset + window]
        return torch.tensor(chunk[:-1]), torch.tensor(chunk[1:])

# Training function
def _teacher_forced_loss(model, src, tgt):
    """Return the mean next-token cross-entropy of ``model`` on one batch.

    The decoder is fed ``tgt[:, :-1]`` and scored against ``tgt[:, 1:]``, so
    the token being predicted is never part of the decoder's own input.
    Label positions holding id 0 are excluded from the loss.
    """
    if tgt.size(1) < 2:
        raise ValueError("tgt needs at least two tokens per sequence for teacher forcing")
    decoder_input = tgt[:, :-1]
    labels = tgt[:, 1:]
    logits = model(src, decoder_input)
    # reshape (not view): the sliced labels are not contiguous.
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=0,
    )


def train_model(model, train_loader, val_loader, optimizer, scheduler, device, epochs=10):
    """Train ``model`` with teacher forcing and validate after every epoch.

    Args:
        model: Module called as ``model(src, decoder_input)`` that returns
            logits whose last dimension is the vocabulary.
        train_loader: Iterable of ``(src, tgt)`` batches of token ids.
        val_loader: Iterable of ``(src, tgt)`` batches used for validation.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler, also stepped once per training batch.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object. Whenever the mean validation loss reaches
        a new minimum, its ``state_dict`` is written to ``best_model.pt`` in
        the current working directory.

    NOTE(review): ``src`` is handed to the model unchanged. If the loaders
    yield ``src`` / ``tgt`` as overlapping windows of the same text, the
    labels are presumably still visible to the decoder through the encoder
    output. Confirm the dataset supplies a context that excludes them.
    """
    best_val_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for src, tgt in progress_bar:
            src, tgt = src.to(device), tgt.to(device)
            optimizer.zero_grad()
            loss = _teacher_forced_loss(model, src, tgt)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})
        avg_train_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs}, Training loss: {avg_train_loss:.4f}")

        model.eval()
        val_loss = 0
        progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]")
        with torch.no_grad():
            for src, tgt in progress_bar:
                src, tgt = src.to(device), tgt.to(device)
                loss = _teacher_forced_loss(model, src, tgt)
                val_loss += loss.item()
                progress_bar.set_postfix({"loss": loss.item()})
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1}/{epochs}, Validation loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "best_model.pt")
            print(f"Saved new best model with validation loss: {best_val_loss:.4f}")
    return model

# Text generation
def generate_story(model, tokenizer, prompt, max_length=100, temperature=1.0, device="cuda", context_length=128):
    """Sample a continuation of ``prompt`` one token at a time.

    At every step the most recent ``context_length`` tokens are run through
    the encoder, and the decoder is fed only the latest token; all earlier
    context reaches it through the encoder output.

    Args:
        model: Model exposing ``encode(src)`` and ``decode(tgt, memory)``.
        tokenizer: Tokenizer exposing ``encode``, ``decode``, ``word_to_idx``
            and the ``bos_token`` / ``eos_token`` / ``pad_token`` attributes.
        prompt: Text to continue.
        max_length: Maximum number of tokens to generate.
        temperature: Softmax temperature; values <= 0 select the most likely
            token instead of sampling.
        device: Device the input tensors are created on.
        context_length: Number of trailing tokens given to the encoder. The
            default equals the ``seq_length`` used by this file's training
            script. Previously ``model.d_model`` was used here, which is an
            embedding width rather than a sequence length.

    Returns:
        The prompt plus the generated tokens, decoded by ``tokenizer``.

    Raises:
        ValueError: If ``context_length`` is smaller than 1.
    """
    if context_length < 1:
        raise ValueError("context_length must be at least 1")

    model.eval()
    bos_id = tokenizer.word_to_idx[tokenizer.bos_token]
    eos_id = tokenizer.word_to_idx[tokenizer.eos_token]
    pad_id = tokenizer.word_to_idx[tokenizer.pad_token]

    # Start from [BOS] + prompt with no trailing [EOS]; otherwise the model
    # would be asked to continue a sequence that has already ended.
    output_ids = [bos_id] + tokenizer.encode(prompt, add_special_tokens=False)

    # No gradients are needed for inference, encoder pass included.
    with torch.no_grad():
        for _ in range(max_length):
            curr_input = torch.tensor([output_ids[-context_length:]], dtype=torch.long, device=device)
            memory = model.encode(curr_input)
            tgt_input = torch.tensor([[output_ids[-1]]], dtype=torch.long, device=device)
            logits = model.decode(tgt_input, memory)[0, -1, :]
            # Never emit the padding token: the model masks padding positions
            # out of attention, so a lone PAD as the next decoder input would
            # leave it with nothing to attend to.
            logits[pad_id] = float('-inf')
            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probs, 1).item()
            else:
                next_token_id = torch.argmax(logits).item()
            output_ids.append(next_token_id)
            if next_token_id == eos_id:
                break
    return tokenizer.decode(output_ids)

# Main execution
if __name__ == "__main__":
    # End-to-end script: build a vocabulary from TinyStories, train the model,
    # save a checkpoint, reload it, then generate text interactively.

    # Check for GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load TinyStories dataset
    print("Loading TinyStories dataset...")
    ds = load_dataset("roneneldan/TinyStories")
    train_split = ds["train"]

    # Create tokenizer
    tokenizer = SimpleTokenizer(tokenization_type='word')

    # Fit tokenizer on a larger subset of the data
    print("Fitting tokenizer on dataset...")
    sample_size = 100000  # Stories used to build the vocabulary
    max_samples = 100000  # Stories used for training/validation
    # Fetch the needed rows with a single slice instead of indexing the
    # split one row at a time (and twice) for the two subsets.
    num_loaded = min(max(sample_size, max_samples), len(train_split))
    loaded_texts = train_split[:num_loaded]["text"]
    tokenizer.fit(loaded_texts[:sample_size])

    vocab_size = tokenizer.vocab_size()
    print(f"Vocabulary size: {vocab_size}")

    # Create datasets and dataloaders
    seq_length = 128
    stories = loaded_texts[:max_samples]
    dataset = StoryDataset(stories, tokenizer, seq_length=seq_length)

    # Split the dataset into train and validation sets
    train_size = int(0.8 * len(dataset))  # 80% for training
    val_size = len(dataset) - train_size  # 20% for validation
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # Aim for ~1000 steps per epoch, but cap the batch size at 16 to avoid
    # memory issues with the larger model.
    target_steps = 1000
    batch_size = min(16, max(1, len(train_dataset) // target_steps))
    print(f"Adjusted batch_size: {batch_size}, Estimated steps: {len(train_dataset) // batch_size}")

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, drop_last=False)

    # Architecture hyperparameters live in one dict so the exact same values
    # are stored in the checkpoint and reused when the model is rebuilt.
    model_config = {
        "d_model": 384,
        "nhead": 8,
        "num_encoder_layers": 7,
        "num_decoder_layers": 7,
        "dim_feedforward": 1536,
    }

    # Create model
    model = TransformerStoryModel(vocab_size=vocab_size, **model_config).to(device)

    # Setup optimizer and scheduler
    epochs = 5
    # NOTE(review): 5e-3 is a high learning rate for Adam on a Transformer
    # with no warmup; confirm training is stable at this value.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    # The scheduler is stepped once per batch, so the cosine period must be
    # the real number of training steps. It was hard-coded for 10 epochs
    # while training ran for 5.
    total_steps = max(1, len(train_dataloader) * epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

    # Train model
    trained_model = train_model(
        model,
        train_dataloader,
        val_dataloader,
        optimizer,
        scheduler,
        device,
        epochs=epochs
    )

    # Save the final model
    output_dir = "c:/Users/saiet/JupyterNotebooks/tinystories_model/cusLLM_model"
    os.makedirs(output_dir, exist_ok=True)

    final_model_path = os.path.join(output_dir, "final_model.pt")
    torch.save({
        "model_state_dict": model.state_dict(),
        "tokenizer": {
            "word_to_idx": tokenizer.word_to_idx,
            "idx_to_word": tokenizer.idx_to_word,
            "tokenization_type": tokenizer.tokenization_type
        },
        "config": {
            "vocab_size": model.vocab_size,
            **model_config,
        }
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")

    # Load and test the model
    checkpoint = torch.load(final_model_path, map_location=device)
    print(list(checkpoint.keys()))
    print(checkpoint["tokenizer"]["tokenization_type"])
    tokenizer = SimpleTokenizer(tokenization_type=checkpoint["tokenizer"]["tokenization_type"])
    tokenizer.word_to_idx = checkpoint["tokenizer"]["word_to_idx"]
    tokenizer.idx_to_word = checkpoint["tokenizer"]["idx_to_word"]

    # Rebuild the network from the stored configuration rather than from
    # repeated literals, so it cannot drift from what was trained.
    model = TransformerStoryModel(**checkpoint["config"]).to(device)

    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded model from {final_model_path}")

    # Interactive generation loop; answering "n" stops it.
    while True:
        prompt = input("Enter a prompt for text generation: ")
        generated_text = generate_story(
            model,
            tokenizer,
            prompt,
            max_length=200,
            temperature=0.8,
            device=device
        )
        print(f"Generated text:\n{generated_text}")
        choice = input("next? (y/n):")
        if choice == "n":
            break

SyntaxError: closing parenthesis ')' does not match opening parenthesis '[' (1686918023.py, line 310)