In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import numpy as np
from tqdm import tqdm
import random
import os
from datasets import load_dataset

In [2]:
# Custom Transformer implementation
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter but part of the module)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class CustomTransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_encoder_layers=6, 
                 num_decoder_layers=6, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        
        # Embedding layers
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        
        # Transformer architecture
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, 
                                                  dim_feedforward=dim_feedforward, 
                                                  dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead,
                                                  dim_feedforward=dim_feedforward,
                                                  dropout=dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        
        # Output layer
        self.output_layer = nn.Linear(d_model, vocab_size)
        
        self.d_model = d_model
        self.vocab_size = vocab_size
        
    def create_mask(self, src, tgt):
        src_seq_len = src.shape[1]
        tgt_seq_len = tgt.shape[1]
        
        # Create masks
        src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=src.device)
        tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len).to(tgt.device)
        
        # Create padding masks
        src_padding_mask = (src == 0)
        tgt_padding_mask = (tgt == 0)
        
        return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
    
    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    def forward(self, src, tgt):
        # Embedding and positional encoding
        src_emb = self.positional_encoding(self.token_embedding(src) * math.sqrt(self.d_model))
        tgt_emb = self.positional_encoding(self.token_embedding(tgt) * math.sqrt(self.d_model))
        
        # Create masks
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = self.create_mask(src, tgt)
        
        # Transformer encoding and decoding
        memory = self.transformer_encoder(src_emb, src_key_padding_mask=src_padding_mask)
        output = self.transformer_decoder(tgt_emb, memory, 
                                         tgt_mask=tgt_mask,
                                         tgt_key_padding_mask=tgt_padding_mask,
                                         memory_key_padding_mask=src_padding_mask)
        
        # Project to vocabulary size
        return self.output_layer(output)
    
    def encode(self, src):
        src_emb = self.positional_encoding(self.token_embedding(src) * math.sqrt(self.d_model))
        src_padding_mask = (src == 0)
        return self.transformer_encoder(src_emb, src_key_padding_mask=src_padding_mask)
    
    def decode(self, tgt, memory):
        tgt_emb = self.positional_encoding(self.token_embedding(tgt) * math.sqrt(self.d_model))
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        tgt_padding_mask = (tgt == 0)
        
        output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask)
        return self.output_layer(output)

# Custom tokenizer
class SimpleTokenizer:
    def __init__(self, tokenization_type='word'):
        self.word_to_idx = {}
        self.idx_to_word = {}
        self.tokenization_type = tokenization_type  # 'word' or 'char'
        
        # Special tokens
        self.pad_token = '[PAD]'
        self.unk_token = '[UNK]'
        self.bos_token = '[BOS]'
        self.eos_token = '[EOS]'
        
        # Add special tokens
        self.add_special_tokens()
    
    def add_special_tokens(self):
        self.word_to_idx = {
            self.pad_token: 0,
            self.unk_token: 1,
            self.bos_token: 2,
            self.eos_token: 3
        }
        self.idx_to_word = {v: k for k, v in self.word_to_idx.items()}
    
    def tokenize(self, text):
        if self.tokenization_type == 'word':
            return text.split()
        else:  # char tokenization
            return list(text)
    
    def fit(self, texts):
        vocab = set()
        
        # Extract all unique tokens
        for text in texts:
            tokens = self.tokenize(text)
            vocab.update(tokens)
        
        # Add tokens to vocabulary
        for token in sorted(vocab):
            if token not in self.word_to_idx:
                self.word_to_idx[token] = len(self.word_to_idx)
                self.idx_to_word[len(self.idx_to_word)] = token
    
    def encode(self, text, add_special_tokens=True):
        tokens = self.tokenize(text)
        
        # Add special tokens if needed
        if add_special_tokens:
            tokens = [self.bos_token] + tokens + [self.eos_token]
        
        # Convert tokens to IDs
        ids = [self.word_to_idx.get(token, self.word_to_idx[self.unk_token]) for token in tokens]
        return ids
    
    def decode(self, ids, skip_special_tokens=True):
        tokens = [self.idx_to_word.get(id, self.unk_token) for id in ids]
        
        # Remove special tokens if needed
        if skip_special_tokens:
            tokens = [token for token in tokens if token not in [self.pad_token, self.unk_token, self.bos_token, self.eos_token]]
        
        # Join tokens
        if self.tokenization_type == 'word':
            return ' '.join(tokens)
        else:  # char tokenization
            return ''.join(tokens)
    
    def vocab_size(self):
        return len(self.word_to_idx)

