# Benchmarking MoR: Bangla & WikiText-2 (Universal Edition)

This notebook runs the full suite of experiments for the Mixture-of-Recursions (MoR) Transformer on the **Bangla** and **WikiText-2** datasets.

**Compatible Environments:**
*   **Kaggle Kernels**
*   **Google Colab**
*   **Local PC**

**Experiments:**
1. **Baseline N=12**: Standard Transformer with 12 layers.
2. **MoR N=12 (Exp 1)**: MoR with 12 layers, prioritizing efficiency.
3. **Baseline N=6**: Standard Transformer with 6 layers (comparable cost target).
4. **MoR N=12 (Exp 2)**: MoR with 12 layers, tuned for equal cost to N=6.

Includes automated setup, training, and visualization of results (Plots & Matrices).

In [None]:
# 1. Setup Repository & Dependencies
import os
import sys
import subprocess

# Environment Detection
IN_COLAB = False
IN_KAGGLE = False

try:
    import google.colab
    IN_COLAB = True
    print("Detected Environment: Google Colab")
except ImportError:
    if os.path.exists('/kaggle'):
        IN_KAGGLE = True
        print("Detected Environment: Kaggle")
    else:
        print("Detected Environment: Local PC")

# Optimized Directory Logic
# If we are already in the 'code' directory (e.g. running locally), skip cloning
if os.path.exists('train_amp.py') and os.path.exists('config.py'):
    print(f"Already in code directory: {os.getcwd()}")
else:
    # Need to setup
    REPO_URL = "https://github.com/ShMazumder/Benchmarking-MoR-on-fine-tuned-SLM.git"
    REPO_DIR = "Benchmarking-MoR-on-fine-tuned-SLM"

    if not os.path.exists(REPO_DIR):
        # Check if we are in the repo root
        if os.path.exists('code') and os.path.exists('README.md'):
            print("Already in repository root.")
        else:
            print(f"Cloning repository from {REPO_URL}...")
            !git clone {REPO_URL}
    
    # Move to code dir
    if os.path.exists(os.path.join(REPO_DIR, 'code')):
        os.chdir(os.path.join(REPO_DIR, 'code'))
    elif os.path.exists('code'):
        # If in repo root
        os.chdir('code')
        
    print(f"Changed directory to {os.getcwd()}")

# Install Requirements
if IN_COLAB or IN_KAGGLE:
    print("Installing dependencies (Cloud Environment)...")
    !pip install -r requirements.txt --quiet
    !pip install seaborn matplotlib pandas scikit-learn datasets --quiet
    print("Dependencies installed.")
else:
    print("\n[NOTICE] Local Environment detected.")
    print("Skipping automatic 'pip install' to preserve your local environment.")
    print("Please ensure you have installed usage requirements:")
    print("   pip install -r requirements.txt")
    print("   pip install seaborn matplotlib pandas scikit-learn datasets")

In [None]:
# 1.2 Check GPU Status
import torch
# Report whether CUDA is usable; when it is not, print the hint that matches
# the environment detected by the setup cell.
if not torch.cuda.is_available():
    print("WARNING: No GPU detected. Training will be extremely slow (FP32 CPU).")
    if IN_COLAB:
        gpu_hint = "Colab: Runtime -> Change runtime type -> GPU."
    elif IN_KAGGLE:
        gpu_hint = "Kaggle: Session Options -> Accelerator -> GPU P100."
    else:
        gpu_hint = "Local: Ensure you have NVIDIA Drivers and PyTorch with CUDA installed."
    print(gpu_hint)
else:
    print(f"GPU Detected: {torch.cuda.get_device_name(0)}")
    print("FP16/AMP will be enabled automatically.")

In [None]:
# 1.5. Download Substitute Bangla Dataset (if missing)
# This ensures the notebook runs anywhere (Colab/Kaggle/Local) without manual file uploads.

import os
from pathlib import Path
from datasets import load_dataset

# Location the Bangla text is saved to (and looked for on later runs).
BANGLA_DATA_PATH = Path('data/bangla/bangla_slm.txt')

