In [None]:
import rp
from icecream import ic
from easydict import EasyDict
import copy

In [None]:
class TestSample:
    """
    One evaluation example: an image together with its text prompt, file path and name.
    """

    def __init__(self, image, prompt:str, path:str, name:str):
        """
        Store the fields of a sample.

        image: The image object for this sample (it is later handed to rp.display_image).
        prompt: The text prompt associated with the image.
        path: Path of the image file; it is stored after passing through rp.get_absolute_path.
        name: An identifier for the sample.
        """
        self.name=name
        self.prompt=prompt
        self.image=image
        #Resolved at construction time, so the stored path does not depend on later directory changes
        self.path=rp.get_absolute_path(path)

    def copy(self):
        """
        Return a shallow copy of this sample (attributes are shared, not duplicated).
        """
        duplicate = copy.copy(self)
        return duplicate

    def display(self):
        """
        Print the prompt and the path (one per line), then show the image with rp.display_image.
        """
        for line in (self.prompt, self.path):
            print(line)
        rp.display_image(self.image)

In [None]:
class CroppedRefVOCDataset:
    """
    A list-like collection of samples from the cropped RefVOC dataset.

    Each sample is an object with .image, .prompt, .path and .name attributes
    (TestSample objects when the dataset is loaded from disk).
    Supports len(), iteration, integer indexing (returns a sample) and
    slicing (returns a new CroppedRefVOCDataset).
    """

    #Used when no dataset_dir is passed to __init__
    DEFAULT_DATASET_DIR='/mnt/md0/nfs/datasets/RefVOC'

    #Maps a prompt found in the dataset files to the prompt we actually want to use
    DEFAULT_PROMPT_REPLACEMENTS={
        'aeroplane'    : 'aeroplane',
        'bicycle'      : 'bicycle',
        'bird'         : 'bird',
        'boat'         : 'boat',
        'bottle'       : 'bottle',
        'bus'          : 'bus',
        'car'          : 'car',
        'cat'          : 'cat',
        'chair'        : 'chair',
        'cow'          : 'cow',
        'dog'          : 'dog',
        'horse'        : 'horse',
        'motorbike'    : 'motorcycle',
        'person'       : 'person',
        'potted plant' : 'potted plant',
        'sheep'        : 'sheep',
        'sofa'         : 'sofa',
        'train'        : 'train',
        'tv/monitor'   : 'tv/monitor',
    }

    def __init__(self, samples=None, prompt_replacements:dict=None, dataset_dir:str=None):
        """
        Create the dataset, either from the given samples or by loading them from disk.

        samples: An indexable collection of sample objects. If None, the samples are loaded from dataset_dir.
        prompt_replacements: Optional dict mapping an original prompt to its replacement.
                             It is merged over DEFAULT_PROMPT_REPLACEMENTS (entries given here win).
        dataset_dir: Directory containing the dataset files. If None, DEFAULT_DATASET_DIR is used.

        Note: the prompts of the given samples are rewritten in place.
        """

        #Xiang made this dataset
        if dataset_dir is None:
            dataset_dir = type(self).DEFAULT_DATASET_DIR
        self.dataset_dir = dataset_dir

        if samples is None:
            samples = self._load_samples(dataset_dir)

        #Replace the given prompts
        #A None default (rather than {}) avoids sharing one dict object between all calls
        prompt_replacements = {**type(self).DEFAULT_PROMPT_REPLACEMENTS, **(prompt_replacements or {})}
        for sample in samples:
            if sample.prompt in prompt_replacements:
                sample.prompt = prompt_replacements[sample.prompt]
        self.prompt_replacements=prompt_replacements

        self.samples = samples

    @staticmethod
    def _load_samples(dataset_dir):
        """
        Load every sample listed in dataset_dir/refvoc_files.txt and return them as a list of TestSample.
        """
        with rp.SetCurrentDirectoryTemporarily(dataset_dir):
            image_names=rp.text_file_to_string('refvoc_files.txt').strip().splitlines()
            #It has a bunch of lines that look like '2008_003885\n2008_004212\n2008_004612\n2008_004621' etc
            image_paths=['cropped-'+x+'.jpg' for x in image_names]
            #These files should exist in the current directory
            images=rp.load_images(image_paths, use_cache=True, show_progress=True)

            assert all(rp.is_image_file(x) for x in image_paths)

            #Get the prompts for every image
            prompt_lines=rp.text_file_to_string('cropped.txt').strip().splitlines()
            #It has a bunch of lines that look like '2010_000241 bird\n2010_000342 bicycle\n2010_000628 car' etc
            #maxsplit=1 keeps multi-word prompts such as 'potted plant' in one piece
            prompt_by_name=dict(x.split(maxsplit=1) for x in prompt_lines)
            prompts=[prompt_by_name[image_name] for image_name in image_names]

            #The samples are created inside the 'with' block because TestSample resolves
            #its (relative) path to an absolute one at construction time
            return [TestSample(image,prompt,path,name) for image,prompt,path,name in zip(images,prompts,image_paths,image_names)]

    def _from_samples(self, samples):
        """
        Return a new dataset of the same type holding the given samples, sharing this dataset's settings.

        __init__ is deliberately bypassed: the samples already had their prompts replaced,
        and applying the replacements a second time could change them again.
        """
        subset = type(self).__new__(type(self))
        subset.dataset_dir = self.dataset_dir
        subset.prompt_replacements = dict(self.prompt_replacements)
        subset.samples = samples
        return subset

    @property
    def images(self): return [s.image for s in self.samples]

    @property
    def prompts(self): return [s.prompt for s in self.samples]

    @property
    def names(self): return [s.name for s in self.samples]

    @property
    def image_paths(self): return [s.path for s in self.samples] #Samples store their path under .path

    def __getitem__(self, index):
        """
        Return one sample for a non-slice index, or a new dataset of the selected samples for a slice.
        """
        if isinstance(index, slice):
            return self._from_samples(self.samples[index])
        return self.samples[index]

    def split(self, num_divisions, division_index):
        """
        Split the dataset into num_divisions parts and return the division_index part.
        The parts are contiguous; the last part also receives the leftover samples
        when len(self) is not a multiple of num_divisions.

        num_divisions: An integer representing the number of equal parts to divide the dataset into.
        division_index: An integer representing the 0-indexed part to return.
        """
        division_size = len(self) // num_divisions
        start_index = division_size * division_index
        end_index = start_index + division_size
        if division_index == num_divisions - 1:
            end_index = len(self)
        return self[start_index:end_index]

    def __len__(self):
        return len(self.samples)

    def __iter__(self):
        return iter(self.samples)

In [None]:
#Build the dataset with its default arguments, and expose the same object under a shorter alias
data = cropped_ref_voc_dataset = CroppedRefVOCDataset()

In [None]:
#Quick visual check: take an element of the dataset chosen by rp.random_element and call its display()
rp.random_element(data).display()