In [1]:
import torch
dinoViT = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14').cuda()

Using cache found in /sailhome/redfairy/.cache/torch/hub/facebookresearch_dinov2_main


In [2]:
import torchvision.transforms.functional as TF
import glob
import os
import numpy as np
from PIL import Image
from tqdm import tqdm

In [6]:
class MultiscenesDataset(torch.utils.data.Dataset):
    def __init__(self, dataroot, n_scenes=5000, input_size=14*64):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.input_size = input_size
        self.scenes = []
        image_filenames = sorted(glob.glob(os.path.join(dataroot, '*.png')))  # root/00000_sc000_az00_el00.png
        mask_filenames = sorted(glob.glob(os.path.join(dataroot, '*_mask.png')))
        fg_mask_filenames = sorted(glob.glob(os.path.join(dataroot, '*_mask_for_moving.png')))
        moved_filenames = sorted(glob.glob(os.path.join(dataroot, '*_moved.png')))
        bg_mask_filenames = sorted(glob.glob(os.path.join(dataroot, '*_mask_for_bg.png')))
        bg_in_mask_filenames = sorted(glob.glob(os.path.join(dataroot, '*_mask_for_providing_bg.png')))
        changed_filenames = sorted(glob.glob(os.path.join(dataroot, '*_changed.png')))
        bg_in_filenames = sorted(glob.glob(os.path.join(dataroot, '*_providing_bg.png')))
        changed_filenames_set, bg_in_filenames_set = set(changed_filenames), set(bg_in_filenames)
        bg_mask_filenames_set, bg_in_mask_filenames_set = set(bg_mask_filenames), set(bg_in_mask_filenames)
        image_filenames_set, mask_filenames_set = set(image_filenames), set(mask_filenames)
        fg_mask_filenames_set, moved_filenames_set = set(fg_mask_filenames), set(moved_filenames)
        filenames_set = image_filenames_set - mask_filenames_set - fg_mask_filenames_set - moved_filenames_set - changed_filenames_set - bg_in_filenames_set - bg_mask_filenames_set - bg_in_mask_filenames_set
        filenames = sorted(list(filenames_set))
        self.n_scenes = n_scenes
        self.n_img_each_scene = 4
        for i in range(self.n_scenes):
            scene_filenames = [x for x in filenames if 'sc{:04d}'.format(i) in x]
            self.scenes.append(scene_filenames)

    def _transform_encoder(self, img): # for ImageNet encoder
        img = TF.resize(img, (self.input_size, self.input_size))
        img = TF.to_tensor(img)
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return img

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - a random integer for data indexing, here it is scene_idx
        """
        scene_idx = index
        scene_filenames = self.scenes[scene_idx]
        filenames = scene_filenames[:self.n_img_each_scene]
        rets = []
        for rd, path in enumerate(filenames):
            img = Image.open(path).convert('RGB')
            img_data = self._transform_encoder(img)
            rets.append((img_data, path))
        paths = [x[1] for x in rets]
        imgs = torch.stack([x[0] for x in rets])
        return imgs, paths
            
    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.n_scenes

In [7]:
feat_size = 64
input_size = 14*feat_size
dataset = MultiscenesDataset('/viscam/projects/uorf-extension/datasets/room_diverse_nobg/train-1obj-manysize-trans-orange', n_scenes=50, input_size=input_size)

In [12]:
out_channel = 1024
for imgs, paths in tqdm(dataset):
    print(imgs.shape, paths)
    imgs = imgs.cuda()
    with torch.no_grad():
        feats = dinoViT.forward_features(imgs)['x_norm_patchtokens'].reshape(-1, feat_size, feat_size, out_channel)
        feats = feats.cpu().numpy()
    # save features
    for rd, path in enumerate(paths):
        path = path.replace('png', '.npy')
        np.save(path, feats[rd])
        print(feats[rd].shape, path)
    break


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4, 3, 896, 896]) ['/viscam/projects/uorf-extension/datasets/room_diverse_nobg/train-1obj-manysize-trans-orange/00000_sc0000_az00.png', '/viscam/projects/uorf-extension/datasets/room_diverse_nobg/train-1obj-manysize-trans-orange/00001_sc0000_az01.png', '/viscam/projects/uorf-extension/datasets/room_diverse_nobg/train-1obj-manysize-trans-orange/00002_sc0000_az02.png', '/viscam/projects/uorf-extension/datasets/room_diverse_nobg/train-1obj-manysize-trans-orange/00003_sc0000_az03.png']
(64, 64, 1024) /viscam/projects/uorf-extension/datasets/room_diverse_nobg/train-1obj-manysize-trans-orange/00000_sc0000_az00..npy
(64, 64, 1024) /viscam/projects/uorf-extension/datasets/room_diverse_nobg/train-1obj-manysize-trans-orange/00001_sc0000_az01..npy
(64, 64, 1024) /viscam/projects/uorf-extension/datasets/room_diverse_nobg/train-1obj-manysize-trans-orange/00002_sc0000_az02..npy


  0%|          | 0/50 [00:07<?, ?it/s]

(64, 64, 1024) /viscam/projects/uorf-extension/datasets/room_diverse_nobg/train-1obj-manysize-trans-orange/00003_sc0000_az03..npy