if not BANGLA_DATA_PATH.exists():
    print("Bangla dataset not found. Downloading Bangla Wikipedia subset...")
    BANGLA_DATA_PATH.parent.mkdir(parents=True, exist_ok=True)

    # Load Bengali Wikipedia (streaming mode to avoid full download)
    # We aim for ~15MB of text to be slightly larger than WikiText-2 (10MB)
    wiki_args = ('wikimedia/wikipedia', '20231101.bn')
    wiki_kwargs = {'split': 'train', 'streaming': True}
    try:
        dataset = load_dataset(*wiki_args, **wiki_kwargs)
    except Exception as e:
        # No alternative source is configured, so the second attempt is a
        # plain retry of the same request; a second failure propagates.
        print(f"Failed to load wikimedia/wikipedia: {e}. Retrying once...")
        dataset = load_dataset(*wiki_args, **wiki_kwargs)

    target_size = 15 * 1024 * 1024  # 15 MB
    current_size = 0
    text_accumulated = []

    print("downloading...")
    for i, article in enumerate(dataset):
        text = article['text']
        text_accumulated.append(text)
        # Size is measured in UTF-8 bytes, i.e. what will end up on disk.
        current_size += len(text.encode('utf-8'))

        if current_size >= target_size:
            break

        if i % 100 == 0:
            print(f"Downloaded {current_size / 1024 / 1024:.2f} MB...")

    if not text_accumulated:
        # An empty file would satisfy the exists() check on the next run and
        # only fail later, inside the data loader.
        raise RuntimeError("No Bangla text was downloaded; nothing to save.")

    # Write to a sibling '.part' file and rename it into place, so that an
    # interrupted write never leaves a truncated file at BANGLA_DATA_PATH.
    partial_path = BANGLA_DATA_PATH.with_name(BANGLA_DATA_PATH.name + '.part')
    with open(partial_path, 'w', encoding='utf-8') as f:
        f.write('\n\n'.join(text_accumulated))
    partial_path.replace(BANGLA_DATA_PATH)

    print(f"Saved {current_size / 1024 / 1024:.2f} MB of Bangla text to {BANGLA_DATA_PATH}")
else:
    print(f"Bangla dataset found at {BANGLA_DATA_PATH}")

In [None]:
# 2. HOST FIX: Patch bangla.py (Fixing known issues & Pathing)
# This patch makes the loader robust for both Kaggle input paths and local/Colab paths.