# Dataset for text generation using TinyStories
class TinyStoriesDataset(Dataset):
    def __init__(self, dataset, tokenizer, seq_length=64, split="train"):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.dataset = dataset[split]
        
        # Tokenize the stories
        print(f"Tokenizing {split} dataset...")
        self.tokenized_stories = []
        
        # Process a subset for faster training (adjust as needed)
        num_samples = min(10000, len(self.dataset))
        for i in tqdm(range(num_samples)):
            story = self.dataset[i]["text"]
            tokens = tokenizer.encode(story, add_special_tokens=True)
            self.tokenized_stories.append(tokens)
    
    def __len__(self):
        return len(self.tokenized_stories)
    
    def __getitem__(self, idx):
        tokens = self.tokenized_stories[idx]
        
        # Ensure tokens are the right length
        if len(tokens) <= self.seq_length + 1:
            # Pad to sequence length
            tokens = tokens + [0] * (self.seq_length + 1 - len(tokens))
        else:
            # Choose a random starting point to fit sequence length
            start = random.randint(0, len(tokens) - self.seq_length - 1)
            tokens = tokens[start:start + self.seq_length + 1]
        
        src = torch.tensor(tokens[:-1])
        tgt = torch.tensor(tokens[1:])
        
        return src, tgt

# Training function
def train_custom_model(model, train_dataloader, val_dataloader, optimizer, scheduler, device, epochs=10):
    model.train()
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        epoch_loss = 0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        
        for batch_idx, (src, tgt) in enumerate(progress_bar):
            src, tgt = src.to(device), tgt.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            output = model(src, src)  # Using teacher forcing
            
            # Reshape output and target for loss calculation
            output_flat = output.view(-1, model.vocab_size)
            target_flat = tgt.contiguous().view(-1)
            
            # Calculate loss
            loss = F.cross_entropy(output_flat, target_flat, ignore_index=0)
            
            # Backward pass and optimize
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            # Update loss
            epoch_loss += loss.item()
            
            # Update progress bar
            progress_bar.set_postfix({"loss": loss.item()})
        
        avg_train_loss = epoch_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Training loss: {avg_train_loss:.4f}")
        
        # Validation phase
        model.eval()
        val_loss = 0
        progress_bar = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Val]")
        
        with torch.no_grad():
            for batch_idx, (src, tgt) in enumerate(progress_bar):
                src, tgt = src.to(device), tgt.to(device)
                
                # Forward pass
                output = model(src, src)
                
                # Calculate loss
                output_flat = output.view(-1, model.vocab_size)
                target_flat = tgt.contiguous().view(-1)
                loss = F.cross_entropy(output_flat, target_flat, ignore_index=0)
                
                # Update loss
                val_loss += loss.item()
                
                # Update progress bar
                progress_bar.set_postfix({"loss": loss.item()})
        
        avg_val_loss = val_loss / len(val_dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Validation loss: {avg_val_loss:.4f}")
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_loss": best_val_loss,
            }, "best_model.pt")
            print(f"Saved new best model with validation loss: {best_val_loss:.4f}")
    
    return model

