# PCA & t-SNE Visualization of θ Space

This notebook loads all `theta.pt` vectors (TinySIREN, TinySIREN+WT, or TinySIREN+WT+KFAC),
applies PCA and t-SNE, and plots 2D embeddings colored by class.

**Prerequisites**:
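- A root folder with one subfolder per class, each containing θ files named `*_theta.pt` (e.g. `data/Thetas/TinySIREN_WT/ACDC_Subset/<class_name>/..._theta.pt`). Files without the `_theta.pt` suffix are ignored.
- A `class_to_idx` JSON file mapping class-folder names to integer labels. It is also reverse-mapped to recover class names for plot legends.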
- Installed: `torch`, `numpy`, `scikit-learn`, `matplotlib`.


In [None]:
import inspect
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Adjust these paths as needed.
# Raw string: Windows backslashes such as "\E" or "\P" are not valid escape
# sequences and trigger a SyntaxWarning on Python 3.12+ in a normal string.
THETA_ROOT = r"D:\Emin\PythonProjects\lfd_project\data\Thetas\Baseline_SMALL_INR\BRATS_Subset"
# NOTE(review): this mapping is the ACDC one, while THETA_ROOT and OUTPUT_DIR point
# at BRATS. Class folders that are not keys of this JSON are skipped by the loader,
# so a mismatch yields few or no θ files. Confirm the intended mapping file.
CLASS_TO_IDX_JSON = "src/class_to_idx_acdc.json"
OUTPUT_DIR = "results/pca_tsne_brats"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load class_to_idx (class-folder name -> integer label). Then build the reverse
# map idx -> class name, used for plot legends and titles.
# Explicit UTF-8: without it, Windows decodes with the locale codepage.
with open(CLASS_TO_IDX_JSON, "r", encoding="utf-8") as f:
    class_to_idx = json.load(f)
idx_to_class = {v: k for k, v in class_to_idx.items()}
# Duplicate indices would silently drop classes from the reverse map.
assert len(idx_to_class) == len(class_to_idx), "class_to_idx maps several classes to the same index"


## 1. Load θ Vectors and Corresponding Labels


In [None]:
# Collect every "*_theta.pt" file under THETA_ROOT/<class_name>/, labelled with the
# class's integer index from class_to_idx.
theta_paths = []
labels = []
for cls, idx in class_to_idx.items():
    cls_dir = os.path.join(THETA_ROOT, cls)
    if not os.path.isdir(cls_dir):
        # Classes without a folder under THETA_ROOT are skipped, not treated as an error.
        continue
    # os.listdir order is arbitrary. Sorting makes the sample order, and therefore
    # the seeded t-SNE layout, reproducible across machines.
    for fname in sorted(os.listdir(cls_dir)):
        if fname.endswith("_theta.pt"):
            theta_paths.append(os.path.join(cls_dir, fname))
            labels.append(idx)

print(f"Found {len(theta_paths)} theta files across {len(set(labels))} of {len(class_to_idx)} classes.")
if not theta_paths:
    raise FileNotFoundError(
        f"No '*_theta.pt' files found under {THETA_ROOT!r} for classes {sorted(class_to_idx)}"
    )

# map_location="cpu" lets tensors saved from a GPU load on any machine.
# detach() allows .numpy() even if a tensor was saved with requires_grad.
# reshape(-1) flattens each θ to a 1-D vector.
# NOTE(review): torch.load unpickles data; only point THETA_ROOT at trusted files.
thetas = np.stack(
    [torch.load(p, map_location="cpu").detach().reshape(-1).numpy() for p in theta_paths],
    axis=0,
)  # Shape = (N, D)
labels = np.array(labels)
print("θ shape:", thetas.shape, "Labels shape:", labels.shape)


## 2. PCA (2D) Embedding


In [None]:
# Linear 2-D projection of the θ vectors onto their top two principal components.
pca2 = PCA(n_components=2)
emb_pca2 = pca2.fit_transform(thetas)  # one 2-D point per θ vector, shape (N, 2)

fig, ax = plt.subplots(figsize=(6, 6))
# One scatter call per class, so each class gets its own color and legend entry.
for class_idx in np.unique(labels):
    in_class = labels == class_idx
    ax.scatter(
        emb_pca2[in_class, 0],
        emb_pca2[in_class, 1],
        label=idx_to_class[class_idx],
        s=15,
        alpha=0.7,
    )
ax.legend(loc="best")
ax.set_title("PCA(2) of θ Space")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, "pca2d.png"))
plt.show()


## 3. t-SNE (2D) Embedding

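- **Perplexity** = 30 and **iterations** = 1000. The iteration keyword is `max_iter` in scikit-learn ≥ 1.5 and `n_iter` in older releases. This step may take a minute.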


In [None]:
TSNE_PERPLEXITY = 30
TSNE_ITERS = 1000

# scikit-learn requires perplexity < n_samples. Clamp it so small subsets still run.
perplexity = min(TSNE_PERPLEXITY, max(len(thetas) - 1, 1))
if perplexity != TSNE_PERPLEXITY:
    print(f"Only {len(thetas)} samples: using perplexity={perplexity} instead of {TSNE_PERPLEXITY}.")

# scikit-learn 1.5 deprecated `n_iter` in favour of `max_iter`, and later releases
# removed `n_iter`. Pass the iteration count under whichever keyword is accepted.
iter_kw = "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"
tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, **{iter_kw: TSNE_ITERS})
emb_tsne2 = tsne.fit_transform(thetas)  # (N, 2)

fig, ax = plt.subplots(figsize=(6, 6))
# One scatter call per class, so each class gets its own color and legend entry.
for class_idx in np.unique(labels):
    in_class = labels == class_idx
    ax.scatter(emb_tsne2[in_class, 0], emb_tsne2[in_class, 1],
               label=idx_to_class[class_idx], s=15, alpha=0.7)
ax.legend(loc="best")
ax.set_title("t-SNE(2) of θ Space")
ax.set_xlabel("Dim 1")
ax.set_ylabel("Dim 2")
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, "tsne2d.png"))
plt.show()


## 4. Compare a Few Points in θ Space

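Here we pick one θ vector from each of two different classes, load each into the MLP, and evaluate it on a small coordinate grid. The resulting images are shown for a visual comparison.
(Optional sanity check. If only one class is present, a single reconstruction is shown.)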


In [None]:
from src.inr_models import TinySIREN

def reconstruct_from_theta(theta_vec, omega_0=30.0, H=32, W=32, model=None):
    """
    Reconstruct a low-res image by evaluating an INR on a coarse coordinate grid.

    Meant for quick visual sanity checks, not faithful reconstructions.

    Parameters
    ----------
    theta_vec : array-like
        Flat parameter vector. It is consumed in ``model.parameters()`` order, and
        each consecutive slice is reshaped to the matching parameter's shape.
        NOTE(review): assumes θ was flattened in this same order when saved.
    omega_0 : float
        Passed to ``TinySIREN`` when ``model`` is None.
    H, W : int
        Output grid size. Coordinates span [0, 1] along each axis.
        NOTE(review): confirm this matches the coordinate range used in training.
    model : torch.nn.Module, optional
        Network whose weights are overwritten with ``theta_vec``. Defaults to a
        fresh ``TinySIREN(omega_0=omega_0)``. Pass another module to reconstruct
        θ vectors from a different architecture.

    Returns
    -------
    np.ndarray
        Model output on the grid, reshaped to ``(H, W)``.

    Raises
    ------
    ValueError
        If the length of ``theta_vec`` differs from the model's parameter count.
    """
    if model is None:
        model = TinySIREN(omega_0=omega_0)

    theta = np.asarray(theta_vec).reshape(-1)
    n_params = sum(p.numel() for p in model.parameters())
    # Without this check, a θ from another architecture would be silently
    # truncated (too long) or fail with an unclear reshape error (too short).
    if theta.size != n_params:
        raise ValueError(
            f"theta has {theta.size} values but the model has {n_params} parameters"
        )

    # Copy consecutive slices of θ into the parameters in place. copy_ keeps each
    # parameter's dtype and device and casts the incoming values as needed.
    with torch.no_grad():
        offset = 0
        for param in model.parameters():
            numel = param.numel()
            param.copy_(torch.as_tensor(theta[offset:offset + numel].reshape(param.shape)))
            offset += numel

    # Put coordinates on the same device as the weights, without relying on
    # model-specific attribute names.
    device = next(model.parameters()).device

    # Coarse grid of (x, y) coordinates, shape (H*W, 2). With "ij" indexing, row i
    # corresponds to xs[i] and column j to ys[j].
    xs = np.linspace(0, 1, H)
    ys = np.linspace(0, 1, W)
    grid = np.stack(np.meshgrid(xs, ys, indexing="ij"), axis=-1).reshape(-1, 2)
    coords = torch.from_numpy(grid.astype(np.float32)).to(device)
    with torch.no_grad():
        out = model(coords).cpu().numpy()
    return out.reshape(H, W)

# Pick one θ from each of the (up to) two lowest class indices actually present.
# Hard-coding labels 0 and 1 raised IndexError whenever class 0 had no θ files.
present_classes = np.unique(labels)
if present_classes.size == 0:
    raise RuntimeError("No θ vectors loaded; nothing to reconstruct.")
chosen = [int(np.flatnonzero(labels == c)[0]) for c in present_classes[:2]]
recs = [reconstruct_from_theta(thetas[i], omega_0=30.0, H=64, W=64) for i in chosen]

# One panel per reconstruction, plus an |a - b| panel when two classes are compared.
show_diff = len(recs) == 2
n_panels = len(recs) + int(show_diff)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
for ax, i, rec in zip(axes[0], chosen, recs):
    ax.imshow(rec, cmap="gray")
    ax.set_title(f"Reconstruction of θ[{i}] ({idx_to_class[labels[i]]})")
    ax.axis("off")
if show_diff:
    diff_ax = axes[0, -1]
    diff_im = diff_ax.imshow(np.abs(recs[1] - recs[0]), cmap="magma")
    diff_ax.set_title("|difference|")
    diff_ax.axis("off")
    fig.colorbar(diff_im, ax=diff_ax, fraction=0.046)
fig.tight_layout()
plt.show()


## 5. Conclusion

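- The PCA(2) and t-SNE(2) plots show whether, and how strongly, θ vectors cluster by class.
- t-SNE layouts depend on `perplexity` and the iteration count (`max_iter`, formerly `n_iter`); vary them to check that apparent clusters are stable.
- Next: quantify separation numerically, e.g. intra-class vs. inter-class distances or silhouette scores.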