# Full replacement source for data/bangla.py. Compared with the previous patch
# text: the temporary SentencePiece training file is removed even when training
# raises, the subword branch sets stoi/itos to None (as the WikiText loader
# does), and the unused 'found' flag is gone.
bangla_py_content = '''"""Bangla SLM Dataset Loader with configurable tokenization"""
import os
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from config import Config

try:
    from data.tokenizers import load_sentencepiece, train_sentencepiece, build_word_vocab_from_text, encode_with_sp
except Exception:
    from tokenizers import load_sentencepiece, train_sentencepiece, build_word_vocab_from_text, encode_with_sp


class BanglaSLMDataset(Dataset):
    def __init__(self, seq_length=64, split='train', split_ratio=0.9, tokenization='char', tokenizer_model=None, vocab_size=None, data_file=None):
        self.seq_length = seq_length
        cfg = Config()

        # Determine data path: use argument if provided, else use config default
        if data_file:
            data_path = Path(data_file)
        else:
            # Fallback to config
            data_path = Path(cfg.bangla_data_file)

        # Check if file exists, if not search Kaggle/Colab common paths
        if not data_path.exists():
            # Kaggle: check input directory for common patterns
            search_paths = [
                Path('/kaggle/input/bangla-slm/bangla_slm.txt'),
                Path('/kaggle/input/bangla-dataset/bangla_slm.txt'),
                Path('/kaggle/input/bangla.txt'),
                # Add other potential paths here
            ]
            for kp in search_paths:
                if kp.exists():
                    print(f"Found dataset at: {kp}")
                    data_path = kp
                    break

        if data_path.exists():
            text = data_path.read_text(encoding='utf-8')
        else:
            # Should not happen if preceding download cell ran correctly
            raise FileNotFoundError(f"Bangla dataset not found at {data_path}. Please ensure the download cell ran successfully.")

        self.tokenization = tokenization
        if tokenization == 'char':
            chars = sorted(list(set(text)))
            self.vocab_size = len(chars)
            self.stoi = {ch: i for i, ch in enumerate(chars)}
            self.itos = {i: ch for i, ch in enumerate(chars)}
            data_ids = [self.stoi[ch] for ch in text]

        elif tokenization == 'word':
            stoi, itos = build_word_vocab_from_text(text)
            self.stoi = stoi
            self.itos = itos
            self.vocab_size = len(stoi)
            data_ids = [self.stoi[w] for w in text.split()]

        elif tokenization == 'subword':
            # No Python-side vocabulary for SentencePiece; keep the attributes defined.
            self.stoi = None
            self.itos = None

            model_path = tokenizer_model or cfg.tokenizer_model_bangla
            model_file = Path(model_path)
            if not model_file.parent.exists():
                model_file.parent.mkdir(parents=True, exist_ok=True)

            if not model_file.exists():
                if len(text) > 100:
                    print(f"Training SentencePiece model for Bangla at {model_file}...")
                    # We need to write the text to a temp file for SP training
                    temp_train_file = 'temp_bangla_for_sp.txt'
                    model_prefix = str(model_file.with_suffix(''))
                    try:
                        with open(temp_train_file, 'w', encoding='utf-8') as f:
                            f.write(text)
                        train_sentencepiece(temp_train_file, model_prefix, vocab_size or cfg.subword_vocab_size)
                    finally:
                        # Remove the temp file even if training raised
                        if os.path.exists(temp_train_file):
                            os.remove(temp_train_file)
                    model_file = Path(model_prefix + '.model')
                else:
                    print("Text too short to train tokenizer.")

            if model_file.exists():
                sp = load_sentencepiece(str(model_file))
                self.vocab_size = sp.get_piece_size()
                data_ids = encode_with_sp(sp, text)
            else:
                self.vocab_size = 100
                data_ids = []

        else:
            raise ValueError(f'Unknown tokenization: {tokenization}')

        data = torch.tensor(data_ids, dtype=torch.long)

        split_idx = int(len(data) * split_ratio)
        if split == 'train':
            self.data = data[:split_idx]
        else:
            self.data = data[split_idx:]

    def __len__(self):
        if len(self.data) <= self.seq_length:
            return 0
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_length]
        y = self.data[idx + 1:idx + self.seq_length + 1]
        return x, y


def get_bangla_loaders(batch_size=64, seq_length=64, split_ratio=0.9, tokenization=None, tokenizer_model=None, vocab_size=None, data_file=None):
    cfg = Config()
    tokenization = tokenization or cfg.tokenization
    tokenizer_model = tokenizer_model or (cfg.tokenizer_model_bangla if tokenization == 'subword' else None)
    vocab_size = vocab_size or cfg.subword_vocab_size

    train_dataset = BanglaSLMDataset(seq_length=seq_length, split='train', split_ratio=split_ratio,
                                     tokenization=tokenization, tokenizer_model=tokenizer_model, vocab_size=vocab_size, data_file=data_file)
    test_dataset = BanglaSLMDataset(seq_length=seq_length, split='test', split_ratio=split_ratio,
                                    tokenization=tokenization, tokenizer_model=tokenizer_model, vocab_size=vocab_size, data_file=data_file)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    return train_loader, test_loader, train_dataset.vocab_size
'''

# Overwrite the repository's loader with the patched source. The encoding is
# given explicitly so the result does not depend on the platform default.
with open('data/bangla.py', 'w', encoding='utf-8') as f:
    f.write(bangla_py_content)
print("Patched data/bangla.py")

