In [None]:
# 1) Clone repo (if running on a fresh Kaggle session) and install deps
!git clone https://github.com/ShMazumder/Benchmarking-MoR-on-fine-tuned-SLM.git || true
%cd Benchmarking-MoR-on-fine-tuned-SLM/code
# Install requirements (Kaggle may already have torch; this will install others)
!pip install -r requirements.txt

## 2) Prepare a patched script `train_amp.py`

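This cell first lowers the epoch counts in `config.py` so the quick test finishes fast, then copies `train.py` to `train_amp.py` and programmatically appends safe overrides for `train_baseline` and `train_mor`. The overrides use AMP (automatic mixed precision), save a checkpoint after every epoch, and collect per-epoch history into JSON files under `results/`. The overrides are written as a plain (non-f) triple-quoted string so that the `{...}` placeholders inside them are not interpolated while the script is being generated.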

In [None]:
from pathlib import Path
# Source script to patch and the patched copy that the later cells execute.
src = Path('train.py')
assert src.exists(), 'train.py not found in current directory'
dst = Path('train_amp.py')

# Reduce epochs for quick tests (edit config.py safely)
# Maps the exact assignment text expected in config.py to its quick-test value.
EPOCH_OVERRIDES = {
    'epochs_baseline = 30': 'epochs_baseline = 3',
    'epochs_mor_exp1 = 30': 'epochs_mor_exp1 = 3',
    'epochs_mor_exp2 = 50': 'epochs_mor_exp2 = 5',
}
cfg_path = Path('config.py')
if cfg_path.exists():
    cfg_original = cfg_path.read_text()
    cfg_patched = cfg_original
    for full_setting, quick_setting in EPOCH_OVERRIDES.items():
        cfg_patched = cfg_patched.replace(full_setting, quick_setting)
    # Only rewrite the file and report success when something actually changed;
    # otherwise the message would wrongly suggest the quick-test epochs are active.
    if cfg_patched != cfg_original:
        cfg_path.write_text(cfg_patched)
        print('Updated config.py to smaller epoch counts for quick tests')
    else:
        print('config.py left unchanged (already patched, or the expected epoch settings were not found)')

# Copy base script and append safe overrides written as a plain triple-quoted string
dst.write_text(src.read_text())

# Replacement definitions of train_baseline / train_mor, kept as plain text so they can
# be written into train_amp.py. It is deliberately NOT an f-string: the `{...}`
# placeholders and the doubled backslashes (`\\n`) must reach the generated file
# unevaluated. The code relies on names it expects train.py to provide already
# (torch, nn, optim, tqdm, Timer, calculate_accuracy, save_results) -- TODO confirm
# against train.py.
overrides = '''
# --- APPENDED OVERRIDES (AMP + checkpointing + history) ---
from torch.cuda.amp import autocast, GradScaler
import json
from pathlib import Path as _Path

def train_baseline(model, train_loader, test_loader, config, experiment_name):
    device = config.device
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()
    use_amp = True if str(config.device).startswith('cuda') else False
    scaler = GradScaler() if use_amp else None
    timer = Timer()
    history = []
    print('\\nTraining {}...'.format(experiment_name))
    timer.start()
    for epoch in range(config.epochs_baseline):
        model.train()
        total_loss = 0.0
        total_acc = 0.0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.epochs_baseline}')
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            if use_amp:
                with autocast():
                    logits, effective_depth = model(x)
                    loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits, effective_depth = model(x)
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            acc = calculate_accuracy(logits, y)
            total_loss += float(loss.item())
            total_acc += float(acc)
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{acc:.2f}%', 'depth': f'{float(effective_depth):.2f}'})
        avg_loss = total_loss / len(train_loader)
        avg_acc = total_acc / len(train_loader)
        print(f'Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={avg_acc:.2f}%')
        _Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        torch.save({'epoch': epoch+1, 'model_state': model.state_dict(), 'optimizer': optimizer.state_dict()}, _Path(config.checkpoint_dir)/f'{experiment_name}_epoch{epoch+1}.pt')
        history.append({'epoch': epoch+1, 'loss': avg_loss, 'acc': avg_acc})
    timer.stop()
    training_time = timer.get_elapsed()
    print('\\nEvaluating...')
    model.eval()
    test_acc = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            test_acc += float(calculate_accuracy(logits, y))
    test_acc /= len(test_loader)
    results = {
        'experiment': experiment_name,
        'model_type': 'baseline',
        'n_layers': model.n_layers,
        'accuracy': avg_acc,
        'test_accuracy': test_acc,
        'effective_depth': float(model.n_layers),
        'training_time_seconds': training_time
    }
    _Path(config.results_dir).mkdir(parents=True, exist_ok=True)
    save_results(results, f'{config.results_dir}/{experiment_name}.json')
    with open(f'{config.results_dir}/{experiment_name}_history.json','w') as f:
        json.dump(history, f)
    print('\\nResults:')
    print('  Training Accuracy: {:.2f}%'.format(avg_acc))
    print('  Test Accuracy: {:.2f}%'.format(test_acc))
    print('  Effective Depth: {}'.format(model.n_layers))
    print('  Training Time: {:.0f}s'.format(training_time))
    return results

def train_mor(model, train_loader, test_loader, config, experiment_name, epochs, lambda_penalty=0.1):
    device = config.device
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()
    use_amp = True if str(config.device).startswith('cuda') else False
    scaler = GradScaler() if use_amp else None
    timer = Timer()
    history = []
    print('\\nTraining {}...'.format(experiment_name))
    timer.start()
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_acc = 0.0
        total_depth = 0.0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            if use_amp:
                with autocast():
                    logits, effective_depth, routing_stats = model(x, training=True)
                    ce_loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                    depth_penalty = lambda_penalty * effective_depth
                    loss = ce_loss + depth_penalty
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits, effective_depth, routing_stats = model(x, training=True)
                ce_loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                depth_penalty = lambda_penalty * effective_depth
                loss = ce_loss + depth_penalty
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            acc = calculate_accuracy(logits, y)
            total_loss += float(ce_loss.item())
            total_acc += float(acc)
            total_depth += float(effective_depth)
            pbar.set_postfix({'loss': f'{ce_loss.item():.4f}', 'acc': f'{acc:.2f}%', 'depth': f'{float(effective_depth):.2f}'})
        avg_loss = total_loss / len(train_loader)
        avg_acc = total_acc / len(train_loader)
        avg_depth = total_depth / len(train_loader)
        print(f'Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={avg_acc:.2f}%, Depth={avg_depth:.2f}')
        _Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        torch.save({'epoch': epoch+1, 'model_state': model.state_dict(), 'optimizer': optimizer.state_dict()}, _Path(config.checkpoint_dir)/f'{experiment_name}_epoch{epoch+1}.pt')
        history.append({'epoch': epoch+1, 'loss': avg_loss, 'acc': avg_acc, 'depth': avg_depth})
    timer.stop()
    training_time = timer.get_elapsed()
    print('\\nEvaluating...')
    model.eval()
    test_acc = 0.0
    test_depth = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits, effective_depth, routing_stats = model(x, training=False)
            test_acc += float(calculate_accuracy(logits, y))
            test_depth += float(effective_depth)
    test_acc /= len(test_loader)
    test_depth /= len(test_loader)
    results = {
        'experiment': experiment_name,
        'model_type': 'mor',
        'n_layers': model.n_layers,
        'accuracy': avg_acc,
        'test_accuracy': test_acc,
        'effective_depth': avg_depth,
        'test_effective_depth': test_depth,
        'training_time_seconds': training_time,
        'lambda_penalty': lambda_penalty
    }
    _Path(config.results_dir).mkdir(parents=True, exist_ok=True)
    save_results(results, f'{config.results_dir}/{experiment_name}.json')
    with open(f'{config.results_dir}/{experiment_name}_history.json','w') as f:
        json.dump(history, f)
    print('\\nResults:')
    print('  Training Accuracy: {:.2f}%'.format(avg_acc))
    print('  Test Accuracy: {:.2f}%'.format(test_acc))
    print('  Effective Depth: {:.2f}'.format(avg_depth))
    print('  Test Effective Depth: {:.2f}'.format(test_depth))
    print('  Training Time: {:.0f}s'.format(training_time))
    return results
'''

# Write the overrides into train_amp.py (without this step the file is just an
# unmodified copy of train.py). They must be defined BEFORE the script's entry point
# runs: text placed after a top-level `if __name__ == '__main__':` block executes only
# once that block has finished, i.e. after training already used the original
# functions. So splice the overrides in just ahead of the first top-level main guard,
# and fall back to a plain append when train.py has no such guard.
base_lines = src.read_text().splitlines(keepends=True)
guard_at = next(
    (idx for idx, line in enumerate(base_lines) if line.startswith('if __name__')),
    None,
)
if guard_at is None:
    # The extra newline protects against a base script that lacks a trailing newline.
    patched_parts = base_lines + ['\n', overrides]
    print(f'No top-level main guard found in {src}; overrides appended at the end of {dst}')
else:
    patched_parts = base_lines[:guard_at] + [overrides, '\n'] + base_lines[guard_at:]
    print(f'Inserted overrides into {dst} ahead of the main guard')
dst.write_text(''.join(patched_parts))

## 3) Run quick 3-epoch test (MoR)

This runs the patched script on the GPU (the command passes `--device cuda`, so enable a GPU accelerator for the session) and writes per-epoch history to `results/MoR_Exp1_history.json` and checkpoints to `checkpoints/`.

In [None]:
# Run the MoR quick test (use mor_exp1 for initial test)
# Executes the patched script in a subprocess via the shell, so it picks up the
# current contents of train_amp.py and config.py on disk.
# NOTE(review): `--device cuda` is hard-coded; presumably this fails on a session
# without a GPU -- confirm which values train.py accepts before changing it.
!python train_amp.py --dataset shakespeare --experiment mor_exp1 --device cuda

## 4) Plot per-epoch loss/accuracy and short analysis

This cell reads the produced history JSON and plots training loss and accuracy per epoch. It then prints the final training accuracy and suggested next steps; compare these numbers against the expected results listed in the README.

In [None]:
import json
import matplotlib.pyplot as plt
from pathlib import Path

def find_history(prefixs=('MoR_','Baseline_'), results_dir='results'):
    """Locate a per-epoch history JSON file inside the results directory.

    Prefixes are tried in the order given; the first prefix with at least one
    ``<prefix>*history.json`` match wins. If no prefix matches, any
    ``*history.json`` file is accepted. Matches are sorted by path so the choice
    is deterministic instead of depending on filesystem listing order.

    Args:
        prefixs: File-name prefixes to try, in priority order.
        results_dir: Directory to search (defaults to ``results``).

    Returns:
        Path of the selected history file, or None when the directory does not
        exist or contains no history file.
    """
    results_path = Path(results_dir)
    if not results_path.exists():
        print(f'{results_dir}/ directory not found')
        return None
    # Prefixed patterns first, then the catch-all as a last resort.
    patterns = [f'{pref}*history.json' for pref in prefixs] + ['*history.json']
    for pattern in patterns:
        matches = sorted(results_path.glob(pattern))
        if matches:
            return matches[0]
    return None

# Prefer the history written by the MoR quick test; otherwise fall back to whatever
# history file find_history can locate (MoR runs first, then Baseline, then any).
hist_path = Path('results/MoR_Exp1_history.json')
if not hist_path.exists():
    hist_path = find_history(('MoR_','Baseline_'))
if hist_path is None:
    print('No history JSON found in results/. Run the training cell and re-run this cell.')
else:
    # Context manager so the file handle is closed as soon as the JSON is parsed.
    with open(hist_path) as hist_file:
        hist = json.load(hist_file)
    epochs = [h['epoch'] for h in hist]
    loss = [h.get('loss', None) for h in hist]
    acc = [h.get('acc', None) for h in hist]

    # Loss on the left axis, accuracy on a twin right axis (different scales).
    fig, ax1 = plt.subplots()
    if any(v is not None for v in loss):
        ax1.plot(epochs, loss, '-o', color='tab:red', label='train loss')
        ax1.set_ylabel('loss', color='tab:red')
    ax1.set_xlabel('epoch')
    ax2 = ax1.twinx()
    if any(v is not None for v in acc):
        ax2.plot(epochs, acc, '-s', color='tab:blue', label='train acc')
        ax2.set_ylabel('accuracy (%)', color='tab:blue')
    plt.title(f'Training metrics from {hist_path.name}')
    fig.tight_layout()
    plt.show()

    # Short analysis
    # A single backslash here: this string is printed directly by the notebook, so a
    # doubled backslash would show up as a literal "\n" in the output.
    print('\nShort analysis:')
    if acc and acc[-1] is not None:
        print(f"  Final training accuracy: {acc[-1]:.2f}%")
    print('  Observed behavior: check routing statistics and effective depth in MoR runs')
    print('  Next steps: compare with Baseline, adjust lambda_penalty, or increase epochs for stability')