# Reconstruction Demo (Colab-ready)

This mirrors `notebooks/clip_unclip_reconstruction.ipynb` but installs dependencies, clones the repo, and downloads data/models automatically so it can run in a fresh Google Colab runtime.

In [None]:
import os, sys, subprocess
from pathlib import Path

def run(cmd):
    """Echo a command, then execute it without a shell.

    ``cmd`` is a list of argument strings. The arguments are printed joined
    by single spaces before the command runs.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status.
    """
    command_line = " ".join(cmd)
    print(command_line)
    # check=True gives the same raise-on-failure contract as check_call.
    subprocess.run(cmd, check=True)

def install_torch(torch_index: str):
    """Install the pinned torch stack and the demo's other packages with pip.

    Args:
        torch_index: URL handed to pip as ``--extra-index-url`` (the callers
            in this file pass a PyTorch wheel index).

    Raises:
        subprocess.CalledProcessError: propagated from ``run`` when pip exits
            with a non-zero status.
    """
    pip_install = [sys.executable, "-m", "pip", "install"]
    pip_options = [
        "-q", "--upgrade", "--force-reinstall",
        "--extra-index-url", torch_index,
    ]
    requirements = (
        "numpy<2",
        "torch==2.4.1",
        "torchvision==0.19.1",
        "torchaudio==2.4.1",
        "diffusers",
        "transformers",
        "fastai>=2.7.16",
        "omegaconf",
        "git+https://github.com/ml-jku/hopfield-layers.git",
        "timm",
        "tqdm",
    )
    # One pip invocation for everything so the pins are resolved together.
    run(pip_install + pip_options + list(requirements))

def install_pillow():
    """Force-reinstall the pinned Pillow release (12.0.0) with pip.

    Raises:
        subprocess.CalledProcessError: propagated from ``run`` when pip exits
            with a non-zero status.
    """
    pip_install = [sys.executable, "-m", "pip", "install"]
    run([*pip_install, "-q", "--upgrade", "--force-reinstall", "pillow==12.0.0"])

# On Colab, install the pinned dependency stack. This is unconditional once
# Colab is detected: both helpers invoke pip with --upgrade --force-reinstall,
# so packages are reinstalled even when they are already present.
# Colab ships numpy>=2 which is ABI-incompatible with these torch wheels; pin to <2.
if "google.colab" in sys.modules:
    try:
        # First attempt: use the CUDA 12.1 wheel index.
        install_torch("https://download.pytorch.org/whl/cu121")
    except subprocess.CalledProcessError:
        # pip exited non-zero; retry the whole install against the CPU wheel index.
        print("CUDA wheels unavailable; retrying CPU wheels.")
        install_torch("https://download.pytorch.org/whl/cpu")
    # Pillow is pinned in a separate pip call after the torch stack
    # (presumably so the earlier install cannot leave another version in place).
    install_pillow()
    import PIL
    # NOTE(review): if PIL was already imported in this kernel, the version
    # printed here may be the pre-install one until the runtime restarts — confirm.
    print("Pillow version:", PIL.__version__)

# Clone repo if src/ is absent
REPO_URL = os.environ.get("REPO_URL", "https://github.com/aletheia88/dream-with-embeddings")
fallback_repo_dir = Path("/content/large-embedding-models") if "google.colab" in sys.modules else Path.cwd()
# Resolve to an absolute path up front. A relative REPO_DIR override would
# otherwise be re-interpreted relative to the *new* working directory after
# os.chdir() below, so the sys.path entry would point at a non-existent
# <REPO_DIR>/<REPO_DIR>/src.
default_repo_dir = Path(os.environ.get("REPO_DIR", fallback_repo_dir)).expanduser().resolve()
# Prefer a pre-existing checkout nearby before cloning.
repo_candidates = [Path.cwd(), default_repo_dir]
REPO_DIR = next((p for p in repo_candidates if (p / "src").exists()), default_repo_dir)
if not (REPO_DIR / "src").exists():
    # NOTE(review): outside Colab the default target is the current directory;
    # `git clone` refuses to clone into a non-empty directory, so set REPO_DIR
    # explicitly in that case.
    run(["git", "clone", REPO_URL, str(REPO_DIR)])

