# TensorFlow Deconvolutional Autoencoder Lab

### Why this notebook
- Demonstrate the convolutional autoencoder implemented in TensorFlow/Keras.
- Provide a structured walkthrough for training and inspecting reconstructions.
- Mirror the PyTorch notebook for learners comparing frameworks.

### Learning objectives
- Train the convolutional autoencoder on Fashion-MNIST using the Keras `Model` subclass.
- Reconstruct samples and compare how much spatial detail they retain against the dense baseline.
- Experiment with architectural tweaks and observe their impact on PSNR.
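A common architectural tweak is adding or removing a stride-2 layer, and that quietly changes the output size. The sketch below traces spatial sizes through a mirrored encoder/decoder. It assumes every layer uses `padding='same'`, for which Keras gives `ceil(n / stride)` for `Conv2D` and `n * stride` for `Conv2DTranspose`. The helper names are hypothetical and not part of `../src`.

```python
import math


def conv_same_out(size: int, stride: int) -> int:
    # Keras Conv2D with padding='same': output = ceil(input / stride)
    return math.ceil(size / stride)


def deconv_same_out(size: int, stride: int) -> int:
    # Keras Conv2DTranspose with padding='same' and no output_padding: output = input * stride
    return size * stride


def trace_shapes(input_size: int, strides: list[int]) -> tuple[list[int], list[int]]:
    """Hypothetical helper: spatial sizes through a mirrored encoder/decoder stack."""
    encoder = [input_size]
    for s in strides:
        encoder.append(conv_same_out(encoder[-1], s))
    decoder = [encoder[-1]]
    for s in reversed(strides):  # the decoder undoes the encoder strides in reverse order
        decoder.append(deconv_same_out(decoder[-1], s))
    return encoder, decoder


print(trace_shapes(28, [2, 2]))     # ([28, 14, 7], [7, 14, 28])
print(trace_shapes(28, [2, 2, 2]))  # ([28, 14, 7, 4], [4, 8, 16, 32])
```

With three stride-2 layers the decoder emits 32×32 images that no longer match the 28×28 inputs. Before the loss can be computed, you would need a crop, a `valid`-padded layer, or padded inputs.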

### Prerequisites
- TensorFlow 2.x installed (a GPU is optional).
- Familiarity with the vanilla autoencoder notebook and convolution basics.
- Matplotlib, which the reconstruction cell uses for visual comparisons.

### Notebook workflow
1. Import configuration, training, and inference utilities from `../src`.
2. Run `train(CONFIG)` to fit the model and log metrics; device selection is handled automatically.
3. Load the saved weights, reconstruct a test image, and visualise the original next to its reconstruction.
4. Extend with feature-map inspection, skip connections, or denoising tests.

```python
from pathlib import Path
import sys

# Make the project's ../src directory importable from this notebook.
NOTEBOOK_DIR = Path().resolve()
SRC_DIR = NOTEBOOK_DIR.parent / 'src'
if str(SRC_DIR) not in sys.path:
    sys.path.append(str(SRC_DIR))

from config import CONFIG  # noqa: E402
from inference import load_model, reconstruct  # noqa: E402
from train import train  # noqa: E402

CONFIG  # display the active configuration
```

```python
# Fit the model and keep the logged metrics for inspection.
metrics = train(CONFIG)
metrics
```

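### Interpret the metrics
- Track reconstruction loss and PSNR across epochs to judge convergence: loss should fall and PSNR should rise until both plateau.
- Plot the curves. If a validation split is logged, a widening gap between training and validation loss signals overfitting, which the higher-capacity decoder makes more likely.
- Compare against the PyTorch notebook to check that both frameworks reach similar loss and PSNR.
- For longer experiments, consider logging to TensorBoard.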
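PSNR is a log-scaled view of mean squared error, so its value depends on the pixel range it is computed over. Below is a minimal NumPy sketch of the standard definition, `10 * log10(data_range**2 / MSE)`. It assumes nothing about how `../src` computes its own PSNR: `data_range` is 1 for images in [0, 1] and 2 for images normalised to [-1, 1]. In TensorFlow, `tf.image.psnr(a, b, max_val)` computes the same quantity.

```python
import numpy as np


def psnr(reference, estimate, data_range):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    mse = np.mean((reference - estimate) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))


a = np.zeros((28, 28))
# An error of 0.1 on a [0, 1] scale...
print(psnr(a, a + 0.1, data_range=1.0))  # ≈ 20.0
# ...is the same relative error as 0.2 on a [-1, 1] scale.
print(psnr(a, a + 0.2, data_range=2.0))  # ≈ 20.0
```

Passing the wrong `data_range` shifts every reported value by a constant (about 6 dB for a factor of 2). This matters when comparing against the PyTorch notebook.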

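```python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Take the first Fashion-MNIST test image as a (28, 28, 1) float32 array.
(_, _), (test_images, _) = tf.keras.datasets.fashion_mnist.load_data()
image = test_images[0][:, :, np.newaxis].astype(np.float32)
# Scale [0, 255] -> [-1, 1]; this assumes training used the same normalisation.
image_normalised = (image / 255.0 - 0.5) / 0.5
model = load_model(config=CONFIG)
reconstruction = reconstruct([image_normalised], model=model)[0]

def to_numpy(arr):
    """Convert a [-1, 1] tensor or array to a [0, 1] NumPy image for display."""
    arr = np.asarray(arr).squeeze()  # works for both NumPy arrays and eager tensors
    return np.clip(arr * 0.5 + 0.5, 0.0, 1.0)

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
# Fix vmin/vmax so both panels share the same grey scale instead of auto-scaling.
axes[0].imshow(image.squeeze() / 255.0, cmap='gray', vmin=0, vmax=1)
axes[0].set_title('Original')
axes[0].axis('off')
axes[1].imshow(to_numpy(reconstruction), cmap='gray', vmin=0, vmax=1)
axes[1].set_title('Reconstruction')
axes[1].axis('off')
plt.tight_layout()
plt.show()
```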
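Side-by-side panels make it hard to see where spatial detail is lost. A per-pixel absolute-error map shows it directly. The sketch below is a hypothetical helper, not part of `../src`. It assumes the original is a raw uint8 image and the reconstruction is in [-1, 1], as in the cell above, and maps both to [0, 1] before differencing.

```python
import numpy as np


def error_map(original_uint8, reconstruction_normalised):
    """Per-pixel absolute error in [0, 1] between a uint8 image and a [-1, 1] reconstruction."""
    original = np.asarray(original_uint8, dtype=np.float32).squeeze() / 255.0
    recon = np.asarray(reconstruction_normalised, dtype=np.float32).squeeze()
    recon = np.clip(recon * 0.5 + 0.5, 0.0, 1.0)  # undo the [-1, 1] normalisation
    return np.abs(original - recon)


# Toy check on synthetic data (not model output): a perfect reconstruction has zero error.
orig = np.full((28, 28, 1), 255, dtype=np.uint8)
perfect = np.ones((28, 28, 1), dtype=np.float32)  # +1.0 in [-1, 1] maps back to 1.0
err = error_map(orig, perfect)
print(err.shape, err.max())  # (28, 28) 0.0
```

In the notebook, `plt.imshow(error_map(test_images[0], reconstruction), cmap='magma')` would render the map for the image above. Bright pixels mark regions the decoder fails to reproduce, which makes the comparison with the dense baseline more concrete.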

### Next experiments
- Visualise intermediate feature maps to understand encoder/decoder behaviour.
- Add skip connections to build a lightweight U-Net and compare reconstructions.
- Introduce noise or masking to test robustness without retraining.
- Export the model to TensorFlow Lite to experiment with edge deployment.
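For the noise and masking experiment, the trained model can stay fixed; only its inputs are corrupted. Below is a minimal NumPy sketch of such a corruption step. The `corrupt` helper and its parameters are hypothetical. It works in the [-1, 1] normalised space used above, sets masked pixels to -1.0 (the normalised value of a black pixel), and clips the result back into range.

```python
import numpy as np


def corrupt(images, noise_std=0.0, mask_fraction=0.0, seed=None):
    """Hypothetical helper: add Gaussian noise and/or mask random pixels of [-1, 1] images."""
    rng = np.random.default_rng(seed)
    images = np.asarray(images, dtype=np.float32)
    # Additive Gaussian noise with the given standard deviation (0.0 leaves images unchanged).
    noisy = images + rng.normal(0.0, noise_std, size=images.shape).astype(np.float32)
    if mask_fraction > 0:
        # Each pixel is masked independently with probability mask_fraction.
        mask = rng.random(images.shape) < mask_fraction
        noisy = np.where(mask, -1.0, noisy)
    # Keep values inside the normalised range the model was trained on.
    return np.clip(noisy, -1.0, 1.0).astype(np.float32)
```

Assuming `reconstruct` accepts a list of normalised arrays as in the reconstruction cell, `reconstruct(list(corrupt([image_normalised], noise_std=0.2, seed=0)), model=model)` compares clean and noisy inputs without retraining. Scoring the outputs against the clean image shows how quickly quality degrades as `noise_std` or `mask_fraction` grows.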