## One-time notebook to pre-compute features for images

In [2]:
import open_clip
import torch
from tqdm import tqdm
import torch
# from utils import custom_collate_fn
from eval_datasets import CaptionDataset
from eval_datasets import VQADataset

In [6]:
def custom_collate_fn(batch):
    """
    Turn a list of per-sample dicts into one dict of per-key lists.

    The set of keys (and their order) is taken from the first sample; every
    sample in the batch must provide each of those keys.
    """
    first_sample = batch[0]
    return {
        field: [sample[field] for sample in batch]
        for field in first_sample.keys()
    }

In [7]:
class RICES:
    """
    Retrieve the dataset items whose images are most similar to query images.

    Every image in ``dataset`` is embedded with an open_clip vision encoder and
    L2-normalised. ``find`` embeds the query images the same way and ranks the
    dataset items by dot-product similarity.

    Args:
        dataset: Indexable dataset whose items are dicts containing an
            "image" entry that the open_clip image transform accepts.
        device: Device the vision encoder runs on (e.g. "cuda").
        batch_size: Number of images encoded per step while precomputing.
        vision_encoder_path: Model name passed to
            ``open_clip.create_model_and_transforms``.
        vision_encoder_pretrained: Pretrained tag passed to the same call.
        cached_features: Optional precomputed feature matrix with one row per
            dataset item. When given, the precomputation pass is skipped.
    """

    def __init__(
        self,
        dataset,
        device,
        batch_size,
        vision_encoder_path="ViT-B-32",
        vision_encoder_pretrained="openai",
        cached_features=None,
    ):
        self.dataset = dataset
        self.device = device
        self.batch_size = batch_size

        # Load the model and processor
        vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
            vision_encoder_path,
            pretrained=vision_encoder_pretrained,
        )
        self.model = vision_encoder.to(self.device)
        self.image_processor = image_processor

        # Precompute features unless the caller supplied a cache
        if cached_features is None:
            self.features = self._precompute_features()
        else:
            self.features = cached_features

    def _precompute_features(self):
        """
        Embed every image of ``self.dataset`` and return the stacked features.

        Returns:
            A tensor with one L2-normalised row per dataset item, in dataset
            order. The rows stay on the device the encoder produced them on.
        """
        features = []

        # Switch to evaluation mode
        self.model.eval()

        # Set up loader
        loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            collate_fn=custom_collate_fn,
        )

        with torch.no_grad():
            for batch in tqdm(
                loader,
                desc="Precomputing features for RICES",
            ):
                images = batch["image"]
                inputs = torch.stack(
                    [self.image_processor(image) for image in images]
                ).to(self.device)
                image_features = self.model.encode_image(inputs)
                # Normalise so that a dot product equals cosine similarity
                image_features /= image_features.norm(dim=-1, keepdim=True)
                features.append(image_features.detach())

        return torch.cat(features)

    def find(self, batch, num_examples):
        """
        Get the top num_examples most similar examples to the images.

        Args:
            batch: Sequence of query images accepted by the image transform.
            num_examples: Maximum number of dataset items returned per query.

        Returns:
            One list per query image. Each list holds up to ``num_examples``
            dataset items ordered from least to most similar, i.e. the best
            match comes last. An empty ``batch`` yields an empty list.
        """
        # Switch to evaluation mode
        self.model.eval()

        with torch.no_grad():
            processed = [self.image_processor(image) for image in batch]
            if not processed:
                # torch.stack cannot handle an empty list
                return []
            inputs = torch.stack(processed).to(self.device)

            # A single cached feature vector is treated as a one-row matrix
            reference = self.features
            if reference.ndim == 1:
                reference = reference.unsqueeze(0)

            # Get the feature of the input image
            query_feature = self.model.encode_image(inputs)
            query_feature /= query_feature.norm(dim=-1, keepdim=True)
            # Compare on the device that holds the stored features: they are on
            # the encoder device after precomputation, but usually on CPU when
            # loaded from a cache.
            query_feature = query_feature.detach().to(reference.device)

            if query_feature.ndim == 1:
                query_feature = query_feature.unsqueeze(0)

            # Shape (num_queries, num_items). No squeeze here: squeezing would
            # merge the two axes whenever either of them has length one.
            similarity = query_feature @ reference.T

            # Get the indices of the 'num_examples' most similar images
            indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples]

        # Return with the most similar images last; plain ints are used so the
        # dataset's __getitem__ never receives a tensor index.
        return [
            [self.dataset[i] for i in reversed(row)] for row in indices.tolist()
        ]

### COCO captions - train2017 dataset

In [None]:
# Locations of the COCO train2017 caption annotations and images.
# NOTE(review): image_dir_path is an absolute, user-specific scratch path;
# adjust it for your environment.
annotations_path = "../dataset/annotations/captions_train2017.json"
image_dir_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/train2017"
# Captioning dataset that is handed to RICES for feature extraction.
dataset = CaptionDataset(image_dir_path, annotations_path)

In [6]:
# No cached features are passed, so constructing RICES embeds every image now.
retriever = RICES(dataset, device="cuda", batch_size=320)

Precomputing features for RICES: 100%|██████████████████████████████████████████████| 370/370 [26:13<00:00,  4.25s/it]


In [16]:
# Preview the precomputed feature matrix (a CPU copy is displayed).
retriever.features.cpu()

