In [1]:
!pip install rouge-score sacrebleu sentencepiece -q

  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 51.8/51.8 kB 1.8 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 104.1/104.1 kB 3.8 MB/s eta 0:00:00
  Building wheel for rouge-score (setup.py) ... done


In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import math
from tqdm import tqdm
import sentencepiece as spm
import warnings
warnings.filterwarnings('ignore')


# Positional encoding - classic Vaswani et al. approach
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold the
    matching ``cos(pos * freq)``, using the frequencies from "Attention Is
    All You Need". The table is registered as the buffer ``pe`` with shape
    ``(1, max_len, dim)``.

    Args:
        dim: Feature size of the embeddings. Odd sizes are supported.
        max_len: Number of positions the table covers.
    """

    def __init__(self, dim, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, dim)
        pos = torch.arange(0, max_len).unsqueeze(1).float()

        # one frequency per (sin, cos) column pair
        div = torch.exp(torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim))
        angles = pos * div

        pe[:, 0::2] = torch.sin(angles)
        # With an odd ``dim`` there is one fewer cosine column than sine
        # column, so the last frequency is dropped instead of failing on a
        # shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.

        Raises:
            ValueError: If ``seq_len`` is larger than the table's ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]


# Standard scaled dot-product attention
def attention(q, k, v, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor; its last dimension sets the scaling factor.
        k: Key tensor, multiplied against ``q`` over the last dimension.
        v: Value tensor combined with the attention weights.
        mask: Optional tensor broadcastable to the score matrix. Positions
            where it equals 0 are filled with -1e9 before the softmax.
        dropout: Optional callable applied to the attention weights.

    Returns:
        The attention-weighted combination of ``v``.
    """
    scale = math.sqrt(q.size(-1))
    logits = torch.matmul(q, k.transpose(-2, -1)) / scale

    if mask is not None:
        # a large negative score makes the softmax weight effectively zero
        logits = logits.masked_fill(mask == 0, -1e9)

    weights = torch.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    return torch.matmul(weights, v)


class MultiHeadAttention(nn.Module):
    """Multi-head attention built on the module-level ``attention`` helper.

    Queries, keys and values are linearly projected, split into ``n_heads``
    heads of size ``dim // n_heads``, attended independently, merged back to
    ``dim`` features and passed through an output projection.

    Args:
        dim: Model feature size; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout_p: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim, n_heads, dropout_p=0.1):
        super().__init__()
        assert dim % n_heads == 0

        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # input projections followed by the output projection
        self.q_linear = nn.Linear(dim, dim)
        self.k_linear = nn.Linear(dim, dim)
        self.v_linear = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(dropout_p)

    def _split_heads(self, projected, batch):
        """Reshape ``(batch, seq, dim)`` to ``(batch, n_heads, seq, head_dim)``."""
        return projected.view(batch, -1, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v`` and return ``(batch, q_len, dim)``."""
        batch = q.size(0)

        heads_q = self._split_heads(self.q_linear(q), batch)
        heads_k = self._split_heads(self.k_linear(k), batch)
        heads_v = self._split_heads(self.v_linear(v), batch)

        attended = attention(heads_q, heads_k, heads_v, mask, self.dropout)

        # move heads back next to the feature axis and flatten them
        merged = attended.transpose(1, 2).contiguous().view(batch, -1, self.dim)
        return self.out(merged)


