In [None]:
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerFast, PretrainedConfig, PreTrainedModel
from datasets import load_dataset

# ==========================================
# 1. ARCHITECTURE DEFINITION
# ==========================================

class KonkanSmallConfig(PretrainedConfig):
    """Hyper-parameters of the KonkanGPT decoder.

    Defaults: 12 layers x 12 heads, d_model 768, SwiGLU hidden size d_ff 3072,
    vocabulary 32000, max_len 1024. ``dropout`` is recorded on the config only;
    the blocks defined alongside it apply no dropout.
    """
    model_type = "konkangpt"

    def __init__(self, vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
                 d_ff=3072, max_len=1024, dropout=0.1, **kwargs):
        # Remaining keyword arguments are forwarded to PretrainedConfig.
        super().__init__(**kwargs)
        hyperparams = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "d_ff": d_ff,
            "max_len": max_len,
            "dropout": dropout,
        }
        for name, value in hyperparams.items():
            setattr(self, name, value)

class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables.

    ``inv_freq`` holds inverse frequencies ``10000 ** (-k / dim)`` for even ``k``
    in ``[0, dim)``. ``forward`` returns ``(cos, sin)`` of shape
    ``(seq_len, 2 * len(inv_freq))``: the angle vector is the frequencies
    repeated twice, matching ``rotate_half``'s two-way split.
    ``max_seq_len`` is accepted but unused; tables are rebuilt on every call.
    """

    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # Persistent buffer: it appears in saved state dicts, so its name must not change.
        self.register_buffer("inv_freq", 1.0 / torch.pow(10000, exponents))

    def forward(self, x, seq_len):
        # ``x`` only supplies the target device.
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = positions[:, None] * self.inv_freq[None, :]
        angles = torch.cat([angles, angles], dim=-1)
        return angles.cos(), angles.sin()

def rotate_half(x):
    """Split the last dim into two chunks (lo, hi) and return ``cat(-hi, lo)``.

    This is the 90-degree rotation used by RoPE.
    """
    lo, hi = torch.chunk(x, 2, dim=-1)
    return torch.cat([hi.neg(), lo], dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` by the RoPE angles for its sequence positions.

    As called from KonkanBlock, ``x`` is (batch, heads, seq, head_dim), and
    ``cos``/``sin`` are (max_positions, head_dim) tables. Only their first
    ``seq`` rows are used, broadcast over batch and heads.
    """
    seq = x.shape[-2]
    cos_b = cos[:seq, :][None, None]
    sin_b = sin[:seq, :][None, None]
    return (x * cos_b) + (rotate_half(x) * sin_b)

class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dim with a learned per-channel scale.

    There is no bias and no mean-centering.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # added under the square root for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated activation: split the last dim into (value, gate) and return ``value * silu(gate)``."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)