# Text generation function
def generate_text(model, tokenizer, prompt, max_length=100, temperature=1.0, device="cuda"):
    model.eval()
    
    # Encode prompt
    input_ids = tokenizer.encode(prompt)
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
    
    # Initialize output sequence
    output_ids = input_ids.copy()
    
    # Generate one token at a time
    for _ in range(max_length):
        # Prepare input (truncate if too long)
        curr_input = torch.tensor([output_ids[-min(len(output_ids), model.d_model):]], dtype=torch.long).to(device)
        
        # Generate memory from encoder
        memory = model.encode(curr_input)
        
        # Create target input (last token)
        tgt_input = torch.tensor([[output_ids[-1]]], dtype=torch.long).to(device)
        
        # Get prediction
        with torch.no_grad():
            output = model.decode(tgt_input, memory)
            
        # Apply temperature and get probabilities
        logits = output[0, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        
        # Sample from the distribution
        next_token_id = torch.multinomial(probs, 1).item()
        
        # Add to output sequence
        output_ids.append(next_token_id)
        
        # Stop if end of sequence
        if next_token_id == tokenizer.word_to_idx[tokenizer.eos_token]:
            break
    
    # Decode output sequence
    generated_text = tokenizer.decode(output_ids)
    return generated_text

# Main function
def main():
    # Check for GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Load TinyStories dataset
    print("Loading TinyStories dataset...")
    ds = load_dataset("roneneldan/TinyStories")
    
    # Create tokenizer
    tokenizer = SimpleTokenizer(tokenization_type='word')
    
    # Fit tokenizer on a subset of the data
    print("Fitting tokenizer on dataset...")
    sample_size = 10000  # Adjust based on your needs
    sample_texts = [ds["train"][i]["text"] for i in range(sample_size)]
    tokenizer.fit(sample_texts)
    
    vocab_size = tokenizer.vocab_size()
    print(f"Vocabulary size: {vocab_size}")
    
    # Create datasets and dataloaders
    seq_length = 128
    train_dataset = TinyStoriesDataset(ds, tokenizer, seq_length=seq_length, split="train")
    val_dataset = TinyStoriesDataset(ds, tokenizer, seq_length=seq_length, split="validation")
    
    batch_size = 32  # Adjust based on your GPU memory
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
    
    # Create model
    model = CustomTransformerModel(
        vocab_size=vocab_size,
        d_model=256,      # Embedding dimension
        nhead=8,          # Number of attention heads
        num_encoder_layers=4,
        num_decoder_layers=4,
        dim_feedforward=1024
    ).to(device)
    
    # Setup optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_dataloader) * 10)
    
    # Train model
    trained_model = train_custom_model(
        model, 
        train_dataloader,
        val_dataloader,
        optimizer,
        scheduler,
        device,
        epochs=50 # Adjust based on your needs
    )
    
    # Load best model
    checkpoint = torch.load("best_model.pt")
    model.load_state_dict(checkpoint["model_state_dict"])
    
    # Generate text
    prompt = "Once upon a time"
    generated_text = generate_text(
        model,
        tokenizer,
        prompt,
        max_length=200,
        temperature=0.8,
        device=device
    )
    
    print(f"Generated text:\n{generated_text}")
    
    # Save final model
    output_dir = "./tinystories_model"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    torch.save({
        "model_state_dict": model.state_dict(),
        "tokenizer": {
            "word_to_idx": tokenizer.word_to_idx,
            "idx_to_word": tokenizer.idx_to_word,
            "tokenization_type": tokenizer.tokenization_type
        },
        "config": {
            "d_model": model.d_model,
            "vocab_size": model.vocab_size,
        }
    }, f"{output_dir}/model.pt")
    print(f"Model saved to {output_dir}/model.pt")

if __name__ == "__main__":
    main()

Using device: cuda
Loading TinyStories dataset...


README.md: 0.00B [00:00, ?B/s]

(…)-00000-of-00004-2d5a1467fff1081b.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(…)-00001-of-00004-5852b56a2bd28fd9.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00002-of-00004-a26307300439e943.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(…)-00003-of-00004-d243063613e5a057.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00000-of-00001-869c898b519ad725.parquet:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

Fitting tokenizer on dataset...
Vocabulary size: 30348
Tokenizing train dataset...


100%|██████████| 10000/10000 [00:01<00:00, 7924.24it/s]


Tokenizing validation dataset...


100%|██████████| 10000/10000 [00:01<00:00, 8484.73it/s]
Epoch 1/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.74it/s, loss=6.77]


