In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from collections import Counter
from datasets import load_dataset

# -------------------------------
# 1. Configuration Class (Same as Used in Training)
# -------------------------------
@dataclass
class TitanConfig:
    d_model: int = 256
    vocab_size: int = 10000  # Update after building vocab
    seq_len: int = 128
    n_heads: int = 8
    alpha: float = 0.1
    eta: float = 0.9
    theta: float = 0.01
    window_size: int = 256
    batch_size: int = 32
    n_layers: int = 8
    N_p: int = 128
    bos_token_id: int = 2
    eos_token_id: int = 3

# -------------------------------
# 2. Load WikiText-2 Dataset and Build Vocabulary
# -------------------------------
def simple_tokenizer(text):
    return text.lower().split()

# Load a subset of WikiText dataset
print("Loading dataset...")
wikitext = load_dataset("wikitext", "wikitext-2-raw-v1")

# Reduce dataset size for faster loading
N = 5000
for split in ["train", "validation", "test"]:
    wikitext[split] = wikitext[split].select(range(min(N, len(wikitext[split]))))

# Build vocabulary
print("Building vocabulary...")
counter = Counter()
for line in wikitext["train"]["text"]:
    if line.strip():
        counter.update(simple_tokenizer(line))

# Special tokens
special_tokens = ["<unk>", "<pad>", "<bos>", "<eos>"]
vocab = {token: idx for idx, token in enumerate(special_tokens)}

# Add words appearing at least twice
min_freq = 2
for token, freq in counter.items():
    if freq >= min_freq and token not in vocab:
        vocab[token] = len(vocab)

vocab_size = len(vocab)
itos = {idx: token for token, idx in vocab.items()}  # id -> token mapping

# -------------------------------
# 3. Define Model Components
# -------------------------------
class PersistentMemory(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.persistent = nn.Parameter(torch.randn(config.N_p, config.d_model))
    
    def forward(self, batch_size):
        return self.persistent.unsqueeze(0).expand(batch_size, -1, -1)

class TitanMemory(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.register_buffer("M", torch.eye(config.d_model))
        self.register_buffer("S", torch.zeros(config.d_model, config.d_model))
        self.query = nn.Linear(config.d_model, config.d_model, bias=False)
        self.key = nn.Linear(config.d_model, config.d_model, bias=False)
        self.value = nn.Linear(config.d_model, config.d_model, bias=False)
        self.alpha = config.alpha
        self.eta = config.eta
        self.theta = config.theta

    def forward(self, x):
        q = self.query(x)
        return torch.matmul(q, self.M)

class SlidingWindowAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=config.d_model, num_heads=config.n_heads, batch_first=True)

    def forward(self, x):
        return self.attention(x, x, x)[0]

class TitanMAG(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.long_memory = TitanMemory(config)
        self.attn_layers = nn.ModuleList([SlidingWindowAttention(config) for _ in range(config.n_layers)])
        self.persistent = PersistentMemory(config)
        self.layernorm = nn.LayerNorm(config.d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        persistent_tokens = self.persistent(batch_size)
        out = torch.cat([persistent_tokens, x], dim=1)
        for layer in self.attn_layers:
            out = layer(out)
        memory_retrieval = self.long_memory(out)
        combined = self.layernorm(out) * self.layernorm(memory_retrieval)
        return combined[:, -seq_len:, :]

class TitanMAGLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Parameter(torch.randn(config.seq_len, config.d_model))
        self.titan = TitanMAG(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, x):
        emb = self.embedding(x) + self.pos_embedding[: x.size(1), :].unsqueeze(0)
        return self.lm_head(self.titan(emb))

    def generate(self, prompt, max_length=100, k=10):
        self.eval()
        generated = prompt.copy()
        with torch.no_grad():
            for _ in range(max_length):
                input_ids = torch.tensor([generated[-self.config.seq_len:]], dtype=torch.long).to(next(self.parameters()).device)
                logits = self.forward(input_ids)[0, -1, :]
                topk_logits, topk_indices = torch.topk(logits, k)
                probs = F.softmax(topk_logits, dim=-1)
                next_token = topk_indices[torch.multinomial(probs, num_samples=1)].item()
                generated.append(next_token)
                if next_token == self.config.eos_token_id:
                    break
        return generated

# -------------------------------
# 4. Load the Model and Checkpoint
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = TitanConfig(vocab_size=vocab_size)
config.bos_token_id = vocab["<bos>"]
config.eos_token_id = vocab["<eos>"]

model = TitanMAGLM(config).to(device)
checkpoint_path = "titan_checkpoint-3.pth"
print(f"Loading checkpoint from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# -------------------------------
# 5. Chatbot Inference Function
# -------------------------------
def generate_text(prompt, model, vocab, itos, max_length=100, k=10):
    tokens = simple_tokenizer(prompt)
    prompt_ids = [vocab["<bos>"]] + [vocab.get(token, vocab["<unk>"]) for token in tokens]
    generated_ids = model.generate(prompt_ids, max_length=max_length, k=k)
    return " ".join([itos.get(i, "<unk>") for i in generated_ids])

# -------------------------------
# 6. Interactive Chatbot Loop
# -------------------------------
print("Chatbot is ready! Type 'exit' or 'quit' to stop.")
while True:
    user_input = input("User: ")
    if user_input.lower() in ["exit", "quit"]:
        break
    response = generate_text(user_input, model, vocab, itos, max_length=100, k=10)
    print("Bot:", response)


Loading dataset...
Building vocabulary...
Loading checkpoint from titan_checkpoint-3.pth...


  checkpoint = torch.load(checkpoint_path, map_location=device)


RuntimeError: Error(s) in loading state_dict for TitanMAGLM:
	size mismatch for embedding.weight: copying a param with shape torch.Size([26000, 256]) from checkpoint, the shape in current model is torch.Size([11674, 256]).
	size mismatch for lm_head.weight: copying a param with shape torch.Size([26000, 256]) from checkpoint, the shape in current model is torch.Size([11674, 256]).
	size mismatch for lm_head.bias: copying a param with shape torch.Size([26000]) from checkpoint, the shape in current model is torch.Size([11674]).