# Bangla Training - FINAL FIX

## Fixes:
1. ✅ Curated corpus (20MB Wikipedia)
2. ✅ 16K vocab, 10 epochs
3. ✅ LR 1e-4 + warmup + cosine
4. ✅ Gradient clipping + weight decay
5. ✅ Batch 64

**Target**: Loss <6.0; test accuracy >30% for N=6 and >15% for N=12

In [None]:
# 1. Setup
import os
import sys

if os.path.exists('/content'):
    print("Environment: Colab")
elif os.path.exists('/kaggle'):
    print("Environment: Kaggle")
else:
    print("Environment: Local")

if not os.path.exists('train_amp_v2.py'):
    !git clone https://github.com/ShMazumder/Benchmarking-MoR-on-fine-tuned-SLM.git
    os.chdir('Benchmarking-MoR-on-fine-tuned-SLM/code')
    print(f"Changed to: {os.getcwd()}")

!pip install -q datasets sentencepiece
print("✓ Ready")

In [None]:
# 2. GPU Check
# Print the name of CUDA device 0, or 'None' when no GPU is visible.
import torch

if torch.cuda.is_available():
    device_label = torch.cuda.get_device_name(0)
else:
    device_label = 'None'
print(f"GPU: {device_label}")

In [None]:
# 3. Download Corpus (Wikipedia - more reliable)
from pathlib import Path
from datasets import load_dataset
import re
import shutil

def clean_bangla(text):
    """Normalise raw article text into one deduplicated sentence per line.

    URL- and e-mail-like tokens are removed and every whitespace run
    (including newlines) is collapsed to a single space.  The text is then
    re-split after each sentence terminator: the Bangla danda ``।``, ``!``
    or ``?``.  Stripped lines of 20 characters or fewer are discarded, and
    repeated lines are dropped, keeping the first occurrence in order.
    """
    without_links = re.sub(r'http\S+|www\S+|\S+@\S+', '', text)
    single_spaced = re.sub(r'\s+', ' ', without_links)
    per_sentence = re.sub(r'([।!?])\s*', r'\1\n', single_spaced)

    seen = set()
    kept = []
    for raw_line in per_sentence.split('\n'):
        sentence = raw_line.strip()
        # Skip short fragments and anything already emitted.
        if len(sentence) <= 20 or sentence in seen:
            continue
        seen.add(sentence)
        kept.append(sentence)
    return '\n'.join(kept)

# Build (once) a ~20 MB cleaned Bangla corpus from Wikipedia, then show a sample.
corpus_path = Path('data/bangla/bangla_slm.txt')
corpus_path.parent.mkdir(parents=True, exist_ok=True)

if not corpus_path.exists():
    print("Downloading Wikipedia Bangla...")
    # Stream the split instead of downloading the whole Bangla dump: only the
    # first ~20 MB of cleaned text is kept.
    dataset = load_dataset('wikimedia/wikipedia', '20231101.bn', split='train', streaming=True)

    separator = '\n\n'
    sep_bytes = len(separator.encode('utf-8'))
    texts = []
    size = 0
    target = 20 * 1024 * 1024

    for article in dataset:
        text = clean_bangla(article['text'])
        if len(text) <= 100:
            continue
        # Count the joining separator too, so `size` equals the bytes written.
        size += len(text.encode('utf-8')) + (sep_bytes if texts else 0)
        texts.append(text)
        if len(texts) % 100 == 0:
            print(f"  {size/1024/1024:.1f}MB")
        if size >= target:
            break

    if not texts:
        raise RuntimeError("No usable Wikipedia articles collected; corpus would be empty")

    # Write to a temp file and rename it into place, so an interrupted write
    # never leaves a truncated corpus that the exists() check would accept.
    tmp_path = corpus_path.with_name(corpus_path.name + '.tmp')
    tmp_path.write_text(separator.join(texts), encoding='utf-8')
    tmp_path.replace(corpus_path)
    print(f"✓ Saved {size/1024/1024:.1f}MB")
