# 02 — GAN Analysis

Inspect cGAN training: loss curves, synthetic vs real sample distributions,
and quality of generated EEG features.

In [None]:
import sys, os

# Work from the repository root so relative paths (e.g. 'config/default.yaml')
# resolve and the project's `src` package is importable.
REPO_ROOT = '/content/amers'
os.chdir(REPO_ROOT)
# An absolute entry stays valid even if the cwd changes later, and the guard
# avoids piling up duplicate sys.path entries when this cell is re-run.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Drive folder holding checkpoints and outputs (assumes Google Drive is
# mounted at /content/drive — confirm before running).
DRIVE_BASE = Path('/content/drive/MyDrive/AMERS')
CKPT = DRIVE_BASE / 'checkpoints'
OUT = DRIVE_BASE / 'outputs'

In [None]:
# Load training results and plot the cGAN generator/discriminator loss curves.
# NOTE(review): assumes results['gan'] holds per-epoch 'g_loss' / 'd_loss'
# sequences written by the training script — confirm the x-axis unit.
results = torch.load(OUT / 'training_results.pt', map_location='cpu')

if 'gan' in results:
    gan = results['gan']
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(gan['g_loss'], label='Generator')
    ax.plot(gan['d_loss'], label='Discriminator')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('cGAN Training Loss')
    ax.legend()
    fig.tight_layout()
    plt.show()
else:
    # Say why nothing was plotted instead of producing an empty cell.
    print(f"No 'gan' entry in training_results.pt (found keys: {list(results)}); skipping loss plot.")

In [None]:
# Generate synthetic samples and compare with real
from src.models.gan import ConditionalGAN
from src.utils.config import load_config
from src.data.deap_loader import DEAPLoader
from src.data.label_mapper import LabelMapper

cfg = load_config('config/default.yaml')
# Single source of truth for the class count: the model and the sampling loop
# below must agree (the loop was previously hard-coded to 4 classes).
NUM_CLASSES = cfg.model.num_classes

gan_model = ConditionalGAN(
    feature_dim=cfg.model.gan.feature_dim,
    noise_dim=cfg.model.gan.noise_dim,
    hidden_dim=cfg.model.gan.hidden_dim,
    num_classes=NUM_CLASSES,
)
gan_model.load_state_dict(torch.load(CKPT / 'gan' / 'gan_final.pt', map_location='cpu'))
gan_model.eval()

# Generate N_PER_CLASS synthetic samples for every class.
N_PER_CLASS = 200
# generate() presumably draws random noise internally; seeding the global torch
# RNG makes reruns of this notebook produce the same samples.
torch.manual_seed(42)

syn_feats, syn_labels = [], []
with torch.no_grad():
    for c in range(NUM_CLASSES):
        labels = torch.full((N_PER_CLASS,), c, dtype=torch.long)
        syn_feats.append(gan_model.generate(labels).cpu().numpy())
        syn_labels.extend([c] * N_PER_CLASS)

syn_feats = np.concatenate(syn_feats)
syn_labels = np.array(syn_labels)
print(f'Generated: {syn_feats.shape}')

In [None]:
# t-SNE: real vs synthetic
loader = DEAPLoader(processed_dir=str(DRIVE_BASE / 'data' / 'deap' / 'processed'), label_mapper=LabelMapper())
real_feats, real_labels = loader.load_all(flatten=True)

# Subsample real for speed. A seeded Generator makes the subsample (and so the
# whole t-SNE picture, which already fixes random_state=42) reproducible.
N_REAL_MAX = 800
rng = np.random.default_rng(42)
idx = rng.choice(len(real_feats), size=min(N_REAL_MAX, len(real_feats)), replace=False)
real_sub = real_feats[idx]

# Both sets are embedded jointly, so they must share one feature space; fail
# with an actionable message rather than an opaque np.vstack error.
if real_sub.shape[1:] != syn_feats.shape[1:]:
    raise ValueError(
        f'Feature shape mismatch: real {real_sub.shape[1:]} vs synthetic {syn_feats.shape[1:]}; '
        'check cfg.model.gan.feature_dim against the processed DEAP features.'
    )

combined = np.vstack([real_sub, syn_feats])
# 0 = real row, 1 = synthetic row (same order as `combined`).
is_synthetic = np.repeat([0, 1], [len(real_sub), len(syn_feats)])

tsne = TSNE(n_components=2, perplexity=30, random_state=42)
emb = tsne.fit_transform(combined)

plt.figure(figsize=(8, 6))
for flag, name in ((0, 'Real'), (1, 'Synthetic')):
    mask = is_synthetic == flag
    plt.scatter(emb[mask, 0], emb[mask, 1], alpha=0.3, s=10, label=name)
plt.legend()
plt.title('t-SNE: Real vs Synthetic EEG Features')
plt.show()