In [None]:

# clone the repo
#!git clone https://github.com/Sarthak16082/DDT
# cd into DDT
#%cd DDT

#!pip install lightning==2.5.0 torch torchvision torchaudio pyyaml diffusers timm
#!wget https://huggingface.co/MCG-NJU/DDT-XL-22en6de-R512/resolve/main/model.ckpt


In [None]:
import os

from huggingface_hub import login

# ImageNet-1k is a gated dataset and is streamed from the Hub later in this notebook,
# so the account behind the token must already have been granted access to it.
# The token is read from the HF_TOKEN environment variable rather than pasted into
# the notebook (where it would be saved and shared with the file). When the variable
# is unset, token=None makes login() fall back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

## Imports and Model Loading

In [None]:


# Install dependencies
#!pip install datasets
%cd DDT

OUTPUT_DIR = "extracted_features_21"
# 
import os, sys, yaml, copy, gc, logging, torch, torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from src.lightning_model import LightningModel as MyLightningModel
from sklearn.model_selection import train_test_split
from datasets import load_dataset
from huggingface_hub import login
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from glob import glob
import os
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Free Python garbage and any cached CUDA blocks left over from earlier runs.
gc.collect()
torch.cuda.empty_cache()

# Prefer the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Put the current working directory (the repo root after the `%cd DDT` above)
# first on the import path.
repo_root = os.getcwd()
sys.path.insert(0, repo_root)

# helpers
def instantiate(cfg):
    """Build an object from a ``{"class_path": ..., "init_args": ...}`` mapping.

    Args:
        cfg: Dict with a dotted ``class_path`` ("package.module.ClassName") and an
            optional ``init_args`` dict of keyword arguments. A missing or null
            ``init_args`` (e.g. an empty key in the YAML file) means "no arguments".

    Returns:
        The result of calling the resolved attribute with ``init_args`` as keywords.

    Raises:
        ValueError: If ``cfg`` is not a dict, has no ``class_path``, or the
            ``class_path`` is not of the form ``module.ClassName``.
    """
    import importlib  # local import keeps this helper self-contained

    if not isinstance(cfg, dict) or "class_path" not in cfg:
        raise ValueError("Config must have a 'class_path'")

    class_path = cfg["class_path"]
    module_path, dot, cls_name = class_path.rpartition(".")
    if not (dot and module_path and cls_name):
        raise ValueError(f"class_path must look like 'module.ClassName', got {class_path!r}")

    module = importlib.import_module(module_path)
    # `or {}` also covers an explicit null, which `.get("init_args", {})` would
    # hand to `**` and crash on.
    init_args = cfg.get("init_args") or {}
    return getattr(module, cls_name)(**init_args)

# Parse the YAML experiment config; every component below is built from it.
with open("configs/repa_improved_ddt_xlen22de6_512.yaml") as config_file:
    cfg = yaml.safe_load(config_file)

model_cfg = cfg["model"]

print("Instantiating model parts...")
vae, cond, den = (
    instantiate(model_cfg[section]) for section in ("vae", "conditioner", "denoiser")
)

# One scheduler instance is shared by trainer and sampler. Its class path is taken
# from the trainer section, falling back to the sampler section.
sched_path = model_cfg["diffusion_trainer"]["init_args"].get("scheduler")
if not sched_path:
    sched_path = model_cfg["diffusion_sampler"]["init_args"].get("scheduler")
assert isinstance(sched_path, str)
sched = instantiate({"class_path": sched_path, "init_args": {}})

# Work on deep copies so that `cfg` itself keeps the original scheduler entry
# while the copies carry the live scheduler object.
trainer_cfg = copy.deepcopy(model_cfg["diffusion_trainer"])
trainer_cfg["init_args"]["scheduler"] = sched
trainer = instantiate(trainer_cfg)

sampler_cfg = copy.deepcopy(model_cfg["diffusion_sampler"])
sampler_cfg["init_args"]["scheduler"] = sched
sampler = instantiate(sampler_cfg)