class KonkanBlock(nn.Module):
    """Pre-norm transformer block.

    The block applies RoPE multi-head self-attention, then a SwiGLU MLP, each
    with a residual connection. Submodule registration order is kept as-is:
    it fixes parameter order (optimizer state dicts) and the sequence of
    random initialisations.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = width // config.n_heads
        # Bias-free attention projections.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)
        # Fused gate+up projection; SwiGLU splits it back into two halves of size d_ff.
        self.gate_up_proj = nn.Linear(width, 2 * config.d_ff, bias=False)
        self.down_proj = nn.Linear(config.d_ff, width, bias=False)
        self.input_layernorm = RMSNorm(width)
        self.post_attention_layernorm = RMSNorm(width)
        self.act = SwiGLU()

    def forward(self, x, cos, sin, mask):
        normed = self.input_layernorm(x)
        batch, seq_len, width = normed.shape

        def to_heads(proj):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return proj(normed).reshape(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q = apply_rotary_pos_emb(to_heads(self.q_proj), cos, sin)
        k = apply_rotary_pos_emb(to_heads(self.k_proj), cos, sin)
        v = to_heads(self.v_proj)
        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        merged = attended.transpose(1, 2).contiguous().reshape(batch, seq_len, width)
        hidden = x + self.o_proj(merged)
        mlp_out = self.down_proj(self.act(self.gate_up_proj(self.post_attention_layernorm(hidden))))
        return hidden + mlp_out

class KonkanGPT(PreTrainedModel):
    """Decoder-only causal LM with tied embeddings, RoPE blocks and a final RMSNorm.

    ``forward`` returns ``{"loss", "logits"}``. In this variant ``labels`` must
    already be aligned position-by-position with ``logits``: the caller
    performs the next-token shift, and no shift is done here.
    """
    config_class = KonkanSmallConfig

    def __init__(self, config):
        super().__init__(config)
        # Registration order is unchanged on purpose (parameter order / init sequence).
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.rope = RotaryEmbedding(config.d_model // config.n_heads, config.max_len)
        self.layers = nn.ModuleList(KonkanBlock(config) for _ in range(config.n_layers))
        self.norm = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie the input embedding and output projection to one shared matrix.
        self.token_emb.weight = self.head.weight
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        _, seq_len = input_ids.shape
        cos, sin = self.rope(input_ids, seq_len)
        # Boolean lower-triangular mask: True where attention is allowed.
        causal = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()[None, None]
        hidden = self.token_emb(input_ids)
        for block in self.layers:
            hidden = block(hidden, cos, sin, causal)
        logits = self.head(self.norm(hidden))
        if labels is None:
            return {"loss": None, "logits": logits}
        # reshape (not view) so non-contiguous label tensors are accepted.
        vocab = logits.size(-1)
        loss = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1))
        return {"loss": loss, "logits": logits}

# ==========================================
# 2. CHECKPOINT & DATA MANAGER
# ==========================================

class PitstopManager:
    """Rolling, resumable training checkpoints ("pitstops") kept in one directory.

    Each save writes ``pitstop_step_{step}.pt`` holding epoch, step and the
    model/optimizer/scaler state dicts. It then prunes the oldest files so that
    at most ``max_to_keep`` remain (the newest is always kept).
    """

    _PREFIX = "pitstop_step_"
    _SUFFIX = ".pt"

    def __init__(self, save_dir="pitstops", max_to_keep=2):
        self.save_dir = save_dir
        self.max_to_keep = max_to_keep
        os.makedirs(save_dir, exist_ok=True)

    def _sorted_checkpoints(self):
        """Return pitstop paths oldest-first.

        Ordering is by modification time; equal mtimes are broken by the step
        number in the file name, so the order is deterministic.
        """
        pattern = os.path.join(self.save_dir, f"{self._PREFIX}*{self._SUFFIX}")

        def order_key(path):
            stem = os.path.basename(path)[len(self._PREFIX):-len(self._SUFFIX)]
            try:
                step = int(stem)
            except ValueError:
                step = -1
            return (os.path.getmtime(path), step)

        return sorted(glob.glob(pattern), key=order_key)

    def save(self, model, optimizer, scaler, epoch, step):
        """Write a pitstop for (epoch, step) and prune old ones.

        A ``torch.compile`` wrapper is unwrapped via ``_orig_mod`` so the saved
        keys match the plain model.
        """
        raw_model = getattr(model, "_orig_mod", model)
        checkpoint = {
            'epoch': epoch, 'step': step,
            'model_state_dict': raw_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
        }
        final_path = os.path.join(self.save_dir, f"{self._PREFIX}{step}{self._SUFFIX}")
        # Write to a temp name and atomically rename. If the process dies mid-write,
        # the previous pitstops stay intact instead of a truncated file becoming
        # the "latest". The ".tmp" name does not match the pitstop glob.
        tmp_path = final_path + ".tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, final_path)

        ckpts = self._sorted_checkpoints()
        keep = max(self.max_to_keep, 1)
        while len(ckpts) > keep:
            os.remove(ckpts.pop(0))

    def load_latest(self, model, optimizer, scaler):
        """Restore the newest readable pitstop into model/optimizer/scaler.

        Unreadable (e.g. truncated) files are skipped in favour of older ones.

        Returns:
            ``(epoch, step)`` stored in the restored pitstop, or ``(0, 0)`` when
            none is available.
        """
        import pickle

        for path in reversed(self._sorted_checkpoints()):
            try:
                ckpt = torch.load(path, map_location="cpu")
            except (RuntimeError, EOFError, OSError, ValueError, pickle.UnpicklingError) as err:
                print(f"Skipping unreadable checkpoint {path}: {err}")
                continue
            print(f"üîÑ Resuming from: {path}")
            model.load_state_dict(ckpt['model_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            scaler.load_state_dict(ckpt['scaler_state_dict'])
            return ckpt['epoch'], ckpt['step']
        return 0, 0

def pack_dataset(tokenizer, data_path, max_seq_len=1024, num_proc=4):
    """Tokenize a plain-text corpus and pack it into fixed-length training rows.

    Every record of the ``text`` dataset is tokenized and followed by one EOS
    id. All ids are concatenated and the stream is cut into rows of
    ``max_seq_len + 1`` ids (one extra so callers can shift into inputs/labels).
    The trailing remainder that does not fill a whole row is dropped.

    Args:
        tokenizer: callable tokenizer exposing ``eos_token_id``.
        data_path: path of the text file passed to ``load_dataset("text", ...)``.
        max_seq_len: model context length; rows hold ``max_seq_len + 1`` ids.
        num_proc: worker processes for the tokenization ``map``.

    Returns:
        An int64 tensor of shape ``(num_rows, max_seq_len + 1)``.

    Raises:
        ValueError: if the tokenizer has no EOS token, or the corpus is too
            short to fill a single row.
    """
    print("üì¶ Packing Dataset (Constant Length Training)...")
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        # Without this check, None would be appended and torch.tensor would fail obscurely.
        raise ValueError("tokenizer has no eos_token_id; cannot separate documents when packing")

    ds = load_dataset("text", data_files={"train": data_path}, split="train")
    tokenized = ds.map(lambda x: tokenizer(x["text"]), batched=True,
                       remove_columns=["text"], num_proc=num_proc)

    all_ids = []
    for ids in tokenized["input_ids"]:
        all_ids.extend(ids)
        all_ids.append(eos_id)

    chunk_size = max_seq_len + 1
    total_chunks = len(all_ids) // chunk_size
    if total_chunks == 0:
        raise ValueError(
            f"corpus has only {len(all_ids)} tokens, fewer than one row of {chunk_size}"
        )
    # Explicit int64: embedding lookups require integer indices.
    packed_data = torch.tensor(all_ids[:total_chunks * chunk_size], dtype=torch.long)
    packed_data = packed_data.view(total_chunks, chunk_size)
    print(f"‚úÖ Created {len(packed_data)} dense blocks.")
    return packed_data

class PackedDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-packed rows of token ids (e.g. from ``pack_dataset``).

    Item ``i`` is ``{"input_ids": data[i]}``.
    """

    def __init__(self, tensor_data):
        self.data = tensor_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        row = self.data[i]
        return {"input_ids": row}

# ==========================================
# 3. TRAINING ENGINE
# ==========================================

