# MAE on DTD — Plug & Play (Colab)

Run cells in order: setup → install → data → train → evaluate → visualize.

## 1. Setup project

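**Option A:** Clone from GitHub: set `REPO_URL` to your repository and `USE_CLONE = True` in the cell below.  
**Option B:** Upload the project zip to Colab and unzip it to `/content/D_MAE` (or point `PROJECT_DIR` at the extracted folder), leaving `USE_CLONE = False`.  
**Option C:** Mount Google Drive and set `PROJECT_DIR` to your project folder, e.g. `/content/drive/MyDrive/D_MAE`.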

In [None]:
import os

# Option A: Clone (change to your repo or leave as placeholder)
REPO_URL = "https://github.com/YOUR_USERNAME/D_MAE.git"  # ← change this
USE_CLONE = False  # set True to clone from REPO_URL

# Option C: Drive path (if you copied the project to Drive)
PROJECT_DIR = "/content/D_MAE"  # default after clone; or "/content/drive/MyDrive/D_MAE"

if USE_CLONE:
    !git clone {REPO_URL} /content/D_MAE
    %cd /content/D_MAE
else:
    # Assume project is already at PROJECT_DIR (e.g. uploaded zip extracted to /content/D_MAE)
    if not os.path.isdir(PROJECT_DIR):
        raise SystemExit("Project not found. Clone the repo (USE_CLONE=True), upload & unzip here, or set PROJECT_DIR.")
    %cd {PROJECT_DIR}

print("CWD:", os.getcwd())

## 2. Install dependencies

In [None]:
!pip install -q -r requirements.txt
print("Done.")

## 3. Download DTD

In [None]:
import os

from torchvision.datasets import DTD

# Fetch the three DTD splits (partition 1) into ./data.
# Instantiating the dataset with download=True is what triggers the download;
# the dataset objects themselves are not kept.
os.makedirs("data", exist_ok=True)
for split_name in ("train", "val", "test"):
    DTD(root="data", split=split_name, partition=1, download=True)

print("DTD ready at data/dtd/")

## 4. Choose config and train MAE

In [None]:
# Supported mask ratios and the config file used for each.
CONFIG_BY_MASK_RATIO = {
    0.75: "configs/mask75.yaml",  # baseline
    0.90: "configs/mask90.yaml",  # high
    0.95: "configs/mask95.yaml",  # extreme
}

# Pick one: 0.75 (baseline), 0.90 (high), 0.95 (extreme)
MASK_RATIO = 0.75

# Fail loudly on an unsupported ratio instead of silently falling back to the
# 0.95 config, which would train with a different ratio than the one printed.
CONFIG = CONFIG_BY_MASK_RATIO.get(MASK_RATIO)
if CONFIG is None:
    raise ValueError(
        f"Unsupported MASK_RATIO={MASK_RATIO!r}; choose one of {sorted(CONFIG_BY_MASK_RATIO)}"
    )

EPOCHS = 10  # use 100–150 for full runs; 10 for a quick test

print(f"Training MAE with mask_ratio={MASK_RATIO}, config={CONFIG}, epochs={EPOCHS}")

In [None]:
import sys

# Make the project's `src` package importable from the current working
# directory. Guarded so re-running the cell does not stack duplicate entries.
project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.utils.config import load_config
from src.training.train_mae import train_mae
import torch
import yaml

# Train from a copy of the chosen config with the epoch count overridden,
# so the base config files stay untouched.
RUN_CONFIG = "configs/colab_run.yaml"

cfg = load_config(CONFIG)
# `cfg.get("training")` is None both when the key is absent and when the YAML
# has an empty `training:` section; treat both as "no training settings yet".
cfg["training"] = (cfg.get("training") or {}) | {"epochs": EPOCHS}

# Write updated config so train_mae can load it
with open(RUN_CONFIG, "w") as f:
    yaml.dump(cfg, f, default_flow_style=False, sort_keys=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# The return value is used as the checkpoint directory by the evaluation cell.
ckpt_dir = train_mae(RUN_CONFIG, device=device)
print("Checkpoints:", ckpt_dir)

## 5. Evaluate: reconstruction + spectrum + linear probe

In [None]:
from pathlib import Path

from src.evaluation.reconstruction import run_reconstruction
from src.evaluation.spectrum import run_spectrum_analysis

# Evaluate the "best.pt" checkpoint from the training run, using the same
# config file the training cell wrote.
CONFIG_USED = "configs/colab_run.yaml"
CKPT = str(ckpt_dir / "best.pt")

results_root = Path("experiments/results")
results_root.mkdir(parents=True, exist_ok=True)

# Reconstruction quality on the checkpoint.
metrics = run_reconstruction(
    CONFIG_USED,
    CKPT,
    "experiments/results/reconstruction/colab",
    device=device,
)
print("Reconstruction PSNR / SSIM:", metrics)

# Spectrum analysis, capped at 20 batches.
spec = run_spectrum_analysis(
    CONFIG_USED,
    CKPT,
    "experiments/results/spectrum/colab",
    device=device,
    max_batches=20,
)
print("Spectrum (effective rank, etc.):", spec)

In [None]:
from src.training.linear_probe import run_linear_probe

# Run a 20-epoch linear probe on the checkpoint from the evaluation cell
# and report the resulting accuracy.
probe_options = {"device": device, "epochs": 20}
acc = run_linear_probe(CONFIG_USED, CKPT, **probe_options)

print(f"Linear probe test accuracy: {acc:.4f}")

## 6. Visualize reconstruction

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt

# Show the reconstruction grid image from the results folder, if it exists.
p = Path("experiments/results/reconstruction/colab/reconstruction_grid.png")

if not p.exists():
    print("Run the reconstruction cell above first.")
else:
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.imshow(plt.imread(p))
    ax.axis("off")
    ax.set_title("Original (top) vs Reconstructed (bottom)")
    fig.tight_layout()
    plt.show()