Epoch 1/50, Training loss: 6.8258


  output = torch._nested_tensor_from_mask(
Epoch 1/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.90it/s, loss=6.47]


Epoch 1/50, Validation loss: 6.7199
Saved new best model with validation loss: 6.7199


Epoch 2/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.97it/s, loss=6.72]


Epoch 2/50, Training loss: 6.7134


Epoch 2/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.50it/s, loss=6.44]


Epoch 2/50, Validation loss: 6.6975
Saved new best model with validation loss: 6.6975


Epoch 3/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.97it/s, loss=6.79]


Epoch 3/50, Training loss: 6.7027


Epoch 3/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.73it/s, loss=6.43]


Epoch 3/50, Validation loss: 6.6924
Saved new best model with validation loss: 6.6924


Epoch 4/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.58]


Epoch 4/50, Training loss: 6.6951


Epoch 4/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.45it/s, loss=6.4]


Epoch 4/50, Validation loss: 6.6839
Saved new best model with validation loss: 6.6839


Epoch 5/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.74]


Epoch 5/50, Training loss: 6.6873


Epoch 5/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.20it/s, loss=6.41]


Epoch 5/50, Validation loss: 6.6864


Epoch 6/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.94it/s, loss=6.71]


Epoch 6/50, Training loss: 6.6814


Epoch 6/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.31it/s, loss=6.43]


Epoch 6/50, Validation loss: 6.6843


Epoch 7/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.94it/s, loss=6.51]


Epoch 7/50, Training loss: 6.6766


Epoch 7/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.39it/s, loss=6.42]


Epoch 7/50, Validation loss: 6.6758
Saved new best model with validation loss: 6.6758


Epoch 8/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.94it/s, loss=6.82]


Epoch 8/50, Training loss: 6.6704


Epoch 8/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.65it/s, loss=6.42]


Epoch 8/50, Validation loss: 6.6783


Epoch 9/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.82]


Epoch 9/50, Training loss: 6.6655


Epoch 9/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.39it/s, loss=6.42]


Epoch 9/50, Validation loss: 6.6768


Epoch 10/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.69]


Epoch 10/50, Training loss: 6.6620


Epoch 10/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.54it/s, loss=6.42]


Epoch 10/50, Validation loss: 6.6779


Epoch 11/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.94it/s, loss=6.72]


Epoch 11/50, Training loss: 6.6639


Epoch 11/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.24it/s, loss=6.4]


Epoch 11/50, Validation loss: 6.6787


Epoch 12/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.62]


Epoch 12/50, Training loss: 6.6636


Epoch 12/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.42it/s, loss=6.42]


Epoch 12/50, Validation loss: 6.6832


Epoch 13/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.94it/s, loss=6.69]


Epoch 13/50, Training loss: 6.6699


Epoch 13/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.49it/s, loss=6.45]


Epoch 13/50, Validation loss: 6.6820


Epoch 14/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.66]


Epoch 14/50, Training loss: 6.6762


Epoch 14/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.47it/s, loss=6.41]


Epoch 14/50, Validation loss: 6.6851


Epoch 15/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.74]


Epoch 15/50, Training loss: 6.6839


Epoch 15/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.44it/s, loss=6.43]


Epoch 15/50, Validation loss: 6.6933


Epoch 16/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.96it/s, loss=6.65]


Epoch 16/50, Training loss: 6.6858


Epoch 16/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.70it/s, loss=6.42]


Epoch 16/50, Validation loss: 6.6939


Epoch 17/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.97it/s, loss=6.7]


Epoch 17/50, Training loss: 6.6915


Epoch 17/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 27.08it/s, loss=6.42]


Epoch 17/50, Validation loss: 6.6982


Epoch 18/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.97it/s, loss=6.6]