class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the inner layer.
        dropout_p: Dropout probability applied after the ReLU.
    """

    def __init__(self, dim, hidden_dim, dropout_p=0.1):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = torch.relu(self.w1(x))
        hidden = self.dropout(hidden)
        return self.w2(hidden)


class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sublayer's output goes through dropout, is added to its input, and
    the sum is layer-normalised.

    Args:
        dim: Model feature size.
        n_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sublayer.
        dropout_p: Dropout probability used throughout the block.
    """

    def __init__(self, dim, n_heads, ff_dim, dropout_p=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(dim, n_heads, dropout_p)
        self.ff = FeedForward(dim, ff_dim, dropout_p)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.dropout1 = nn.Dropout(dropout_p)
        self.dropout2 = nn.Dropout(dropout_p)

    def forward(self, x, mask):
        """Run the block on ``x`` using ``mask`` for self-attention."""
        # self-attention sublayer
        attended = self.dropout1(self.attn(x, x, x, mask))
        x = self.norm1(x + attended)

        # feed-forward sublayer
        transformed = self.dropout2(self.ff(x))
        return self.norm2(x + transformed)


class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sublayer's output goes through dropout, is added to its input, and
    the sum is layer-normalised.

    Args:
        dim: Model feature size.
        n_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sublayer.
        dropout_p: Dropout probability used throughout the block.
    """

    def __init__(self, dim, n_heads, ff_dim, dropout_p=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(dim, n_heads, dropout_p)
        self.cross_attn = MultiHeadAttention(dim, n_heads, dropout_p)
        self.ff = FeedForward(dim, ff_dim, dropout_p)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

        self.dropout1 = nn.Dropout(dropout_p)
        self.dropout2 = nn.Dropout(dropout_p)
        self.dropout3 = nn.Dropout(dropout_p)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Run the block on target states ``x`` given encoder output ``enc_out``."""
        # target attends to itself under tgt_mask
        own_context = self.dropout1(self.self_attn(x, x, x, tgt_mask))
        x = self.norm1(x + own_context)

        # target queries attend over encoder states under src_mask
        source_context = self.dropout2(self.cross_attn(x, enc_out, enc_out, src_mask))
        x = self.norm2(x + source_context)

        # position-wise feed-forward
        transformed = self.dropout3(self.ff(x))
        return self.norm3(x + transformed)


class Transformer(nn.Module):
    """Encoder-decoder transformer for sequence-to-sequence token prediction.

    Source and target tokens use separate embedding tables of the same
    vocabulary size and share one ``PositionalEncoding``. Token id 0 is
    treated as padding when masks are built.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        dim: Model feature size.
        n_heads: Attention heads per block.
        n_enc_layers: Number of encoder blocks.
        n_dec_layers: Number of decoder blocks.
        ff_dim: Hidden width of each feed-forward sublayer.
        max_len: Number of positions in the positional table.
        dropout_p: Dropout probability.
    """

    def __init__(self, vocab_size, dim=128, n_heads=4, n_enc_layers=2, 
                 n_dec_layers=2, ff_dim=512, max_len=512, dropout_p=0.1):
        super().__init__()

        self.dim = dim
        self.vocab_size = vocab_size

        # token embeddings and the shared position table
        self.src_embed = nn.Embedding(vocab_size, dim)
        self.tgt_embed = nn.Embedding(vocab_size, dim)
        self.pos_enc = PositionalEncoding(dim, max_len)

        encoder_blocks = [
            EncoderBlock(dim, n_heads, ff_dim, dropout_p)
            for _ in range(n_enc_layers)
        ]
        self.enc_layers = nn.ModuleList(encoder_blocks)

        decoder_blocks = [
            DecoderBlock(dim, n_heads, ff_dim, dropout_p)
            for _ in range(n_dec_layers)
        ]
        self.dec_layers = nn.ModuleList(decoder_blocks)

        self.dropout = nn.Dropout(dropout_p)
        self.final_layer = nn.Linear(dim, vocab_size)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def _embed(self, table, token_ids):
        """Look up, scale by sqrt(dim), add positions and apply dropout."""
        scaled = table(token_ids) * math.sqrt(self.dim)
        return self.dropout(self.pos_enc(scaled))

    def make_src_mask(self, src):
        """Return a ``(batch, 1, 1, src_len)`` bool mask that is False at id 0."""
        not_pad = src != 0
        return not_pad.unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        """Return a ``(batch, 1, tgt_len, tgt_len)`` causal-and-padding mask."""
        _, length = tgt.size()

        # padding is applied along the query axis: (batch, 1, tgt_len, 1)
        not_pad = (tgt != 0).unsqueeze(1).unsqueeze(3)

        # lower triangle: position i may see positions <= i
        lower_triangle = torch.ones((length, length), device=tgt.device)
        causal = torch.tril(lower_triangle).bool()

        return not_pad & causal

    def encode(self, src, src_mask):
        """Embed ``src`` and run it through every encoder block."""
        hidden = self._embed(self.src_embed, src)
        for block in self.enc_layers:
            hidden = block(hidden, src_mask)
        return hidden

    def decode(self, tgt, enc_out, src_mask, tgt_mask):
        """Embed ``tgt`` and run it through every decoder block."""
        hidden = self._embed(self.tgt_embed, tgt)
        for block in self.dec_layers:
            hidden = block(hidden, enc_out, src_mask, tgt_mask)
        return hidden

    def forward(self, src, tgt):
        """Return vocabulary logits for each target position."""
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)

        memory = self.encode(src, src_mask)
        states = self.decode(tgt, memory, src_mask, tgt_mask)
        return self.final_layer(states)


