# TensorFlow GAN Lab
This notebook mirrors the PyTorch GAN workflow using Keras/GradientTape. Execute the cells sequentially to configure the environment, train the model, and visualise freshly generated samples. Use the commentary to capture findings for future runs.

In [None]:
import sys
from pathlib import Path

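# Assumes the notebook is launched from its own directory, with the modules in ../src.
NOTEBOOK_DIR = Path().resolve()
SRC_DIR = NOTEBOOK_DIR.parent / "src"
# Make config, train, and inference importable without installing the package.
if str(SRC_DIR) not in sys.path:
    sys.path.append(str(SRC_DIR))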

from config import CONFIG
from train import train
from inference import generate_samples

In [None]:
metrics = train(CONFIG)
metrics

## Review training metrics
Transform the raw training dictionary into a readable summary and optional plots for faster iteration.
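
When pandas is unavailable, a comparable summary can be produced with the standard library alone. The sketch below assumes `metrics` is a dict mapping metric names (such as `d_loss` and `g_loss`) to per-epoch lists of numbers; that layout is an assumption about what `train` returns. `summarise_metrics` is a hypothetical helper, and the example values are placeholders rather than real results.

```python
from statistics import mean


def summarise_metrics(metrics):
    """Return last/min/max/mean for every non-empty metric series."""
    summary = {}
    for name, values in metrics.items():
        values = [float(v) for v in values]
        if not values:
            continue  # skip metrics that were never logged
        summary[name] = {
            "last": values[-1],
            "min": min(values),
            "max": max(values),
            "mean": mean(values),
        }
    return summary


# Placeholder values for illustration only, not real training output.
example_metrics = {"d_loss": [1.2, 0.9, 0.7], "g_loss": [0.8, 1.1, 1.4]}
for name, stats in summarise_metrics(example_metrics).items():
    print(name, stats)
```

In the notebook, `summarise_metrics(metrics)` would replace the `pprint(metrics)` fallback if the assumed layout holds.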

In [None]:
try:
    import pandas as pd
except ImportError:
    pd = None
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

if pd is not None:
    df = pd.DataFrame(metrics)
    display(df.tail())
    if plt is not None and not df.empty:
        ax = df.plot(figsize=(6, 3), title="GAN losses")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        plt.tight_layout()
        plt.show()
else:
    from pprint import pprint
    pprint(metrics)

### Interpret the losses
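- `d_loss` collapsing toward zero while `g_loss` climbs can signal discriminator dominance; try reducing the discriminator learning rate or adding label smoothing. Isolated `d_loss` spikes more often point to general instability.
- `g_loss` collapsing near zero means the discriminator is being fooled consistently, which can hide a weak discriminator or mode collapse rather than good samples; check the sample grid, then adjust optimiser betas or add noise.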
- Aggregate metrics with experiment metadata (e.g., mixed precision status) for later comparison.
- Plot both losses to spot cyclic behaviour typical of GAN training.
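
One way to pair metrics with experiment metadata is to append a single JSON line per run. This is a minimal sketch, again assuming `metrics` maps names to per-epoch lists. `build_run_record`, `append_run_record`, and the metadata keys are hypothetical names chosen for illustration, and the numbers are placeholders.

```python
import json
from pathlib import Path


def build_run_record(metrics, metadata):
    """Combine the final value of each metric series with run metadata."""
    finals = {
        f"final_{name}": float(values[-1])
        for name, values in metrics.items()
        if len(values) > 0
    }
    return {**metadata, **finals}


def append_run_record(record, log_path):
    """Append one JSON line per run so runs can be compared later."""
    log_path = Path(log_path)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a", encoding="utf-8") as handle:
        handle.write(json.dumps(record, sort_keys=True) + "\n")


record = build_run_record(
    {"d_loss": [1.2, 0.7], "g_loss": [0.8, 1.4]},  # placeholder values
    {"run_name": "baseline", "mixed_precision": False},  # hypothetical keys
)
print(record)
```

In the notebook, something like `append_run_record(record, CONFIG.artifact_dir / "runs.jsonl")` would keep the log next to the other artefacts; the file name is an assumption.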

In [None]:
samples = generate_samples(
    CONFIG,
    num_images=16,
    output_path=CONFIG.artifact_dir / "notebook_samples.png",
)
samples.shape

## Visualise the sample grid
Inspect the generated batch to spot artefacts, diversity, and mode collapse signals early.
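
Visual inspection can be backed by a rough numeric diversity check. The sketch below computes the mean pairwise Euclidean distance between flattened images; a value near zero suggests the samples are nearly identical, which is one possible hint of mode collapse. `mean_pairwise_distance` is a hypothetical helper, it is dependency-free for clarity rather than speed, and any threshold for "too low" depends on the dataset.

```python
import math
from itertools import combinations


def mean_pairwise_distance(flat_images):
    """Average Euclidean distance over all pairs of flattened images."""
    pairs = list(combinations(flat_images, 2))
    if not pairs:
        return 0.0  # fewer than two images: no pairs to compare
    return sum(math.dist(a, b) for a, b in pairs) / len(pairs)


# Toy vectors standing in for flattened images.
identical = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
varied = [[0.0, 0.0], [3.0, 4.0]]
print(mean_pairwise_distance(identical))
print(mean_pairwise_distance(varied))
```

With the batch from the previous cell, the input could be built as `samples.numpy().reshape(samples.shape[0], -1).tolist()`, assuming `samples` is a tensor with the batch on the first axis.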

In [None]:
import math
import numpy as np
try:
    import matplotlib.pyplot as plt
except ImportError as exc:
    print(f"Visualization skipped: {exc}")
else:
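    images = np.asarray(samples)  # accepts a tf.Tensor or a NumPy array
    # Map from the assumed tanh output range [-1, 1] to [0, 1] for imshow.
    images = np.clip((images + 1.0) / 2.0, 0.0, 1.0)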
    n_cols = int(math.sqrt(images.shape[0])) or 1
    n_rows = math.ceil(images.shape[0] / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2))
    axes = np.array(axes).reshape(n_rows, n_cols)
    for idx in range(n_rows * n_cols):
        ax = axes[idx // n_cols, idx % n_cols]
        ax.axis("off")
        if idx < images.shape[0]:
            ax.imshow(images[idx].squeeze(), cmap="gray")
    plt.suptitle("Generated samples", y=0.95)
    plt.tight_layout()
    plt.show()

### Next experiments
- Enable a `tf.keras.mixed_precision` policy (`mixed_float16` on GPUs, `mixed_bfloat16` on TPUs) to benchmark speedups.
- Evaluate sample grids per epoch to detect early mode collapse.
- Introduce gradient penalty or spectral normalisation for improved stability.
- Port the pipeline to other data (e.g., EMNIST) by customising `data.py`.
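
The mixed precision experiment in the first item could start from a sketch like the one below. It only sets the global Keras dtype policy before the models are built; whether `train` also needs loss scaling in its `GradientTape` loop is not covered here and depends on how `train.py` is written. `enable_mixed_precision` is a hypothetical helper that returns `None` when TensorFlow is not installed.

```python
def enable_mixed_precision(policy_name="mixed_float16"):
    """Set the global Keras dtype policy and return its name.

    Returns None when TensorFlow cannot be imported.
    """
    try:
        import tensorflow as tf
    except ImportError:
        return None
    # Must run before the generator and discriminator are constructed.
    tf.keras.mixed_precision.set_global_policy(policy_name)
    return tf.keras.mixed_precision.global_policy().name


active_policy = enable_mixed_precision()
print(f"Mixed precision policy: {active_policy}")
```

Call the helper before `train(CONFIG)` and record the returned policy name with the run metadata so timings stay comparable across runs.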