# Work from the repo root and make the modules under src/ importable.
os.chdir(REPO_DIR)
if str(REPO_DIR / "src") not in sys.path:
    sys.path.append(str(REPO_DIR / "src"))

print("Repo ready at", REPO_DIR)


In [None]:
from pathlib import Path
import sys

import torch
from torchvision import transforms
import matplotlib.pyplot as plt


def find_src_root():
    """Find the nearest directory, walking upward from the CWD, that has ``src``.

    The current working directory is checked first, then each of its parents
    in order from closest to furthest.

    Returns:
        A ``(base, src)`` tuple of ``Path`` objects where ``src == base / "src"``.

    Raises:
        RuntimeError: if no directory on the way up contains a ``src`` entry.
    """
    here = Path.cwd()
    for base in (here, *here.parents):
        candidate = base / "src"
        if candidate.exists():
            return base, candidate
    raise RuntimeError("Could not locate 'src' directory. Run this notebook from the repo root.")


# Locate the repo root by searching upward from the current directory, then
# make the modules under src/ importable for the import statements that follow.
repo_root, src_root = find_src_root()
if str(src_root) not in sys.path:
    # Appended (not prepended), so entries already on sys.path take precedence.
    sys.path.append(str(src_root))

from knn_restore import (
    EncoderConfigs,
    ReconstructConfigs,
    GlobalConfigs,
    load_encoder,
    extract_features,
    create_dataloaders,
    train,
    evaluate,
    seed_everything,
    set_up,
)
from restore_methods import evaluate_reconstruction
import dataloader
from diffusers import StableUnCLIPImg2ImgPipeline


In [None]:
# Imported here because SMOKE_TEST below reads os.environ and neither this cell
# nor the import cell above brings `os` into scope; without it the cell raises
# NameError whenever the Colab setup cell was not run in this kernel.
import os

# Download Imagenette via fastai (cached after first run)
# NOTE(review): the download presumably happens inside set_up() or the loader
# helpers used later — confirm which.
global_cfg, encoder_cfg, restore_cfg, unclip_cfg = set_up()
device = torch.device(global_cfg.device)
seed = global_cfg.seed
seed_everything(seed)

# Notebook-friendly overrides
batch_size = 32
num_workers = 4
corrupt_range = (0.50, 0.70)
max_train = None  # passed as max_items when extracting train features
max_val = None  # passed as max_items when extracting validation features
num_visualize = 4  # number of example rows shown in the figures

restore_cfg.hidden_dims = 1024
restore_cfg.epochs = 3

# SMOKE_TEST=1 in the environment shrinks everything for a quick end-to-end run.
SMOKE_TEST = os.environ.get("SMOKE_TEST", "0") == "1"
if SMOKE_TEST:
    batch_size = min(batch_size, 8)
    num_workers = 0
    max_train = 128
    max_val = 64
    num_visualize = min(num_visualize, 2)
    restore_cfg.epochs = 1
    unclip_cfg.enabled = False

# Push the (possibly reduced) settings into the shared config object.
global_cfg.batch_size = batch_size
global_cfg.num_workers = num_workers
global_cfg.corrupt_range = corrupt_range
global_cfg.scheme = "occlude"

# Seeded generator on the compute device, shared through the global config.
generator = torch.Generator(device=device).manual_seed(seed)
global_cfg.generator = generator

torch.manual_seed(seed)

# Per-channel ImageNet statistics shaped (1, 3, 1, 1) so they broadcast over
# image batches when de-normalizing for display.
IMAGENET_MEAN = torch.tensor((0.485, 0.456, 0.406)).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor((0.229, 0.224, 0.225)).view(1, 3, 1, 1)