def train_pure_power(num_epochs=2, accum_steps=16, save_every_updates=32, shuffle_seed=1234):
    """Pretrain KonkanGPT on the packed Konkani book corpus, resuming from pitstops.

    Args:
        num_epochs: total passes over the packed dataset, counted from epoch 0,
            so a resumed run continues up to the same total.
        accum_steps: micro-batches (of 4 rows each) accumulated per optimizer update.
        save_every_updates: write a pitstop after this many optimizer updates,
            and at the end of every epoch. Pitstops are written right after an
            update, so no partially accumulated gradient is lost on resume.
        shuffle_seed: base seed. Epoch ``e`` is shuffled with ``shuffle_seed + e``,
            so a resumed run replays the same batch order. Skipping the
            already-seen steps therefore skips exactly the batches trained on.
    """
    device = "cuda"
    TOKEN_DIR = "konkani-tokenizer-v3-32k"
    DATA_PATH = "/kaggle/input/konkani-book-corpus/konkani_book_corpus.txt"

    tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKEN_DIR)
    model = KonkanGPT(KonkanSmallConfig(vocab_size=len(tokenizer))).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
    scaler = torch.amp.GradScaler('cuda')
    pitstop_manager = PitstopManager("pitstops")

    # Load checkpoint if exists (before torch.compile, so state-dict keys match).
    start_epoch, start_step = pitstop_manager.load_latest(model, optimizer, scaler)

    # DRY RUN: one forward pass on random ids to catch shape/device errors early.
    print("üß™ Dry Run...")
    test_ids = torch.randint(0, 100, (2, 129)).to(device)
    with torch.no_grad():
        _ = model(test_ids[:, :-1].contiguous(), labels=test_ids[:, 1:].contiguous())
    print("‚úÖ Verified.")

    # Data. A dedicated generator is re-seeded every epoch to make the order reproducible.
    packed_tensor = pack_dataset(tokenizer, DATA_PATH)
    shuffle_gen = torch.Generator()
    train_loader = DataLoader(PackedDataset(packed_tensor), batch_size=4, shuffle=True,
                              generator=shuffle_gen)

    model = torch.compile(model)
    model.train()
    save_every = accum_steps * save_every_updates

    for epoch in range(start_epoch, num_epochs):
        shuffle_gen.manual_seed(shuffle_seed + epoch)
        num_batches = len(train_loader)
        progress_bar = tqdm(enumerate(train_loader), total=num_batches, desc=f"Epoch {epoch}")
        for step, batch in progress_bar:
            # Fast-forward past micro-batches already consumed before the pitstop.
            if epoch == start_epoch and step <= start_step:
                continue

            ids = batch["input_ids"].to(device, non_blocking=True)
            # Each packed row holds max_seq_len + 1 ids: inputs and next-token labels.
            inputs, labels = ids[:, :-1].contiguous(), ids[:, 1:].contiguous()

            with torch.amp.autocast('cuda'):
                outputs = model(inputs, labels=labels)
                loss = outputs["loss"] / accum_steps

            scaler.scale(loss).backward()

            # Also step on the epoch's final batch, so a partial accumulation window
            # is neither dropped nor carried into the next epoch.
            is_last_batch = (step + 1) == num_batches
            if (step + 1) % accum_steps == 0 or is_last_batch:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                if (step + 1) % save_every == 0 or is_last_batch:
                    pitstop_manager.save(model, optimizer, scaler, epoch, step)

            if step % 10 == 0:
                progress_bar.set_postfix({"loss": f"{loss.item()*accum_steps:.4f}"})

    torch.save(getattr(model, "_orig_mod", model).state_dict(), "konkan_160m_final_pure.pt")

# Run pretraining only when this file/cell is executed as a script.
if __name__ == "__main__":
    train_pure_power()

In [None]:
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast

# 1. SETUP - Match these to your training config
# MODEL_PATH may be the final bare state dict or a pitstops/pitstop_step_*.pt
# training checkpoint; both formats are handled below.
MODEL_PATH = "konkan_160m_final_pure.pt"
TOKEN_DIR = "konkani-tokenizer-v3-32k"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 2. LOAD TOKENIZER AND MODEL
tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKEN_DIR)

# Reuses the KonkanGPT / KonkanSmallConfig classes from the training cell,
# which must have been run earlier in this session.
config = KonkanSmallConfig(vocab_size=len(tokenizer))
model = KonkanGPT(config).to(DEVICE)

# Load weights. Pitstop checkpoints nest the weights under "model_state_dict"
# (next to optimizer/scaler state); the final export is the state dict itself.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
state_dict = checkpoint.get("model_state_dict", checkpoint)
model.load_state_dict(state_dict)
model.eval()

print("‚úÖ Model Loaded. Ready to generate.")

def generate(prompt, max_new_tokens=50, temperature=0.8, top_k=50):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    Uses the globals ``model``, ``tokenizer`` and ``DEVICE`` defined above.

    Args:
        prompt: raw text to continue.
        max_new_tokens: upper bound on generated tokens.
        temperature: softmax temperature; ``<= 0`` selects greedy decoding.
        top_k: if > 0, sample only among the ``top_k`` highest-scoring tokens.

    Returns:
        The decoded prompt plus continuation, with special tokens skipped.
        Generation stops early at the tokenizer's EOS token.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    # Never feed the model more positions than it was configured for.
    context_limit = getattr(getattr(model, "config", None), "max_len", None)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            context = input_ids if context_limit is None else input_ids[:, -context_limit:]
            logits = model(context)["logits"][:, -1, :]

            if temperature <= 0:
                # Greedy: avoids the division by zero of logits / 0.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                # Top-K filtering keeps low-probability "junk" tokens out of the sample.
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# 3. TEST IT
# NOTE(review): this Devanagari literal looks like UTF-8 text decoded as Mac Roman
# (mojibake); confirm against the original notebook and restore if so.
prompt = "‡§ó‡•ã‡§Ç‡§Ø ‡§è‡§ï" # "Goa is a..."
print(f"\nPrompt: {prompt}")
completion = generate(prompt)
print(f"Generated: {completion}")

