# MAE on DTD — Plug & Play (Colab)

Same flow as `run_colab_local.py`: setup → install → DTD download → train → evaluate → visualize. Run cells in order.

## 1. Setup project root

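**Colab:** Clone from GitHub (set `REPO_URL` and `USE_CLONE = True`), or upload a zip / mount Google Drive and point `PROJECT_DIR` at the project folder.  
**Local:** Run from the project root or from `notebooks/`. If `PROJECT_DIR` does not exist, the cell falls back to the current directory, or to its parent when launched from `notebooks/`.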

In [None]:
import os
import sys

# Colab: set REPO_URL and USE_CLONE = True to clone; or set PROJECT_DIR to a Drive/upload path.
REPO_URL = "https://github.com/YOUR_USERNAME/D_MAE.git"
USE_CLONE = False
PROJECT_DIR = "/content/D_MAE"  # Colab after clone; or e.g. "/content/drive/MyDrive/D_MAE"

# Clone before the existence check below, and only if the target is missing
# (re-running `git clone` into an existing directory fails).
if USE_CLONE and not os.path.isdir(PROJECT_DIR):
    !git clone {REPO_URL} {PROJECT_DIR}

# Local fallback: use the CWD, or its parent when launched from notebooks/.
if not os.path.isdir(PROJECT_DIR):
    cwd = os.getcwd()
    PROJECT_DIR = os.path.dirname(cwd) if os.path.basename(cwd) == "notebooks" else cwd

%cd {PROJECT_DIR}
sys.path.insert(0, os.getcwd())  # make the `src.*` packages importable
print("CWD:", os.getcwd())
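The fallback in the cell above can be factored into a small helper that is easy to test. This is a sketch only: `resolve_project_dir` is a hypothetical name that does not exist in the repo, and unlike the cell it also normalizes a trailing slash on the working directory.

```python
import os


def resolve_project_dir(preferred, cwd=None):
    """Hypothetical helper: pick the project root the way the setup cell does."""
    # Use the configured path (clone target, Drive folder, upload) when it exists.
    if os.path.isdir(preferred):
        return preferred
    cwd = os.path.normpath(cwd or os.getcwd())  # also drops a trailing slash
    # Launched from notebooks/: the project root is the parent directory.
    if os.path.basename(cwd) == "notebooks":
        return os.path.dirname(cwd)
    return cwd


print(resolve_project_dir("/no/such/dir", "/home/user/D_MAE/notebooks/"))
```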

## 2. Install dependencies

Same as `run_colab_local.py`: `pip install -r requirements.txt` from the project root, so run step 1 first.

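In [None]:
import os

# Fail early with a clear message if step 1 was skipped (pip needs the project root as CWD).
assert os.path.isfile("requirements.txt"), "requirements.txt not found: run step 1 first."
%pip install -q -r requirements.txt
print("Dependencies ready.")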


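If a later import fails, it can help to list which entries in `requirements.txt` have no installed distribution in the current kernel. The sketch below uses only the standard library (`importlib.metadata`); `requirement_names` and `missing_requirements` are hypothetical helpers, and the parser is deliberately simple: it skips comments and pip options such as `-r`, and does not handle URL-only or VCS requirements.

```python
import re
from importlib import metadata

# A distribution name at the start of a requirement line, e.g. "torch" in "torch>=2.0".
_NAME = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*")


def requirement_names(lines):
    """Hypothetical helper: extract distribution names from requirements.txt lines."""
    names = []
    for raw in lines:
        line = raw.split("#", 1)[0].strip()  # drop inline and full-line comments
        if not line or line.startswith("-"):
            continue  # blank line or a pip option such as -r / --index-url
        match = _NAME.match(line)
        if match:
            names.append(match.group(0))
    return names


def missing_requirements(path="requirements.txt"):
    """Hypothetical helper: names listed in `path` with no installed distribution."""
    with open(path, encoding="utf-8") as f:
        names = requirement_names(f)
    missing = []
    for name in names:
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing
```

Call `missing_requirements()` from the project root after the install cell; an empty list means every listed distribution was found.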
## 3. Download DTD

Same as `run_colab_local.py`: downloads DTD into `data/` and prepares the train / val / test splits of partition 1.

In [None]:
from torchvision.datasets import DTD
import os

print("Downloading DTD...")
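os.makedirs("data", exist_ok=True)
# The first call downloads and extracts the archive under data/dtd/; later calls find it
# on disk and only read their split list (partition 1 of DTD's 10 official splits).
for split in ["train", "val", "test"]:
    DTD(root="data", split=split, partition=1, download=True)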
print("DTD ready.")
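To sanity-check the download, you can count images per class in a split list. This sketch reads the label file directly and assumes torchvision's on-disk layout for `DTD(root="data", ...)`, namely `data/dtd/dtd/labels/{split}{partition}.txt` with one `class/filename.jpg` entry per line; `split_class_counts` is a hypothetical helper. For the full dataset, each split of partition 1 should list 47 classes with 40 images each (1,880 images).

```python
from collections import Counter
from pathlib import Path


def split_class_counts(root="data", split="train", partition=1):
    """Hypothetical helper: number of images per class listed in a DTD split file."""
    # Assumed layout (torchvision's DTD loader): <root>/dtd/dtd/labels/<split><partition>.txt
    label_file = Path(root) / "dtd" / "dtd" / "labels" / f"{split}{partition}.txt"
    counts = Counter()
    with open(label_file, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                cls, _filename = line.split("/")
                counts[cls] += 1
    return counts
```

For example, `counts = split_class_counts("data", "train", 1)` followed by `len(counts), sum(counts.values())` gives the number of classes and images in that split.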

## 4. Config and train MAE

Same as `run_colab_local.py`: start from the base config `mask75.yaml`, override the number of epochs, write the result to `configs/colab_run.yaml`, then train. Use 1 epoch for a smoke test; increase it for full runs.

In [None]:
# Same defaults as run_colab_local.py: mask75, 1 epoch smoke test. Change for full runs.
CONFIG = "configs/mask75.yaml"
EPOCHS = 1  # use 100–150 for full runs
print(f"Training MAE: config={CONFIG}, epochs={EPOCHS}")
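For intuition about what `mask75` controls: the name suggests a 75% mask ratio (an assumption; check the YAML). With a hypothetical 224×224 input and 16×16 patches, an MAE-style encoder would see only 49 of the 196 patches. The sketch below draws such a random mask by shuffling patch indices; it is plain Python for illustration, not the repo's masking code.

```python
import random


def random_mask(num_patches, mask_ratio, seed=None):
    """Split patch indices into (visible, masked) by a random shuffle, MAE-style."""
    rng = random.Random(seed)
    num_keep = int(num_patches * (1 - mask_ratio))
    order = list(range(num_patches))
    rng.shuffle(order)
    # The first num_keep shuffled indices stay visible; the rest are masked.
    return sorted(order[:num_keep]), sorted(order[num_keep:])


# Hypothetical sizes: 224x224 image, 16x16 patches -> 14 * 14 = 196 patches.
visible, masked = random_mask(num_patches=(224 // 16) ** 2, mask_ratio=0.75, seed=0)
print(len(visible), len(masked))  # 49 147
```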

In [None]:
import yaml
import torch
from src.utils.config import load_config
from src.training.train_mae import train_mae

cfg = load_config(CONFIG)
cfg["training"] = cfg.get("training", {}) | {"epochs": EPOCHS}
with open("configs/colab_run.yaml", "w") as f:
    yaml.dump(cfg, f, default_flow_style=False, sort_keys=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
ckpt_dir = train_mae("configs/colab_run.yaml", device=device)
print("Checkpoints:", ckpt_dir)
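The cell above overrides a single key with a shallow dict union on `cfg["training"]`. If you want to override settings in several sections at once, a recursive merge keeps all untouched keys intact. This is a sketch: `deep_merge` is a hypothetical helper, and the keys in the example config are illustrative, not taken from `mask75.yaml`.

```python
import copy


def deep_merge(base, overrides):
    """Return a copy of `base` with `overrides` merged in; nested dicts merge key by key."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Illustrative config; the real keys depend on configs/mask75.yaml.
base_cfg = {"model": {"mask_ratio": 0.75, "patch_size": 16}, "training": {"epochs": 100, "lr": 1.5e-4}}
run_cfg = deep_merge(base_cfg, {"training": {"epochs": 1}})
print(run_cfg["training"])  # {'epochs': 1, 'lr': 0.00015}
```

Copying instead of mutating in place means the loaded base config can be reused for several runs.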

## 5. Evaluate: reconstruction + spectrum + linear probe

In [None]:
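from pathlib import Path

from src.evaluation.reconstruction import run_reconstruction
from src.evaluation.spectrum import run_spectrum_analysis

# Path() also handles train_mae returning the checkpoint directory as a plain str.
CKPT = str(Path(ckpt_dir) / "best.pt")
CONFIG_USED = "configs/colab_run.yaml"
assert Path(CKPT).is_file(), f"Checkpoint not found: {CKPT} (did training finish?)"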

Path("experiments/results").mkdir(parents=True, exist_ok=True)

metrics = run_reconstruction(CONFIG_USED, CKPT, "experiments/results/reconstruction/colab", device=device)
print("Reconstruction PSNR / SSIM:", metrics)

spec = run_spectrum_analysis(CONFIG_USED, CKPT, "experiments/results/spectrum/colab", device=device, max_batches=20)
print("Spectrum (effective rank, etc.):", spec)
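For reference when reading these numbers: PSNR is conventionally `10 * log10(MAX**2 / MSE)` in dB, and one common definition of effective rank is the exponential of the Shannon entropy of the singular values normalized to sum to 1. Whether `run_reconstruction` and `run_spectrum_analysis` use exactly these definitions (pixel range, channel averaging, singular values vs eigenvalues) is an assumption to check in `src/evaluation/`. The pure-Python sketch below implements both formulas on plain lists.

```python
import math


def psnr(reference, reconstruction, max_value=1.0):
    """PSNR in dB between two equal-length sequences of pixel values."""
    if len(reference) != len(reconstruction):
        raise ValueError("inputs must have the same length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical signals
    return 10 * math.log10(max_value ** 2 / mse)


def effective_rank(singular_values):
    """exp(entropy) of the singular values normalized to sum to 1."""
    total = sum(singular_values)
    probs = [s / total for s in singular_values if s > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return math.exp(entropy)


print(psnr([0.0, 0.5, 1.0, 0.5], [0.1, 0.5, 0.9, 0.5]))  # ~23.01 dB
print(effective_rank([1.0, 1.0, 1.0, 1.0]))  # 4 equal values -> ~4 (up to float rounding)
print(effective_rank([5.0, 0.0, 0.0]))  # one nonzero value -> 1.0
```

An effective rank close to the feature dimension indicates that variance is spread across many directions; a value near 1 indicates collapse onto a single direction.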

In [None]:
from src.training.linear_probe import run_linear_probe

acc = run_linear_probe(CONFIG_USED, CKPT, device=device, epochs=20)
print(f"Linear probe test accuracy: {acc:.4f}")
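A linear probe freezes the pretrained encoder and fits only a linear classifier on its features, so test accuracy measures how linearly separable the learned representation is. `run_linear_probe` presumably follows this protocol (an assumption; its optimizer, feature pooling and normalization are not shown here). The toy sketch below fits a softmax classifier with SGD on fixed feature vectors in plain Python, with hand-made vectors standing in for encoder features.

```python
import math
import random


def train_linear_probe(features, labels, num_classes, epochs=200, lr=0.1, seed=0):
    """Fit softmax regression on frozen features; only W and b are updated."""
    rng = random.Random(seed)
    dim = len(features[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    order = list(range(len(features)))
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            x, y = features[i], labels[i]
            logits = [sum(w * xj for w, xj in zip(W[c], x)) + b[c] for c in range(num_classes)]
            m = max(logits)  # subtract the max for a numerically stable softmax
            exps = [math.exp(z - m) for z in logits]
            total = sum(exps)
            for c in range(num_classes):
                # Cross-entropy gradient w.r.t. logit c: p_c - 1[c == y].
                grad = exps[c] / total - (1.0 if c == y else 0.0)
                b[c] -= lr * grad
                for j in range(dim):
                    W[c][j] -= lr * grad * x[j]
    return W, b


def predict(W, b, x):
    scores = [sum(w * xj for w, xj in zip(row, x)) + bias for row, bias in zip(W, b)]
    return max(range(len(scores)), key=scores.__getitem__)


def accuracy(W, b, features, labels):
    correct = sum(predict(W, b, x) == y for x, y in zip(features, labels))
    return correct / len(labels)


# Toy "frozen features": two well-separated clusters.
feats = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = [0, 0, 1, 1]
W, b = train_linear_probe(feats, labels, num_classes=2)
print(accuracy(W, b, feats, labels))  # 1.0
```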

## 6. Visualize reconstruction

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path

p = Path("experiments/results/reconstruction/colab/reconstruction_grid.png")
if p.exists():
    plt.figure(figsize=(8, 4))
    plt.imshow(plt.imread(p))
    plt.axis("off")
    plt.title("Original (top) vs Reconstructed (bottom)")
    plt.tight_layout()
    plt.show()
else:
    print("Run the reconstruction cell above first.")