# 🧬 DenseNet Exploration Notebook

Welcome! This notebook re-creates all the learning experiments you ran
with **LeNet**, this time for **DenseNet**. It reuses your existing
`densenet.py` model definitions and the training helper
`run_densenet_training` from `train/train_densenet.py`, so you can focus on
experimentation rather than boilerplate.


## Setup

In [None]:
# Uncomment on first run
# !pip install --quiet torch torchvision matplotlib tqdm pandas

from itertools import product

import matplotlib.pyplot as plt

from train.train_densenet import run_densenet_training  # 👈 your helper
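The plotting and grid-search cells below unpack the helper's return value as one `(val_loss, val_acc)` tuple per epoch. To smoke-test the whole notebook without a GPU, a hypothetical stand-in with that assumed contract could look like this. It produces synthetic curves only and is not your `run_densenet_training`.

```python
import math
import random


def fake_run_densenet_training(model_type, epochs, train_batch_size,
                               test_batch_size, learning_rate, optimiser,
                               silent=False, seed=0):
    """Hypothetical stand-in for run_densenet_training (dry runs only).

    Returns a list of (val_loss, val_acc) tuples, one per epoch, which is the
    format the notebook's plotting and grid-search cells assume. The numbers
    are synthetic: a decaying loss and a saturating accuracy plus small noise.
    """
    rng = random.Random(seed)
    history = []
    for epoch in range(1, epochs + 1):
        progress = 1 - math.exp(-epoch / 10)        # rises from ~0.1 towards 1
        val_loss = 2.3 * (1 - progress) + 0.3 + rng.uniform(-0.02, 0.02)
        val_acc = min(1.0, 0.1 + 0.8 * progress + rng.uniform(-0.01, 0.01))
        history.append((val_loss, val_acc))
    if not silent:
        best = max(acc for _, acc in history)
        print(f"{model_type}: best synthetic val_acc = {best:.3f}")
    return history
```

Temporarily binding `run_densenet_training = fake_run_densenet_training` lets you check that every cell runs end to end in seconds; re-import the real helper before training for real.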


## 1️⃣  Baseline: lightweight CIFAR-DenseNet

In [None]:
history_light = run_densenet_training(
    model_type='densenetcustom',   # lightweight variant
    epochs=100,
    train_batch_size=128,
    test_batch_size=256,
    learning_rate=0.1,
    optimiser='sgd',
)
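The original DenseNet paper trains on CIFAR with SGD at an initial learning rate of 0.1 and divides it by 10 at 50% and 75% of the total epochs. Whether `run_densenet_training` applies such a schedule is not shown here; if it does not, this sketch (the function name `step_lr` is hypothetical) computes that step schedule for any epoch budget.

```python
def step_lr(epoch, total_epochs, base_lr=0.1, milestones=(0.5, 0.75), gamma=0.1):
    """Learning rate for a 0-indexed `epoch` under a step schedule.

    The rate is multiplied by `gamma` once for every milestone (given as a
    fraction of `total_epochs`) that the epoch has reached.
    """
    drops = sum(epoch >= int(m * total_epochs) for m in milestones)
    return base_lr * gamma ** drops
```

With `epochs=100` this gives 0.1 for epochs 0 to 49, 0.01 for 50 to 74 and 0.001 afterwards. In PyTorch the same schedule is `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)`, stepped once per epoch.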


### Learning curves (lightweight)

In [None]:
# Each history entry is assumed to be a (val_loss, val_acc) tuple, one per epoch.
val_loss, val_acc = zip(*history_light)
epochs = range(1, len(val_loss) + 1)  # label epochs from 1

plt.figure(figsize=(6, 4))
plt.plot(epochs, val_loss)
plt.title('Light DenseNet – validation loss')
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.show()

plt.figure(figsize=(6, 4))
plt.plot(epochs, val_acc)
plt.title('Light DenseNet – validation accuracy')
plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.show()
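Curves are easier to compare across runs with a few numbers attached. Assuming the same `(val_loss, val_acc)` history format, a small helper (hypothetical name `summarise_history`) can report the best epoch, counting from 1, together with the final accuracy, which shows whether the run was still improving or had started to overfit.

```python
def summarise_history(history):
    """Summarise a list of (val_loss, val_acc) tuples, one per epoch."""
    val_loss, val_acc = zip(*history)
    best_idx = max(range(len(val_acc)), key=val_acc.__getitem__)
    return {
        'best_epoch': best_idx + 1,          # epochs counted from 1
        'best_val_acc': val_acc[best_idx],
        'loss_at_best': val_loss[best_idx],
        'final_val_acc': val_acc[-1],
    }


print(summarise_history([(1.2, 0.55), (0.9, 0.68), (0.95, 0.66)]))
# → {'best_epoch': 2, 'best_val_acc': 0.68, 'loss_at_best': 0.9, 'final_val_acc': 0.66}
```

In the notebook you would call `summarise_history(history_light)` or `summarise_history(history_121)`.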


## 2️⃣  Baseline: full DenseNet-121

In [None]:
history_121 = run_densenet_training(
    model_type='densenet121',
    epochs=100,
    train_batch_size=128,
    test_batch_size=256,
    learning_rate=0.1,
    optimiser='sgd',
)
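In its standard configuration (the original paper and torchvision's `densenet121`), DenseNet-121 starts from 64 channels, uses four dense blocks of (6, 12, 24, 16) layers with growth rate k = 32, and halves the width in each transition layer (compression θ = 0.5). Because every layer concatenates k new feature maps onto its input, the width after each stage is simple bookkeeping. The sketch below (function name hypothetical) assumes your `densenet121` follows that configuration.

```python
def densenet_channels(num_init_features, block_config, growth_rate, compression=0.5):
    """Channel count after each dense block and transition of a DenseNet.

    Each layer in a dense block concatenates `growth_rate` new feature maps,
    so a block of n layers adds n * growth_rate channels. Transition layers
    between blocks scale the width by `compression` (theta), rounding down.
    """
    channels = num_init_features
    widths = []
    for i, num_layers in enumerate(block_config):
        channels += num_layers * growth_rate
        widths.append(('block', i + 1, channels))
        if i != len(block_config) - 1:       # no transition after the last block
            channels = int(channels * compression)
            widths.append(('transition', i + 1, channels))
    return widths


for stage in densenet_channels(64, (6, 12, 24, 16), growth_rate=32):
    print(stage)
# last line printed: ('block', 4, 1024), the width fed to the classifier
```

The same arithmetic covers the CIFAR DenseNet-BC (L = 100, k = 12) setting: 24 initial channels and three blocks of 16 bottleneck layers end at 342 channels.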


### Learning curves (DenseNet-121)

In [None]:
# Each history entry is assumed to be a (val_loss, val_acc) tuple, one per epoch.
val_loss, val_acc = zip(*history_121)
epochs = range(1, len(val_loss) + 1)  # label epochs from 1

plt.figure(figsize=(6, 4))
plt.plot(epochs, val_loss)
plt.title('DenseNet-121 – validation loss')
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.show()

plt.figure(figsize=(6, 4))
plt.plot(epochs, val_acc)
plt.title('DenseNet-121 – validation accuracy')
plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.show()


## 3️⃣  Grid-search hyper-parameter sweep

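We replicate the grid search you performed with LeNet. The default grid
covers 2 × 2 × 2 × 2 = 16 configurations, each trained for 50 epochs, so
expect the sweep to take considerably longer than a single baseline run;
trim or expand the lists to fit your compute budget. Adam usually needs a
much smaller learning rate than SGD (PyTorch's default for Adam is 1e-3),
so the `lr=0.1` Adam runs may diverge.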
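Because the sweep is a full Cartesian product, its cost is the product of the list lengths times the per-run epoch budget. A quick check before launching, copying the grid and the 50-epoch budget from the sweep cell (the names `epochs_per_run` and `n_runs` are illustrative):

```python
import math

# Same grid as the sweep cell; keep the two in sync if you expand it.
param_grid = {
    'model_type':        ['densenetcustom', 'densenet121'],
    'train_batch_size':  [64, 128],
    'learning_rate':     [0.1, 0.01],
    'optimiser':         ['sgd', 'adam'],
}
epochs_per_run = 50

n_runs = math.prod(len(values) for values in param_grid.values())
print(f"{n_runs} runs x {epochs_per_run} epochs = {n_runs * epochs_per_run} epochs in total")
# → 16 runs x 50 epochs = 800 epochs in total
```

Adding one more value to any list multiplies the total accordingly, e.g. a third learning rate raises it to 1200 epochs.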

In [None]:
param_grid = {
    'model_type':        ['densenetcustom', 'densenet121'],
    'train_batch_size':  [64, 128],
    'learning_rate':     [0.1, 0.01],
    'optimiser':         ['sgd', 'adam'],
}

search_results = []

# Iterate over every combination of grid values; run_id counts from 1.
for run_id, (model_type, bs, lr, opt) in enumerate(product(
        param_grid['model_type'],
        param_grid['train_batch_size'],
        param_grid['learning_rate'],
        param_grid['optimiser']), start=1):
    print(f"🔍  Run {run_id}: {model_type}, bs={bs}, lr={lr}, opt={opt}")
    hist = run_densenet_training(
        model_type=model_type,
        epochs=50,                        # fewer epochs per run during the search
        train_batch_size=bs,
        test_batch_size=256,
        learning_rate=lr,
        optimiser=opt,
        silent=True,                      # assumes the helper accepts `silent`; drop it otherwise
    )
    best_acc = max(acc for _loss, acc in hist)
    search_results.append({
        'model_type': model_type, 'batch_size': bs, 'lr': lr,
        'optimiser': opt, 'best_val_acc': best_acc
    })

print("✅  Grid search complete!")
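A full sweep can take hours, and `search_results` lives only in memory, so a crash or kernel restart loses everything. One option, sketched here with hypothetical helper names and file name, is to append each result to a JSON Lines file as soon as a run finishes and to skip configurations already on disk when the cell is re-run.

```python
import json
from pathlib import Path


def load_done(path):
    """Return completed results from a JSON Lines log (empty if none yet)."""
    path = Path(path)
    if not path.exists():
        return []
    with path.open() as f:
        return [json.loads(line) for line in f if line.strip()]


def append_result(path, result):
    """Append one result dict as a single JSON line."""
    with Path(path).open('a') as f:
        f.write(json.dumps(result) + '\n')


def config_key(result):
    """Hashable identity of a configuration, ignoring its score."""
    return (result['model_type'], result['batch_size'],
            result['lr'], result['optimiser'])
```

Inside the sweep you would build `done = {config_key(r) for r in load_done('grid_results.jsonl')}` once before the loop, `continue` when `(model_type, bs, lr, opt)` is already in `done`, and call `append_result('grid_results.jsonl', ...)` right after appending to `search_results`. Pass the score through `float(...)` first, since `json` cannot serialise tensors.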


### Results DataFrame

In [None]:
import pandas as pd
res_df = pd.DataFrame(search_results)
res_df.sort_values('best_val_acc', ascending=False, inplace=True)
res_df.reset_index(drop=True, inplace=True)
res_df
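A sorted table shows the winner but not which knob mattered. Grouping `best_val_acc` by each hyper-parameter gives a rough marginal view; this sketch (hypothetical function name, column names taken from `search_results`) returns one table indexed by `(param, value)`.

```python
import pandas as pd


def marginal_means(df, score='best_val_acc',
                   params=('model_type', 'batch_size', 'lr', 'optimiser')):
    """Mean, max and count of `score` for each value of each hyper-parameter."""
    frames = []
    for p in params:
        stats = df.groupby(p)[score].agg(['mean', 'max', 'count'])
        # Values are cast to str so that model names, batch sizes and
        # learning rates can share one index level.
        stats.index = pd.MultiIndex.from_product(
            [[p], stats.index.astype(str)], names=['param', 'value'])
        frames.append(stats)
    return pd.concat(frames)
```

Call it as `marginal_means(res_df)`. With a single run per configuration, each mean also averages over the other parameters, so treat the result as a guide for where to search next rather than a significance test.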


### Best configuration

In [None]:
best_cfg = res_df.iloc[0]
print(best_cfg)


## 4️⃣  Train best DenseNet from scratch

In [None]:
history_best = run_densenet_training(
    model_type=best_cfg.model_type,
    epochs=150,
    train_batch_size=int(best_cfg.batch_size),
    test_batch_size=256,
    learning_rate=float(best_cfg.lr),
    optimiser=best_cfg.optimiser,
)


### Final curves & insights

In [None]:
# Each history entry is assumed to be a (val_loss, val_acc) tuple, one per epoch.
val_loss, val_acc = zip(*history_best)
epochs = range(1, len(val_loss) + 1)  # label epochs from 1

plt.figure(figsize=(6, 4))
plt.plot(epochs, val_loss)
plt.title('Best DenseNet – validation loss')
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.show()

plt.figure(figsize=(6, 4))
plt.plot(epochs, val_acc)
plt.title('Best DenseNet – validation accuracy')
plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.show()


## 5️⃣  Next explorations

* Try **DenseNet-BC** variants, varying the growth rate *k* and the
  compression factor θ.
* Add **Cutout / RandAugment** to improve generalisation.
* Replace the classifier with an **ArcFace** head for metric-learning
  experiments.
* Port the best-found config to **Tiny-ImageNet** to study scaling
  behaviour.
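Cutout masks one random square patch of each training image, forcing the network to rely on more of the image; a 16 × 16 patch is a common choice for 32 × 32 CIFAR images. A minimal NumPy sketch (hypothetical function name) for `H × W × C` arrays, such as CIFAR images converted with `np.asarray` before `ToTensor`, sampling the patch centre anywhere in the image so that patches near the border are clipped:

```python
import numpy as np


def cutout(image, size=16, rng=None):
    """Return a copy of `image` (H x W x C) with one size x size patch zeroed.

    The patch centre is sampled uniformly over the image, so patches near the
    border are clipped and may cover fewer than size x size pixels.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    cy, cx = rng.integers(0, h), rng.integers(0, w)
    y0, y1 = max(cy - size // 2, 0), min(cy + size // 2, h)
    x0, x1 = max(cx - size // 2, 0), min(cx + size // 2, w)
    out = image.copy()                       # leave the input untouched
    out[y0:y1, x0:x1] = 0
    return out
```

For tensor pipelines, torchvision ships `transforms.RandomErasing`, a closely related built-in, and `transforms.RandAugment` for the second idea.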