tensor([[-0.0104,  0.0480,  0.0346,  ...,  0.0827,  0.0402,  0.0019],
        [-0.0458, -0.0148,  0.0108,  ...,  0.0407, -0.0092, -0.0045],
        [ 0.0483, -0.0166,  0.0134,  ...,  0.0371,  0.0147, -0.0147],
        ...,
        [-0.0217,  0.0206,  0.0063,  ...,  0.0627,  0.0450, -0.0022],
        [-0.0073,  0.0270, -0.0074,  ...,  0.0908,  0.0426,  0.0249],
        [-0.0177,  0.0144, -0.0181,  ...,  0.1143,  0.0487, -0.0009]])

In [10]:
# Pickle file that will hold the cached COCO train2017 image features.
save_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/features-cache/coco_train.pkl"

In [12]:
import pickle
# Write a CPU copy of the features to save_path so the cached file holds a
# CPU tensor rather than one tied to the GPU used for encoding.
with open(save_path,'wb') as f:
    pickle.dump(retriever.features.cpu(), f)

In [13]:
# Read the cache back to check the round trip.
# pickle.load can execute arbitrary code: only load files you wrote yourself.
with open(save_path,'rb') as f:
    ret_f2 = pickle.load(f)

In [14]:
# Display the reloaded features for comparison with the tensor saved above.
ret_f2

tensor([[-0.0104,  0.0480,  0.0346,  ...,  0.0827,  0.0402,  0.0019],
        [-0.0458, -0.0148,  0.0108,  ...,  0.0407, -0.0092, -0.0045],
        [ 0.0483, -0.0166,  0.0134,  ...,  0.0371,  0.0147, -0.0147],
        ...,
        [-0.0217,  0.0206,  0.0063,  ...,  0.0627,  0.0450, -0.0022],
        [-0.0073,  0.0270, -0.0074,  ...,  0.0908,  0.0426,  0.0249],
        [-0.0177,  0.0144, -0.0181,  ...,  0.1143,  0.0487, -0.0009]])

In [15]:
# Rows correspond to encoded images, columns to the embedding dimensions.
ret_f2.shape

torch.Size([118287, 512])

### VQA - coco 2014 train dataset

In [3]:
# Locations of the VQAv2 images (COCO train2014), questions and annotations.
# These assignments rebind image_dir_path, annotations_path and dataset from
# the caption section.
image_dir_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/train2014"
questions_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_OpenEnded_mscoco_train2014_questions.json"
annotations_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_mscoco_train2014_annotations.json"
# NOTE(review): the positional True is presumably a "training split" flag and
# "vqav2" the dataset name - confirm against VQADataset's signature.
dataset = VQADataset(image_dir_path, questions_path, annotations_path,True, "vqav2")

In [8]:
# Rebinds retriever: embeds every image of the VQA dataset (no cache passed).
retriever = RICES(dataset, device="cuda", batch_size=320)

Precomputing features for RICES: 100%|█████████████████████████████████| 259/259 [15:20<00:00,  3.55s/it]


In [9]:
# Preview the precomputed VQA feature matrix (a CPU copy is displayed).
retriever.features.cpu()

tensor([[ 1.3744e-02,  2.3473e-02,  2.2617e-02,  ..., -7.0341e-03,
         -2.1294e-02, -4.3456e-02],
        [-5.6258e-02,  3.1442e-02, -5.4619e-03,  ...,  1.9843e-02,
          9.1135e-03,  3.2896e-02],
        [ 1.2168e-02,  7.9854e-03, -1.5400e-02,  ...,  5.4900e-02,
         -2.1708e-02, -2.3898e-02],
        ...,
        [-1.6284e-02,  2.7989e-03, -4.2672e-03,  ...,  9.8049e-02,
          7.1288e-03, -2.5878e-05],
        [-2.0299e-02,  2.4458e-02, -4.5172e-03,  ...,  8.3662e-02,
         -1.6326e-02, -1.1608e-02],
        [ 6.3887e-03,  5.9976e-02, -9.9650e-03,  ...,  7.5395e-02,
         -9.3653e-03, -7.6931e-03]])

In [10]:
# Pickle file that will hold the cached VQA (COCO train2014) image features.
save_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/features-cache/coco_train_2014.pkl"

In [11]:
import pickle
# Write a CPU copy of the features to save_path so the cached file holds a
# CPU tensor rather than one tied to the GPU used for encoding.
with open(save_path,'wb') as f:
    pickle.dump(retriever.features.cpu(), f)

In [12]:
# Read the cache back to check the round trip.
# pickle.load can execute arbitrary code: only load files you wrote yourself.
with open(save_path,'rb') as f:
    ret_f2 = pickle.load(f)

In [13]:
# Display the reloaded features for comparison with the tensor saved above.
ret_f2

tensor([[ 1.3744e-02,  2.3473e-02,  2.2617e-02,  ..., -7.0341e-03,
         -2.1294e-02, -4.3456e-02],
        [-5.6258e-02,  3.1442e-02, -5.4619e-03,  ...,  1.9843e-02,
          9.1135e-03,  3.2896e-02],
        [ 1.2168e-02,  7.9854e-03, -1.5400e-02,  ...,  5.4900e-02,
         -2.1708e-02, -2.3898e-02],
        ...,
        [-1.6284e-02,  2.7989e-03, -4.2672e-03,  ...,  9.8049e-02,
          7.1288e-03, -2.5878e-05],
        [-2.0299e-02,  2.4458e-02, -4.5172e-03,  ...,  8.3662e-02,
         -1.6326e-02, -1.1608e-02],
        [ 6.3887e-03,  5.9976e-02, -9.9650e-03,  ...,  7.5395e-02,
         -9.3653e-03, -7.6931e-03]])

In [14]:
# Rows correspond to encoded images, columns to the embedding dimensions.
ret_f2.shape

torch.Size([82783, 512])