In [None]:
# 1) Clone repo (if running on a fresh Kaggle session) and install deps
!git clone https://github.com/ShMazumder/Benchmarking-MoR-on-fine-tuned-SLM.git || true
%cd Benchmarking-MoR-on-fine-tuned-SLM/code
# Install requirements (Kaggle may already have torch; this will install others)
!pip install -r requirements.txt

## 2) Prepare a patched script `train_amp.py`

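This cell reads `train.py` and writes a patched copy, `train_amp.py`, in which the `train_baseline` and `train_mor` functions are programmatically replaced with versions that use AMP, save a checkpoint every epoch, and collect per-epoch history into JSON files under `results/`. It also edits `config.py` in place, lowering the default epoch counts (30/30/50 → 3/3/5) for a quick test.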

In [None]:
import io, sys, re
from pathlib import Path
p = Path('train.py')
assert p.exists(), 'train.py not found in code/'
s = p.read_text()

# 1) add imports for AMP, JSON (if not already present).
# The new imports are anchored on train.py's utils import line. If that line is
# missing or reformatted, str.replace would silently do nothing and train_amp.py
# would only fail later with NameError (GradScaler / json / _Path), so fail here.
UTILS_IMPORT = "from utils import calculate_accuracy, save_checkpoint, save_results, Timer, print_model_info"
AMP_IMPORTS = "from torch.cuda.amp import autocast, GradScaler\nimport json\nfrom pathlib import Path as _Path"
assert UTILS_IMPORT in s, 'utils import line not found in train.py; cannot inject AMP/json imports'
if AMP_IMPORTS not in s:
    s = s.replace(UTILS_IMPORT, UTILS_IMPORT + "\n" + AMP_IMPORTS)

# 2) Shrink the default epoch counts for a quick test by patching config.py.
# NOTE: this rewrites config.py on disk, so it persists for every script that
# reads it until restored (e.g. `git checkout config.py`).
EPOCH_OVERRIDES = {
    'epochs_baseline = 30': 'epochs_baseline = 3',
    'epochs_mor_exp1 = 30': 'epochs_mor_exp1 = 3',
    'epochs_mor_exp2 = 50': 'epochs_mor_exp2 = 5',
}
cfg_path = Path('config.py')
cfg = cfg_path.read_text()
applied = []
for old, new in EPOCH_OVERRIDES.items():
    if old in cfg:
        cfg = cfg.replace(old, new)
        applied.append(new)
if applied:
    cfg_path.write_text(cfg)
    print('Updated config.py to smaller epoch counts for quick tests:', ', '.join(applied))
else:
    print('config.py unchanged: default epoch settings not found (already patched or edited)')

# 3) prepare new train_baseline function (AMP + checkpoint + history).
# This is source text spliced into train_amp.py, so it must be valid Python on its
# own. The inner docstring uses ''' because """ would terminate this raw string.
new_baseline = r"""
def train_baseline(model, train_loader, test_loader, config, experiment_name):
    '''Train a baseline model with optional AMP, checkpoint every epoch, then evaluate.

    AMP (autocast + GradScaler) is enabled only when config.device is a CUDA device.
    After each epoch a checkpoint (model, optimizer and, under AMP, GradScaler state)
    is written to config.checkpoint_dir and the epoch's mean loss/accuracy is appended
    to a history list saved as <results_dir>/<experiment_name>_history.json.
    Final metrics are saved to <results_dir>/<experiment_name>.json and returned.
    '''
    device = config.device
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()
    # str() so a torch.device works as well as a plain 'cuda' / 'cuda:0' string.
    use_amp = str(config.device).startswith('cuda')
    scaler = GradScaler() if use_amp else None
    timer = Timer()
    history = []
    # Defined up front so the results dict is still valid if epochs_baseline == 0.
    avg_loss = 0.0
    avg_acc = 0.0
    print(f"\nTraining {experiment_name}...")
    timer.start()
    for epoch in range(config.epochs_baseline):
        model.train()
        total_loss = 0
        total_acc = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs_baseline}")
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            if use_amp:
                with autocast():
                    logits, effective_depth = model(x)
                    loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                optimizer.zero_grad()
                # Scaled backward/step guards fp16 gradients against underflow.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits, effective_depth = model(x)
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            acc = calculate_accuracy(logits, y)
            total_loss += loss.item()
            total_acc += acc
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{acc:.2f}%',
                'depth': f'{effective_depth:.2f}'
            })
        avg_loss = total_loss / len(train_loader)
        avg_acc = total_acc / len(train_loader)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={avg_acc:.2f}%")
        # checkpoint every epoch so a long run can be resumed after a session dies
        ckpt_dir = _Path(config.checkpoint_dir)
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        torch.save({
            'epoch': epoch + 1,
            'model_state': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            # GradScaler state is needed to resume AMP training with the same loss scale.
            'scaler': scaler.state_dict() if scaler is not None else None,
        }, ckpt_dir / f'{experiment_name}_epoch{epoch+1}.pt')
        # record history
        history.append({'epoch': epoch+1, 'loss': avg_loss, 'acc': avg_acc})
    timer.stop()
    training_time = timer.get_elapsed()
    print('\nEvaluating...')
    model.eval()
    test_acc = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            test_acc += calculate_accuracy(logits, y)
    test_acc /= len(test_loader)
    results = {
        'experiment': experiment_name,
        'model_type': 'baseline',
        'n_layers': model.n_layers,
        'accuracy': avg_acc,
        'test_accuracy': test_acc,
        'effective_depth': float(model.n_layers),
        'training_time_seconds': training_time
    }
    # save results and history
    _Path(config.results_dir).mkdir(parents=True, exist_ok=True)
    save_results(results, f'{config.results_dir}/{experiment_name}.json')
    with open(f'{config.results_dir}/{experiment_name}_history.json','w') as f:
        json.dump(history, f)
    print('\nResults:')
    print(f"  Training Accuracy: {avg_acc:.2f}%")
    print(f"  Test Accuracy: {test_acc:.2f}%")
    print(f"  Effective Depth: {model.n_layers}")
    print(f"  Training Time: {training_time:.0f}s")
    return results
"""

# 4) prepare new train_mor function (AMP + checkpoint + history, plus routing depth).
# Spliced into train_amp.py as source text, so it must be valid Python on its own.
# The inner docstring uses ''' because """ would terminate this raw string.
new_mor = r"""
def train_mor(model, train_loader, test_loader, config, experiment_name, epochs, lambda_penalty=0.1):
    '''Train an MoR model with optional AMP, checkpoint every epoch, then evaluate.

    The optimized loss is cross-entropy + lambda_penalty * effective_depth, where
    effective_depth is the second value returned by model(x, training=True); the
    logged/recorded loss is the cross-entropy term only. AMP (autocast + GradScaler)
    is enabled only when config.device is a CUDA device. Each epoch writes a
    checkpoint to config.checkpoint_dir and appends mean loss/accuracy/depth to a
    history list saved as <results_dir>/<experiment_name>_history.json. Final metrics
    are saved to <results_dir>/<experiment_name>.json and returned.
    '''
    device = config.device
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()
    # str() so a torch.device works as well as a plain 'cuda' / 'cuda:0' string.
    use_amp = str(config.device).startswith('cuda')
    scaler = GradScaler() if use_amp else None
    timer = Timer()
    history = []
    # Defined up front so the results dict is still valid if epochs == 0.
    avg_loss = 0.0
    avg_acc = 0.0
    avg_depth = 0.0
    print(f"\nTraining {experiment_name}...")
    timer.start()
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_acc = 0
        total_depth = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            if use_amp:
                with autocast():
                    logits, effective_depth, routing_stats = model(x, training=True)
                    ce_loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                    depth_penalty = lambda_penalty * effective_depth
                    loss = ce_loss + depth_penalty
                optimizer.zero_grad()
                # Scaled backward/step guards fp16 gradients against underflow.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits, effective_depth, routing_stats = model(x, training=True)
                ce_loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                depth_penalty = lambda_penalty * effective_depth
                loss = ce_loss + depth_penalty
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            acc = calculate_accuracy(logits, y)
            total_loss += ce_loss.item()
            total_acc += acc
            total_depth += effective_depth.item()
            pbar.set_postfix({
                'loss': f'{ce_loss.item():.4f}',
                'acc': f'{acc:.2f}%',
                'depth': f'{effective_depth.item():.2f}'
            })
        avg_loss = total_loss / len(train_loader)
        avg_acc = total_acc / len(train_loader)
        avg_depth = total_depth / len(train_loader)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={avg_acc:.2f}%, Depth={avg_depth:.2f}")
        # checkpoint every epoch so a long run can be resumed after a session dies
        ckpt_dir = _Path(config.checkpoint_dir)
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        torch.save({
            'epoch': epoch + 1,
            'model_state': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            # GradScaler state is needed to resume AMP training with the same loss scale.
            'scaler': scaler.state_dict() if scaler is not None else None,
        }, ckpt_dir / f'{experiment_name}_epoch{epoch+1}.pt')
        history.append({'epoch': epoch+1, 'loss': avg_loss, 'acc': avg_acc, 'depth': avg_depth})
    timer.stop()
    training_time = timer.get_elapsed()
    print('\nEvaluating...')
    model.eval()
    test_acc = 0
    test_depth = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits, effective_depth, routing_stats = model(x, training=False)
            test_acc += calculate_accuracy(logits, y)
            test_depth += effective_depth.item()
    test_acc /= len(test_loader)
    test_depth /= len(test_loader)
    results = {
        'experiment': experiment_name,
        'model_type': 'mor',
        'n_layers': model.n_layers,
        'accuracy': avg_acc,
        'test_accuracy': test_acc,
        'effective_depth': avg_depth,
        'test_effective_depth': test_depth,
        'training_time_seconds': training_time,
        'lambda_penalty': lambda_penalty
    }
    _Path(config.results_dir).mkdir(parents=True, exist_ok=True)
    save_results(results, f'{config.results_dir}/{experiment_name}.json')
    with open(f'{config.results_dir}/{experiment_name}_history.json','w') as f:
        json.dump(history, f)
    print('\nResults:')
    print(f"  Training Accuracy: {avg_acc:.2f}%")
    print(f"  Test Accuracy: {test_acc:.2f}%")
    print(f"  Effective Depth: {avg_depth:.2f}")
    print(f"  Test Effective Depth: {test_depth:.2f}")
    print(f"  Training Time: {training_time:.0f}s")
    return results
"""

# 5) replace original functions in the file.
# A callable replacement is essential: with a plain string, re.sub processes
# backslash escapes in the replacement, so the literal backslash-n inside e.g.
# print(f"\nTraining ...") would become a real newline and break train_amp.py.
def _splice_function(source, func_name, new_src):
    """Replace the first ``def func_name(...)`` block in ``source`` with ``new_src``.

    The block is matched lazily from ``def func_name(`` up to the first
    ``return results`` followed by a blank line (how train.py's functions end).

    Raises:
        AssertionError: if the function is not found, so a changed train.py does
            not silently produce an unpatched train_amp.py.
    """
    pattern = r"def " + re.escape(func_name) + r"\([\s\S]*?return results\n\n"
    patched, n_subs = re.subn(pattern, lambda _match: new_src + "\n\n", source, count=1)
    assert n_subs == 1, f'could not locate {func_name}() in train.py; train_amp.py not written'
    return patched


s2 = _splice_function(s, 'train_baseline', new_baseline)
s2 = _splice_function(s2, 'train_mor', new_mor)
compile(s2, 'train_amp.py', 'exec')  # fail here, not at launch, if the splice broke syntax

Path('train_amp.py').write_text(s2)
print('Wrote train_amp.py with AMP+checkpointing and history.\nRun: python train_amp.py --dataset shakespeare --experiment baseline_6 --device cuda')

## 3) Run quick 3-epoch test (baseline_6) using the patched script

This will use GPU if available and write per-epoch history to `results/Baseline_N6_history.json` and checkpoints to `checkpoints/`.

In [None]:
# Run the quick test
!python train_amp.py --dataset shakespeare --experiment baseline_6 --device cuda

## 4) Plot per-epoch loss/accuracy and short analysis

This cell reads the produced history JSON and plots training loss and accuracy per epoch. It also prints a short comparison versus the expected results from the README.

In [None]:
import json
import matplotlib.pyplot as plt
from pathlib import Path

RESULTS_DIR = Path('results')
# Experiment name used in train.py for baseline_6 is Baseline_N6.
PREFERRED_HISTORY = 'Baseline_N6_history.json'
EXPECTED_TEST_ACC = 39.87  # README example: Baseline N=6 test accuracy (%) on tiny Shakespeare


def find_history_file(results_dir=RESULTS_DIR, preferred=PREFERRED_HISTORY):
    """Return the per-epoch history JSON to plot, or None when none exists.

    Prefers ``results_dir / preferred``. Otherwise falls back to the most recently
    modified ``*history.json`` so the pick does not depend on filesystem glob order.
    """
    preferred_path = results_dir / preferred
    if preferred_path.exists():
        return preferred_path
    candidates = sorted(results_dir.glob('*history.json'), key=lambda path: path.stat().st_mtime)
    return candidates[-1] if candidates else None


def load_json(path):
    """Read and parse a JSON file, closing the handle deterministically."""
    with open(path) as f:
        return json.load(f)


def compare_to_expected(observed, expected=EXPECTED_TEST_ACC):
    """Describe how an observed accuracy (%) compares with the README figure."""
    gap = observed - expected
    if gap >= 0:
        return f'at or above the README figure (+{gap:.2f} pts)'
    return f'below the README figure ({gap:.2f} pts)'


hist_path = find_history_file()
if hist_path is None:
    print('No history JSON found in results/. Look for files in results/ and adjust path.')
else:
    hist = load_json(hist_path)
    if not hist:
        print(f'{hist_path} contains no epochs; nothing to plot.')
    else:
        epochs = [h['epoch'] for h in hist]
        loss = [h['loss'] for h in hist]
        acc = [h['acc'] for h in hist]

        fig, ax1 = plt.subplots()
        ax1.plot(epochs, loss, '-o', color='tab:red', label='train loss')
        ax1.set_xlabel('epoch')
        ax1.set_ylabel('loss', color='tab:red')
        ax2 = ax1.twinx()
        ax2.plot(epochs, acc, '-s', color='tab:blue', label='train acc')
        ax2.set_ylabel('accuracy (%)', color='tab:blue')
        # Twin axes keep separate legends; merge both lines into one.
        lines = ax1.get_lines() + ax2.get_lines()
        ax1.legend(lines, [line.get_label() for line in lines], loc='center right')
        ax1.set_title('Training loss and accuracy per epoch')
        fig.tight_layout()
        plt.show()

        # Short analysis. The README figure is a *test* accuracy, so compare against
        # the test accuracy from the matching <name>.json when it exists instead of
        # training accuracy, and derive the verdict from the data.
        print('\nShort analysis:')
        print(f"  Final training accuracy: {acc[-1]:.2f}%")
        print(f'  Expected (README): Baseline N=6 test Acc ≈ {EXPECTED_TEST_ACC:.2f}% (example)')
        test_acc = None
        suffix = '_history.json'
        if hist_path.name.endswith(suffix):
            results_path = hist_path.with_name(hist_path.name[:-len(suffix)] + '.json')
            if results_path.exists():
                test_acc = load_json(results_path).get('test_accuracy')
        if test_acc is None:
            observed, label = acc[-1], 'final training accuracy (no test result found)'
        else:
            observed, label = test_acc, 'test accuracy'
        print(f'  Observed {label}: {observed:.2f}% — {compare_to_expected(observed)}')
        if observed < EXPECTED_TEST_ACC:
            print('  Consider:')
            print('    - this quick run uses far fewer epochs than the README configuration')
            print('    - learning rate, model size, dataset preprocessing or labeling differences')
            print('    - try larger epochs, or hyperparameter tuning; verify dataset tokens and vocab mapping')