In [None]:
# Two loader pairs built with identical batch settings and shuffle=False: one
# with the "baseline" scheme and one with the configured corruption scheme.
# NOTE(review): the loop below assumes both loaders yield the same images in the
# same order, so that index i in each batch refers to the same image — confirm.
clean_train_loader, clean_val_loader = dataloader.get_imagenette_loaders(
    scheme="baseline",
    corrupt_range=None,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)
corrupt_train_loader, corrupt_val_loader = dataloader.get_imagenette_loaders(
    scheme=global_cfg.scheme,
    corrupt_range=global_cfg.corrupt_range,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

def denormalize(batch: torch.Tensor) -> torch.Tensor:
    """Undo ImageNet mean/std normalization and clamp the result to [0, 1]."""
    return (batch * IMAGENET_STD.to(batch) + IMAGENET_MEAN.to(batch)).clamp(0, 1)

# Collect the first `num_visualize` validation images (clean and corrupted).
clean_samples, corrupt_samples, sample_labels = [], [], []
clean_iter = iter(clean_val_loader)
corrupt_iter = iter(corrupt_val_loader)
while len(clean_samples) < num_visualize:
    clean_item = next(clean_iter, None)
    corrupt_item = next(corrupt_iter, None)
    if clean_item is None or corrupt_item is None:
        # Loader ran out before enough samples were gathered; show what we have
        # instead of failing with a bare StopIteration.
        print(f"Validation loader exhausted after {len(clean_samples)} of {num_visualize} samples.")
        break
    clean_batch, clean_lbl = clean_item
    corrupt_batch, _ = corrupt_item
    for idx in range(clean_batch.size(0)):
        clean_samples.append(denormalize(clean_batch[idx:idx+1])[0].cpu())
        corrupt_samples.append(denormalize(corrupt_batch[idx:idx+1])[0].cpu())
        sample_labels.append(int(clean_lbl[idx]))
        if len(clean_samples) >= num_visualize:
            break

num_shown = len(clean_samples)
if num_shown:
    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[row, col] indexing below also works when only one sample is shown.
    fig, axes = plt.subplots(num_shown, 2, figsize=(5, 2.7 * num_shown), squeeze=False)
    for row in range(num_shown):
        # Tensors are channel-first; imshow expects channel-last.
        axes[row, 0].imshow(clean_samples[row].permute(1, 2, 0).numpy())
        axes[row, 0].set_title(f"Original (class {sample_labels[row]})")
        axes[row, 0].axis("off")
        axes[row, 1].imshow(corrupt_samples[row].permute(1, 2, 0).numpy())
        axes[row, 1].set_title("Occluded input")
        axes[row, 1].axis("off")
    plt.tight_layout()


In [None]:
encoder_bundle = load_encoder(encoder_cfg, device)


def _embed_split(loader, scheme, seed_offset, limit):
    """Run extract_features on one loader with a generator seeded at seed + seed_offset."""
    return extract_features(
        encoder_bundle,
        loader,
        scheme=scheme,
        device=device,
        generator=torch.Generator(device=device).manual_seed(seed + seed_offset),
        noise_std=None,
        max_items=limit,
    )


# Embed every split; each call gets its own generator (seed, seed+1, seed+2, seed+3).
clean_train_embeddings, train_labels = _embed_split(clean_train_loader, "baseline", 0, max_train)
corrupt_train_embeddings, _ = _embed_split(corrupt_train_loader, global_cfg.scheme, 1, max_train)
clean_valid_embeddings, valid_labels = _embed_split(clean_val_loader, "baseline", 2, max_val)
corrupt_valid_embeddings, _ = _embed_split(corrupt_val_loader, global_cfg.scheme, 3, max_val)

# Wrap the clean/corrupt embedding pairs into loaders for the restoration model.
train_loader, train_eval_loader, val_loader = create_dataloaders(
    global_cfg, restore_cfg,
    clean_train_embeddings, corrupt_train_embeddings,
    clean_valid_embeddings, corrupt_valid_embeddings,
)

# Fit the restoration model, then collect its predictions on both splits.
# The "normalize" entry of the encoder bundle (default True) is forwarded to
# both train() and evaluate().
model, losses = train(
    restore_cfg, train_loader, val_loader,
    device=device,
    feature_dim=encoder_bundle["feature_dim"],
    normalize_outputs=encoder_bundle.get("normalize", True),
)
train_pred, valid_pred = evaluate(
    model, train_eval_loader, val_loader, device, encoder_bundle.get("normalize", True)
)

# Score the predictions against the clean embeddings of the same split.
print("Train metrics:", evaluate_reconstruction(train_pred, clean_train_embeddings))
print("Valid metrics:", evaluate_reconstruction(valid_pred, clean_valid_embeddings))


In [None]:
# Half precision on CUDA, full precision on any other device.
pipe_dtype = torch.float32 if device.type != "cuda" else torch.float16
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "sd2-community/stable-diffusion-2-1-unclip-small",
    torch_dtype=pipe_dtype,
).to(device)
pipe.set_progress_bar_config(disable=True)
pipe.enable_attention_slicing()