# Dataset handling
class QADataset(Dataset):
    """Dataset of (question, answer) token-id tensors.

    Each text is encoded with ``tokenizer.encode_as_ids``, truncated to
    ``max_len - 2`` ids, and wrapped with BOS (1) and EOS (2).

    Args:
        questions: Indexable collection of question strings.
        answers: Indexable collection of answer strings, aligned by index.
        tokenizer: Object exposing ``encode_as_ids(text)``.
        max_len: Maximum length of each returned sequence, BOS/EOS included.
    """

    def __init__(self, questions, answers, tokenizer, max_len=128):
        self.questions = questions
        self.answers = answers
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.questions)

    def _encode(self, text):
        """Tokenise ``text``, trim it, and frame it with BOS/EOS ids."""
        body = self.tokenizer.encode_as_ids(text)[:self.max_len - 2]
        return torch.tensor([1] + body + [2])

    def __getitem__(self, idx):
        question = self._encode(self.questions[idx])
        answer = self._encode(self.answers[idx])
        return question, answer


def collate_batch(batch):
    """Pad a list of (question, answer) tensor pairs into two batch tensors.

    Sequences are right-padded with 0 to the longest item in the batch and
    returned as ``(questions, answers)``, each of shape ``(batch, max_len)``.
    """
    question_seqs, answer_seqs = zip(*batch)
    padded = [
        pad_sequence(seqs, batch_first=True, padding_value=0)
        for seqs in (question_seqs, answer_seqs)
    ]
    return padded[0], padded[1]