In [None]:
# Longer output with slightly higher temperature and tighter top-k, for a more
# "story-telling" flow.
story = generate("‡§ó‡•ã‡§Ç‡§Ø ‡§è‡§ï", max_new_tokens=100, temperature=0.85, top_k=40)
print(story)

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizerFast, PretrainedConfig, PreTrainedModel
from tqdm.auto import tqdm

# ==========================================
# 1. YOUR MODEL ARCHITECTURE (Exact Copy)
# ==========================================
# We include this so the script can load your pre-trained weights correctly.

class KonkanSmallConfig(PretrainedConfig):
    """Hyper-parameters of the KonkanGPT decoder.

    Defaults: 12 layers x 12 heads, d_model 768, SwiGLU hidden size d_ff 3072,
    vocabulary 32000, max_len 1024. ``dropout`` is recorded on the config only;
    the blocks defined alongside it apply no dropout.
    """
    model_type = "konkangpt"

    def __init__(self, vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
                 d_ff=3072, max_len=1024, dropout=0.1, **kwargs):
        # Remaining keyword arguments are forwarded to PretrainedConfig.
        super().__init__(**kwargs)
        hyperparams = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "d_ff": d_ff,
            "max_len": max_len,
            "dropout": dropout,
        }
        for name, value in hyperparams.items():
            setattr(self, name, value)

class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables.

    ``inv_freq`` holds inverse frequencies ``10000 ** (-k / dim)`` for even ``k``
    in ``[0, dim)``. ``forward`` returns ``(cos, sin)`` of shape
    ``(seq_len, 2 * len(inv_freq))``: the angle vector is the frequencies
    repeated twice, matching ``rotate_half``'s two-way split.
    ``max_seq_len`` is accepted but unused; tables are rebuilt on every call.
    """

    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # Persistent buffer: it appears in saved state dicts, so its name must not change.
        self.register_buffer("inv_freq", 1.0 / torch.pow(10000, exponents))

    def forward(self, x, seq_len):
        # ``x`` only supplies the target device.
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = positions[:, None] * self.inv_freq[None, :]
        angles = torch.cat([angles, angles], dim=-1)
        return angles.cos(), angles.sin()

def rotate_half(x):
    """Split the last dim into two chunks (lo, hi) and return ``cat(-hi, lo)``.

    This is the 90-degree rotation used by RoPE.
    """
    lo, hi = torch.chunk(x, 2, dim=-1)
    return torch.cat([hi.neg(), lo], dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` by the RoPE angles for its sequence positions.

    As called from KonkanBlock, ``x`` is (batch, heads, seq, head_dim), and
    ``cos``/``sin`` are (max_positions, head_dim) tables. Only their first
    ``seq`` rows are used, broadcast over batch and heads.
    """
    seq = x.shape[-2]
    cos_b = cos[:seq, :][None, None]
    sin_b = sin[:seq, :][None, None]
    return (x * cos_b) + (rotate_half(x) * sin_b)

class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dim with a learned per-channel scale.

    There is no bias and no mean-centering.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # added under the square root for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated activation: split the last dim into (value, gate) and return ``value * silu(gate)``."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)