# Decode at most `num_visualize` validation items.
num_recon = min(num_visualize, valid_pred.size(0))
prompts = [""] * num_recon
# `clip_embeds_clean` holds the restoration model's predictions (valid_pred);
# `clip_embeds_corrupt` holds the unrestored embeddings of the corrupted inputs.
clip_embeds_corrupt = corrupt_valid_embeddings[:num_recon].to(device=device, dtype=pipe_dtype)
clip_embeds_clean = valid_pred[:num_recon].to(device=device, dtype=pipe_dtype)

# PIL copies of the corrupted inputs; not consumed in this cell, since the
# pipeline is driven purely by image_embeds (image=None).
to_pil = transforms.ToPILImage()
init_images = [to_pil(sample.cpu()) for sample in corrupt_samples[:num_recon]]

# Sampling settings shared by both decoding passes.
_unclip_kwargs = dict(
    image=None,
    prompt=prompts,
    guidance_scale=5.0,
    num_inference_steps=25,
)

generator = torch.Generator(device=device).manual_seed(seed)
with torch.inference_mode():
    recon_from_clean = pipe(
        image_embeds=clip_embeds_clean,
        generator=generator,
        **_unclip_kwargs,
    ).images
    # Second pass uses its own generator seeded with seed + 1.
    recon_from_corrupt = pipe(
        image_embeds=clip_embeds_corrupt,
        generator=torch.Generator(device=device).manual_seed(seed + 1),
        **_unclip_kwargs,
    ).images

# Back to tensors so the plotting cell can treat every panel the same way.
recon_tensors = list(map(transforms.ToTensor(), recon_from_clean))
fragment_tensors = list(map(transforms.ToTensor(), recon_from_corrupt))

In [None]:
# One row per decoded sample, one column per stage of the pipeline.
columns = ["Original", "Occluded", "Stable unCLIP", "Fragmented unCLIP"]
# squeeze=False keeps `axes` two-dimensional even when num_recon == 1; without
# it matplotlib returns a 1-D array and axes[row, col] raises IndexError.
fig, axes = plt.subplots(num_recon, len(columns), figsize=(12, 3 * num_recon), squeeze=False)
for row in range(num_recon):
    panels = [
        clean_samples[row],
        corrupt_samples[row],
        recon_tensors[row],
        fragment_tensors[row],
    ]
    for col, (title, image) in enumerate(zip(columns, panels)):
        ax = axes[row, col]
        # Tensors are channel-first; imshow expects channel-last.
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title(title)
        ax.axis("off")
plt.tight_layout()