In [None]:
# 3. HOST FIX: Patch wikitext.py to handle sentencepiece training file correctly
# Full replacement source for data/wikitext.py. Compared with the previous
# patch text, __len__ is clamped at 0 (matching the Bangla loader) so that a
# split shorter than seq_length yields an empty dataset instead of making
# len() raise on a negative value.
# NOTE(review): in 'char' and 'word' mode every split builds its vocabulary
# from its own text, so token ids are not guaranteed to agree across
# train/validation/test. This notebook only uses 'subword'.
wikitext_py_content = '''"""WikiText-2 Dataset Loader with configurable tokenization"""
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from pathlib import Path
from config import Config
import os

try:
    from data.tokenizers import load_sentencepiece, train_sentencepiece, build_word_vocab_from_text, encode_with_sp
except Exception:
    from tokenizers import load_sentencepiece, train_sentencepiece, build_word_vocab_from_text, encode_with_sp


class WikiText2Dataset(Dataset):
    def __init__(self, seq_length=64, split='train', tokenization='char', tokenizer_model=None, vocab_size=None):
        self.seq_length = seq_length
        cfg = Config()

        # Load WikiText-2 from HuggingFace
        dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split=split)
        text = ' '.join(dataset['text'])

        self.tokenization = tokenization
        if tokenization == 'char':
            chars = sorted(list(set(text)))
            self.vocab_size = len(chars)
            self.stoi = {ch: i for i, ch in enumerate(chars)}
            self.itos = {i: ch for i, ch in enumerate(chars)}
            data_ids = [self.stoi[ch] for ch in text]

        elif tokenization == 'word':
            stoi, itos = build_word_vocab_from_text(text)
            self.stoi = stoi
            self.itos = itos
            self.vocab_size = len(stoi)
            data_ids = [self.stoi[w] for w in text.split()]

        elif tokenization == 'subword':
            model_path = tokenizer_model or cfg.tokenizer_model_wikitext
            model_file = Path(model_path)
            if not model_file.parent.exists():
                model_file.parent.mkdir(parents=True, exist_ok=True)

            if not model_file.exists():
                model_prefix = str(model_file.with_suffix(''))
                # FIX: Ensure the raw file exists for SP training
                raw_file = 'wikitext_raw.txt'
                if not os.path.exists(raw_file):
                    # Use training split for training tokenizer
                    train_ds = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
                    full_text = ' '.join(train_ds['text'])
                    with open(raw_file, 'w', encoding='utf-8') as f:
                        f.write(full_text)

                train_sentencepiece(raw_file, model_prefix, vocab_size or cfg.subword_vocab_size)
                model_file = Path(model_prefix + '.model')
            sp = load_sentencepiece(str(model_file))
            self.vocab_size = sp.get_piece_size()
            self.stoi = None
            self.itos = None
            data_ids = encode_with_sp(sp, text)

        else:
            raise ValueError(f'Unknown tokenization: {tokenization}')

        self.data = torch.tensor(data_ids, dtype=torch.long)

    def __len__(self):
        # Never negative: len() raises ValueError on a negative result
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_length]
        y = self.data[idx + 1:idx + self.seq_length + 1]
        return x, y


def get_wikitext_loaders(batch_size=64, seq_length=64, tokenization=None, tokenizer_model=None, vocab_size=None):
    cfg = Config()
    tokenization = tokenization or cfg.tokenization
    tokenizer_model = tokenizer_model or (cfg.tokenizer_model_wikitext if tokenization == 'subword' else None)
    vocab_size = vocab_size or cfg.subword_vocab_size

    train_dataset = WikiText2Dataset(seq_length, 'train', tokenization, tokenizer_model, vocab_size)
    val_dataset = WikiText2Dataset(seq_length, 'validation', tokenization, tokenizer_model, vocab_size)
    test_dataset = WikiText2Dataset(seq_length, 'test', tokenization, tokenizer_model, vocab_size)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    return train_loader, val_loader, test_loader, train_dataset.vocab_size
'''

# Overwrite the repository's loader with the patched source. The encoding is
# given explicitly so the result does not depend on the platform default.
with open('data/wikitext.py', 'w', encoding='utf-8') as f:
    f.write(wikitext_py_content)
print("Patched data/wikitext.py")

In [None]:
# 4. Run Experiments
import subprocess
import torch

# Define your experiment configuration
# `datasets` and `experiments` are passed to train_amp.py as --dataset and
# --experiment; every (dataset, experiment) pair is run once.
datasets = ['bangla', 'wikitext'] 
experiments = ['baseline_12', 'mor_exp1', 'baseline_6', 'mor_exp2']
# Passed as --epochs; also used later to build the expected checkpoint filename.
EPOCHS = 10 # Adjust based on your time budget
# Passed as --subword_vocab_size; reused when rebuilding loaders for the confusion matrix.
VOCAB = 4000

def run_cmd(cmd):
    """Echo *cmd* (a list of argument strings) and run it as a subprocess.

    Raises:
        subprocess.CalledProcessError: if the process exits with a non-zero status.
    """
    print(f"Running: {' '.join(cmd)}")
    subprocess.run(cmd, check=True)

# The '--amp' flag is always passed. The original author's note says
# 'train_amp.py' already enables AMP by default on CUDA, so this is only to
# be explicit.

# Train on CUDA when it is available, otherwise on the CPU.
device_arg = 'cuda' if torch.cuda.is_available() else 'cpu'

for datas in datasets:
    print(f"\n=== Running Experiments for {datas.upper()} ===")

    for exp in experiments:
        # One training subprocess per (dataset, experiment) pair, launched
        # with the interpreter that is running this notebook.
        cmd = [sys.executable, 'train_amp.py']
        cmd += ['--dataset', datas]
        cmd += ['--experiment', exp]
        cmd += ['--tokenization', 'subword']
        cmd += ['--subword_vocab_size', str(VOCAB)]
        cmd += ['--epochs', str(EPOCHS)]
        cmd += ['--device', device_arg]
        cmd += ['--amp']  # Request Mixed Precision

        # A failed run is reported and skipped so the remaining runs still execute.
        try:
            run_cmd(cmd)
        except Exception as e:
            print(f"Experiment {exp} on {datas} failed: {e}")