Epoch 18/50, Training loss: 6.6911


Epoch 18/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.71it/s, loss=6.41]


Epoch 18/50, Validation loss: 6.7050


Epoch 19/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.77]


Epoch 19/50, Training loss: 6.6928


Epoch 19/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.68it/s, loss=6.44]


Epoch 19/50, Validation loss: 6.7014


Epoch 20/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.97it/s, loss=6.64]


Epoch 20/50, Training loss: 6.6918


Epoch 20/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.67it/s, loss=6.45]


Epoch 20/50, Validation loss: 6.7030


Epoch 21/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.96it/s, loss=6.86]


Epoch 21/50, Training loss: 6.6912


Epoch 21/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.24it/s, loss=6.43]


Epoch 21/50, Validation loss: 6.6982


Epoch 22/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.92it/s, loss=6.76]


Epoch 22/50, Training loss: 6.6897


Epoch 22/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.30it/s, loss=6.45]


Epoch 22/50, Validation loss: 6.7078


Epoch 23/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.93it/s, loss=6.8]


Epoch 23/50, Training loss: 6.6842


Epoch 23/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.55it/s, loss=6.46]


Epoch 23/50, Validation loss: 6.7022


Epoch 24/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.93it/s, loss=6.73]


Epoch 24/50, Training loss: 6.6805


Epoch 24/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.25it/s, loss=6.45]


Epoch 24/50, Validation loss: 6.7020


Epoch 25/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.96it/s, loss=6.61]


Epoch 25/50, Training loss: 6.6760


Epoch 25/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.57it/s, loss=6.45]


Epoch 25/50, Validation loss: 6.6989


Epoch 26/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.72]


Epoch 26/50, Training loss: 6.6743


Epoch 26/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.48it/s, loss=6.42]


Epoch 26/50, Validation loss: 6.6967


Epoch 27/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.74]


Epoch 27/50, Training loss: 6.6694


Epoch 27/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.59it/s, loss=6.42]


Epoch 27/50, Validation loss: 6.6963


Epoch 28/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.73]


Epoch 28/50, Training loss: 6.6645


Epoch 28/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.41it/s, loss=6.42]


Epoch 28/50, Validation loss: 6.6953


Epoch 29/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.96it/s, loss=6.8]


Epoch 29/50, Training loss: 6.6618


Epoch 29/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.76it/s, loss=6.42]


Epoch 29/50, Validation loss: 6.6945


Epoch 30/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.84]


Epoch 30/50, Training loss: 6.6623


Epoch 30/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.45it/s, loss=6.43]


Epoch 30/50, Validation loss: 6.6937


Epoch 31/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.49]


Epoch 31/50, Training loss: 6.6598


Epoch 31/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.92it/s, loss=6.41]


Epoch 31/50, Validation loss: 6.6972


Epoch 32/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.96it/s, loss=6.63]


Epoch 32/50, Training loss: 6.6627


Epoch 32/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.34it/s, loss=6.45]


Epoch 32/50, Validation loss: 6.6963


Epoch 33/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.95it/s, loss=6.67]


Epoch 33/50, Training loss: 6.6663


Epoch 33/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.58it/s, loss=6.44]


Epoch 33/50, Validation loss: 6.6983


Epoch 34/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.97it/s, loss=6.61]


Epoch 34/50, Training loss: 6.6675


Epoch 34/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.37it/s, loss=6.44]


Epoch 34/50, Validation loss: 6.6999


Epoch 35/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.94it/s, loss=6.69]


Epoch 35/50, Training loss: 6.6745


Epoch 35/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.36it/s, loss=6.42]


Epoch 35/50, Validation loss: 6.7036


Epoch 36/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.91it/s, loss=6.6]


Epoch 36/50, Training loss: 6.6766


Epoch 36/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.65it/s, loss=6.42]


Epoch 36/50, Validation loss: 6.7052


Epoch 37/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.91it/s, loss=6.63]


Epoch 37/50, Training loss: 6.6816