class KonkanBlock(nn.Module):
    """Pre-norm transformer block.

    The block applies RoPE multi-head self-attention, then a SwiGLU MLP, each
    with a residual connection. Submodule registration order is kept as-is:
    it fixes the parameter/state-dict order that pretrained checkpoints and
    optimizer states rely on.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = width // config.n_heads
        # Bias-free attention projections.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)
        # Fused gate+up projection; SwiGLU splits it back into two halves of size d_ff.
        self.gate_up_proj = nn.Linear(width, 2 * config.d_ff, bias=False)
        self.down_proj = nn.Linear(config.d_ff, width, bias=False)
        self.input_layernorm = RMSNorm(width)
        self.post_attention_layernorm = RMSNorm(width)
        self.act = SwiGLU()

    def forward(self, x, cos, sin, mask):
        normed = self.input_layernorm(x)
        batch, seq_len, width = normed.shape

        def to_heads(proj):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return proj(normed).reshape(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q = apply_rotary_pos_emb(to_heads(self.q_proj), cos, sin)
        k = apply_rotary_pos_emb(to_heads(self.k_proj), cos, sin)
        v = to_heads(self.v_proj)
        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        merged = attended.transpose(1, 2).contiguous().reshape(batch, seq_len, width)
        hidden = x + self.o_proj(merged)
        mlp_out = self.down_proj(self.act(self.gate_up_proj(self.post_attention_layernorm(hidden))))
        return hidden + mlp_out

class KonkanGPT(PreTrainedModel):
    """Decoder-only causal LM with tied embeddings, RoPE blocks and a final RMSNorm.

    ``forward`` returns ``{"loss", "logits"}``. In this variant ``labels`` are
    aligned with ``input_ids`` and the next-token shift happens here: logits at
    position i are scored against label i+1. Label -100 is ignored, which is
    ``F.cross_entropy``'s default ``ignore_index``.
    """
    config_class = KonkanSmallConfig

    def __init__(self, config):
        super().__init__(config)
        # Registration order is unchanged on purpose (parameter order / init sequence).
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.rope = RotaryEmbedding(config.d_model // config.n_heads, config.max_len)
        self.layers = nn.ModuleList(KonkanBlock(config) for _ in range(config.n_layers))
        self.norm = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie the input embedding and output projection to one shared matrix.
        self.token_emb.weight = self.head.weight
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        _, seq_len = input_ids.shape
        cos, sin = self.rope(input_ids, seq_len)
        # Boolean lower-triangular mask: True where attention is allowed.
        causal = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()[None, None]
        hidden = self.token_emb(input_ids)
        for block in self.layers:
            hidden = block(hidden, cos, sin, causal)
        logits = self.head(self.norm(hidden))
        if labels is None:
            return {"loss": None, "logits": logits}
        # Next-token objective: drop the last prediction and the first label.
        pred = logits[..., :-1, :].contiguous()
        target = labels[..., 1:].contiguous()
        loss = F.cross_entropy(pred.view(-1, pred.size(-1)), target.view(-1))
        return {"loss": loss, "logits": logits}

# ==========================================
# 2. SFT DATASET & MASKING LOGIC
# ==========================================

class KonkanSFTDataset(Dataset):
    """Instruction/response pairs rendered into the SFT chat template.

    The JSON file must hold a list of objects with ``instruction`` and
    ``response`` keys; both are stripped. Each sample is stored as two strings:
    ``full_text`` is prompt + response + ``<|endoftext|>``, and ``prompt_text``
    is the prompt alone, later used to mask the instruction out of the loss.
    Tokenization happens in the collate function. ``tokenizer`` and ``max_len``
    are stored but not used by this class.
    """

    def __init__(self, json_path, tokenizer, max_len=512):
        self.data = []
        self.tokenizer = tokenizer
        self.max_len = max_len

        with open(json_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        print(f"üîÑ Processing {len(raw_data)} SFT samples...")

        for record in raw_data:
            instruction = record['instruction'].strip()
            response = record['response'].strip()
            # The newlines around the role tags give the model a clear structure.
            prompt = f"<|user|>\n{instruction}\n<|assistant|>\n"
            self.data.append({
                "full_text": f"{prompt}{response}<|endoftext|>",
                "prompt_text": prompt,
            })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def sft_collate_fn(batch, tokenizer, device, max_len=1024):
    """Tokenize SFT samples, mask the instruction part of the labels, and right-pad.

    Labels mirror ``input_ids``, but every prompt position and all padding are
    set to -100 (``F.cross_entropy``'s default ignore index). Only response
    tokens therefore contribute to the loss.

    Args:
        batch: list of dicts with ``'prompt_text'`` and ``'full_text'`` strings;
            ``full_text`` is expected to start with ``prompt_text``.
        tokenizer: provides ``encode(text, add_special_tokens=False)`` and ``pad_token_id``.
        device: device the returned tensors are moved to.
        max_len: maximum tokens kept per sample; longer samples are right-truncated.

    Returns:
        ``(input_ids, labels)``, both int64 tensors of shape (batch, longest sample).
    """
    input_ids_list = []
    labels_list = []

    for item in batch:
        prompt_text = item['prompt_text']
        full_text = item['full_text']
        prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
        full_ids = tokenizer.encode(full_text, add_special_tokens=False)

        # Masking by prompt length assumes the prompt's tokens are a prefix of the
        # full encoding. If a merge across the boundary broke that, re-encode the
        # response on its own so the boundary is exact.
        if full_ids[:len(prompt_ids)] != prompt_ids and full_text.startswith(prompt_text):
            response_ids = tokenizer.encode(full_text[len(prompt_text):], add_special_tokens=False)
            full_ids = prompt_ids + response_ids

        full_ids = list(full_ids[:max_len])

        # Mask the prompt. If truncation cut into (or through) the prompt, mask
        # everything that remains instead of training on instruction tokens.
        n_masked = min(len(prompt_ids), len(full_ids))
        labels = [-100] * n_masked + full_ids[n_masked:]

        input_ids_list.append(full_ids)
        labels_list.append(labels)

    max_batch_len = max((len(ids) for ids in input_ids_list), default=0)

    # Pad everything to the right; padded label positions are ignored (-100).
    padded_inputs = torch.full((len(batch), max_batch_len), tokenizer.pad_token_id, dtype=torch.long)
    padded_labels = torch.full((len(batch), max_batch_len), -100, dtype=torch.long)

    for i, (ids, labs) in enumerate(zip(input_ids_list, labels_list)):
        padded_inputs[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
        padded_labels[i, :len(labs)] = torch.tensor(labs, dtype=torch.long)

    return padded_inputs.to(device), padded_labels.to(device)


# ==========================================
# 3. TRAINING LOOP
# ==========================================

def train_sft():
    """Supervised fine-tuning of the pretrained KonkanGPT on instruction/response pairs.

    The function loads the pretrained weights and trains with a masked-instruction
    loss, so only response tokens contribute. The result is saved to
    ``konkan_sft_gonyai.pt``.
    """
    # SETUP
    device = "cuda" if torch.cuda.is_available() else "cpu"
    TOKEN_DIR = "konkani-tokenizer-v3-32k" # Path to your tokenizer folder
    PRETRAINED_MODEL_PATH = "konkan_160m_final_pure.pt" # Your Pre-trained weights
    JSON_PATH = "/kaggle/input/sft-160m/sft_dataset_ds.json"

    # 1. Load Tokenizer (pad with EOS if no pad token is defined)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKEN_DIR)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 2. Load Model Structure
    config = KonkanSmallConfig(vocab_size=len(tokenizer))
    model = KonkanGPT(config).to(device)

    # 3. Load Pre-trained Weights (Crucial for not forgetting)
    if os.path.exists(PRETRAINED_MODEL_PATH):
        print(f"üì• Loading Pre-trained Weights from {PRETRAINED_MODEL_PATH}...")
        state_dict = torch.load(PRETRAINED_MODEL_PATH, map_location=device)
        # Accept a pitstop training checkpoint as well as a bare state dict.
        if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        # strict=False tolerates mismatches silently. Report them, so a wrong file
        # cannot quietly leave the model randomly initialised.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f"WARNING: checkpoint mismatch - missing keys: {result.missing_keys}; "
                  f"unexpected keys: {result.unexpected_keys}")
    else:
        print("‚ö†Ô∏è WARNING: Pre-trained weights not found! Training from scratch (Not Recommended for SFT).")

    # 4. Optimizer - LOW LR (1e-5) to avoid overwriting pretrained grammar knowledge
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

    # 5. Dataset
    dataset = KonkanSFTDataset(JSON_PATH, tokenizer)
    loader = DataLoader(dataset, batch_size=4, shuffle=True,
                        collate_fn=lambda b: sft_collate_fn(b, tokenizer, device))

    # 6. Training Loop
    model.train()
    EPOCHS = 3 # Small dataset, don't overdo it
    ACCUM_STEPS = 4

    print("üöÄ Starting SFT Training (Masked Instruction Loss)...")

    for epoch in range(EPOCHS):
        total_loss = 0.0
        counted_batches = 0
        num_batches = len(loader)
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

        for step, (inputs, labels) in enumerate(progress_bar):
            # Inputs = full text; labels = full text with the instruction set to -100.
            outputs = model(inputs, labels=labels)
            batch_loss = outputs["loss"]

            # A batch whose labels are all -100 yields a NaN mean loss. Backpropagating
            # it would make every gradient (and then every weight) NaN.
            if torch.isfinite(batch_loss):
                (batch_loss / ACCUM_STEPS).backward()
                total_loss += batch_loss.item()
                counted_batches += 1
                progress_bar.set_postfix({"loss": f"{batch_loss.item():.4f}"})

            # Also step on the final batch so the tail of the epoch is applied
            # instead of leaking into the next epoch (or being dropped at the end).
            if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == num_batches:
                # Gradient clipping prevents "exploding" updates that ruin memory.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

        if counted_batches:
            print(f"Epoch {epoch+1}/{EPOCHS} mean loss: {total_loss / counted_batches:.4f}")

    # 7. Save SFT Model
    save_path = "konkan_sft_gonyai.pt"
    torch.save(model.state_dict(), save_path)
    print(f"‚úÖ SFT Complete. Model saved to {save_path}")

# ==========================================
# 4. INFERENCE FUNCTION (TESTING)
# ==========================================
def chat_with_gonyai(instruction, model_path="konkan_sft_gonyai.pt",
                     tokenizer_dir="konkani-tokenizer-v3-32k",
                     max_new_tokens=150, temperature=0.7, top_p=0.9):
    """Load the SFT checkpoint and answer a single instruction.

    ``KonkanGPT.forward`` returns a plain dict and the class defines no
    generation hooks, so ``model.generate`` cannot drive it. Tokens are
    therefore sampled manually (temperature + nucleus/top-p). Only the newly
    generated tokens are decoded, so the prompt never leaks into the answer.

    Args:
        instruction: the user's question or instruction.
        model_path: SFT state-dict file.
        tokenizer_dir: tokenizer directory.
        max_new_tokens: upper bound on generated tokens.
        temperature: softmax temperature; ``<= 0`` selects greedy decoding.
        top_p: nucleus mass kept when sampling.

    Returns:
        The decoded, stripped response text.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)
    config = KonkanSmallConfig(vocab_size=len(tokenizer))
    model = KonkanGPT(config).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Format exactly like training data, but stop at the assistant tag. Encode
    # without extra special tokens, as the SFT collate function does.
    prompt = f"<|user|>\n{instruction}\n<|assistant|>\n"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(device)

    generated = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            context = generated[:, -config.max_len:]
            logits = model(context)["logits"][:, -1, :]
            if temperature <= 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
                # Nucleus filter: drop a token once the mass ranked above it already
                # exceeds top_p. The most likely token is always kept.
                mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
                sorted_probs[mass_before > top_p] = 0.0
                choice = torch.multinomial(sorted_probs, num_samples=1)
                next_token = sorted_idx.gather(-1, choice)
            generated = torch.cat((generated, next_token), dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Extract only the response part.
    response_ids = generated[0, input_ids.shape[1]:]
    return tokenizer.decode(response_ids, skip_special_tokens=True).strip()

if __name__ == "__main__":
    train_sft()

    # Smoke-test the freshly saved SFT checkpoint straight away.
    print("\nüß™ Testing Gonyai:")
    sample_instruction = "‡§§‡•Ç‡§Ç ‡§ï‡•ã‡§£?"
    print(chat_with_gonyai(sample_instruction))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast

# 1. ROBUST IMPORTS FOR BASE CLASSES
try:
    from transformers import PreTrainedModel, PretrainedConfig
except ImportError:
    # Fallback for older/specific versions
    from transformers import PreTrainedModel, PreTrainedConfig as PretrainedConfig

# 2. ARCHITECTURE DEFINITION
class KonkanSmallConfig(PretrainedConfig):
    """Hyper-parameters of the KonkanGPT decoder.

    Defaults: 12 layers x 12 heads, d_model 768, SwiGLU hidden size d_ff 3072,
    vocabulary 32000, max_len 1024. ``dropout`` is recorded on the config only;
    the blocks defined alongside it apply no dropout.
    """
    model_type = "konkangpt"

    def __init__(self, vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
                 d_ff=3072, max_len=1024, dropout=0.1, **kwargs):
        # Remaining keyword arguments are forwarded to PretrainedConfig.
        super().__init__(**kwargs)
        hyperparams = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "d_ff": d_ff,
            "max_len": max_len,
            "dropout": dropout,
        }
        for name, value in hyperparams.items():
            setattr(self, name, value)