## 5. Visualizations: Plots, Graphs, & Matrices
We will now visualize the results:
1. **Training Curves**: Loss and Accuracy over time.
2. **Comparative Analysis**: Bar charts comparing Accuracy vs. Compute (Depth).
3. **Confusion Matrix**: On the test set for the MoR Exp 1 model (`mor_exp1`), restricted to the 20 most frequent tokens.

In [None]:
import json
import matplotlib.pyplot as plt
import pandas as pd
import glob
import seaborn as sns

# Set style
# Apply seaborn's "whitegrid" theme to the figures drawn below.
sns.set_theme(style="whitegrid")

def plot_training_curves(dataset_name):
    """Plot per-epoch training loss and accuracy for every run of a dataset.

    Reads each ``results/<dataset_name>_*_history.json`` file, which must
    provide 'epoch', 'loss' and 'acc' columns, and draws one line per file on
    two side-by-side axes. Prints a message and returns without plotting when
    no history file matches.
    """
    prefix = f'{dataset_name}_'
    suffix = '_history.json'

    paths = glob.glob(f'results/{prefix}*{suffix}')
    if not paths:
        print(f"No history files found for {dataset_name}")
        return

    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 6))

    for path in paths:
        # The legend label is the file name with the dataset prefix and the
        # history suffix stripped.
        label = os.path.basename(path).replace(prefix, '').replace(suffix, '')
        with open(path, 'r') as handle:
            frame = pd.DataFrame(json.load(handle))

        loss_ax.plot(frame['epoch'], frame['loss'], marker='o', label=label)
        acc_ax.plot(frame['epoch'], frame['acc'], marker='s', label=label)

    panels = (
        (loss_ax, 'Training Loss', 'Loss'),
        (acc_ax, 'Training Accuracy', 'Accuracy (%)'),
    )
    for axis, heading, y_label in panels:
        axis.set_title(f'{dataset_name.capitalize()}: {heading}')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()

    plt.tight_layout()
    plt.show()

# Draw the training curves for every dataset in the experiment configuration.
for d in datasets:
    plot_training_curves(d)

In [None]:
def plot_comparison(dataset_name):
    """Compare test accuracy and effective depth across a dataset's experiments.

    Loads every ``results/<dataset_name>_*.json`` file whose path does not
    contain '_history'. Each file must provide 'experiment', 'test_accuracy'
    and 'effective_depth'. Accuracy is drawn as bars and effective depth as a
    line on a secondary y-axis. Returns silently when no result file is found.
    """
    summaries = []
    for path in glob.glob(f'results/{dataset_name}_*.json'):
        # History files share the prefix but hold per-epoch curves, not summaries.
        if '_history' in path:
            continue
        with open(path, 'r') as handle:
            summaries.append(json.load(handle))

    if not summaries:
        return

    table = pd.DataFrame(summaries)

    _, acc_ax = plt.subplots(figsize=(10, 6))

    # Bars: test accuracy per experiment.
    sns.barplot(data=table, x='experiment', y='test_accuracy', ax=acc_ax, palette='viridis', alpha=0.8)
    acc_ax.set_ylabel('Test Accuracy (%)')
    acc_ax.set_title(f'{dataset_name.capitalize()}: Accuracy vs Effective Depth')

    # Line: effective depth on a twin axis that shares the same x positions.
    depth_ax = acc_ax.twinx()
    sns.lineplot(data=table, x='experiment', y='effective_depth', ax=depth_ax, color='red', marker='D', markersize=10, linewidth=3, sort=False)
    depth_ax.set_ylabel('Effective Depth (Layers)')
    depth_ax.grid(False)

    plt.show()

# Draw the accuracy / effective-depth comparison for every dataset.
for d in datasets:
    plot_comparison(d)

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np

