In [None]:
import datasets

# Load the dataset stored on disk under ./partial_data.
# NOTE(review): a later cell indexes this with 'train', so it presumably holds
# named splits — confirm against how the data was saved.
dataset = datasets.load_from_disk('./partial_data')

In [None]:
# Select the 'train' split; the bare expression on the last line makes the
# notebook display the first record so its fields can be inspected.
train_split = dataset['train']
train_split[0]

In [None]:
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

# Load the dual-encoder model and its matching processor from one checkpoint
# id, so the two can never drift apart.
RCLIP_CHECKPOINT = "kaveh/rclip"
model = VisionTextDualEncoderModel.from_pretrained(RCLIP_CHECKPOINT)
processor = VisionTextDualEncoderProcessor.from_pretrained(RCLIP_CHECKPOINT)

In [None]:
from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
    """Map-style dataset yielding ``(index, first image)`` for each record of a split.

    Args:
        split: An indexable, sized collection whose items are mappings with an
            ``'images'`` entry that is itself indexable.
    """

    def __init__(self, split):
        self.split = split

    def __len__(self):
        """Return the number of records in the wrapped split."""
        record_count = len(self.split)
        return record_count

    def __getitem__(self, idx):
        """Return ``(idx, image)`` where ``image`` is the first entry of the record's ``'images'``."""
        record = self.split[idx]
        first_image = record['images'][0]
        # The index travels with the image so embeddings can be matched back to rows.
        return idx, first_image

In [None]:
import torch

# Run on the GPU when one is present; otherwise the model stays where it was loaded.
if torch.cuda.is_available():
    model = model.to("cuda")

In [None]:
# Wrap the train split so each item is an (index, first image) pair.
image_dataset = ImageDataset(train_split)

def collate(batch):
    """Split a batch of ``(index, image)`` items into two parallel lists.

    Images are returned as a plain list rather than stacked, leaving any
    resizing and tensor conversion to whatever consumes the batch.

    Args:
        batch: Sequence of items whose element 0 is the index and element 1 the image.

    Returns:
        A tuple ``(indices, images)`` of two lists in batch order.
    """
    indices = []
    images = []
    for sample in batch:
        indices.append(sample[0])
        images.append(sample[1])
    return indices, images

# Iterate the dataset in its original order (no shuffling), letting worker
# processes fetch the images and `collate` keep them as plain lists.
data_loader = DataLoader(
    image_dataset,
    batch_size=2560,
    shuffle=False,
    collate_fn=collate,
    num_workers=4,
    persistent_workers=True,
)

In [None]:
from tqdm.notebook import tqdm
import json

def to_cuda(data, device=None):
    """Move every value of ``data`` that has a ``.to`` method onto ``device``, in place.

    The model in this notebook is only moved to the GPU when CUDA is available,
    so the default target follows the same rule; unconditionally sending the
    inputs to ``'cuda'`` would fail on a CPU-only machine.

    Args:
        data: Mutable mapping (e.g. the processor's output) whose values may be
            tensors. Values without a ``.to`` attribute are left untouched.
        device: Target device passed to each value's ``.to``. Defaults to
            ``'cuda'`` when CUDA is available, otherwise ``'cpu'``.

    Returns:
        The same ``data`` object, for convenience; existing callers that
        ignore the return value are unaffected.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Snapshot the items so reassigning values never interferes with iteration.
    for key, value in list(data.items()):
        if hasattr(value, 'to'):
            data[key] = value.to(device)
    return data

# Stream the image embeddings to disk as JSON Lines: one object per image,
# holding the dataset index and the embedding as a list of floats.
with open("image_embeddings.jsonl", 'w') as f:
    for idxs, batch_images in tqdm(data_loader):
        # Inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            inputs = processor(text=None, images=batch_images, return_tensors="pt", padding=True)
            to_cuda(inputs)
            outputs = model.get_image_features(**inputs)
        # Copy the whole batch to host memory once instead of once per row.
        # Assumes `outputs` is a tensor with one row per image — the original
        # per-row `.cpu()` loop relied on the same thing.
        batch_embeddings = outputs.cpu().tolist()
        for idx, embedding in zip(idxs, batch_embeddings):
            json.dump({'index': idx, 'embedding': embedding}, f)
            f.write('\n')