In [21]:
import pandas as pd
import os
import torch
from torchvision import transforms as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [65]:
class ImageDataset(Dataset):
    """Dataset that loads RGB images listed in a metadata DataFrame.

    Every row of ``df`` must provide a ``category`` and a ``filename``
    column; the image for a row is read from
    ``<root_dir>/<category>/<filename>``.

    Args:
        df: DataFrame with ``category`` and ``filename`` columns.
        transform: Optional callable applied to the loaded PIL image.
        root_dir: Directory holding one sub-folder per category. Defaults to
            ``'../dataset/raw'``, the location this notebook has always used.
    """

    def __init__(self, df, transform=None, root_dir='../dataset/raw'):
        self.df = df
        self.transform = transform
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of rows in the metadata frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at positional index ``idx``.

        Returns:
            The RGB image, passed through ``transform`` when one was given.
        """
        # Fetch the row once instead of running a positional lookup per column.
        row = self.df.iloc[idx]
        img_path = os.path.join(self.root_dir, row['category'], row['filename'])

        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image

In [48]:
def dinov2_preprocess(H, W, patch_size=14):
    """Build the preprocessing pipeline for a patch-based DINOv2 backbone.

    The requested size is rounded down to the nearest multiple of
    ``patch_size`` in each dimension so the resized image splits into whole
    patches.

    Args:
        H: Requested image height in pixels.
        W: Requested image width in pixels.
        patch_size: Side length of one square patch. Defaults to 14, matching
            the 14x14 patch embedding of the model loaded in this notebook.

    Returns:
        A ``T.Compose`` that resizes (bicubic), converts to a tensor, and
        normalizes with the ImageNet channel mean / std.

    Raises:
        ValueError: If ``patch_size`` is not positive, or if ``H`` or ``W`` is
            smaller than one patch (the rounded size would be zero).
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    new_H = H - H % patch_size
    new_W = W - W % patch_size
    # Fail here rather than later inside Resize with a zero-sized target.
    if new_H <= 0 or new_W <= 0:
        raise ValueError(
            f"Image size ({H}, {W}) is smaller than one {patch_size}x{patch_size} patch"
        )

    transform = T.Compose([
        T.Resize((new_H, new_W), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    return transform

In [67]:
# --- Preprocessing -----------------------------------------------------------
# Requested (H, W); dinov2_preprocess rounds each side down to a patch multiple.
img_shape = (256, 256)
transform = dinov2_preprocess(img_shape[0], img_shape[1])

# --- Metadata ----------------------------------------------------------------
train_csv_path = '../dataset/dataset_metadata/train_set.csv'
test_csv_path = '../dataset/dataset_metadata/test_set.csv'
train_df = pd.read_csv(train_csv_path)
test_df = pd.read_csv(test_csv_path)

# --- Datasets (both splits share the same transform) -------------------------
train_dataset = ImageDataset(df=train_df, transform=transform)
test_dataset = ImageDataset(df=test_df, transform=transform)

# --- Loaders -----------------------------------------------------------------
# Shuffling stays off so batches arrive in CSV row order and embedding row k
# corresponds to metadata row k.
batch_size = 64
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [19]:
# Fetch the 'dinov2_vitb14' entry point from the DINOv2 hub repository.
dinov2_model = torch.hub.load(repo_or_dir='facebookresearch/dinov2', model='dinov2_vitb14')
# Switch to evaluation mode and move to the GPU; both calls return the module,
# so the chained expression is still what the cell displays.
dinov2_model.eval().cuda()

Using cache found in /home/hice1/asinghal81/.cache/torch/hub/facebookresearch_dinov2_main


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (n

In [70]:
# Compute all image embeddings (one row per image; 768-dim for dinov2_vitb14)

def compute_embeddings(dataloader, num_images, model=None, embed_dim=None):
    """Run a model over every batch and collect the outputs in one CPU tensor.

    Rows are written at a running offset taken from each batch's actual
    length, so the result does not depend on any global ``batch_size`` and
    works with loaders of any batch size.

    Args:
        dataloader: Sized iterable yielding image batches as tensors.
        num_images: Total number of images the loader will yield.
        model: Callable mapping an image batch to a ``(batch, dim)`` tensor.
            Defaults to the notebook-level ``dinov2_model``.
        embed_dim: Embedding width. When ``None`` it is inferred from the
            first batch's output (768 is used if the loader is empty).

    Returns:
        A ``(num_images, dim)`` tensor on the CPU, filled in iteration order.

    Raises:
        ValueError: If the loader yields more or fewer than ``num_images``
            images in total.
    """
    if model is None:
        model = dinov2_model

    # Send inputs to wherever the model's weights live instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    img_embeddings = None
    offset = 0

    with torch.no_grad():
        for images in tqdm(dataloader, total=len(dataloader)):
            if device is not None:
                images = images.to(device)
            outputs = model(images)

            # Allocate lazily so the width can follow the model's real output.
            if img_embeddings is None:
                dim = outputs.shape[-1] if embed_dim is None else embed_dim
                img_embeddings = torch.zeros((num_images, dim))

            n = outputs.shape[0]
            if offset + n > num_images:
                raise ValueError(
                    f"dataloader yielded more than num_images={num_images} images"
                )
            # Slice assignment copies the batch into the CPU buffer.
            img_embeddings[offset:offset + n, :] = outputs
            offset += n

    if img_embeddings is None:
        # Empty loader: keep the historical (num_images, 768) default shape.
        img_embeddings = torch.zeros((num_images, 768 if embed_dim is None else embed_dim))

    if offset != num_images:
        raise ValueError(
            f"dataloader yielded {offset} images but num_images={num_images}"
        )

    return img_embeddings

# One embedding row per metadata row, for each split.
train_img_embeddings = compute_embeddings(
    dataloader=train_dataloader,
    num_images=len(train_df),
)
test_img_embeddings = compute_embeddings(
    dataloader=test_dataloader,
    num_images=len(test_df),
)

100%|██████████| 136/136 [01:29<00:00,  1.53it/s]
100%|██████████| 34/34 [00:21<00:00,  1.57it/s]


In [71]:
# Save image embeddings as tensor files, so we don't need to recompute every time
embeddings_dir = '../dataset/img_embeddings'
# torch.save does not create missing parent directories, so create the folder first.
os.makedirs(embeddings_dir, exist_ok=True)
torch.save(train_img_embeddings, os.path.join(embeddings_dir, 'train_img_embeddings.pt'))
torch.save(test_img_embeddings, os.path.join(embeddings_dir, 'test_img_embeddings.pt'))
print(train_img_embeddings.shape, test_img_embeddings.shape)

torch.Size([8679, 768]) torch.Size([2170, 768])