# Function to generate confusion matrix for a specific model
def _find_checkpoint(experiment, preferred_epoch):
    """Return the checkpoint path to load for *experiment*, or None if there is none.

    Prefers ``checkpoints/<experiment>_epoch<preferred_epoch>.pt``; otherwise
    falls back to the existing checkpoint with the highest epoch number.
    """
    import re  # local import: the notebook does not import `re` at the top of this cell

    preferred = f"checkpoints/{experiment}_epoch{preferred_epoch}.pt"
    if os.path.exists(preferred):
        return preferred

    print(f"Checkpoint {preferred} not found. Looking for an alternate checkpoint...")
    epoch_pattern = re.compile(r"_epoch(\d+)\.pt$")
    candidates = []
    for path in glob.glob(f"checkpoints/{experiment}_epoch*.pt"):
        match = epoch_pattern.search(path)
        if match:
            candidates.append((int(match.group(1)), path))

    if not candidates:
        print(f"No checkpoint found for {experiment}. Skipping CM.")
        return None

    # Compare epochs numerically: a plain string sort ranks 'epoch9' above 'epoch10'.
    _, latest = max(candidates)
    print(f"Using alternate checkpoint: {latest}")
    return latest


def plot_confusion_matrix(dataset, experiment, num_samples=1000):
    """Plot a confusion matrix over the 20 most frequent test-set tokens.

    Args:
        dataset: 'bangla' or 'wikitext'.
        experiment: experiment name; names containing 'baseline' build a
            BaselineTransformer (12 layers if the name contains '12', else 6),
            all other names build a 12-layer MoRTransformer.
        num_samples: evaluation stops once the running sample estimate
            (batch index * batch size) exceeds this value.

    Returns:
        None. Returns early, without plotting, when no checkpoint exists for
        the experiment or the test loader yields no batches.

    Raises:
        ValueError: if *dataset* is not one of the supported names.
    """
    if dataset not in ('bangla', 'wikitext'):
        raise ValueError(f"Unknown dataset: {dataset!r} (expected 'bangla' or 'wikitext')")

    print(f"Generating confusion matrix for {dataset} - {experiment}...")

    # Resolve the checkpoint first so that a missing checkpoint does not cost
    # a dataset load and a model build.
    ckpt_path = _find_checkpoint(experiment, EPOCHS)
    if ckpt_path is None:
        return

    # Load config and model (This requires importing project modules)
    from config import Config
    from models import BaselineTransformer, MoRTransformer
    from data import get_bangla_loaders, get_wikitext_loaders

    cfg = Config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cfg.device = device

    # Get Loaders
    if dataset == 'bangla':
        _, test_loader, vocab_size = get_bangla_loaders(batch_size=32, tokenization='subword', vocab_size=VOCAB)
    else:
        _, _, test_loader, vocab_size = get_wikitext_loaders(batch_size=32, tokenization='subword', vocab_size=VOCAB)

    # Initialize Model Structure
    is_baseline = 'baseline' in experiment
    if is_baseline:
        n_layers = 12 if '12' in experiment else 6
        model = BaselineTransformer(vocab_size, n_layers=n_layers, **vars(cfg))
    else:
        model = MoRTransformer(vocab_size, n_layers=12, **vars(cfg))

    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    model.to(device)
    model.eval()

    # Collect predictions
    true_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for i, (x, y) in enumerate(test_loader):
            if i * x.size(0) > num_samples:
                break
            x = x.to(device)
            # Use the same flag that selected the model class, so the output
            # unpacking always matches the model that was built.
            if is_baseline:
                logits, _ = model(x)
            else:
                logits, _, _ = model(x, training=False)

            preds = torch.argmax(logits, dim=-1)
            true_chunks.append(y.reshape(-1).cpu().numpy())
            pred_chunks.append(preds.reshape(-1).cpu().numpy())

    if not true_chunks:
        print("Test loader produced no batches. Skipping CM.")
        return

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)

    # Restrict the matrix to the 20 most frequent true tokens to keep it
    # readable; the full vocabulary-sized matrix is never needed.
    unique, counts = np.unique(y_true, return_counts=True)
    top_tokens = unique[np.argsort(counts)[::-1][:20]]

    cm_subset = confusion_matrix(y_true, y_pred, labels=top_tokens)

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm_subset, annot=True, fmt='d', cmap='Blues',
                xticklabels=top_tokens, yticklabels=top_tokens)
    plt.title(f'Confusion Matrix (Top 20 Tokens) - {experiment}')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

# Example: plot the confusion matrix of the 'mor_exp1' model for every dataset.
# The try/except sits inside the loop so that a failure on one dataset does
# not prevent the remaining datasets from being plotted.
for d in datasets:
    try:
        plot_confusion_matrix(d, 'mor_exp1')
    except Exception as e:
        print(f"Could not plot confusion matrix for {d}: {e}")