class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables.

    ``inv_freq`` holds inverse frequencies ``10000 ** (-k / dim)`` for even ``k``
    in ``[0, dim)``. ``forward`` returns ``(cos, sin)`` of shape
    ``(seq_len, 2 * len(inv_freq))``: the angle vector is the frequencies
    repeated twice, matching ``rotate_half``'s two-way split.
    ``max_seq_len`` is accepted but unused; tables are rebuilt on every call.
    """

    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # Persistent buffer: it appears in saved state dicts, so its name must not change.
        self.register_buffer("inv_freq", 1.0 / torch.pow(10000, exponents))

    def forward(self, x, seq_len):
        # ``x`` only supplies the target device.
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = positions[:, None] * self.inv_freq[None, :]
        angles = torch.cat([angles, angles], dim=-1)
        return angles.cos(), angles.sin()

def rotate_half(x):
    """Split the last dim into two chunks (lo, hi) and return ``cat(-hi, lo)``.

    This is the 90-degree rotation used by RoPE.
    """
    lo, hi = torch.chunk(x, 2, dim=-1)
    return torch.cat([hi.neg(), lo], dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` by the RoPE angles for its sequence positions.

    As called from KonkanBlock, ``x`` is (batch, heads, seq, head_dim), and
    ``cos``/``sin`` are (max_positions, head_dim) tables. Only their first
    ``seq`` rows are used, broadcast over batch and heads.
    """
    seq = x.shape[-2]
    cos_b = cos[:seq, :][None, None]
    sin_b = sin[:seq, :][None, None]
    return (x * cos_b) + (rotate_half(x) * sin_b)