Epoch 37/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.83it/s, loss=6.43]


Epoch 37/50, Validation loss: 6.7085


Epoch 38/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.91it/s, loss=6.71]


Epoch 38/50, Training loss: 6.6832


Epoch 38/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.80it/s, loss=6.44]


Epoch 38/50, Validation loss: 6.7104


Epoch 39/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.91it/s, loss=6.69]


Epoch 39/50, Training loss: 6.6865


Epoch 39/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.72it/s, loss=6.45]


Epoch 39/50, Validation loss: 6.7115


Epoch 40/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.91it/s, loss=6.83]


Epoch 40/50, Training loss: 6.6861


Epoch 40/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.80it/s, loss=6.44]


Epoch 40/50, Validation loss: 6.7122


Epoch 41/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.92it/s, loss=6.68]


Epoch 41/50, Training loss: 6.6850


Epoch 41/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 26.08it/s, loss=6.45]


Epoch 41/50, Validation loss: 6.7118


Epoch 42/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.92it/s, loss=6.66]


Epoch 42/50, Training loss: 6.6824


Epoch 42/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.25it/s, loss=6.44]


Epoch 42/50, Validation loss: 6.7147


Epoch 43/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.92it/s, loss=6.63]


Epoch 43/50, Training loss: 6.6805


Epoch 43/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.99it/s, loss=6.45]


Epoch 43/50, Validation loss: 6.7113


Epoch 44/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.92it/s, loss=6.61]


Epoch 44/50, Training loss: 6.6777


Epoch 44/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.32it/s, loss=6.43]


Epoch 44/50, Validation loss: 6.7086


Epoch 45/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.93it/s, loss=6.61]


Epoch 45/50, Training loss: 6.6746


Epoch 45/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.15it/s, loss=6.44]


Epoch 45/50, Validation loss: 6.7068


Epoch 46/50 [Train]: 100%|██████████| 313/313 [00:34<00:00,  8.94it/s, loss=6.86]


Epoch 46/50, Training loss: 6.6713


Epoch 46/50 [Val]: 100%|██████████| 313/313 [00:11<00:00, 26.27it/s, loss=6.4]


Epoch 46/50, Validation loss: 6.7069


Epoch 47/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.91it/s, loss=6.74]


Epoch 47/50, Training loss: 6.6686


Epoch 47/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.76it/s, loss=6.4]


Epoch 47/50, Validation loss: 6.7052


Epoch 48/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.89it/s, loss=6.8]


Epoch 48/50, Training loss: 6.6644


Epoch 48/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.65it/s, loss=6.43]


Epoch 48/50, Validation loss: 6.7047


Epoch 49/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.92it/s, loss=6.72]


Epoch 49/50, Training loss: 6.6622


Epoch 49/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 26.01it/s, loss=6.41]


Epoch 49/50, Validation loss: 6.7052


Epoch 50/50 [Train]: 100%|██████████| 313/313 [00:35<00:00,  8.91it/s, loss=6.68]


Epoch 50/50, Training loss: 6.6608


Epoch 50/50 [Val]: 100%|██████████| 313/313 [00:12<00:00, 25.86it/s, loss=6.42]
  checkpoint = torch.load("best_model.pt")


Epoch 50/50, Validation loss: 6.7036
Generated text:
Once upon a time that teddy and and and Tommy They a she to ran and to day, the started the the a a and her the and it she had knew her was they to to loud a and and beautiful made fast scared to the She the go one calling. and happy! mud on was and The She if The You was stone. that happy. and and and for to One out car He had mom favorite Tim the a her "Would the little a and keep It dog He happy she She recording, so Bob to She and proud thought and He a and much and hurt. build the to way was a He the fall was a The was it was They the and her the angry box. day. was He day, had He Give her he But of a you fun. they that time, the strange to the so said, to for the he a everyone first, make not so day the and I and three and you Lily to was I her and She the up a to to could the the She that to that was and I and loved home. the They mom his scaring and
Model saved to ./tinystories_model/model.pt
