In [None]:
import os
import torch
import torchvision.datasets as datasets
from torchvision.transforms.v2 import Compose, RandomHorizontalFlip, ToDtype, Lambda, Resize
from torchvision.transforms.v2.functional import pil_to_tensor
from einops import rearrange
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Run on CUDA when torch reports it as available; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# Side length (in pixels) that every image is resized to.
resolution = 252


def get_train_transform(resolution=resolution):
    """Build the preprocessing pipeline applied to each dataset image.

    The steps, in order: `pil_to_tensor`, a resize to a square of
    `resolution` x `resolution`, a horizontal flip applied with
    probability 0.5, a cast to float32, and a division by 255.

    Args:
        resolution: Target height and width handed to `Resize`. The default
            is the value of the module-level `resolution` at the time this
            function is defined.

    Returns:
        A `Compose` object chaining the steps above.
    """
    def _divide_by_255(t):
        # Same arithmetic as the usual inline lambda, just named.
        return t / 255.0

    target_size = (resolution, resolution)
    steps = [
        pil_to_tensor,
        Resize(target_size),
        RandomHorizontalFlip(p=0.5),  # random: repeated loads of one index may differ
        ToDtype(torch.float32),
        Lambda(_divide_by_255),
    ]
    return Compose(steps)

# Load the 'dinov2_vitb14' entry point of the DINOv2 hub repository, then put
# the module on the selected device and in evaluation mode.
model = torch.hub.load(repo_or_dir='facebookresearch/dinov2', model='dinov2_vitb14')
model = model.to(device).eval()

# Dataset root. Override it by setting the IMAGENET_DIR environment variable;
# when the variable is unset the original cluster path is used.
# NOTE(review): ImageFolder is documented to treat each immediate subdirectory
# of the root as one class -- confirm this root sits at the class-folder level
# (e.g. a train split) rather than one level above it.
imagenet_dir = os.environ.get('IMAGENET_DIR', '/datasets01/imagenet_full_size/061417')
transform_train = get_train_transform()
dataset_train = datasets.ImageFolder(imagenet_dir, transform=transform_train)


In [None]:
# Pick one sample; `image` is the transformed tensor, `label` its class index.
idx = 3
image, label = dataset_train[idx]
# Channel-last copy of the un-normalized image, kept for display only.
image_rgb = rearrange(image, 'c h w -> h w c')

# prepare image for the model
# DINOv2's reference evaluation pipeline normalizes inputs with the ImageNet
# mean/std, so apply that to the model input (not to the display image).
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], dtype=image.dtype).view(-1, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], dtype=image.dtype).view(-1, 1, 1)
image_batch = ((image - imagenet_mean) / imagenet_std).unsqueeze(0).to(device)

# Inference only: without no_grad the output tensor requires grad and the
# `.numpy()` call below raises a RuntimeError.
with torch.no_grad():
    outputs = model.forward_features(image_batch)
x_norm_patchtokens = outputs['x_norm_patchtokens'].squeeze(0).cpu().numpy()

patch_size = 14
num_tokens, d = x_norm_patchtokens.shape
# Derive the patch grid from the actual image size rather than the global
# `resolution`, so the grid stays correct if the transform is built differently.
H = image.shape[-2] // patch_size
W = image.shape[-1] // patch_size
assert num_tokens == H * W, (
    f'expected {H * W} patch tokens for a {H}x{W} grid, got {num_tokens}'
)


In [None]:
def _rescale01(values):
    """Min-max rescale `values` using its global minimum and maximum."""
    return (values - values.min()) / (values.max() - values.min())


# --- Step 1: background filtering with 1 PCA component ---
# Project every patch token onto the first principal component, lay the scores
# out on the H x W patch grid, and rescale them to [0, 1].
pca_background = PCA(n_components=1)
background = _rescale01(
    rearrange(pca_background.fit_transform(x_norm_patchtokens), '(H W) d -> H W d', H=H, W=W)
)
# Patches scoring below the threshold are zeroed out.
# NOTE(review): this assumes low first-component scores mark the background;
# the sign of a principal component is arbitrary, so confirm per image.
mask = background < 0.55
background[mask] = 0.0

# --- Step 2: RGB visualization with 3 PCA components ---
# A second PCA, fitted on all patch tokens, gives three values per patch that
# are rescaled jointly (one global min/max) and shown as RGB.
pca_rgb = PCA(n_components=3)
colors = _rescale01(
    rearrange(pca_rgb.fit_transform(x_norm_patchtokens), '(H W) d -> H W d', H=H, W=W)
)
# Expand the single-channel mask to three channels and black out those patches.
mask_repeated = np.repeat(mask, 3, axis=-1)
colors[mask_repeated] = 0.0


In [None]:
# Three panels side by side: the input image, the thresholded first-component
# map, and the 3-component PCA rendered as RGB.
fig, axes = plt.subplots(1, 3, figsize=(15, 5), constrained_layout=True)

im0 = axes[0].imshow(image_rgb)
axes[0].set_title('Original')
axes[0].axis('off')

# Scalar map: a gray colormap with a fixed [0, 1] range, so the colorbar is meaningful.
im1 = axes[1].imshow(background.squeeze(-1), cmap='gray', vmin=0, vmax=1)
axes[1].set_title('Background Filtered (PCA=1)')
axes[1].axis('off')
fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

# RGB data: imshow ignores cmap/vmin/vmax for (H, W, 3) input, and a colorbar
# here would show an unrelated default colormap, so neither is used.
im2 = axes[2].imshow(colors)
axes[2].set_title('RGB (PCA=3)')
axes[2].axis('off')

plt.show()