class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dim with a learned per-channel scale.

    There is no bias and no mean-centering.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # added under the square root for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated activation: split the last dim into (value, gate) and return ``value * silu(gate)``."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)

class KonkanBlock(nn.Module):
    """Pre-norm transformer block.

    The block applies RoPE multi-head self-attention, then a SwiGLU MLP, each
    with a residual connection. Submodule registration order is kept as-is:
    it fixes the state-dict order that saved checkpoints rely on.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = width // config.n_heads
        # Bias-free attention projections.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)
        # Fused gate+up projection; SwiGLU splits it back into two halves of size d_ff.
        self.gate_up_proj = nn.Linear(width, 2 * config.d_ff, bias=False)
        self.down_proj = nn.Linear(config.d_ff, width, bias=False)
        self.input_layernorm = RMSNorm(width)
        self.post_attention_layernorm = RMSNorm(width)
        self.act = SwiGLU()

    def forward(self, x, cos, sin, mask):
        normed = self.input_layernorm(x)
        batch, seq_len, width = normed.shape

        def to_heads(proj):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return proj(normed).reshape(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q = apply_rotary_pos_emb(to_heads(self.q_proj), cos, sin)
        k = apply_rotary_pos_emb(to_heads(self.k_proj), cos, sin)
        v = to_heads(self.v_proj)
        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        merged = attended.transpose(1, 2).contiguous().reshape(batch, seq_len, width)
        hidden = x + self.o_proj(merged)
        mlp_out = self.down_proj(self.act(self.gate_up_proj(self.post_attention_layernorm(hidden))))
        return hidden + mlp_out

class KonkanGPT(PreTrainedModel):
    """Decoder-only causal LM with tied embeddings, RoPE blocks and a final RMSNorm.

    ``forward`` returns ``{"loss", "logits"}``. When ``labels`` are given they
    are aligned with ``input_ids`` and shifted here: logits at position i are
    scored against label i+1. Label -100 is ignored, which is
    ``F.cross_entropy``'s default ``ignore_index``.
    """
    config_class = KonkanSmallConfig

    def __init__(self, config):
        super().__init__(config)
        # Registration order is unchanged on purpose (state-dict / init order).
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.rope = RotaryEmbedding(config.d_model // config.n_heads, config.max_len)
        self.layers = nn.ModuleList(KonkanBlock(config) for _ in range(config.n_layers))
        self.norm = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie the input embedding and output projection to one shared matrix.
        self.token_emb.weight = self.head.weight
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        _, seq_len = input_ids.shape
        cos, sin = self.rope(input_ids, seq_len)
        # Boolean lower-triangular mask: True where attention is allowed.
        causal = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()[None, None]
        hidden = self.token_emb(input_ids)
        for block in self.layers:
            hidden = block(hidden, cos, sin, causal)
        logits = self.head(self.norm(hidden))
        if labels is None:
            return {"loss": None, "logits": logits}
        # Next-token objective: drop the last prediction and the first label.
        pred = logits[..., :-1, :].contiguous()
        target = labels[..., 1:].contiguous()
        loss = F.cross_entropy(pred.view(-1, pred.size(-1)), target.view(-1))
        return {"loss": loss, "logits": logits}

# 3. HELPER FUNCTIONS
def load_gonyai(model_path, tokenizer_dir):
    """Build KonkanGPT for the tokenizer's vocabulary and load weights from ``model_path``.

    The model is placed on CUDA when available (otherwise CPU) and switched to
    eval mode. If the tokenizer has no pad token, EOS is used as the pad token.

    Returns:
        ``(model, tokenizer)``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = KonkanGPT(KonkanSmallConfig(vocab_size=len(tokenizer))).to(device)

    print(f"üì• Loading Gonyai Weights from {model_path}...")
    weights = torch.load(model_path, map_location=device)
    model.load_state_dict(weights)
    model.eval()
    return model, tokenizer