# Restore the Lightning checkpoint on CPU, wiring in the components built above.
assert os.path.isfile("model.ckpt"), "Checkpoint not found."
model_cpu = MyLightningModel.load_from_checkpoint(
    "model.ckpt",
    map_location="cpu",
    strict=False,
    vae=vae,
    conditioner=cond,
    denoiser=den,
    diffusion_trainer=trainer,
    diffusion_sampler=sampler,
)
print("Loaded checkpoint on CPU")

# Move to GPU(s) if available; wrap in DataParallel only when several GPUs are visible.
has_cuda = torch.cuda.is_available()
if has_cuda and torch.cuda.device_count() > 1:
    model = nn.DataParallel(model_cpu).to(device)
elif has_cuda:
    model = model_cpu.to(device)
else:
    model = model_cpu

# eval()/freeze() are applied to the underlying Lightning module, never to the wrapper.
core_model = model.module if isinstance(model, nn.DataParallel) else model
core_model.eval()
core_model.freeze()

print(f"Model ready on {device}")

# Direct handles to the sub-modules used by the extraction code further down.
ddt = core_model.denoiser
vae = core_model.vae
print(f"DDT model device: {next(ddt.parameters()).device}")
"""
#not applicable for the correct timestep
x = torch.randn(1, 4, 64, 64, device=device)
t = torch.randint(0, 1000, (1,), device=device)
y = torch.randint(0, 1000, (1,), device=device)
"""
# Container the forward hook writes into; key "f" holds the most recent capture.
feat = {}


def hook(m, i, o):
    """Forward hook: store the module output, detached, in ``feat["f"]``.

    When the output ``o`` is a tuple only its first element is kept.
    """
    captured = o[0] if isinstance(o, tuple) else o
    feat["f"] = captured.detach()
# Module the disabled smoke test below would hook: entry `num_encoder_blocks - 1`
# of ddt.blocks (presumably the last encoder block — confirm against the denoiser).
layer = ddt.blocks[cfg["model"]["denoiser"]["init_args"]["num_encoder_blocks"] - 1]
# The string literal below is a disabled smoke test kept for reference; the code
# inside it is not executed.
"""
h = layer.register_forward_hook(hook)
_ = ddt(x, t, y)
h.remove()
print("Extracted features shape:", feat["f"].shape)
"""

