# TensorFlow Vanilla Autoencoder Lab

### Why this notebook
- Demonstrate the baseline TensorFlow/Keras autoencoder built in the `src` package.
- Serve as the TensorFlow counterpart to the PyTorch notebook so the two implementations can be compared side by side.
- Capture the exact steps for training, evaluating, and visualising reconstructions.

### Learning objectives
- Train the dense autoencoder on Fashion-MNIST using a custom training loop.
- Inspect the metrics returned by `train()` and interpret PSNR trends.
- Reconstruct test images and reason about reconstruction fidelity.
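A custom training loop replaces `model.fit` with an explicit forward pass, loss, and gradient update. The sketch below shows one such step with `tf.GradientTape`. It is not the project's `train.py`: the tiny `sketch_model` architecture, the Adam learning rate, and the MSE loss are assumptions for illustration. Inputs are assumed to be scaled to [-1, 1], which is why PSNR uses `max_val=2.0`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical dense autoencoder; the project's model.py may differ.
sketch_model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),        # bottleneck (assumed size)
    tf.keras.layers.Dense(28 * 28, activation="tanh"),   # tanh matches [-1, 1] pixels
    tf.keras.layers.Reshape((28, 28, 1)),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)  # assumed optimiser settings
mse = tf.keras.losses.MeanSquaredError()

def train_step(batch):
    """One optimisation step: reconstruct, score, backpropagate, update."""
    with tf.GradientTape() as tape:
        recon = sketch_model(batch, training=True)
        loss = mse(batch, recon)
    grads = tape.gradient(loss, sketch_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, sketch_model.trainable_variables))
    # Per-image PSNR averaged over the batch; the dynamic range of [-1, 1] is 2.0.
    psnr = tf.reduce_mean(tf.image.psnr(batch, recon, max_val=2.0))
    return loss, psnr

# Random stand-in batch in [-1, 1]; replace with real Fashion-MNIST batches.
batch = tf.random.uniform((8, 28, 28, 1), minval=-1.0, maxval=1.0)
loss, psnr = train_step(batch)
print(float(loss), float(psnr))
```

The step runs eagerly here for clarity; wrapping it in `tf.function` usually speeds up real training once the optimiser state has been created.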

### Prerequisites
- TensorFlow 2.x (a GPU is optional).
- Familiarity with the overall autoencoder project layout.
- Matplotlib, which the reconstruction cell uses for the side-by-side comparison.

### Notebook workflow
1. Import config and helpers from `../src`.
2. Run `train(CONFIG)` to fit the model and capture metrics.
3. Load saved weights and reconstruct sample images.
4. Experiment with configuration tweaks (latent size, optimiser, learning rate schedules).


In [None]:
from pathlib import Path
import sys
import numpy as np

NOTEBOOK_DIR = Path().resolve()
SRC_DIR = NOTEBOOK_DIR.parent / 'src'
if str(SRC_DIR) not in sys.path:
    sys.path.append(str(SRC_DIR))

from config import CONFIG  # noqa: E402
from inference import load_model, reconstruct  # noqa: E402
from train import train  # noqa: E402

CONFIG

In [None]:
metrics = train(CONFIG)
metrics

### Interpret the metrics
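- The returned dictionary holds the reconstruction loss and PSNR for each epoch.
- If each entry is a list with one value per epoch, `pandas.DataFrame(metrics).plot()` draws the curves so plateaus or divergence stand out.
- To follow these values in TensorBoard, log them from the training loop with `tf.summary.scalar`, or route them through the project's logging callbacks.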
- Benchmark against other variants (denoising, sparse) to understand baseline behaviour.
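PSNR is `10 * log10(R**2 / MSE)`, where `R` is the dynamic range of the pixel values. The sketch below computes it with NumPy, assuming pixels scaled to [-1, 1] (so `R = 2.0`); the project may use a different range or `tf.image.psnr`. It also adds a small hypothetical helper, `epochs_since_best`, that counts how many epochs have passed since the best value, a simple way to flag a plateau. The metric key `"psnr"` and the numbers below are toy values, not project output; check `metrics.keys()` for the real names.

```python
import numpy as np

def psnr(original, reconstruction, data_range=2.0):
    """Peak signal-to-noise ratio in dB; data_range=2.0 assumes pixels in [-1, 1]."""
    diff = np.asarray(original, dtype=np.float64) - np.asarray(reconstruction, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))

def epochs_since_best(values, higher_is_better=True):
    """Epochs elapsed since the best value; ties count as no improvement."""
    values = np.asarray(values, dtype=np.float64)
    best = int(np.argmax(values) if higher_is_better else np.argmin(values))
    return len(values) - 1 - best

# Toy example: a constant error of 0.2 gives MSE 0.04, so PSNR = 10*log10(4/0.04) = 20 dB.
print(psnr(np.zeros((28, 28)), np.full((28, 28), 0.2)))
history = {"psnr": [20.1, 22.4, 23.0, 22.9, 23.0]}  # hypothetical key and values
print(epochs_since_best(history["psnr"]))
```

Because `np.argmax` returns the first maximum, an epoch that merely ties the best score does not reset the counter.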

In [None]:
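import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# fashion_mnist.load_data() takes no path argument; Keras downloads and caches the data itself.
(_, _), (x_test, _) = tf.keras.datasets.fashion_mnist.load_data()

# Scale uint8 pixels to [-1, 1] as float32 (assumed to match training) and add batch/channel axes.
image = x_test[0].astype('float32') / 127.5 - 1.0
image = np.expand_dims(image, axis=(0, -1))  # shape (1, 28, 28, 1)

model = load_model(config=CONFIG)
reconstruction = reconstruct(image, model=model, config=CONFIG)[0]

def undo_normalisation(arr):
    """Map [-1, 1] back to [0, 1] for display; np.asarray also handles tf.Tensor outputs."""
    return np.clip((np.asarray(arr).squeeze() + 1.0) * 0.5, 0.0, 1.0)

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
for ax, img, title in zip(axes, (image[0], reconstruction), ('Original', 'Reconstruction')):
    # Fixed vmin/vmax so both panels share the same grey scale.
    ax.imshow(undo_normalisation(img), cmap='gray', vmin=0.0, vmax=1.0)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()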

### Next experiments
- Enable checkpointing so the weights from the epoch with the best PSNR are kept.
- Introduce dropout or batch normalisation layers in `model.py` and observe the effect.
- Export the encoder as a standalone feature extractor for downstream classifiers.
- Compare training curves with the PyTorch implementation to spot framework differences.
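With the Keras functional API, the encoder can be cut out of a trained autoencoder by building a new `tf.keras.Model` that ends at the bottleneck layer. The sketch below uses a small hypothetical architecture with a layer named `"latent"`; the real model in `model.py` may use different layer names and sizes, so look them up with `model.summary()` first.

```python
import numpy as np
import tensorflow as tf

LATENT_DIM = 32  # hypothetical bottleneck size

# Hypothetical functional autoencoder standing in for the trained project model.
inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Flatten()(inputs)
x = tf.keras.layers.Dense(128, activation="relu")(x)
latent = tf.keras.layers.Dense(LATENT_DIM, activation="relu", name="latent")(x)
x = tf.keras.layers.Dense(128, activation="relu")(latent)
x = tf.keras.layers.Dense(28 * 28, activation="tanh")(x)
outputs = tf.keras.layers.Reshape((28, 28, 1))(x)
autoencoder = tf.keras.Model(inputs, outputs, name="autoencoder")

# Reuse the same layers (and weights) from the input up to the bottleneck.
encoder = tf.keras.Model(
    autoencoder.input, autoencoder.get_layer("latent").output, name="encoder"
)
encoder.trainable = False  # freeze when used as a fixed feature extractor

# Random stand-in images in [-1, 1]; real inputs would be normalised test images.
batch = np.random.uniform(-1.0, 1.0, size=(4, 28, 28, 1)).astype("float32")
features = encoder.predict(batch, verbose=0)
print(features.shape)
```

Because the encoder shares layers with the autoencoder, it picks up the trained weights without copying, and the resulting features can feed a downstream classifier.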