def generate_manual(model, tokenizer, prompt, max_new_tokens=150, temperature=0.7):
    """Answer ``prompt`` using the SFT chat template and temperature sampling.

    The prompt is wrapped in the same user/assistant template used for SFT.
    It is encoded without extra special tokens, matching how the SFT collate
    function encodes training prompts. Tokens are then generated one at a time
    until EOS or ``max_new_tokens``.

    Args:
        model: module whose call returns a dict with ``"logits"`` of shape
            (batch, seq, vocab). If it has ``config.max_len``, the context is
            cropped to that many positions.
        tokenizer: provides ``encode``/``decode`` and ``eos_token_id``.
        prompt: the user instruction.
        max_new_tokens: upper bound on generated tokens.
        temperature: softmax temperature; ``<= 0`` selects greedy decoding.

    Returns:
        The decoded response only (prompt tokens excluded), stripped.
    """
    device = next(model.parameters()).device
    full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
    input_ids = tokenizer.encode(full_prompt, add_special_tokens=False, return_tensors="pt").to(device)
    context_limit = getattr(getattr(model, "config", None), "max_len", None)

    generated = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            context = generated if context_limit is None else generated[:, -context_limit:]
            next_token_logits = model(context)["logits"][:, -1, :]
            if temperature <= 0:
                # Greedy: avoids dividing the logits by zero.
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    response_tokens = generated[0][input_ids.shape[1]:]
    return tokenizer.decode(response_tokens, skip_special_tokens=True).strip()

# 4. EXECUTION
TOKEN_DIR = "konkani-tokenizer-v3-32k"
SFT_MODEL_PATH = "konkan_sft_gonyai.pt"

# Load the SFT model once and reuse it for every prompt below.
model, tokenizer = load_gonyai(SFT_MODEL_PATH, TOKEN_DIR)

# Mixed Konkani / English prompts.
# NOTE(review): the Devanagari literals below look like UTF-8 text decoded as
# Mac Roman (mojibake); confirm against the original notebook and restore if so.
test_questions = [
    "‡§§‡•Å‡§ú‡•Ä ‡§µ‡§≥‡§ñ ‡§∏‡§æ‡§Ç‡§ó ‡§Ü‡§®‡•Ä ‡§§‡•Å‡§ú‡•á‡§Ç ‡§ß‡•ç‡§Ø‡•á‡§Ø ‡§ï‡§ø‡§§‡•á‡§Ç?",
    "‡§ó‡•ã‡§Ç‡§Ø‡§ö‡•ç‡§Ø‡§æ ‡§∂‡§ø‡§ó‡§Æ‡•ã ‡§â‡§§‡•ç‡§∏‡§µ‡§æ‡§µ‡§ø‡§∂‡•Ä‡§Ç ‡§Æ‡§æ‡§π‡§ø‡§§‡•Ä ‡§¶‡•Ä.",
    "‡§ö‡§µ‡§• ‡§∏‡§£‡§æ‡§ö‡•á‡§Ç ‡§Æ‡•ç‡§π‡§§‡•ç‡§µ ‡§∏‡§æ‡§Ç‡§ó.",
    "‡§™‡§æ‡§µ‡§∏‡§æ‡§ö‡•á‡§∞ ‡§è‡§ï ‡§∏‡•ã‡§¨‡•Ä‡§§ ‡§ï‡§µ‡§ø‡§§‡§æ ‡§¨‡§∞‡•ã‡§µ.",
    "‡§è‡§ï ‡§≤‡•ç‡§π‡§æ‡§® ‡§ï‡§æ‡§£‡•Ä ‡§∏‡§æ‡§Ç‡§ó ‡§ú‡§æ‡§§‡•Ç‡§Ç‡§§ ‡§è‡§ï ‡§∏‡§∏‡§£‡•ã ‡§Ü‡§®‡•Ä ‡§ï‡§æ‡§Ç‡§∏‡§µ ‡§Ü‡§∏‡§æ.",
    "‡§ó‡•ã‡§Ç‡§Ø‡§ö‡•á ‡§™‡§Ø‡§≤‡•á ‡§Æ‡•Å‡§ñ‡•ç‡§Ø‡§Æ‡§Ç‡§§‡•ç‡§∞‡•Ä ‡§ï‡•ã‡§£ ‡§Ü‡§∂‡§ø‡§≤‡•ç‡§≤‡•á?",
    "‡§Æ‡§æ‡§Ç‡§°‡§µ‡•Ä ‡§®‡•ç‡§π‡§Ç‡§Ø‡§ö‡•á‡§Ç ‡§ó‡•ã‡§Ç‡§Ø‡§ö‡•ç‡§Ø‡§æ ‡§ú‡•Ä‡§µ‡§ø‡§§‡§æ‡§§‡§≤‡•á‡§Ç ‡§∏‡•ç‡§•‡§æ‡§® ‡§∏‡§æ‡§Ç‡§ó.",
    "‡§ï‡•ã‡§Ç‡§ï‡§£‡•Ä ‡§≠‡§æ‡§∂‡•á‡§Ç‡§§ '‡§∏‡§¶‡§æ‡§ö‡§æ‡§∞' ‡§Æ‡•ç‡§π‡§≥‡•ç‡§Ø‡§æ‡§∞ ‡§ï‡§ø‡§§‡•á‡§Ç?",
    "Tell me a story about a king in English.",
    "How can I learn Konkani fast?",
]

print("\nüé® Gonyai Multi-Directional Testing Starting...\n")
for q in test_questions:
    print(f"User Question: {q}")
    response = generate_manual(model, tokenizer, q, temperature=0.7)
    print(f"Gonyai: {response}")
    print("-" * 50)