In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

In [2]:
def l2_loss(pred,target):
    return torch.sum((pred - target) ** 2)

In [3]:
@dataclass
class TitanConfig():
    d_model: int = 128
    vocab_size: int = 10000   # will update later
    seq_len: int = 32
    n_heads: int = 4
    alpha: float = 0.1
    eta: float = 0.9
    theta: float = 0.01
    window_size: int = 128
    batch_size: int = 8
    n_layers: int = 2
    chunk_size: int = 64    # for MAC variant (not used in MAG)
    N_p: int = 10
    bos_token_id: int = 2
    eos_token_id: int = 3

In [4]:
class TitanMemory(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.register_buffer("M",torch.eye(config.d_model))
        self.register_buffer("S",torch.zeros(config.d_model,config.d_model))

        self.query = nn.Linear(config.d_model,config.d_model,bias=False)
        self.key = nn.Linear(config.d_model,config.d_model,bias=False)
        self.value = nn.Linear(config.d_model,config.d_model,bias=False)

        self.alpha = config.alpha
        self.eta = config.eta
        self.theta = config.theta

    def forward(self,x):
        q = self.query(x)
        y = torch.matmul(q,self.M)
        return y

    def update_memory(self,x):
        B = x.size(0)
        if B != 1:
            for i in range(B):
                self.update_memory(x[i:i+1])
            return

        k = self.key(x)
        v = self.value(x)

        v_pred = torch.matmul(k,self.M)
        loss = l2_loss(v_pred,v)
        error = v_pred - v

        g = 2 * torch.matmul(error.t(),k)

        self.S = self.eta * self.S - self.theta * g
        self.S = torch.clamp(self.S, -1e3, 1e3)
        self.M = (1-self.alpha) * self.M + self.S
        self.M = torch.clamp(self.M, -1e3, 1e3)
        return loss

In [5]:
class SlidingWindowAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.window_size = config.window_size
        self.attention = nn.MultiheadAttention(embed_dim = config.d_model, num_heads = config.n_heads, batch_first= True)

    def forward(self,x):
        batch_size,seq_len,_ = x.size()
        output = []

        for i in range(0,seq_len,self.window_size):
            x_chunk = x[:,i:i+self.window_size,:]
            attn_out,_ = self.attention(x_chunk,x_chunk,x_chunk)
            output.append(attn_out)
        return torch.cat(output,dim=1)
        

In [6]:
class PersistentMemory(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.persistent = nn.Parameter(torch.randn(config.N_p,config.d_model))

    def forward(self,batch_size):
        return self.persistent.unsqueeze(0).expand(batch_size,-1,-1)

In [7]:
'''class TitanMAC(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.chunk_size = config.chunk_size
        self.persistent = PersistentMemory(config)
        self.long_memory = TitanMemory(config)
        self.attn_layer = nn.ModuleList([
        nn.TransformerEncoderLayer(d_model=config.d_model,nhead=config.n_heads,batch_first=True) for _ in range(config.n_layers) ])
        
    def forward(self,x):
        batch_size,seq_len,_ = x.size()
        output = []

        for i in range(0,seq_len,self.chunk_size):
            x_chunk = x[:,i:i+self.chunk_size,:]
            x_flat = x_chunk.reshape(-1,self.d_model)
            out = self.long_memory(x_flat)
            out = out.reshape(batch_size,-1,self.d_model)
            persistent = self.persistent(batch_size)

            cat_input = torch.cat([persistent,out,x_chunk],dim=1)

            for layer in self.attn_layer:
                cat_input = layer(cat_input)
            out_chunk = cat_input[:,-x_chunk.size(1):,:]
            self.long_memory.update_memory(out_chunk.reshape(-1,self.d_model))
            retrived = self.long_memory(out_chunk.reshape(-1,self.d_model))
            retrived = retrieved.reshape(out_chunk.shape)
            out = out_chunk * retrieved
            output.append(out)
            
        output = torch.cat(output,dim=1)
        return output'''

'class TitanMAC(nn.Module):\n    def __init__(self,config):\n        super().__init__()\n        self.d_model = config.d_model\n        self.chunk_size = config.chunk_size\n        self.persistent = PersistentMemory(config)\n        self.long_memory = TitanMemory(config)\n        self.attn_layer = nn.ModuleList([\n        nn.TransformerEncoderLayer(d_model=config.d_model,nhead=config.n_heads,batch_first=True) for _ in range(config.n_layers) ])\n        \n    def forward(self,x):\n        batch_size,seq_len,_ = x.size()\n        output = []\n\n        for i in range(0,seq_len,self.chunk_size):\n            x_chunk = x[:,i:i+self.chunk_size,:]\n            x_flat = x_chunk.reshape(-1,self.d_model)\n            out = self.long_memory(x_flat)\n            out = out.reshape(batch_size,-1,self.d_model)\n            persistent = self.persistent(batch_size)\n\n            cat_input = torch.cat([persistent,out,x_chunk],dim=1)\n\n            for layer in self.attn_layer:\n                cat_inp

In [8]:
class TitanMAG(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.window_size = config.window_size
        self.long_memory = TitanMemory(config)
        self.attn_layers = nn.ModuleList([SlidingWindowAttention(config) for _ in range(config.n_layers)])
        self.persistent = PersistentMemory(config)

    def forward(self,x):
        batch_size,seq_len,d_model = x.size()

        x_flat = x.reshape(-1,d_model)
        with torch.no_grad():
            self.long_memory.update_memory(x_flat)
        
        persistent_tokens = self.persistent(batch_size)
        out = torch.cat([persistent_tokens,x],dim=1)

        for layer in self.attn_layers:
            out = layer(out)
        y = out
        out_flat = out.reshape(-1,self.d_model)
        long_term = self.long_memory(out_flat)
        long_term = long_term.reshape(batch_size,-1,d_model)

        output = y * long_term
        output = output[:,-seq_len:,:]
        return output

In [9]:
'''class TitanMAL(nn.Module):
    def __init__(self,config):
        self.d_model = config.d_model
        self.long_memory = TitanMemory(config)
        self.short_memory = SlidingWindowAttention(config)
        self.n_layer = config.n_layer
        self.percistent = PersistentMemory(config)
        self.attn_layers = nn.ModuleList([ nn.TransformerEncoderLayer(d_model=config.d_model,nhead=config.n_heads,batch_first=True) for _ in range(config.n_layer)])

    def forward(self,x):
        batch_size,seq_len,d_model = x.size()
        persistent_tokens = self.persistent(batch_size)
        out = torch.cat([persistent_tokens,x],dim=1)
        
        y = self.long_memory(out)
        o = self.short_memory(y)
        o = o[:,-seq_len:,:]

        x = x + out
        x = self.short_memory(x)

        for layer in self.attn_layers:
            x = layer(x)
        return x'''

'class TitanMAL(nn.Module):\n    def __init__(self,config):\n        self.d_model = config.d_model\n        self.long_memory = TitanMemory(config)\n        self.short_memory = SlidingWindowAttention(config)\n        self.n_layer = config.n_layer\n        self.percistent = PersistentMemory(config)\n        self.attn_layers = nn.ModuleList([ nn.TransformerEncoderLayer(d_model=config.d_model,nhead=config.n_heads,batch_first=True) for _ in range(config.n_layer)])\n\n    def forward(self,x):\n        batch_size,seq_len,d_model = x.size()\n        persistent_tokens = self.persistent(batch_size)\n        out = torch.cat([persistent_tokens,x],dim=1)\n        \n        y = self.long_memory(out)\n        o = self.short_memory(y)\n        o = o[:,-seq_len:,:]\n\n        x = x + out\n        x = self.short_memory(x)\n\n        for layer in self.attn_layers:\n            x = layer(x)\n        return x'

In [10]:
class TitanMAGLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Parameter(torch.randn(config.seq_len, config.d_model))
        self.titan = TitanMAG(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
    
    def forward(self, x):
        # x: [B, seq_len] token ids
        B, L = x.size()
        emb = self.embedding(x)  # [B, L, d_model]
        pos = self.pos_embedding[:L, :].unsqueeze(0)  # [1, L, d_model]
        emb = emb + pos
        out = self.titan(emb)  # [B, L, d_model]
        logits = self.lm_head(out)  # [B, L, vocab_size]
        return logits
    
    def generate(self, prompt, max_length=50):
        self.eval()
        generated = prompt.copy()
        with torch.no_grad():
            for _ in range(max_length):
                input_ids = torch.tensor([generated[-self.config.seq_len:]], dtype=torch.long).to(next(self.parameters()).device)
                logits = self.forward(input_ids)  # [1, seq_len, vocab_size]
                next_token_logits = logits[0, -1, :]
                next_token = torch.argmax(next_token_logits).item()
                generated.append(next_token)
                if next_token == self.config.eos_token_id:
                    break
        return generated

In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from collections import Counter
from tqdm import tqdm

# -------------------------------
# 1. Load WikiText-2 using Hugging Face datasets
# and train a BPE tokenizer using the tokenizers library.
# -------------------------------
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, processors

# Load WikiText-2 (use 50% of the data for faster experimentation)
wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split={
    "train": "train[:20%]",
    "validation": "validation[:10%]",
    "test": "test[:10%]"
})

# Prepare an iterator of training lines (skipping empty ones)
def line_iterator(split):
    for line in wikitext[split]["text"]:
        if line.strip():
            yield line

# Initialize a BPE tokenizer
bpe_tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
trainer = trainers.BpeTrainer(vocab_size=100000, special_tokens=["<unk>", "<pad>", "<bos>", "<eos>"])
bpe_tokenizer.train_from_iterator(line_iterator("train"), trainer)

bpe_tokenizer.post_processor = processors.TemplateProcessing(
    single="<bos> $A <eos>",
    pair="<bos> $A <eos> $B:1 <eos>:1",
    special_tokens=[
        ("<bos>", bpe_tokenizer.token_to_id("<bos>")),
        ("<eos>", bpe_tokenizer.token_to_id("<eos>"))
    ],
)


bpe_tokenizer.save("bpe_tokenizer.json")

vocab = bpe_tokenizer.get_vocab() 
vocab_size = len(vocab)
print("Vocabulary size:", vocab_size)

# Build inverse mapping for generation.
itos = {id: token for token, id in vocab.items()}

# -------------------------------
# 2. Create WikiText Dataset for Language Modeling using the BPE tokenizer.
# -------------------------------
class WikiTextLMDataset(Dataset):
    def __init__(self, split, tokenizer, seq_len):
        self.seq_len = seq_len
        self.tokenizer = tokenizer
        data = []
        for line in wikitext[split]["text"]:
            if not line.strip():
                continue
            encoding = tokenizer.encode(line)
            # encoding.ids already includes the BOS and EOS from post-processing.
            data.extend(encoding.ids)
        self.data = torch.tensor(data, dtype=torch.long)
        
    def __len__(self):
        return (len(self.data) - 1) // self.seq_len
    
    def __getitem__(self, idx):
        i = idx * self.seq_len
        x = self.data[i : i+self.seq_len]
        y = self.data[i+1 : i+self.seq_len+1]
        return x, y

seq_len = 128
train_dataset = WikiTextLMDataset(split="train", tokenizer=bpe_tokenizer, seq_len=seq_len)
valid_dataset = WikiTextLMDataset(split="validation", tokenizer=bpe_tokenizer, seq_len=seq_len)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32)

@dataclass
class TitanConfig():
    d_model: int = 256
    vocab_size: int = None 
    seq_len: int = 128
    n_heads: int = 8
    alpha: float = 0.1
    eta: float = 0.9
    theta: float = 0.01
    window_size: int = 256
    batch_size: int = 32
    n_layers: int = 8
    chunk_size: int = 64    # for MAC variant (not used here)
    N_p: int = 128
    bos_token_id: int = 2   # will update later
    eos_token_id: int = 3   # will update later

config = TitanConfig(vocab_size=vocab_size, d_model=256, seq_len=seq_len, n_layers=8, N_p=128)

class PersistentMemory(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.persistent = nn.Parameter(torch.randn(config.N_p, config.d_model))
    
    def forward(self, batch_size):
        return self.persistent.unsqueeze(0).expand(batch_size, -1, -1)

class TitanMemory(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.register_buffer("M", torch.eye(config.d_model))
        self.register_buffer("S", torch.zeros(config.d_model, config.d_model))
        self.query = nn.Linear(config.d_model, config.d_model, bias=False)
        self.key = nn.Linear(config.d_model, config.d_model, bias=False)
        self.value = nn.Linear(config.d_model, config.d_model, bias=False)
        self.alpha = config.alpha
        self.eta = config.eta
        self.theta = config.theta

    def forward(self, x):
        q = self.query(x)
        y = torch.matmul(q, self.M)
        return y

    def update_memory(self, x):
        B = x.size(0)
        if B != 1:
            for i in range(B):
                self.update_memory(x[i:i+1])
            return
        k = self.key(x)
        v = self.value(x)
        v_pred = torch.matmul(k, self.M)
        loss = torch.sum((v_pred - v) ** 2)
        error = v_pred - v
        g = 2 * torch.matmul(error.t(), k)
        self.S = self.eta * self.S - self.theta * g
        self.S = torch.clamp(self.S, -1e3, 1e3)
        self.M = (1 - self.alpha) * self.M + self.S
        self.M = torch.clamp(self.M, -1e3, 1e3)
        return loss

# Sliding-Window Attention Module (using standard MultiheadAttention as a proxy)
class SlidingWindowAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.window_size = config.window_size
        self.attention = nn.MultiheadAttention(embed_dim=config.d_model, num_heads=config.n_heads, batch_first=True)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        output = []
        for i in range(0, seq_len, self.window_size):
            x_chunk = x[:, i:i+self.window_size, :]
            attn_out, _ = self.attention(x_chunk, x_chunk, x_chunk)
            output.append(attn_out)
        return torch.cat(output, dim=1)

# TitanMAG: Gated Memory (MAG) Architecture
class TitanMAG(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.window_size = config.window_size
        self.long_memory = TitanMemory(config)
        self.attn_layers = nn.ModuleList([SlidingWindowAttention(config) for _ in range(config.n_layers)])
        self.persistent = PersistentMemory(config)
        self.layernorm = nn.LayerNorm(config.d_model)
    
    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        x_flat = x.reshape(-1, d_model)
        with torch.no_grad():
            self.long_memory.update_memory(x_flat)
        
        persistent_tokens = self.persistent(batch_size)
        tilde_x = torch.cat([persistent_tokens, x], dim=1)
        
        out = tilde_x
        for layer in self.attn_layers:
            out = layer(out)
        y = out
        
        tilde_x_flat = tilde_x.reshape(-1, d_model)
        memory_retrieval = self.long_memory(tilde_x_flat)
        memory_retrieval = memory_retrieval.reshape(batch_size, -1, d_model)
        
        norm_y = self.layernorm(y)
        norm_memory = self.layernorm(memory_retrieval)
        combined = norm_y * norm_memory
        
        output = combined[:, -seq_len:, :]
        return output

# TitanMAGLM: TitanMAG with LM Head for Language Modeling.
class TitanMAGLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Parameter(torch.randn(config.seq_len, config.d_model))
        self.titan = TitanMAG(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
    
    def forward(self, x):
        B, L = x.size()
        emb = self.embedding(x)
        pos = self.pos_embedding[:L, :].unsqueeze(0)
        emb = emb + pos
        out = self.titan(emb)
        logits = self.lm_head(out)
        return logits
    
    def generate(self, prompt, max_length=50, k=10):
        self.eval()
        generated = prompt.copy()
        with torch.no_grad():
            for _ in range(max_length):
                input_ids = torch.tensor([generated[-self.config.seq_len:]], dtype=torch.long).to(next(self.parameters()).device)
                logits = self.forward(input_ids)
                next_token_logits = logits[0, -1, :]
                topk_logits, topk_indices = torch.topk(next_token_logits, k)
                probs = F.softmax(topk_logits, dim=-1)
                next_token = topk_indices[torch.multinomial(probs, num_samples=1)].item()
                generated.append(next_token)
                if next_token == self.config.eos_token_id:
                    break
        return generated

def generate_sample_from_predefined_prompt(model, predefined_prompt, vocab, itos, max_length=256, k=10):
    # Convert the predefined prompt (a string) into token IDs.
    tokens = simple_tokenizer(predefined_prompt)
    prompt_ids = [vocab["<bos>"]] + [vocab.get(token, vocab["<unk>"]) for token in tokens]
    
    # Generate continuation using the model's generate method.
    generated_ids = model.generate(prompt_ids, max_length=max_length, k=k)
    
    # Convert the prompt IDs back to text.
    prompt_text = " ".join([itos.get(i, "<unk>") for i in prompt_ids])
    # The generated text is the continuation after the prompt.
    continuation_ids = generated_ids[len(prompt_ids):]
    generated_text = " ".join([itos.get(i, "<unk>") for i in continuation_ids])
    
    return prompt_text, generated_text





# -------------------------------
# 4. Training and Validation Setup with Checkpointing and tqdm
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config.vocab_size = vocab_size
config.seq_len = seq_len
config.bos_token_id = vocab["<bos>"]
config.eos_token_id = vocab["<eos>"]

model = TitanMAGLM(config).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
checkpoint_path = "titan_checkpoint-3.pth"
start_epoch = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"Resuming training from epoch {start_epoch}")

def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for x, y in progress_bar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits.view(-1, config.vocab_size), y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=f"{loss.item():.4f}")
    return total_loss / len(dataloader)

def evaluate_epoch(model, dataloader, device):
    model.eval()
    total_loss = 0.0
    progress_bar = tqdm(dataloader, desc="Validation", leave=False)
    with torch.no_grad():
        for x, y in progress_bar:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits.view(-1, config.vocab_size), y.view(-1))
            total_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")
    return total_loss / len(dataloader)

num_epochs = 100
for epoch in range(start_epoch, num_epochs):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Total parameters:", total_params)
    print("Trainable parameters:", trainable_params)
    
    train_loss = train_epoch(model, train_loader, optimizer, device)
    val_loss = evaluate_epoch(model, valid_loader, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    
    # Example usage:
    predefined_prompt = "The "
    prompt_text, generated_text = generate_sample_from_predefined_prompt(model, predefined_prompt, vocab, itos, max_length=256, k=10)
    print("Prompt:", prompt_text)
    print("Continuation:", generated_text)

    
    # Save checkpoint after each epoch.
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch+1}")





Vocabulary size: 44368
Total parameters: 25128784
Trainable parameters: 25128784


                                                                                

Epoch 1/100, Train Loss: 10.6930, Val Loss: 10.6873




NameError: name 'simple_tokenizer' is not defined

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from collections import Counter
from tqdm import tqdm
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, processors

# -------------------------------
# 0. Define a Simple Tokenizer Function (for prompt processing)
# -------------------------------
def simple_tokenizer(text):
    return text.lower().split()

# -------------------------------
# 1. Load WikiText-2 and Initialize/Train a Custom BPE Tokenizer
# -------------------------------
print("Loading WikiText-2 dataset...")
wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split={
    "train": "train[:10%]",
    "validation": "validation[:10%]",
    "test": "test[:10%]"
})

# Check if a saved tokenizer exists; otherwise, train one.
tokenizer_path = "bpe_tokenizer.json"
if os.path.exists(tokenizer_path):
    print("Loading saved BPE tokenizer...")
    bpe_tokenizer = Tokenizer.from_file(tokenizer_path)
else:
    print("Training a new BPE tokenizer...")
    def line_iterator(split):
        for line in wikitext[split]["text"]:
            if line.strip():
                yield line
    bpe_tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
    bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = trainers.BpeTrainer(vocab_size=100000, special_tokens=["<unk>", "<pad>", "<bos>", "<eos>"])
    bpe_tokenizer.train_from_iterator(line_iterator("train"), trainer)
    bpe_tokenizer.post_processor = processors.TemplateProcessing(
        single="<bos> $A <eos>",
        pair="<bos> $A <eos> $B:1 <eos>:1",
        special_tokens=[
            ("<bos>", bpe_tokenizer.token_to_id("<bos>")),
            ("<eos>", bpe_tokenizer.token_to_id("<eos>"))
        ],
    )
    bpe_tokenizer.save(tokenizer_path)
    print("Tokenizer saved to", tokenizer_path)

# Build vocabulary and inverse mapping from the BPE tokenizer.
vocab = bpe_tokenizer.get_vocab()  # dict: token -> id
vocab_size = len(vocab)
print("Vocabulary size:", vocab_size)
itos = {id: token for token, id in vocab.items()}

# -------------------------------
# 2. Create WikiText Dataset for LM using the Custom BPE Tokenizer
# -------------------------------
class WikiTextLMDataset(Dataset):
    def __init__(self, split, tokenizer, seq_len):
        self.seq_len = seq_len
        self.tokenizer = tokenizer
        data = []
        for line in wikitext[split]["text"]:
            if not line.strip():
                continue
            # The post-processor adds <bos> and <eos>
            encoding = tokenizer.encode(line)
            data.extend(encoding.ids)
        self.data = torch.tensor(data, dtype=torch.long)
        
    def __len__(self):
        return (len(self.data) - 1) // self.seq_len
    
    def __getitem__(self, idx):
        i = idx * self.seq_len
        x = self.data[i : i+self.seq_len]
        y = self.data[i+1 : i+self.seq_len+1]
        return x, y

seq_len = 64
train_dataset = WikiTextLMDataset(split="train", tokenizer=bpe_tokenizer, seq_len=seq_len)
valid_dataset = WikiTextLMDataset(split="validation", tokenizer=bpe_tokenizer, seq_len=seq_len)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32)

# -------------------------------
# 3. Define Titan Configuration and Model Components
# -------------------------------
@dataclass
class TitanConfig:
    d_model: int = 256
    vocab_size: int = 100000  # will update below
    seq_len: int = 64
    n_heads: int = 8
    alpha: float = 0.1
    eta: float = 0.9
    theta: float = 0.01
    window_size: int = 128
    batch_size: int = 32
    n_layers: int = 6
    chunk_size: int = 64    # for MAC variant (not used here)
    N_p: int = 64
    bos_token_id: int = 2   # will update below
    eos_token_id: int = 3   # will update below

config = TitanConfig(vocab_size=vocab_size, d_model=256, seq_len=seq_len, n_layers=6, N_p=64)

# Persistent Memory Module
class PersistentMemory(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.persistent = nn.Parameter(torch.randn(config.N_p, config.d_model))
    
    def forward(self, batch_size):
        return self.persistent.unsqueeze(0).expand(batch_size, -1, -1)

# Titan Memory Module (Long-Term Memory)
class TitanMemory(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.register_buffer("M", torch.eye(config.d_model))
        self.register_buffer("S", torch.zeros(config.d_model, config.d_model))
        self.query = nn.Linear(config.d_model, config.d_model, bias=False)
        self.key = nn.Linear(config.d_model, config.d_model, bias=False)
        self.value = nn.Linear(config.d_model, config.d_model, bias=False)
        self.alpha = config.alpha
        self.eta = config.eta
        self.theta = config.theta

    def forward(self, x):
        q = self.query(x)
        return torch.matmul(q, self.M)

    def update_memory(self, x):
        B = x.size(0)
        if B != 1:
            for i in range(B):
                self.update_memory(x[i:i+1])
            return
        k = self.key(x)
        v = self.value(x)
        v_pred = torch.matmul(k, self.M)
        loss = torch.sum((v_pred - v) ** 2)
        error = v_pred - v
        g = 2 * torch.matmul(error.t(), k)
        self.S = self.eta * self.S - self.theta * g
        self.S = torch.clamp(self.S, -1e3, 1e3)
        self.M = (1 - self.alpha) * self.M + self.S
        self.M = torch.clamp(self.M, -1e3, 1e3)
        return loss

# Sliding-Window Attention Module (using standard MultiheadAttention as a proxy)
class SlidingWindowAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=config.d_model, num_heads=config.n_heads, batch_first=True)

    def forward(self, x):
        return self.attention(x, x, x)[0]

# TitanMAG: Gated Memory (MAG) Architecture
class TitanMAG(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.long_memory = TitanMemory(config)
        self.attn_layers = nn.ModuleList([SlidingWindowAttention(config) for _ in range(config.n_layers)])
        self.persistent = PersistentMemory(config)
        self.layernorm = nn.LayerNorm(config.d_model)
    
    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        x_flat = x.reshape(-1, d_model)
        with torch.no_grad():
            self.long_memory.update_memory(x_flat)
        
        persistent_tokens = self.persistent(batch_size)
        tilde_x = torch.cat([persistent_tokens, x], dim=1)
        
        out = tilde_x
        for layer in self.attn_layers:
            out = layer(out)
        y = out
        
        tilde_x_flat = tilde_x.reshape(-1, d_model)
        memory_retrieval = self.long_memory(tilde_x_flat)
        memory_retrieval = memory_retrieval.reshape(batch_size, -1, d_model)
        
        norm_y = self.layernorm(y)
        norm_memory = self.layernorm(memory_retrieval)
        combined = norm_y * norm_memory
        
        output = combined[:, -seq_len:, :]
        return output

# TitanMAGLM: TitanMAG with LM Head for Language Modeling.
class TitanMAGLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Parameter(torch.randn(config.seq_len, config.d_model))
        self.titan = TitanMAG(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
    
    def forward(self, x):
        B, L = x.size()
        emb = self.embedding(x) + self.pos_embedding[:L, :].unsqueeze(0)
        out = self.titan(emb)
        logits = self.lm_head(out)
        return logits
    
    def generate(self, prompt, max_length=50, k=10, temperature=1.0):
        self.eval()
        generated = prompt.copy()
        with torch.no_grad():
            for _ in range(max_length):
                input_ids = torch.tensor([generated[-self.config.seq_len:]], dtype=torch.long).to(next(self.parameters()).device)
                logits = self.forward(input_ids)
                # Get logits for the last token and apply temperature scaling.
                next_token_logits = logits[0, -1, :] / temperature
                topk_logits, topk_indices = torch.topk(next_token_logits, k)
                probs = F.softmax(topk_logits, dim=-1)
                next_token = topk_indices[torch.multinomial(probs, num_samples=1)].item()
                generated.append(next_token)
                if next_token == self.config.eos_token_id:
                    break
        return generated

# -------------------------------
# 4. Training/Validation Setup with Checkpointing and tqdm
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config.vocab_size = vocab_size
config.seq_len = seq_len
config.bos_token_id = vocab["<bos>"]
config.eos_token_id = vocab["<eos>"]

model = TitanMAGLM(config).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
checkpoint_path = "titan_checkpoint-1.pth"
start_epoch = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"Resuming training from epoch {start_epoch}")

def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for x, y in progress_bar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits.view(-1, config.vocab_size), y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=f"{loss.item():.4f}")
    return total_loss / len(dataloader)

def evaluate_epoch(model, dataloader, device):
    model.eval()
    total_loss = 0.0
    progress_bar = tqdm(dataloader, desc="Validation", leave=False)
    with torch.no_grad():
        for x, y in progress_bar:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits.view(-1, config.vocab_size), y.view(-1))
            total_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")
    return total_loss / len(dataloader)

def generate_sample_from_predefined_prompt(model, predefined_prompt, vocab, itos, max_length=256, k=10, temperature=1.0):
    tokens = simple_tokenizer(predefined_prompt)
    prompt_ids = [vocab["<bos>"]] + [vocab.get(token, vocab["<unk>"]) for token in tokens]
    generated_ids = model.generate(prompt_ids, max_length=max_length, k=k, temperature=temperature)
    prompt_text = " ".join([itos.get(i, "<unk>") for i in prompt_ids])
    continuation_ids = generated_ids[len(prompt_ids):]
    generated_text = " ".join([itos.get(i, "<unk>") for i in continuation_ids])
    return prompt_text, generated_text

# -------------------------------
# 5. Training Loop with Checkpoint Saving and Sampling
# -------------------------------
num_epochs = 100
for epoch in range(start_epoch, num_epochs):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Total parameters:", total_params)
    print("Trainable parameters:", trainable_params)
    
    train_loss = train_epoch(model, train_loader, optimizer, device)
    val_loss = evaluate_epoch(model, valid_loader, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    
    predefined_prompt = "The "
    prompt_text, generated_text = generate_sample_from_predefined_prompt(model, predefined_prompt, vocab, itos, max_length=256, k=10, temperature=1.0)
    print("Prompt:", prompt_text)
    print("Continuation:", generated_text)
    
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch+1}")


print("Chatbot is ready! Type 'exit' or 'quit' to stop.")
while True:
    user_input = input("User: ")
    if user_input.lower() in ["exit", "quit"]:
        break
    prompt_text, response = generate_sample_from_predefined_prompt(model, user_input, vocab, itos, max_length=100, k=10, temperature=1.0)
    print("Bot:", response)


Loading WikiText-2 dataset...
Loading saved BPE tokenizer...
Vocabulary size: 44368


  checkpoint = torch.load(checkpoint_path, map_location=device)


Resuming training from epoch 40
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 41/100, Train Loss: 7.2734, Val Loss: 8.1746
Prompt: <bos> the
Continuation: . of a . the and , <bos> <eos>
Checkpoint saved at epoch 41
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 42/100, Train Loss: 7.2782, Val Loss: 8.2301
Prompt: <bos> the
Continuation: and . the the . in in The . , to , <eos>
Checkpoint saved at epoch 42
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 43/100, Train Loss: 7.2714, Val Loss: 8.2189
Prompt: <bos> the
Continuation: to and of and = " , . . . , . = the and , the , , the and the of in the in a of <eos>
Checkpoint saved at epoch 43
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 44/100, Train Loss: 7.2471, Val Loss: 8.2265
Prompt: <bos> the
Continuation: . the , a , , of @-@ a of in and a the <bos> s to <bos> to , the the the the of , the . in . , " in the = of . a . to the , in , of , . the the , <eos>
Checkpoint saved at epoch 44
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 45/100, Train Loss: 7.2361, Val Loss: 8.2517
Prompt: <bos> the
Continuation: the the . the and of , of = . of to and = in of . to the , and of of of = <bos> the of . the . , and a in . , <eos>
Checkpoint saved at epoch 45
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 46/100, Train Loss: 7.2607, Val Loss: 8.2340
Prompt: <bos> the
Continuation: was the a in in the , the . " a in the the , . , the a @-@ of , . of , the , , was the in the the the of and the the . and , , and . . , . in the = a in to in of . . , the . . , of a the . of . . in the <bos> = , and <eos>
Checkpoint saved at epoch 46
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 47/100, Train Loss: 7.2305, Val Loss: 8.2516
Prompt: <bos> the
Continuation: and and . and of . , <bos> , " <bos> to the of , , and in <eos>
Checkpoint saved at epoch 47
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 48/100, Train Loss: 7.2234, Val Loss: 8.2133
Prompt: <bos> the
Continuation: in , , in , the to . of and , the = . of : of a = a , in the a the = <bos> . of the , , . = . , <eos>
Checkpoint saved at epoch 48
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 49/100, Train Loss: 7.2069, Val Loss: 8.2397
Prompt: <bos> the
Continuation: of @-@ a and , , in " of the the and , the of the the a the , was the , " in . <eos>
Checkpoint saved at epoch 49
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 50/100, Train Loss: 7.2077, Val Loss: 8.1989
Prompt: <bos> the
Continuation: . , = " of and a of s <bos> of in , of , in the . , the = of to the = <bos> , , . in to the . , the = in the . . . the . and and = , " the a of in and to , the to , a the , of , a . . , , and of = of . The the the to of the , and . the . , . the = the the , the . of to the in , the the a , and of , to a of in of the , the = in the , , and = a , of , of of . and = of . = and of the . the of . in to and and to , a . . , in , , and . , , , the , . in of = the . <bos> in to the of . and . <eos>
Checkpoint saved at epoch 50
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 51/100, Train Loss: 7.2103, Val Loss: 8.2404
Prompt: <bos> the
Continuation: the . . , . to to @-@ , . . of . in , . of the to the . the = , , of = <bos> the and = to " , . the a to the . , <eos>
Checkpoint saved at epoch 51
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 52/100, Train Loss: 7.2073, Val Loss: 8.2432
Prompt: <bos> the
Continuation: the = = . the a in the in . and . , . = , . to the in . , . , the . , and the of of , of the and the <eos>
Checkpoint saved at epoch 52
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 53/100, Train Loss: 7.1946, Val Loss: 8.2368
Prompt: <bos> the
Continuation: and of to , in the the , . in <eos>
Checkpoint saved at epoch 53
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 54/100, Train Loss: 7.1922, Val Loss: 8.2733
Prompt: <bos> the
Continuation: in , a to , in , , to , of of to , the a = in , <bos> the , . of , . . . . , . . in " , . the , in the = in " and and , = in the a , = , of the and of , to , a and the , , a = to of . , the and " a The = a and in = to . , , of the to and and the , the , , and " was . the and the the in to of was <bos> . the in = the the = to in in <bos> the , of to of <eos>
Checkpoint saved at epoch 54
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 55/100, Train Loss: 7.1798, Val Loss: 8.2596
Prompt: <bos> the
Continuation: of of and " to , , the , in the the the of a , the the the a , of . , the of , of . in and and a in the the to , the to in the was = , to the = a the . <eos>
Checkpoint saved at epoch 55
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 56/100, Train Loss: 7.1692, Val Loss: 8.2747
Prompt: <bos> the
Continuation: . the the in in the the , in the . the <eos>
Checkpoint saved at epoch 56
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 57/100, Train Loss: 7.1758, Val Loss: 8.2707
Prompt: <bos> the
Continuation: . " to the . a in , the the of the the , the the a and and in the and the the <eos>
Checkpoint saved at epoch 57
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 58/100, Train Loss: 7.1682, Val Loss: 8.2764
Prompt: <bos> the
Continuation: of the in . the the . to , , and of , . " , . of and s . and of , , the , . a the , , . to and in . the the " and a = the to , in " " <eos>
Checkpoint saved at epoch 58
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 59/100, Train Loss: 7.1736, Val Loss: 8.2888
Prompt: <bos> the
Continuation: to in . in of the . of and the . in . to a and " . the . , , and " , . , , a . the <bos> , the a = the the of " , , of and " the the s a = the the the , of a , of the , and in . the and , . , and to of of , the of in . in . , the of the . , and of a , of , = of . , the to and " <eos>
Checkpoint saved at epoch 59
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 60/100, Train Loss: 7.1554, Val Loss: 8.2879
Prompt: <bos> the
Continuation: the , in s , . and and a . = the , of the the <eos>
Checkpoint saved at epoch 60
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 61/100, Train Loss: 7.1567, Val Loss: 8.2964
Prompt: <bos> the
Continuation: the to the in . of a of , , to a the the , . and and . . in to to <bos> in in , of = = the to of the the a , to , and the in in . . and = and , the = . . a , was a . a a . in to . . a <eos>
Checkpoint saved at epoch 61
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 62/100, Train Loss: 7.1687, Val Loss: 8.2792
Prompt: <bos> the
Continuation: = of = in <eos>
Checkpoint saved at epoch 62
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 63/100, Train Loss: 7.1625, Val Loss: 8.3223
Prompt: <bos> the
Continuation: the the of of . the " and , a the , , . and the in s , <bos> and the of of . the , of , to . to the of = the , the and . and = a . the in to the and of . was . , the . , = of of and the , the and , a <bos> to , . . and to , the the . . = = the the . of in . in , . , and a the the , and . <bos> to a the and = . and the the a = of the a in a the . of in the the " . . and . . . to , , the . the the of in the and . of = , the a a the and , in the to in , = the . . , and . the the of to a , , in in = . and , . = of . a , , in the , = and a , , the of , the " a . a in and in the , . , to , of a . , , . to the , the of and in , and of the and the , the = , the <eos>
Checkpoint saved at epoch 63
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 64/100, Train Loss: 7.1556, Val Loss: 8.3295
Prompt: <bos> the
Continuation: the and of and , the the and the and a and the the the <eos>
Checkpoint saved at epoch 64
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 65/100, Train Loss: 7.1478, Val Loss: 8.2932
Prompt: <bos> the
Continuation: <bos> and in of the in and in , to and the and the , of and <eos>
Checkpoint saved at epoch 65
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 66/100, Train Loss: 7.1532, Val Loss: 8.3375
Prompt: <bos> the
Continuation: , . , the the in . a in in and in , . and . the in = and = , " " , in . of and of , = . a of , , the , , = the in the . , , <bos> = to . , the , a of the the and , , . " in and in of the of . . the , , , of = and = " the , , the , , = . The " the , in , of , and , , of a the was , to , = in . a . and the , the of the of , a the . . . in and . , , . of = in to = the to , , to the , the . the the , a the of and , of a = and and . to , , to , = and . = of . the and in in and in , <bos> and a in and . to the the , in of , a = the " of the , = the in , the , the the to the . and = . of , <bos> the " and in , . , the in in = , the a , the , and a the in of a in a a . . of = the the the a , . the to . @-@ and , to
Checkpoint saved at epoch 66
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 67/100, Train Loss: 7.1557, Val Loss: 8.3153
Prompt: <bos> the
Continuation: , the in <bos> a in <bos> the in and a , of . = the a the of a , and " to = , a , a , the the the . , , a and a to a , in = a = , , . , " , the , the a , " and the , the the of = of the of . the " to to a . . and <eos>
Checkpoint saved at epoch 67
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 68/100, Train Loss: 7.1387, Val Loss: 8.3658
Prompt: <bos> the
Continuation: to a , <bos> in <bos> , , of the , . to in in a <eos>
Checkpoint saved at epoch 68
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 69/100, Train Loss: 7.1536, Val Loss: 8.3014
Prompt: <bos> the
Continuation: the <bos> the in the " of was , the , the of the , and . = to " to the " a to of a the the the a the and and of . the " , of and a . and to , was , " of " the , the of the . to the and the , and , of = <bos> a the a , , in of and . of . to . in . a of of in of . the . to of of , and to the the a , of the = . of the to to the , , . , the . a in . the . . = the of , , in of , " of the of " in the the . to , . , of of of , , a , <eos>
Checkpoint saved at epoch 69
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 70/100, Train Loss: 7.1435, Val Loss: 8.3134
Prompt: <bos> the
Continuation: and . and = <eos>
Checkpoint saved at epoch 70
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 71/100, Train Loss: 7.1294, Val Loss: 8.3714
Prompt: <bos> the
Continuation: a <eos>
Checkpoint saved at epoch 71
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 72/100, Train Loss: 7.1332, Val Loss: 8.3615
Prompt: <bos> the
Continuation: the and the s and of = and the , , a . the , the . , the . " and the . of , , . the , the = to the of = = . to a and the . of the the of in and , a and , of , in . the a in a <bos> , <bos> . . in the a = , and . , a , and the , to a of , the , <bos> and <eos>
Checkpoint saved at epoch 72
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 73/100, Train Loss: 7.1285, Val Loss: 8.3272
Prompt: <bos> the
Continuation: of , = to . in . and of in the a and in , , to = , = of and , , the . of of . . in <bos> and , and and , <eos>
Checkpoint saved at epoch 73
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 74/100, Train Loss: 7.1282, Val Loss: 8.3795
Prompt: <bos> the
Continuation: . , and . of = in , . of and and = of , of in in the the . in , , , in . , , " a a the to the of the and , a of " a , in in , a of and , , and . and . and " , a the the the = a , , the , to , <eos>
Checkpoint saved at epoch 74
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 75/100, Train Loss: 7.1185, Val Loss: 8.3022
Prompt: <bos> the
Continuation: in = . of . , . . and of the . . in = " to " of , the = to , , in in , to the and to and the of to in , <eos>
Checkpoint saved at epoch 75
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 76/100, Train Loss: 7.1350, Val Loss: 8.3701
Prompt: <bos> the
Continuation: a the , and <eos>
Checkpoint saved at epoch 76
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 77/100, Train Loss: 7.1334, Val Loss: 8.3891
Prompt: <bos> the
Continuation: to the the the , . and in and and , and . and . . and the " the , in , to and of the the <bos> and . , the of , , of of and in to of in the = in . <bos> , , of . . and = the of . . of the in of the , . = = the . the and the of of . a the in the the = of in . , of , the of and in , a and and . a the the and and " <bos> , of . , , a , and the , of of . , , the to , = , a in and the . , , , and . = , to of in the = the . in in , the = , , in " , and . the . . , and in , of = , <eos>
Checkpoint saved at epoch 77
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 78/100, Train Loss: 7.1192, Val Loss: 8.4090
Prompt: <bos> the
Continuation: of = a to <bos> a a , of the a , the <bos> , and and . to , . the , of of and in , the to . . a , the = , , and . a the a , , and the the the the and , , a , of of of in in . the the = a of " of , the a = in = of the and the , of of of the in of , . in and of to a and , the the , , a , and the a , a and of the of . in a , the of and to . and , . in . of the , the the of in = of and the the a the a , in of the <eos>
Checkpoint saved at epoch 78
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 79/100, Train Loss: 7.1278, Val Loss: 8.4292
Prompt: <bos> the
Continuation: . , , <eos>
Checkpoint saved at epoch 79
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 80/100, Train Loss: 7.1236, Val Loss: 8.3709
Prompt: <bos> the
Continuation: , the and the . in " , . a a of the a in = to in " " and . , " the the the , to to , , , and of the the the , the the , , = to . . " the <bos> the " the in the = of . . the , , , the in " of the and the the , in the in . the . to and the = , the of and and of in <bos> . to , the and the = the the of and a , the " to , of . and , . in , to , " <eos>
Checkpoint saved at epoch 80
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 81/100, Train Loss: 7.1205, Val Loss: 8.3714
Prompt: <bos> the
Continuation: , . " " in in . , and , of and the the and and the the , . , the , of = , , and in the and , the of and the the the , , the <bos> of the and the , in the , , . , , in the in a the the . . of and of , the of , , . and the . . , . the the in , the , the of of . in the the in = to the , . , the . , the . . and to <eos>
Checkpoint saved at epoch 81
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 82/100, Train Loss: 7.1151, Val Loss: 8.4010
Prompt: <bos> the
Continuation: the , the of to , , the = and the , a = of , and , the . in a the the of , the the the . , in . the <eos>
Checkpoint saved at epoch 82
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 83/100, Train Loss: 7.1172, Val Loss: 8.4743
Prompt: <bos> the
Continuation: . of and , the the = " . in , , in . the in = to the , . , " . , the a the in and = , , <bos> . <bos> of . of and and , the and , . " . of . the . the the and , to . the the to in . of . the the the a , . . . the , the the of of the the , and the the of of to , the = . to of to = to , a the and of the . the <bos> , the of and of = , the the in of <bos> and , , to the the and and , , <eos>
Checkpoint saved at epoch 83
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 84/100, Train Loss: 7.1121, Val Loss: 8.3712
Prompt: <bos> the
Continuation: the the a was and to and . a of " , . , . of the , to " . the a . the to of = @-@ , in the the , to a the , = = a " and a of , in a the , the the to , " and in the the to the to the to in " <eos>
Checkpoint saved at epoch 84
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 85/100, Train Loss: 7.1175, Val Loss: 8.4676
Prompt: <bos> the
Continuation: . a , and , <bos> a the , , to , in . to of , , a . the , in to . . . , , the and , to the and a the the the to and and in = of , to . and the in a , and in and a , a the . the the = = . , . of and = = . , a the the , . in , and <bos> to , in the and , the a , " of . the , , . and the the . the , the to " in the , <bos> <bos> , , in . a to to , a . the the the the . of <eos>
Checkpoint saved at epoch 85
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 86/100, Train Loss: 7.1201, Val Loss: 8.4070
Prompt: <bos> the
Continuation: , and a in a the a . , the , to the <eos>
Checkpoint saved at epoch 86
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 87/100, Train Loss: 7.1111, Val Loss: 8.4381
Prompt: <bos> the
Continuation: to in in , . , . and a the of . of and and , and of " the the the , to the to the the , in and , to = of the = <bos> of the the of the to , to and , , the of and " of , <eos>
Checkpoint saved at epoch 87
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 88/100, Train Loss: 7.1074, Val Loss: 8.4193
Prompt: <bos> the
Continuation: , , to , <bos> , the the of and , . <bos> in in , " , in the of , , of the of to and the . . the the and to , , and to the to and , = to and . = " the " , . , in the , to the the . = = = , <eos>
Checkpoint saved at epoch 88
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 89/100, Train Loss: 7.1157, Val Loss: 8.4208
Prompt: <bos> the
Continuation: the to the was and and a of a of , , , . , . the . of . a the . the the the in and and the of the . . " in to to , of the . the . , of , and . to <bos> , the . the , to , the a a . , = of a , and = . and in , . the " and . , in and = the the the to , . , , the to , the the a <eos>
Checkpoint saved at epoch 89
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 90/100, Train Loss: 7.1065, Val Loss: 8.4909
Prompt: <bos> the
Continuation: . . the of of the , . a . , , and . . , . of . , the to , the = to the . the . = of . . in , the the the . the . , , and . the the a . of a of . a to the , the , of the . the . the the in the a to , the the a the , in , . . , . . of to , and and . , the and = , . = of . , , . and the to , the the = in , the a . a = , to . , the . in " . to of of . of the of and . to to and a , , , in the . and the , , , " to the , the . , . , = of , " , in " of and the the , . and the . to and , the in " in , . , . of and of , . , of = . . the <bos> . to the to the , . , the of the , " the . the in and the , of , of of , the the the of . = . the the , of and , , , , , the and " , , the . <eos>
Checkpoint saved at epoch 90
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 91/100, Train Loss: 7.1054, Val Loss: 8.4500
Prompt: <bos> the
Continuation: the the in . . . and and the <bos> of , of to in and = a in the . = the a " a a and , in the the a , in to the of of a . to a the a the <eos>
Checkpoint saved at epoch 91
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 92/100, Train Loss: 7.1044, Val Loss: 8.3933
Prompt: <bos> the
Continuation: . , , the the . the <bos> the in to and a to of = the " " in . , . . the the <bos> of a the the the in , " of in to of " the the to = , . , " in the = the = " , , . in of and the in the a = . and a . . the of the in to the the . the the in in , , and the of . . the the . = of of in = to , of = , in , , . of to <eos>
Checkpoint saved at epoch 92
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 93/100, Train Loss: 7.1012, Val Loss: 8.4113
Prompt: <bos> the
Continuation: to to and was and in the the , " a the a a . of , , the the the and to to and , to . of of the and . in = and . the and to a = in the of in , the in , of . and <eos>
Checkpoint saved at epoch 93
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 94/100, Train Loss: 7.1210, Val Loss: 8.4169
Prompt: <bos> the
Continuation: a the and <bos> of the , of of . in in , " of in the a in and in , to the . , the a , , = . the = to the . of . . of of , . , in = , , . . the the = , the , to of and to , in of a and the " in and , of . and the the a a , . of to the = , of of the the a a of . " the a the , , , in . and of , to of , . the to in , the , the the a in , to , . , " of the , . the and the = <eos>
Checkpoint saved at epoch 94
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 95/100, Train Loss: 7.1002, Val Loss: 8.4427
Prompt: <bos> the
Continuation: the . and the , of , , of to in . " <eos>
Checkpoint saved at epoch 95
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 96/100, Train Loss: 7.1102, Val Loss: 8.4042
Prompt: <bos> the
Continuation: to , , , , and the . , to the the and and . a to of , the the a . " the , in the , of in <bos> , , the a , the , . , . , = , the the the the , the . in the to to <eos>
Checkpoint saved at epoch 96
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 97/100, Train Loss: 7.1062, Val Loss: 8.4572
Prompt: <bos> the
Continuation: the the the the , of , , " of and , . to of , the a the of a of . , and , , , in , , of of to , . . the , , the was , the and a . of and the the in in and , of <bos> the the " of of . in . the of of in , = , the the of the the " . in " to the to of , in " the . to the the to . , of a in and a . of the , and , , , the . the = and = and the " , the of , in to in and , was a the the to of to to the , of <bos> a to , the , to , " the of , and in the , the the of . , and , to a in . of " , = , of . . , = the . , of and , , and <bos> a in and a and of of . and = in . of , , was a in , the , the , in and a , of , the was to the to to , the . of . to a the . of the , , the . . of the and the the to " to , . . and to . to and in . . , ,
Checkpoint saved at epoch 97
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 98/100, Train Loss: 7.0996, Val Loss: 8.4595
Prompt: <bos> the
Continuation: to <bos> the a a , of . and , the . . " and . of the in , a to of of = of , the , the a to a , the , the of of , the . and the , , in " , to the = the = the the , and . of = to of to in , . of to the to was . to . . in the . to a the the in . in to the = the the , in , , . and of of in in the in , , and . in the <bos> of . , . " . to the , = a of in in . the the . the of " . . = = of to of . a in the the . to . the and , . and to , to , of , in to of the the and the of a in to a = of the in <eos>
Checkpoint saved at epoch 98
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 99/100, Train Loss: 7.1106, Val Loss: 8.4235
Prompt: <bos> the
Continuation: a , and of the and the to of , in to of to to and the the the and , of to a , a a , the a to , a = the the the , and and . the to a , of , the to the in , a to a of in , the in , of in the the " . , in to in the of the and . the in a , = = and = a , the the = , , of to in the of a of <bos> and , in . the . of , = the = . of the , in the = a to , the a , a a and to , the to of of to = = of to <eos>
Checkpoint saved at epoch 99
Total parameters: 24569680
Trainable parameters: 24569680


                                                                                

Epoch 100/100, Train Loss: 7.0985, Val Loss: 8.4808
Prompt: <bos> the
Continuation: , the in and . . of , and , a of . the the , the and a and to = in and , the <bos> <eos>
Checkpoint saved at epoch 100
Chatbot is ready! Type 'exit' or 'quit' to stop.
