# PyTorch GAN Lab
Welcome to the hands-on companion for the PyTorch DCGAN baseline. Run the cells in order to configure the environment, train the model, and generate fresh samples. Use the commentary cells to capture observations and iterate on experiments.

In [None]:
import sys
from pathlib import Path

NOTEBOOK_DIR = Path().resolve()
SRC_DIR = NOTEBOOK_DIR.parent / "src"
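if str(SRC_DIR) not in sys.path:
    # Prepend so the local `config`, `train`, and `inference` modules win over installed packages with the same names.
    sys.path.insert(0, str(SRC_DIR))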

from config import CONFIG
from train import train
from inference import generate_samples

In [None]:
metrics = train(CONFIG)
metrics

## Review training metrics
Turn the metrics dictionary into an inline report so you can compare runs quickly or log them to external tooling.
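If you want the report to outlive the notebook session, one option is to write the metrics to JSON together with run metadata. The sketch below assumes `metrics` is a dict mapping names such as `d_loss` to lists of numbers. The helper `log_metrics` and the record layout are hypothetical and not part of the project's `src` modules.

```python
import json
import time
from pathlib import Path


def log_metrics(metrics, metadata, path):
    """Write metrics plus run metadata to a JSON file (hypothetical helper).

    Assumes `metrics` maps each metric name to a list of numeric values.
    """
    record = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "metadata": metadata,
        # Cast to float so 0-d tensors or NumPy scalars serialize cleanly.
        "metrics": {name: [float(v) for v in values] for name, values in metrics.items()},
    }
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(record, indent=2))
    return record
```

For example, `log_metrics(metrics, {"lr": 2e-4, "seed": 0}, CONFIG.artifact_dir / "metrics.json")`, replacing the metadata dict with whichever hyperparameters you want to compare across runs.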

In [None]:
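try:
    import pandas as pd
    # `display` is a Jupyter builtin; importing it explicitly lets the cell run outside a notebook kernel too.
    from IPython.display import display
except ImportError:
    pd = None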
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

if pd is not None:
    df = pd.DataFrame(metrics)
    display(df.tail())
    if plt is not None and not df.empty:
        ax = df.plot(figsize=(6, 3), title="GAN losses")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        plt.tight_layout()
        plt.show()
else:
    from pprint import pprint
    pprint(metrics)

### Interpret the losses
- `d_loss` should oscillate while staying finite; if it collapses toward zero, the discriminator is likely dominating and the generator receives vanishing gradients.
- `g_loss` rising steadily indicates the generator is struggling to fool the discriminator; try adjusting the learning rates or the optimizer betas.
- Log the metrics JSON with experiment metadata for easier comparison across runs.
- Plot the loss curves (`pandas.DataFrame(metrics).plot()`) to check trends over training. GAN losses rarely converge smoothly, so look for stable oscillation rather than a monotonic decrease.
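The checks above can be partly automated with a rough heuristic. This sketch assumes `metrics` holds per-epoch lists under `d_loss` and `g_loss`; `diagnose_losses`, the window size, and the collapse threshold are illustrative choices, not values tuned for this baseline.

```python
import math
from statistics import mean


def diagnose_losses(d_loss, g_loss, window=5, collapse_threshold=0.05):
    """Flag common GAN failure signs from per-epoch loss lists (hypothetical helper).

    `window` and `collapse_threshold` are illustrative defaults, not tuned values.
    """
    warnings = []
    # Any NaN or inf makes the remaining checks meaningless, so stop early.
    if not all(math.isfinite(v) for v in list(d_loss) + list(g_loss)):
        warnings.append("non-finite loss detected (NaN or inf)")
        return warnings
    # A discriminator loss averaging near zero suggests it is overpowering the generator.
    recent_d = list(d_loss)[-window:]
    if recent_d and mean(recent_d) < collapse_threshold:
        warnings.append("d_loss near zero: discriminator may be dominating")
    # A generator loss that increases at every step of the recent window is worth a look.
    recent_g = list(g_loss)[-window:]
    if len(recent_g) >= 2 and all(b > a for a, b in zip(recent_g, recent_g[1:])):
        warnings.append("g_loss rising every epoch in the recent window")
    return warnings
```

Call it as `diagnose_losses(metrics["d_loss"], metrics["g_loss"])` and treat any warning as a prompt to inspect the curves, not as a verdict.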

In [None]:
samples = generate_samples(CONFIG, num_images=16, output_path=CONFIG.artifact_dir / 'notebook_samples.png')
samples.shape

## Visualise the sample grid
Render the generated batch so you can inspect image quality directly inside the notebook.

In [None]:
import math
try:
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid
except ImportError as exc:
    print(f"Visualization skipped: {exc}")
else:
    # Detach and move to CPU so the grid can be converted to NumPy for plotting.
    grid = make_grid(samples.detach().cpu(), nrow=int(math.sqrt(len(samples))), normalize=True, value_range=(-1, 1))
    plt.figure(figsize=(6, 6))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.title("Generated samples")
    plt.show()

### Next experiments
- Try smaller batches to see how instability manifests and adjust gradient accumulation if needed.
- Add label smoothing or noise to discriminator targets to improve training stability.
- Replace transpose convolutions with upsampling + convolutions to reduce checkerboard artefacts.
- Track Fréchet Inception Distance (FID) for objective quality measurements.
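Two of these ideas are straightforward to sketch in PyTorch. Both helpers below are hypothetical and assume the discriminator returns raw logits (paired with `nn.BCEWithLogitsLoss`); if the baseline's discriminator ends in a `Sigmoid`, use `nn.BCELoss` instead. `smoothed_discriminator_loss` applies one-sided label smoothing to the real targets, and `upsample_conv_block` is a resize-then-convolve stand-in for a stride-2 `ConvTranspose2d` layer.

```python
import torch
from torch import nn


def smoothed_discriminator_loss(real_logits, fake_logits, real_label=0.9):
    """Discriminator BCE loss with one-sided label smoothing (hypothetical helper).

    Assumes the discriminator outputs raw logits, not sigmoid probabilities.
    """
    criterion = nn.BCEWithLogitsLoss()
    # Real targets become `real_label` instead of 1.0; fake targets stay at 0.
    real_targets = torch.full_like(real_logits, real_label)
    fake_targets = torch.zeros_like(fake_logits)
    return criterion(real_logits, real_targets) + criterion(fake_logits, fake_targets)


def upsample_conv_block(in_channels, out_channels, scale_factor=2):
    """Resize-then-convolve generator block (hypothetical replacement for ConvTranspose2d)."""
    return nn.Sequential(
        # Nearest-neighbour upsampling avoids the uneven kernel overlap behind checkerboard artefacts.
        nn.Upsample(scale_factor=scale_factor, mode="nearest"),
        # A 3x3 convolution with padding 1 preserves the upsampled spatial size.
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
```

With `scale_factor=2` the block doubles height and width, matching a `ConvTranspose2d(kernel_size=4, stride=2, padding=1)` layer, so it can slot into the same position in the generator. The final output layer would still need its own activation (for example `Tanh`) rather than `BatchNorm2d` plus `ReLU`.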