In [51]:
# batch processing 
import warnings
from pathlib import Path

import clip
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# Image files are looked up as <IMG_PATH>/<image id>.jpg when the dataset loads them.
IMG_PATH = '../data/train_images/'

# Training metadata; its "image" column holds the ids of the images to embed.
TRAIN_CSV = '../data/raw/train.csv'
df = pd.read_csv(TRAIN_CSV)

In [52]:
# Sanity check: is a CUDA device visible to PyTorch? (The device is chosen below; it falls back to CPU if not.)
torch.cuda.is_available()

True

In [53]:
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

from torchvision import transforms as T

# DINOv2 preprocessing pipeline. It is built lazily on first use and then reused, so the
# Compose object is constructed once per process (e.g. once per DataLoader worker)
# instead of once per image.
# NOTE(review): believed to match DINOv2's classification eval transform
# (resize 256 bicubic -> center crop 224 -> ImageNet normalisation) — confirm if changing.
_DINOV2_TRANSFORMS = None


def dinov2_preprocess(image: Image.Image):
    """Convert a PIL image into a normalised tensor for the DINOv2 backbone.

    Resizes the shorter side to 256 px (bicubic), center-crops 224x224, converts to
    a float tensor in [0, 1] and normalises with the ImageNet mean/std.

    Args:
        image: PIL image (the dataset converts inputs to RGB before calling this).

    Returns:
        A (3, 224, 224) float tensor.
    """
    global _DINOV2_TRANSFORMS
    if _DINOV2_TRANSFORMS is None:
        _DINOV2_TRANSFORMS = T.Compose([
            T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
    return _DINOV2_TRANSFORMS(image)


class ImageDataset(Dataset):
    """Map-style dataset yielding ``(transformed_image, image_id)`` pairs.

    Item ``i`` is read from ``<img_dir>/<df["image"][i]>.jpg`` and converted to RGB.
    An image that cannot be loaded is replaced by a black placeholder, so a single
    bad file does not abort a multi-hour extraction run. A warning naming the
    failing id is emitted instead of failing silently.

    Args:
        df: DataFrame with an ``"image"`` column of image ids (file stems).
        transform: Callable mapping a PIL image to a tensor.
        img_dir: Directory containing the images. ``None`` (default) uses the
            module-level ``IMG_PATH``, resolved at construction time.
        placeholder_size: ``(width, height)`` of the black fallback image.
    """

    def __init__(self, df, transform, img_dir=None, placeholder_size=(224, 224)):
        self.image_paths = df["image"].values
        self.transform = transform
        self.img_dir = IMG_PATH if img_dir is None else img_dir
        self.placeholder_size = placeholder_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            # `with` guarantees the file handle is closed even if decoding fails.
            with Image.open(Path(self.img_dir, f"{img_path}.jpg")) as raw:
                image = raw.convert("RGB")
        except Exception as e:
            # Best-effort: keep the run going, but make the failure visible.
            warnings.warn(f"Error loading image {img_path}: {e}; using a black placeholder")
            image = Image.new("RGB", self.placeholder_size, (0, 0, 0))

        return self.transform(image), img_path


def collate_fn(batch):
    """Turn a list of ``(tensor, image_id)`` samples into ``(stacked_tensor, ids)``.

    The tensors are stacked along a new leading batch dimension; the ids are
    returned as a tuple in the same order as the samples.
    """
    columns = list(zip(*batch))
    image_list, id_tuple = columns
    return torch.stack(image_list), id_tuple

In [58]:
# Embedding extraction: run every training image through the selected backbone and
# save one L2-normalised feature vector per image id.

# Backbone switch: True -> CLIP ViT-B/32, False -> DINOv2 ViT-B/14.
use_clip = False
BATCH_SIZE = 64
NUM_WORKERS = 8
FEAT_DIR = Path('../outputs/feat')

if use_clip:
    model, preprocess = clip.load("ViT-B/32", device=device)
    encode_images = model.encode_image
    dataset = ImageDataset(df, preprocess)
else:
    dinov2_model = torch.hub.load(
        "facebookresearch/dinov2", "dinov2_vitb14", pretrained=True
    )
    dinov2_model.to(device)
    dinov2_model.eval()
    encode_images = dinov2_model
    dataset = ImageDataset(df, dinov2_preprocess)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    collate_fn=collate_fn,
)

# image id -> 1-D normalised embedding, filled per image, so no batch-keyed dict and
# no second flattening pass. Duplicate ids keep their last embedding.
# `new_embeddings` is only replaced once the loop has finished, so an interrupted run
# (e.g. KeyboardInterrupt) never leaves a partial result there for a later re-save.
extracted = {}

with torch.no_grad():
    for images, paths in tqdm(dataloader):
        image_features = encode_images(images.to(device))
        # L2-normalise so dot products between embeddings are cosine similarities.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        extracted.update(zip(paths, image_features.cpu().numpy()))

new_embeddings = extracted
assert new_embeddings, "no embeddings were extracted; refusing to write an empty archive"

filename = 'clip_embeddings.npz' if use_clip else 'dino_embeddings.npz'

FEAT_DIR.mkdir(parents=True, exist_ok=True)
np.savez_compressed(
    FEAT_DIR / filename,
    embeddings=list(new_embeddings.values()),
    images=list(new_embeddings.keys()),
)

Using cache found in /home/qb/.cache/torch/hub/facebookresearch_dinov2_main
  2%|▏         | 429/23491 [03:07<2:48:22,  2.28it/s]


KeyboardInterrupt: 

In [5]:
# Check the on-disk size of the saved DINOv2 embedding archive.
!du -sh ../outputs/feat/dino_embeddings.npz

3,9G	../outputs/feat/dino_embeddings.npz


In [57]:
# Re-save the in-memory embeddings without re-running the multi-hour extraction.
# Depends on `new_embeddings` and `filename` from the extraction cell. Refuse to write
# an empty archive, which would silently overwrite a previously saved good one.
assert new_embeddings, "new_embeddings is empty; run the extraction cell first"
np.savez_compressed(f'../outputs/feat/{filename}', embeddings=list(new_embeddings.values()), images=list(new_embeddings.keys()))