# Cached dataset

> Classes to create a dataset with cached labels.

In [None]:
#| default_exp dataset.cached_dataset

In [None]:
#| export
from genQC.imports import *
from genQC.dataset.config_dataset import ConfigDataset, ConfigDatasetConfig
from genQC.utils.config_loader import *

In [None]:
#| export
@dataclass
class CachedOpenCLIPDatasetConfig(ConfigDatasetConfig):
    """
    Dataclass config paired by name with `CachedOpenCLIPDataset`.

    Declares no fields of its own; everything is inherited from `ConfigDatasetConfig`.
    """

In [None]:
#| export
class CachedOpenCLIPDataset(ConfigDataset):
    """
    Adds `.caching` to the `ConfigDataset` class.
    
    Cached dataset that caches the label `y` prompts using the CLIP `text_encoder`. This speeds up training significantly.

    After preprocessing, every label is replaced by an integer index into the set of unique
    tokenized prompts, and the text encoder is asked to build its cache over those unique rows
    (`text_encoder.generate_cache`).
    """

    #-----------------------------------
    
    def x_y_preprocess(self, balance_max, shuffle=False, max_samples=None, make_unique=True):
        """
        Run the parent preprocessing, then replace the labels `y` by cache indices (see `caching`).

        Any outputs of `ConfigDataset.x_y_preprocess` beyond `(x, y)` are passed through unchanged.
        Requires `self.text_encoder` to be set (this is done by `get_dataloaders`).
        """
        x_proc, y_proc, *z = super().x_y_preprocess(balance_max=balance_max, shuffle=shuffle, max_samples=max_samples, make_unique=make_unique)
        y_proc = self.caching(y_proc)
        return x_proc, y_proc, *z

    @staticmethod
    def _labels_to_str_list(y_proc):
        """
        Convert `y_proc` to a flat list of `str`, one entry per element along the first axis.

        For a list input, the converted entries of all list elements are concatenated.

        Raises:
            NotImplementedError: if `y_proc` is not a `torch.Tensor`, `np.ndarray` or `list`.
        """
        # A plain `torch.Tensor` check also covers legacy typed tensors (torch.IntTensor, ...).
        if isinstance(y_proc, torch.Tensor):
            return [str(i) for i in y_proc.cpu().tolist()]
        if isinstance(y_proc, np.ndarray):
            return [str(i) for i in y_proc.tolist()]
        if isinstance(y_proc, list):
            y_str = []
            for iy in y_proc:
                if isinstance(iy, np.ndarray): y_str += [str(i) for i in iy.tolist()]        # list of numpy arrays
                else:                          y_str += [str(i) for i in iy.cpu().tolist()]  # list of tensors
            return y_str
        raise NotImplementedError(
            f"caching: unsupported label type {type(y_proc).__name__!r}; "
            "expected torch.Tensor, np.ndarray, or a list of those"
        )

    def caching(self, y_proc, y_on_cpu=False):
        """
        Tokenize the labels once and return, per label, an index into the unique token rows.

        Steps:
        1. Convert the labels to strings.
        2. Tokenize them with `self.text_encoder`.
        3. Prepend the encoder's `empty_token`.
        4. De-duplicate the token rows with `torch.unique(dim=0, return_inverse=True)`.
        5. Let the encoder build its cache over the unique rows via `generate_cache`.

        Args:
            y_proc: labels as a `torch.Tensor`, an `np.ndarray`, or a list of either.
            y_on_cpu: if True, `to_device=False` is passed to `tokenize_and_push_to_device`, the
                tokens are moved to CPU, and the flag is forwarded to `generate_cache`.

        Returns:
            Integer tensor with one entry per row of the tokenized labels: the index of that row
            within the unique-token tensor handed to `generate_cache`.

        Raises:
            NotImplementedError: if `y_proc` is of an unsupported container type.
        """
        print("[INFO]: Generate cache: converting tensors to str and tokenize")   
        
        print(" - to str list")  
        y_str = self._labels_to_str_list(y_proc)
                            
        print(" - tokenize_and_push_to_device")  
        y_tok = self.text_encoder.tokenize_and_push_to_device(y_str, to_device=not y_on_cpu)
        if y_on_cpu: y_tok = y_tok.cpu()
        
        # Prepend the empty token so it is always among the uniques and its cache index is known.
        # NOTE(review): assumes `empty_token` is exactly one row shaped like a row of `y_tok`;
        # the `[1:]` slice below removes exactly one entry.
        empty_token = self.text_encoder.empty_token.to(y_tok.device)
        y_uniques, y_ptrs = torch.unique(torch.cat([empty_token, y_tok], dim=0), dim=0, return_inverse=True)
    
        cached_empty_token_index = y_ptrs[0]  # index of the empty token among the uniques
        y_ptrs                   = y_ptrs[1:] # drop the prepended empty token again
      
        # Use cache
        print(" - generate_cache")  
        self.text_encoder.generate_cache(tokens=y_uniques, cached_empty_token_index=cached_empty_token_index, y_on_cpu=y_on_cpu)
      
        print(f"[INFO]: Generated cache, {y_ptrs.shape=}")  
        # clone: return an independent tensor instead of a view into the full inverse-index tensor
        return y_ptrs.clone()
    
    #-------------------------------------------
    
    def get_dataloaders(self, batch_size, text_encoder, p_valid=0.1, balance_max=None, max_samples=None):
        """
        Store `text_encoder` on the instance (it is used by `caching`) and delegate to the parent.
        """
        self.text_encoder = text_encoder    
        # NOTE(review): arguments are forwarded positionally, so this relies on the parent's
        # parameter order being (batch_size, p_valid, balance_max, max_samples); confirm.
        return super().get_dataloaders(batch_size, p_valid, balance_max, max_samples)     

# Export -

In [None]:
#| hide
# Run nbdev's export so the `#| export` cells are written to their Python modules.
import nbdev
nbdev.nbdev_export()