In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import csv
import numpy as np

# Root directory holding the `clip_features/` and `sam_features/` .npz files,
# and the caption CSV read by the dataset below. The defaults are the original
# author's machine paths; set the environment variables to run elsewhere.
# NOTE(review): the two defaults sit under different roots (/data2/xcg_data vs
# /data/xcg) — confirm this is intentional and not a typo.
dataset_base = os.environ.get("LAVIS_FEATURES_DIR", "/data2/xcg_data/lavis_data/2023us/features")
csvpath = os.environ.get("LAVIS_CAPTIONS_CSV", "/data/xcg/lavis_data/coco-2023us/excels/translated.csv")

In [5]:



# Dataset that groups per-image feature files by person id and returns a fixed
# number of feature frames per person together with that person's caption.
class CustomDataset(Dataset):
    """Per-person dataset of precomputed CLIP and SAM features.

    Each CSV row is ``image_id, caption``. Rows are grouped by person id, taken
    as the second ``_``-separated field of ``image_id``; the caption kept for a
    person is the one on that person's first row (later captions are ignored).

    Features are read from ``<dataset_base>/clip_features/<image_id>.npz`` and
    ``<dataset_base>/sam_features/<image_id>.npz``, using the array stored under
    the key ``"arr"``.

    Args:
        dataset_base: Directory containing ``clip_features/`` and ``sam_features/``.
        csvpath: Path to the caption CSV. Every row is treated as data; a header
            row is not skipped (TODO confirm the file has none).
        limitation: Number of frames returned per person (must be >= 1). When a
            person has fewer images, their images are reused cyclically.

    Raises:
        ValueError: if ``limitation`` is less than 1.
    """

    def __init__(self, dataset_base, csvpath, limitation):
        if limitation < 1:
            # Zero frames would yield (None, None, caption), which the default
            # DataLoader collate function cannot batch.
            raise ValueError(f"limitation must be >= 1, got {limitation}")
        self.csvpath = csvpath
        self.dataset_base = dataset_base
        self.limitation = limitation
        # pairs: {personid: [[image_id, ...], caption]}; keylist fixes index order.
        self.pairs, self.keylist = self.load_caption()

    def load_caption(self):
        """Read the caption CSV and group image ids by person id.

        Returns:
            (pairs, keys): ``pairs`` maps person id -> ``[[image_ids], caption]``
            and ``keys`` lists the person ids in first-seen order.
        """
        pairs = {}
        # newline='' is what the csv module expects, so quoted fields with
        # embedded newlines are parsed correctly.
        with open(self.csvpath, 'r', newline='') as csvfile:
            for row in csv.reader(csvfile):
                if not row:
                    # csv.reader yields [] for blank lines; skip them instead of
                    # failing on row[0].
                    continue
                image_id = row[0]
                personid = str(image_id).split("_")[1]
                if personid in pairs:
                    pairs[personid][0].append(image_id)
                else:
                    pairs[personid] = [[image_id], row[1]]
        return pairs, list(pairs.keys())

    def __len__(self):
        """Number of distinct person ids."""
        return len(self.keylist)

    def _load_feature(self, subdir, image_id):
        """Load ``<dataset_base>/<subdir>/<image_id>.npz`` (key ``"arr"``) as a tensor."""
        path = os.path.join(self.dataset_base, subdir, image_id + '.npz')
        # np.load on an .npz returns an NpzFile that keeps the file open; the
        # context manager closes it once the array has been read into memory.
        with np.load(path) as archive:
            return torch.from_numpy(archive["arr"])

    def __getitem__(self, index):
        """Return ``(clip_feature, sam_feature, caption)`` for one person.

        Each feature tensor is the per-image arrays stacked along a new leading
        dimension of size ``self.limitation``.
        """
        personid = self.keylist[index]
        image_ids, caption = self.pairs[personid]
        # Cycle through the person's images when there are fewer than `limitation`.
        chosen = [image_ids[i % len(image_ids)] for i in range(self.limitation)]
        # Stack once instead of growing with torch.cat in a loop (quadratic copying).
        clip_feature = torch.stack([self._load_feature("clip_features", iid) for iid in chosen], dim=0)
        sam_feature = torch.stack([self._load_feature("sam_features", iid) for iid in chosen], dim=0)
        return clip_feature, sam_feature, caption

In [None]:
# Build the per-person dataset with 4 feature frames per person.
datas = CustomDataset(dataset_base=dataset_base, csvpath=csvpath, limitation=4)
# Inspect one sample by shape and caption rather than rendering the full
# tensors, whose reprs would flood the notebook output.
clip_sample, sam_sample, caption_sample = datas[0]
clip_sample.shape, sam_sample.shape, caption_sample

In [12]:
# Batch 6 people at a time in dataset order (no shuffling) for inspection.
custom_dataloader = DataLoader(dataset=datas, batch_size=6, shuffle=False)

In [13]:
# Peek at the first batch to sanity-check collated shapes. next(iter(...))
# raises on an empty loader instead of silently printing nothing.
clipfeatures, samfeatures, labels = next(iter(custom_dataloader))
print(clipfeatures.shape, samfeatures.shape, len(labels))

torch.Size([6, 4, 677, 1408]) torch.Size([6, 4, 256, 4096]) 6