# Training utilities
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one teacher-forced training pass over ``loader``.

    For each batch the target is split into decoder input (all but the last
    token) and expected output (all but the first token). Gradients are
    clipped to a total norm of 1.0 before the optimizer step.

    Returns:
        The mean of the per-batch loss values.
    """
    model.train()
    running_loss = 0
    steps = 0

    for steps, (src, tgt) in enumerate(tqdm(loader, desc="Training"), start=1):
        src = src.to(device)
        tgt = tgt.to(device)

        # teacher forcing: predict token t+1 from tokens <= t
        decoder_in, expected = tgt[:, :-1], tgt[:, 1:]

        optimizer.zero_grad()

        logits = model(src, decoder_in)
        flat_logits = logits.reshape(-1, logits.size(-1))
        loss = criterion(flat_logits, expected.reshape(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()

    return running_loss / steps


def validate_model(model, loader, criterion, device):
    """Compute the mean per-batch loss over ``loader`` without gradients.

    Uses the same input/target shift as training: the decoder receives all
    but the last target token and is scored against all but the first.
    """
    model.eval()
    running_loss = 0
    steps = 0

    with torch.no_grad():
        for steps, (src, tgt) in enumerate(tqdm(loader, desc="Validation"), start=1):
            src = src.to(device)
            tgt = tgt.to(device)

            decoder_in, expected = tgt[:, :-1], tgt[:, 1:]

            logits = model(src, decoder_in)
            flat_logits = logits.reshape(-1, logits.size(-1))
            running_loss += criterion(flat_logits, expected.reshape(-1)).item()

    return running_loss / steps


# Inference - greedy decoding
def generate_answer(model, question, tokenizer, device, max_len=80):
    """Greedy-decode an answer string for ``question``.

    The question is tokenised, truncated to 126 ids and framed with BOS (1)
    and EOS (2). Decoding starts from BOS and appends the arg-max token at
    each step until EOS is produced or ``max_len`` tokens were generated.

    Args:
        model: Object exposing ``eval``, ``make_src_mask``, ``encode``,
            ``make_tgt_mask``, ``decode`` and ``final_layer``.
        question: Input text.
        tokenizer: Object exposing ``encode_as_ids`` and ``decode_ids``.
        device: Device on which the input tensors are created.
        max_len: Maximum number of generated tokens.

    Returns:
        The decoded answer with ids 0, 1 and 2 removed.
    """
    model.eval()

    # tokenize question: BOS + tokens + EOS
    q_ids = [1] + tokenizer.encode_as_ids(question)[:126] + [2]
    src = torch.tensor([q_ids], device=device)

    # start decoding with BOS token
    tgt_ids = [1]

    # Inference only: without no_grad each step would record an autograd
    # graph through the encoder and decoder for no benefit.
    with torch.no_grad():
        src_mask = model.make_src_mask(src)
        enc_out = model.encode(src, src_mask)

        for _ in range(max_len):
            tgt = torch.tensor([tgt_ids], device=device)
            tgt_mask = model.make_tgt_mask(tgt)

            dec_out = model.decode(tgt, enc_out, src_mask, tgt_mask)
            logits = model.final_layer(dec_out)

            # greedy: pick most probable token at the last position
            next_token = logits[0, -1].argmax().item()
            tgt_ids.append(next_token)

            if next_token == 2:  # EOS
                break

    # drop PAD/BOS/EOS before turning ids back into text
    answer_ids = [i for i in tgt_ids if i not in (0, 1, 2)]
    return tokenizer.decode_ids(answer_ids)


# Main training loop
def main():
    """Train the QA transformer, checkpoint the best epoch and run a smoke test.

    Reads question/answer pairs from a CSV, splits them 90/10 in file order,
    trains with early stopping on validation loss, saves the best weights to
    ``/kaggle/working/urdu_chatbot.pt`` and prints answers to a few sample
    questions using those best weights.

    Returns:
        ``(model, tokenizer, device)``.
    """
    # paths
    qa_path = '/kaggle/input/nlp-a-2/qa_gk.csv'
    tokenizer_path = '/kaggle/input/nlp-a-2/urdu_tokenizer.model'
    checkpoint_path = '/kaggle/working/urdu_chatbot.pt'

    print("Loading data...")
    df = pd.read_csv(qa_path, encoding='utf-8')

    # Drop incomplete rows as whole rows. Dropping NaNs column by column and
    # truncating afterwards would pair every later question with the wrong
    # answer as soon as one cell is missing.
    pairs = df.dropna(subset=['Question', 'Answer'])
    questions = pairs['Question'].astype(str).tolist()
    answers = pairs['Answer'].astype(str).tolist()

    print(f"Loaded {len(questions)} QA pairs")

    # load tokenizer
    print("Loading tokenizer...")
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load(tokenizer_path)
    vocab_size = len(tokenizer)
    print(f"Vocab size: {vocab_size}")

    # train/val split (first 90% / last 10%, no shuffling)
    split_idx = int(len(questions) * 0.9)
    train_q, train_a = questions[:split_idx], answers[:split_idx]
    val_q, val_a = questions[split_idx:], answers[split_idx:]

    print(f"Train: {len(train_q)}, Val: {len(val_q)}")

    # datasets
    train_ds = QADataset(train_q, train_a, tokenizer)
    val_ds = QADataset(val_q, val_a, tokenizer)

    train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, collate_fn=collate_batch)

    # model setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = Transformer(
        vocab_size=vocab_size,
        dim=128,
        n_heads=4,
        n_enc_layers=2,
        n_dec_layers=2,
        ff_dim=512,
        dropout_p=0.2
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")

    criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-5)
    # NOTE(review): T_max=50 is shorter than n_epochs=100 below; confirm the
    # learning-rate behaviour past epoch 50 is what is intended.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

    # training
    best_val_loss = float('inf')
    patience = 10
    patience_count = 0
    n_epochs = 100
    checkpoint_saved = False

    print("\nStarting training...")

    for epoch in range(n_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss = validate_model(model, val_loader, criterion, device)
        scheduler.step()

        print(f"Epoch {epoch+1:3d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}", end="")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_count = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(" [saved]")
        else:
            patience_count += 1
            print()

            if patience_count >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    print(f"\nTraining finished. Best val loss: {best_val_loss:.4f}")

    # The in-memory weights are from the last epoch, which by construction is
    # not the best one after early stopping. Restore the checkpoint so the
    # smoke test and the returned model match the reported best loss.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # test some questions
    print("\nTesting inference:")
    test_questions = [
        "پاکستان کا دارالحکومت کیا ہے؟",
        "اسلام آباد کہاں واقع ہے؟",
        "پاکستان کی کرنسی کیا ہے؟"
    ]

    for q in test_questions:
        answer = generate_answer(model, q, tokenizer, device)
        print(f"\nQ: {q}")
        print(f"A: {answer}")

    return model, tokenizer, device


# Entry point: run training and bind the returned model, tokenizer and device
# to module-level names.
if __name__ == "__main__":
    model, tokenizer, device = main()

Loading QA dataset...
Loaded 5816 QA pairs
Sample Q: پاکستان میں بہترین یونیورسٹیوں کا نام؟
Sample A: LUMS اور AKU پاکستان میں بہترین یونیورسٹی ہیں

Loading tokenizer...
Vocab size: 16000

Train: 5234, Val: 582

Device: cuda

Model parameters: 7,085,696

Training starts...



Training: 100%|██████████| 655/655 [00:14<00:00, 44.50it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 160.55it/s]


Epoch   1 | Train: 6.2724 | Val: 5.7475 [SAVED]


Training: 100%|██████████| 655/655 [00:12<00:00, 52.42it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 181.47it/s]


Epoch   2 | Train: 5.2041 | Val: 5.3878 [SAVED]


Training: 100%|██████████| 655/655 [00:12<00:00, 53.28it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 170.25it/s]


Epoch   3 | Train: 4.6723 | Val: 5.1254 [SAVED]


Training: 100%|██████████| 655/655 [00:12<00:00, 53.08it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 185.93it/s]


Epoch   4 | Train: 4.2604 | Val: 4.9628 [SAVED]


Training: 100%|██████████| 655/655 [00:12<00:00, 52.96it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 186.78it/s]


Epoch   5 | Train: 3.9537 | Val: 4.8656 [SAVED]


Training: 100%|██████████| 655/655 [00:12<00:00, 53.16it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 186.28it/s]


Epoch   6 | Train: 3.6956 | Val: 4.8071 [SAVED]


Training: 100%|██████████| 655/655 [00:12<00:00, 52.36it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 182.44it/s]


Epoch   7 | Train: 3.4900 | Val: 4.7904 [SAVED]


Training: 100%|██████████| 655/655 [00:12<00:00, 53.15it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 183.95it/s]


Epoch   8 | Train: 3.2946 | Val: 4.7898 [SAVED]


Training: 100%|██████████| 655/655 [00:12<00:00, 52.17it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 186.36it/s]


Epoch   9 | Train: 3.1393 | Val: 4.7709 [SAVED]


Training: 100%|██████████| 655/655 [00:12<00:00, 52.60it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 168.87it/s]


Epoch  10 | Train: 3.0106 | Val: 4.8385


Evaluating:  41%|████      | 30/73 [00:01<00:01, 25.15it/s]


  BLEU: 13.22



Training: 100%|██████████| 655/655 [00:12<00:00, 52.59it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 186.69it/s]


Epoch  11 | Train: 2.8855 | Val: 4.8085


Training: 100%|██████████| 655/655 [00:12<00:00, 53.41it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 181.31it/s]


Epoch  12 | Train: 2.7996 | Val: 4.8527


Training: 100%|██████████| 655/655 [00:12<00:00, 53.56it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 187.10it/s]


Epoch  13 | Train: 2.7178 | Val: 4.8542


Training: 100%|██████████| 655/655 [00:12<00:00, 53.79it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 187.21it/s]


Epoch  14 | Train: 2.6510 | Val: 4.8856


Training: 100%|██████████| 655/655 [00:12<00:00, 52.72it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 135.45it/s]


Epoch  15 | Train: 2.5932 | Val: 4.8787


Training: 100%|██████████| 655/655 [00:12<00:00, 52.35it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 184.09it/s]


Epoch  16 | Train: 2.5506 | Val: 4.9054


Training: 100%|██████████| 655/655 [00:12<00:00, 53.21it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 182.36it/s]


Epoch  17 | Train: 2.5012 | Val: 4.9124


Training: 100%|██████████| 655/655 [00:12<00:00, 54.03it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 184.31it/s]


Epoch  18 | Train: 2.4654 | Val: 4.9242


Training: 100%|██████████| 655/655 [00:12<00:00, 52.79it/s]
Validating: 100%|██████████| 73/73 [00:00<00:00, 188.56it/s]


Epoch  19 | Train: 2.4347 | Val: 4.9539
Early stopping at epoch 19

Training complete. Best loss: 4.7709

Testing on sample questions:
Q: پاکستان میں بہترین یونیورسٹیوں کا نام؟
A: مجھے افسوس ہے کہ میں نہیں جانتا، براہ مہربانی زیادہ مخصوص رہیں 

Q: اسلام آباد کیا ہے؟
A: اسلام آباد پاکستان کا دارالحکومت ہے 

Q: پاکستان کی کرنسی کیا ہے؟
A: روپیہ پاکستان کی کرنسی ہے 



In [None]:
import torch
import torch.nn as nn
import sentencepiece as spm
import math


# Same model architecture as training
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings.

    The table is registered as the buffer ``pe`` with shape
    ``(1, max_len, dim)``: even columns hold sines, odd columns cosines.

    Args:
        dim: Feature size of the embeddings. Odd sizes are supported.
        max_len: Number of positions the table covers.
    """

    def __init__(self, dim, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # An odd ``dim`` has one fewer cosine column than sine column; trim
        # the spare frequency so the assignment shapes agree.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` of shape ``(batch, seq_len, dim)`` plus its encodings.

        Raises:
            ValueError: If ``seq_len`` is larger than the table's ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]


def attention(q, k, v, mask=None, dropout=None):
    """Scaled dot-product attention of ``q`` over ``k``/``v``.

    Positions where ``mask`` equals 0 are filled with -1e9 before the
    softmax; ``dropout``, when given, is applied to the attention weights.
    """
    scale = math.sqrt(q.size(-1))
    logits = torch.matmul(q, k.transpose(-2, -1)) / scale
    if mask is not None:
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = torch.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    context = torch.matmul(weights, v)
    return context


class MultiHeadAttention(nn.Module):
    """Multi-head attention built on the module-level ``attention`` helper.

    Inputs are projected, split into ``n_heads`` heads of size
    ``dim // n_heads``, attended, merged and projected again.
    """

    def __init__(self, dim, n_heads, dropout_p=0.1):
        super().__init__()
        assert dim % n_heads == 0
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.q_linear = nn.Linear(dim, dim)
        self.k_linear = nn.Linear(dim, dim)
        self.v_linear = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout_p)

    def _split_heads(self, projected, batch):
        """Reshape ``(batch, seq, dim)`` to ``(batch, n_heads, seq, head_dim)``."""
        return projected.view(batch, -1, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v`` and return ``(batch, q_len, dim)``."""
        batch = q.size(0)
        heads_q = self._split_heads(self.q_linear(q), batch)
        heads_k = self._split_heads(self.k_linear(k), batch)
        heads_v = self._split_heads(self.v_linear(v), batch)
        attended = attention(heads_q, heads_k, heads_v, mask, self.dropout)
        merged = attended.transpose(1, 2).contiguous().view(batch, -1, self.dim)
        return self.out(merged)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, dim, hidden_dim, dropout_p=0.1):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        activated = torch.relu(self.w1(x))
        regularised = self.dropout(activated)
        return self.w2(regularised)


class EncoderBlock(nn.Module):
    """Encoder layer: self-attention then feed-forward.

    Each sublayer output goes through dropout, a residual add and LayerNorm.
    """

    def __init__(self, dim, n_heads, ff_dim, dropout_p=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(dim, n_heads, dropout_p)
        self.ff = FeedForward(dim, ff_dim, dropout_p)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.dropout1 = nn.Dropout(dropout_p)
        self.dropout2 = nn.Dropout(dropout_p)

    def forward(self, x, mask):
        """Run the block on ``x`` using ``mask`` for self-attention."""
        attended = self.dropout1(self.attn(x, x, x, mask))
        x = self.norm1(x + attended)
        transformed = self.dropout2(self.ff(x))
        return self.norm2(x + transformed)


class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sublayer output goes through dropout, a residual add and LayerNorm.
    """

    def __init__(self, dim, n_heads, ff_dim, dropout_p=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(dim, n_heads, dropout_p)
        self.cross_attn = MultiHeadAttention(dim, n_heads, dropout_p)
        self.ff = FeedForward(dim, ff_dim, dropout_p)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.dropout1 = nn.Dropout(dropout_p)
        self.dropout2 = nn.Dropout(dropout_p)
        self.dropout3 = nn.Dropout(dropout_p)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Run the block on target states ``x`` given encoder output ``enc_out``."""
        own_context = self.dropout1(self.self_attn(x, x, x, tgt_mask))
        x = self.norm1(x + own_context)
        source_context = self.dropout2(self.cross_attn(x, enc_out, enc_out, src_mask))
        x = self.norm2(x + source_context)
        transformed = self.dropout3(self.ff(x))
        return self.norm3(x + transformed)


class Transformer(nn.Module):
    """Encoder-decoder transformer used for inference.

    Source and target tokens use separate embedding tables and share one
    ``PositionalEncoding``. Token id 0 is treated as padding when masks are
    built.
    """

    def __init__(self, vocab_size, dim=128, n_heads=4, n_enc_layers=2, 
                 n_dec_layers=2, ff_dim=512, max_len=512, dropout_p=0.1):
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        self.src_embed = nn.Embedding(vocab_size, dim)
        self.tgt_embed = nn.Embedding(vocab_size, dim)
        self.pos_enc = PositionalEncoding(dim, max_len)
        encoder_blocks = [
            EncoderBlock(dim, n_heads, ff_dim, dropout_p)
            for _ in range(n_enc_layers)
        ]
        self.enc_layers = nn.ModuleList(encoder_blocks)
        decoder_blocks = [
            DecoderBlock(dim, n_heads, ff_dim, dropout_p)
            for _ in range(n_dec_layers)
        ]
        self.dec_layers = nn.ModuleList(decoder_blocks)
        self.dropout = nn.Dropout(dropout_p)
        self.final_layer = nn.Linear(dim, vocab_size)

    def _embed(self, table, token_ids):
        """Look up, scale by sqrt(dim), add positions and apply dropout."""
        scaled = table(token_ids) * math.sqrt(self.dim)
        return self.dropout(self.pos_enc(scaled))

    def make_src_mask(self, src):
        """Return a ``(batch, 1, 1, src_len)`` bool mask that is False at id 0."""
        not_pad = src != 0
        return not_pad.unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        """Return a ``(batch, 1, tgt_len, tgt_len)`` causal-and-padding mask."""
        _, length = tgt.size()
        not_pad = (tgt != 0).unsqueeze(1).unsqueeze(3)
        lower_triangle = torch.ones((length, length), device=tgt.device)
        causal = torch.tril(lower_triangle).bool()
        return not_pad & causal

    def encode(self, src, src_mask):
        """Embed ``src`` and run it through every encoder block."""
        hidden = self._embed(self.src_embed, src)
        for block in self.enc_layers:
            hidden = block(hidden, src_mask)
        return hidden

    def decode(self, tgt, enc_out, src_mask, tgt_mask):
        """Embed ``tgt`` and run it through every decoder block."""
        hidden = self._embed(self.tgt_embed, tgt)
        for block in self.dec_layers:
            hidden = block(hidden, enc_out, src_mask, tgt_mask)
        return hidden

    def forward(self, src, tgt):
        """Return vocabulary logits for each target position."""
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        memory = self.encode(src, src_mask)
        states = self.decode(tgt, memory, src_mask, tgt_mask)
        return self.final_layer(states)


class UrduQABot:
    """Easy-to-use wrapper for inference.

    Loads a SentencePiece tokenizer and trained ``Transformer`` weights once,
    then answers questions with greedy decoding.
    """

    def __init__(self, model_path, tokenizer_path, device='cuda'):
        """Load the tokenizer and model.

        Args:
            model_path: Path to a ``state_dict`` saved with ``torch.save``.
            tokenizer_path: Path to the SentencePiece model file.
            device: Requested device; ``'cpu'`` is used when CUDA is
                unavailable.
        """
        print("Loading tokenizer...")
        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.load(tokenizer_path)
        self.vocab_size = len(self.tokenizer)

        print("Loading model...")
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        self.model = Transformer(
            vocab_size=self.vocab_size,
            dim=128,
            n_heads=4,
            n_enc_layers=2,
            n_dec_layers=2,
            ff_dim=512,
            dropout_p=0.2
        ).to(self.device)

        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()

        print(f"Model loaded on {self.device}")

    def ask(self, question, max_len=80):
        """Ask a question and get an answer.

        The question is truncated to 126 token ids and framed with BOS (1)
        and EOS (2). Decoding is greedy and stops at EOS or after
        ``max_len`` generated tokens.

        Returns:
            The decoded answer with ids 0, 1 and 2 removed.
        """
        self.model.eval()

        # tokenize
        q_ids = [1] + self.tokenizer.encode_as_ids(question)[:126] + [2]
        src = torch.tensor([q_ids], device=self.device)

        tgt_ids = [1]  # start with BOS

        # The encoder pass belongs inside no_grad as well; otherwise its
        # autograd graph is kept alive for the whole decoding loop.
        with torch.no_grad():
            src_mask = self.model.make_src_mask(src)
            enc_out = self.model.encode(src, src_mask)

            for _ in range(max_len):
                tgt = torch.tensor([tgt_ids], device=self.device)
                tgt_mask = self.model.make_tgt_mask(tgt)

                dec_out = self.model.decode(tgt, enc_out, src_mask, tgt_mask)
                logits = self.model.final_layer(dec_out)

                next_token = logits[0, -1].argmax().item()
                tgt_ids.append(next_token)

                if next_token == 2:  # EOS
                    break

        # drop PAD/BOS/EOS before turning ids back into text
        answer_ids = [i for i in tgt_ids if i not in (0, 1, 2)]
        return self.tokenizer.decode_ids(answer_ids)

    def chat(self):
        """Interactive chat mode; ends on 'quit', 'exit', 'q', EOF or interrupt."""
        print("\n=== Urdu QA Chatbot ===")
        print("Type 'quit' or 'exit' to stop\n")

        while True:
            try:
                question = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                # stdin was closed or the user interrupted: leave cleanly
                print("\nGoodbye!")
                break

            if question.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break

            if not question:
                continue

            answer = self.ask(question)
            print(f"Bot: {answer}\n")


# Beam search for better quality (optional)
def beam_search(model, src, tokenizer, device, beam_size=3, max_len=80):
    """Beam search decoding for potentially better answers.

    Hypotheses are scored by the sum of token log-probabilities (no length
    normalisation). A hypothesis ending in EOS (2) is carried over unchanged.

    Args:
        model: Object exposing ``eval``, ``make_src_mask``, ``encode``,
            ``make_tgt_mask``, ``decode`` and ``final_layer``.
        src: Source token-id tensor for a single example.
        tokenizer: Object exposing ``decode_ids``.
        device: Device on which target tensors are created.
        beam_size: Number of hypotheses kept per step; must be at least 1.
        max_len: Maximum number of expansion steps.

    Returns:
        Decoded text of the best hypothesis with ids 0, 1 and 2 removed.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError("beam_size must be at least 1")

    model.eval()

    # initialize beam with BOS token
    beams = [([1], 0.0)]  # (sequence, score)

    # The encoder pass is inference work too, so it runs under no_grad with
    # the rest instead of leaving an autograd graph attached to enc_out.
    with torch.no_grad():
        src_mask = model.make_src_mask(src)
        enc_out = model.encode(src, src_mask)

        for _ in range(max_len):
            candidates = []

            for seq, score in beams:
                if seq[-1] == 2:  # already finished
                    candidates.append((seq, score))
                    continue

                tgt = torch.tensor([seq], device=device)
                tgt_mask = model.make_tgt_mask(tgt)

                dec_out = model.decode(tgt, enc_out, src_mask, tgt_mask)
                logits = model.final_layer(dec_out)
                log_probs = torch.log_softmax(logits[0, -1], dim=-1)

                # topk rejects k larger than the vocabulary, so clamp it
                k = min(beam_size, log_probs.size(-1))
                top_scores, top_ids = torch.topk(log_probs, k)

                for token_score, token_id in zip(top_scores.tolist(), top_ids.tolist()):
                    candidates.append((seq + [token_id], score + token_score))

            # keep top beam_size candidates
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]

            # stop if all beams finished
            if all(seq[-1] == 2 for seq, _ in beams):
                break

    # return best sequence
    best_seq = beams[0][0]
    answer_ids = [i for i in best_seq if i not in (0, 1, 2)]
    return tokenizer.decode_ids(answer_ids)


def main():
    """Load the saved chatbot and print its answers to a few sample questions."""
    # paths - update these
    tokenizer_path = '/kaggle/input/nlp-a-2/urdu_tokenizer.model'
    model_path = '/kaggle/working/urdu_chatbot.pt'

    # create bot
    bot = UrduQABot(model_path, tokenizer_path)

    # sample questions for a quick sanity check
    test_questions = (
        "پاکستان کا دارالحکومت کیا ہے؟",
        "اسلام آباد کہاں ہے؟",
        "پاکستان کی کرنسی کیا ہے؟",
        "لاہور کے بارے میں بتائیں؟",
    )

    print("\n=== Testing Questions ===\n")
    for question in test_questions:
        reply = bot.ask(question)
        print(f"Q: {question}")
        print(f"A: {reply}\n")

    # uncomment to start interactive chat
    # bot.chat()


# Alternative: Load and use model directly
def quick_inference():
    """Quick inference without wrapper class.

    Loads the tokenizer and saved weights from fixed Kaggle paths, greedily
    decodes an answer to one hard-coded question (at most 80 tokens) and
    prints the question and answer.
    """

    # setup
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load('/kaggle/input/nlp-a-2/urdu_tokenizer.model')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Transformer(
        vocab_size=len(tokenizer),
        dim=128,
        n_heads=4,
        n_enc_layers=2,
        n_dec_layers=2,
        ff_dim=512,
        dropout_p=0.2
    ).to(device)

    model.load_state_dict(torch.load('/kaggle/working/urdu_chatbot.pt', map_location=device))
    model.eval()

    # ask question
    question = "پاکستان کے بارے میں بتائیں؟"

    # BOS + at most 126 question tokens + EOS
    q_ids = [1] + tokenizer.encode_as_ids(question)[:126] + [2]
    src = torch.tensor([q_ids], device=device)

    tgt_ids = [1]

    # Encoder and decoder both run under no_grad so that no autograd graph
    # is built during inference.
    with torch.no_grad():
        src_mask = model.make_src_mask(src)
        enc_out = model.encode(src, src_mask)

        for _ in range(80):
            tgt = torch.tensor([tgt_ids], device=device)
            tgt_mask = model.make_tgt_mask(tgt)
            dec_out = model.decode(tgt, enc_out, src_mask, tgt_mask)
            logits = model.final_layer(dec_out)
            next_token = logits[0, -1].argmax().item()
            tgt_ids.append(next_token)
            if next_token == 2:  # EOS
                break

    # drop PAD/BOS/EOS before turning ids back into text
    answer_ids = [i for i in tgt_ids if i not in (0, 1, 2)]
    answer = tokenizer.decode_ids(answer_ids)

    print(f"Q: {question}")
    print(f"A: {answer}")


# Entry point: run the sample-question demo defined in main().
if __name__ == "__main__":
    main()
    