In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path

from config import Config
from utils import ensure_dirs, device
from dataset import ImageFolderDataset
from models.dinov2_wrapper import DINOv2Extractor

def extract(split="train", batch_size=32):
    """Extract DINOv2 features for one dataset split and save them as ``.npy`` files.

    Images are read from ``cfg.data_root / split`` with ``augment=False`` and in
    a fixed order (``shuffle=False``). Features and labels are therefore appended
    in lockstep, so row ``i`` of the saved feature array pairs with entry ``i``
    of the saved label array. This assumes the extractor returns one feature row
    per input image along axis 0 (TODO confirm against ``DINOv2Extractor``).

    Args:
        split: Sub-directory of ``cfg.data_root`` to process
            (e.g. ``"train"``, ``"val"``, ``"test"``).
        batch_size: Number of images per forward pass. Defaults to 32, the
            value that was previously hard-coded.

    Returns:
        Tuple ``(X, Y)`` of the concatenated feature and label arrays. These
        are the arrays written to ``cfg.features_dir / f"X_{split}.npy"`` and
        ``cfg.features_dir / f"y_{split}.npy"``.

    Raises:
        ValueError: If the split yields no batches (e.g. an empty directory).
            In that case no ``.npy`` files are written.
    """
    cfg = Config()
    dev = device()
    ensure_dirs(cfg.features_dir)

    split_dir = cfg.data_root / split
    ds = ImageFolderDataset(split_dir, img_size=cfg.img_size, augment=False)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=cfg.num_workers)

    extractor = DINOv2Extractor(cfg.dinov2_variant, str(cfg.dinov2_local_ckpt)).to(dev)
    extractor.eval()

    feats_list, y_list = [], []
    # Gradients are never needed here. Disable autograd once for the whole
    # pass instead of re-entering the context manager on every batch.
    with torch.no_grad():
        for x, y in tqdm(dl, desc=f"Extracting DINOv2 features ({split})"):
            f = extractor(x.to(dev))
            feats_list.append(f.cpu().numpy())
            y_list.append(y.numpy())

    if not feats_list:
        # np.concatenate([]) would otherwise fail with the opaque
        # "need at least one array to concatenate" error.
        raise ValueError(f"No samples found for split {split!r} under {split_dir}")

    X = np.concatenate(feats_list, axis=0)
    Y = np.concatenate(y_list, axis=0)

    x_path = cfg.features_dir / f"X_{split}.npy"
    y_path = cfg.features_dir / f"y_{split}.npy"
    np.save(x_path, X)
    np.save(y_path, Y)
    print("Saved:", x_path)
    return X, Y

if __name__ == "__main__":
    # Run feature extraction for each split in turn. Each call writes its own
    # X_<split>.npy / y_<split>.npy pair.
    for split_name in ("train", "val", "test"):
        extract(split_name)