else:
    print("✓ Corpus exists")

# Read only the first 200 characters rather than the whole corpus.
with corpus_path.open(encoding='utf-8') as fh:
    print(f"Sample: {fh.read(200)}...")

In [None]:
# 4. Config
# Patch config.py in place to batch=64, lr=1e-4.  Plain str.replace is a
# silent no-op when the expected text is absent, which would still print the
# success line, so verify each override.  An already-patched file is accepted,
# which keeps the cell safe to re-run.
overrides = {
    'batch_size = 128': 'batch_size = 64',
    'learning_rate = 3e-4': 'learning_rate = 1e-4',
}
with open('config.py', 'r', encoding='utf-8') as f:
    cfg = f.read()
for old, new in overrides.items():
    if old in cfg:
        cfg = cfg.replace(old, new)
    elif new not in cfg:
        raise RuntimeError(
            f"config.py: found neither '{old}' nor '{new}'; update it manually"
        )
with open('config.py', 'w', encoding='utf-8') as f:
    f.write(cfg)
print("✓ Config: batch=64, lr=1e-4")

In [None]:
# 5. Train All
for exp in ['baseline_6', 'baseline_12', 'mor_exp1', 'mor_exp2']:
    print(f"\n{'='*70}\nTRAINING: {exp}\n{'='*70}")
    !python train_amp_v2.py --dataset bangla --experiment {exp} --tokenization subword --subword_vocab_size 16000 --epochs 10 --device cuda --amp
    print(f"✓ {exp} done")

In [None]:
# 6. Results
# Load per-experiment test metrics, plot training curves, and compare against
# the N=6 / N=12 accuracy targets.
import json
from pathlib import Path

import matplotlib.pyplot as plt

experiments = ['baseline_6', 'baseline_12', 'mor_exp1', 'mor_exp2']

results = {}
missing = []
for exp in experiments:
    f = Path(f'results/bangla_{exp}.json')
    if f.exists():
        results[exp] = json.loads(f.read_text(encoding='utf-8'))
    else:
        missing.append(exp)

print("\nRESULTS:")
for name, data in results.items():
    # Absent metrics print as "nan" rather than 0, so a missing loss is not
    # mistaken for a perfect 0.0000.
    acc = data.get('test_accuracy', float('nan'))
    loss = data.get('test_loss', float('nan'))
    print(f"{name:15s}: {acc:5.2f}% acc, {loss:.4f} loss")
if missing:
    print(f"No results file for: {', '.join(missing)}")

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
plotted = False
for exp in results:
    h = Path(f'results/bangla_{exp}_history.json')
    if h.exists():
        hist = json.loads(h.read_text(encoding='utf-8'))
        epochs = [x['epoch'] for x in hist]
        ax1.plot(epochs, [x['loss'] for x in hist], 'o-', label=exp)
        ax2.plot(epochs, [x['acc'] for x in hist], 's-', label=exp)
        plotted = True

ax1.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
ax2.set(xlabel='Epoch', ylabel='Accuracy (%)', title='Training Accuracy')
for ax in (ax1, ax2):
    # legend() with no labelled artists only emits a warning.
    if plotted:
        ax.legend()
    ax.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('results.png', dpi=150)
plt.show()

# Success check (missing experiments count as 0% accuracy)
b6 = results.get('baseline_6', {}).get('test_accuracy', 0)
b12 = results.get('baseline_12', {}).get('test_accuracy', 0)
print(f"\nN=6: {b6:.1f}% (target >30%)")
print(f"N=12: {b12:.1f}% (target >15%)")
if b6 > 30 and b12 > 15:
    print("✅ SUCCESS!")
elif b6 > 30:
    print("⚠️ Partial (N=6 works, N=12 struggles)")
elif b12 > 15:
    print("⚠️ Partial (N=12 works, N=6 struggles)")
else:
    print("❌ Failed - need more fixes")