# build dataloader
# Image pipeline: resize to 256, centre-crop to 224, convert to a tensor and
# normalise each channel with the mean/std listed below.
# NOTE(review): these are the usual ImageNet classifier statistics and a 224 crop;
# confirm this is the input range/resolution the VAE encoder expects.
_geometry_steps = [T.Resize(256), T.CenterCrop(224)]
_tensor_steps = [
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
transform = T.Compose(_geometry_steps + _tensor_steps)


def preprocess(batch):
    """Add an "imgs" entry with the transformed RGB version of every ``batch["image"]``."""
    tensors = []
    for img in batch["image"]:
        tensors.append(transform(img.convert("RGB")))
    batch["imgs"] = tensors
    return batch


# Stream the training split, shuffle through a 1000-sample buffer, then apply
# `preprocess` in batches of 32.
streamed = load_dataset("imagenet-1k", split="train", streaming=True)
ds = streamed.shuffle(buffer_size=1000).map(preprocess, batched=True, batch_size=32)

class Wrap(IterableDataset):
    """Iterable dataset yielding ``(imgs, label)`` pairs from each item of the wrapped iterable."""

    def __init__(self, ds):
        # Any iterable of mappings that provide "imgs" and "label" keys.
        self.ds = ds

    def __iter__(self):
        yield from ((item["imgs"], item["label"]) for item in self.ds)

# Batches of 128 (imgs, label) pairs, produced by 2 worker processes.
dataloader = DataLoader(Wrap(ds), batch_size=128, num_workers=2)
# Sanity check: pulls one batch from a fresh iterator just to print the image tensor shape.
print("Dataloader ready:", next(iter(dataloader))[0].shape)

# extract and save features
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
BATCH_SIZE = 128  # same value the DataLoader is built with
CHUNK = 100000  # target number of samples per saved chunk file
SAVE_EVERY = CHUNK//BATCH_SIZE  # batches accumulated before a chunk is written
print(SAVE_EVERY)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# preset scheduler constants: alpha/sigma are evaluated once at the fixed time t0
# and reshaped to (-1, 1, 1, 1) so they broadcast against 4-D latents
scheduler = getattr(model.module if isinstance(model, nn.DataParallel) else model, 
                    "diffusion_trainer", model).scheduler
t0 = torch.tensor([0.95], device=device)
a0 = scheduler.alpha(t0).view(-1,1,1,1)
s0 = scheduler.sigma(t0).view(-1,1,1,1)

# Most recent hooked activation; None until the hook has fired at least once.
hook_feats = {"f": None}
def hook_fn(_,__,o):
    """Forward hook: keep the detached output (first element when it is a tuple)."""
    hook_feats["f"] = (o[0] if isinstance(o, tuple) else o).detach()

# `num_encoder_blocks` is a denoiser setting (the `layer = ...` line earlier in this
# cell reads it from the denoiser section). Looking it up under `diffusion_trainer`
# only ever produced the hard-coded fallback, so read it from the denoiser section
# and keep 22 purely as a fallback.
HOOK_IDX = cfg["model"]["denoiser"]["init_args"].get("num_encoder_blocks", 22) - 1

print(HOOK_IDX)
blk = ddt.blocks[HOOK_IDX]
h = blk.register_forward_hook(hook_fn)

# Per-chunk buffers (features, labels, denoiser outputs) and the index of the next chunk file.
chunks, labs, toks = [], [], []
cnt = 0

In [12]:
# Utility cell: run it by hand after a CUDA out-of-memory error to release memory
# before retrying. It is not needed for a normal top-to-bottom run.
import torch
import gc

# Release cached CUDA blocks that are no longer referenced.
torch.cuda.empty_cache()

# Collect unreachable Python objects (which may still be holding tensors).
gc.collect()

# With a GPU present: wait for pending kernels, then release the cache again so
# memory freed by the garbage collection above is returned as well.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

## Feature Extraction Loop

In [None]:

from tqdm import tqdm
NUM_SAMPLES = 1281167
NUM_BATCHES = NUM_SAMPLES // BATCH_SIZE  # progress-bar estimate only

# Label fed to the denoiser for every sample: the config's `num_classes` (default 1000).
# Presumably this is the index of the unconditional / null class — confirm.
NULL_CLASS = cfg["model"]["denoiser"]["init_args"].get("num_classes", 1000)


def _save_chunk(index):
    """Write the buffered features, labels and denoiser outputs as chunk ``index``."""
    torch.save(torch.cat(chunks), os.path.join(OUTPUT_DIR, f"feat_{index}.pt"))
    torch.save(torch.cat(labs), os.path.join(OUTPUT_DIR, f"labs_{index}.pt"))
    torch.save(torch.cat(toks), os.path.join(OUTPUT_DIR, f"toks_{index}.pt"))
    logging.info(f"Saved chunk {index}")


# Start from empty buffers so re-running this cell never mixes in batches that a
# previous, interrupted run left behind.
chunks, labs, toks = [], [], []
cnt = 0

for i, (imgs, labs_batch) in enumerate(tqdm(dataloader, desc="Extracting Features", total=NUM_BATCHES, leave=True)):
    imgs = imgs.to(device)
    labs_batch = labs_batch.to(device)

    # Cleared before each forward pass so a pass that never reaches the hook is
    # detected instead of silently reusing the previous batch's activation.
    hook_feats["f"] = None

    with torch.no_grad():
        z_raw = vae.encode(imgs)
        z = z_raw.sample() if hasattr(z_raw, "sample") else z_raw
        noise = torch.randn_like(z)
        zt = a0 * z + s0 * noise
        uncond = torch.full_like(labs_batch, NULL_CLASS)
        # Condition the denoiser on the same t that produced a0/s0. Building this
        # tensor with dtype=torch.long truncated 0.95 to 0, so the network was told
        # t=0 while the latents were noised at t=0.95.
        t_batch = t0.repeat(z.shape[0])
        out = ddt(zt, t_batch, uncond)

    feats = hook_feats["f"]
    if feats is None:
        raise RuntimeError("Forward hook did not fire; no features were captured for this batch.")
    if feats.ndim == 4:
        feats = feats.mean(dim=[2, 3])  # [B, C, H, W] → [B, C]
    elif feats.ndim == 3:
        feats = feats.mean(dim=1)       # [B, N, D] → [B, D]
    elif feats.ndim != 2:
        raise ValueError(f"Unexpected feat shape: {feats.shape}")

    # Keep the first element only when the denoiser returns a tuple/list; a bare
    # tensor must be kept whole (indexing it with [0] would keep a single sample).
    denoised = out[0] if isinstance(out, (tuple, list)) else out

    chunks.append(feats.cpu())
    labs.append(labs_batch.cpu())
    toks.append(denoised.detach().cpu())
    del z, zt, noise, out, denoised, feats
    torch.cuda.empty_cache()

    if (i + 1) % SAVE_EVERY == 0:
        _save_chunk(cnt)
        cnt += 1
        chunks, labs, toks = [], [], []

# Flush the batches collected after the last full chunk; without this the tail of
# the dataset is dropped whenever the batch count is not a multiple of SAVE_EVERY.
if chunks:
    _save_chunk(cnt)
    cnt += 1
    chunks, labs, toks = [], [], []



## Visualize the Features with a t-SNE Plot

In [None]:
def load_tensors(pattern):
    """Load every file matching ``pattern`` (in sorted path order) and concatenate them.

    Raises:
        FileNotFoundError: If no file matches ``pattern``.
    """
    paths = glob(pattern)
    paths.sort()
    print("loading data")
    if len(paths) == 0:
        raise FileNotFoundError(f"No files match: {pattern}")

    loaded = []
    for path in paths:
        loaded.append(torch.load(path))
    return torch.cat(loaded)

def reshape_if_needed(x):
    """Return ``x`` as a 2-D tensor of shape ``(x.size(0), -1)``.

    A tensor that is already 2-D is returned unchanged (same object). Any other
    rank is flattened over all dimensions after the first.
    """
    if x.dim() == 2:
        return x
    # reshape() instead of view(): identical result for contiguous input, but it
    # also works for non-contiguous tensors, where view() raises.
    return x.reshape(x.size(0), -1)

def plot_tsne(data, labels, title, max_points=20000):
    """Project ``data`` to 2-D with t-SNE and show a scatter plot coloured by label.

    Args:
        data: 2-D tensor with one row per sample.
        labels: 1-D tensor of class indices, one per row of ``data``.
        title: Title of the figure.
        max_points: If ``data`` has more rows than this, a random subset of this
            size is plotted instead.
    """
    data = data.cpu()
    labels = labels.cpu()

    if data.size(0) > max_points:
        # Seeded so the subsample is as reproducible as the t-SNE fit below.
        generator = torch.Generator().manual_seed(42)
        idx = torch.randperm(data.size(0), generator=generator)[:max_points]
        data = data[idx]
        labels = labels[idx]

    tsne = TSNE(n_components=2, perplexity=30, random_state=42)
    reduced = tsne.fit_transform(data.detach().numpy())

    # Normalize labels to [0, 1] for the colormap. When every label is identical
    # the range is zero; map everything to 0 instead of dividing by zero.
    labels_f = labels.to(torch.float64)
    low, high = labels_f.min(), labels_f.max()
    span = (high - low).item()
    if span > 0:
        norm_labels = ((labels_f - low) / span).numpy()
    else:
        norm_labels = torch.zeros_like(labels_f).numpy()

    fig, ax = plt.subplots(figsize=(6, 6))
    # Passing the normalized values plus a colormap (rather than precomputed RGBA
    # colours) ties the colorbar to the colours actually drawn.
    points = ax.scatter(reduced[:, 0], reduced[:, 1], c=norm_labels,
                        cmap="viridis", vmin=0.0, vmax=1.0, s=5, alpha=0.7)
    ax.set_title(title)
    fig.colorbar(points, ax=ax, label='Class index (normalized)')
    # Laid out after the colorbar exists so its axes are accounted for.
    fig.tight_layout()
    plt.show()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read from the directory the extraction step writes to. The previous hard-coded
# 'extracted_features' did not match OUTPUT_DIR, so a fresh run either found no
# files or picked up chunks from some other run.
base = OUTPUT_DIR
feats = load_tensors(os.path.join(base, 'feat_*.pt'))
toks = load_tensors(os.path.join(base, 'toks_*.pt'))
labs = load_tensors(os.path.join(base, 'labs_*.pt')).long().squeeze()

# If the two files hold the opposite ranks from what is expected (4-D "feat",
# 2-D "toks"), swap them.
if feats.dim() == 4 and toks.dim() == 2:
    feats, toks = toks, feats

# Flatten both to (num_samples, -1) for t-SNE and the linear probe.
feats = reshape_if_needed(feats)
toks = reshape_if_needed(toks)

plot_tsne(feats, labs, "t-SNE: Features + Labels")
plot_tsne(toks, labs, "t-SNE: Tokens + Labels")

## Run a Linear Probe

In [None]:
# Probe input: pooled features and flattened denoiser outputs side by side, with
# every row L2-normalised.
X = F.normalize(torch.cat((feats, toks), dim=1), p=2, dim=1)

# A row mismatch means features and labels are misaligned; stop here rather than
# failing later inside the split with a less specific message.
if X.size(0) != labs.size(0):
    raise ValueError(
        f"Mismatch between features and labels: {X.size(0)} rows vs {labs.size(0)} labels."
    )

try:
    y_cpu = labs.cpu()
    X_train, X_val, y_train, y_val = train_test_split(X, labs, test_size=0.2, stratify=y_cpu, random_state=42)
except ValueError as err:
    # Stratification is rejected e.g. when a class has a single sample; say so
    # instead of switching to a plain random split silently.
    print(f"Stratified split failed ({err}); falling back to a random split.")
    X_train, X_val, y_train, y_val = train_test_split(X, labs, test_size=0.2, random_state=42)

# Size the output layer by the largest label id. Counting distinct labels breaks
# CrossEntropyLoss as soon as some class ids are absent from the extracted subset.
num_classes = int(labs.max().item()) + 1
# NOTE(review): this rebinds `model`, which earlier cells use for the diffusion model.
model = nn.Linear(X_train.size(1), num_classes).to(device)
opt = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(100):
    model.train()
    perm = torch.randperm(X_train.size(0))
    for i in range(0, X_train.size(0), 64):
        idx = perm[i:i+64]
        batch_X = X_train[idx].to(device)
        batch_y = y_train[idx].to(device)
        out = model(batch_X)
        loss = loss_fn(out, batch_y)
        opt.zero_grad()
        loss.backward()
        opt.step()

    if (epoch + 1) % 10 == 0:
        model.eval()
        with torch.no_grad():
            y_val_dev = y_val.to(device)
            val_out = model(X_val.to(device))
            val_loss = loss_fn(val_out, y_val_dev)
            acc = (val_out.argmax(1) == y_val_dev).float().mean().item() * 100
            # `loss` is the loss of the last training mini-batch of this epoch.
            print(f"Epoch {epoch+1}: Loss {loss.item():.4f}, Val Loss {val_loss.item():.4f}, Acc {acc:.2f}%")

