In [8]:
import os
import zipfile
import PIL.Image
import json
import copy
from typing import IO, Optional, List

import numpy as np
import torch
from torch import nn

from torch.utils.data import DataLoader, Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Base image dataset that tracks which raw items are exposed.

    Supports optional random subsampling (``max_size``), x-flip doubling
    (``xflip``) and restricting to a subset of classes (``set_class``).
    Subclasses must implement ``_load_raw_labels`` to return a NumPy array with
    one entry per raw item (first dimension == ``raw_shape[0]``).

    The base is spelled ``torch.utils.data.Dataset`` rather than the bare name
    ``Dataset`` because this class rebinds ``Dataset`` at module level.
    ``class Dataset(Dataset)`` would make every re-run of the cell subclass the
    previous version of this class instead of the torch base.
    """

    def __init__(self,
                 name       : str,                     # Name of the Dataset
                 raw_shape  : List,                    # Raw shape of the images [B, C, H, W]
                 max_size   : Optional[int] = None,    # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
                 use_labels : bool = False,            # Enable conditioning labels? False = label dimension is zero.
                 xflip      : bool = False,            # Artificially double the size of the dataset via x-flips. Applied after max_size.
                 random_seed: Optional[int] = 1,       # Random seed to use when applying max_size.
                 ):
        super().__init__()

        self._name        = name
        self._raw_shape   = list(raw_shape)
        self._use_labels  = use_labels
        self._raw_labels  = None
        self._label_shape = None

        num_items = self._raw_shape[0]

        # One index per raw item. An untouched copy is kept so that set_class /
        # set_dynamic_lenght can re-select from the full dataset later.
        self._raw_idx      = np.arange(num_items, dtype=np.int64)
        self._base_raw_idx = self._raw_idx.copy()
        if (max_size is not None) and num_items > max_size:
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Per-position flag: 0 = original item, 1 = x-flipped copy.
        self.xflip = np.zeros_like(self._raw_idx)               # [0, 0, 0, ...]
        if xflip:
            # Doubling the dataset:
            #   self._raw_idx: [0, 1, 2, ..., 0, 1, 2, ...]
            #   self.xflip:    [0, 0, 0, ..., 1, 1, 1, ...]
            # The ones must match the length of the *un-tiled* index, otherwise
            # xflip would end up 3x (not 2x) the original length.
            self._raw_idx = np.tile(self._raw_idx, 2)
            self.xflip = np.concatenate((self.xflip, np.ones_like(self.xflip)))

    def set_dynamic_lenght(self, new_len: int) -> None:
        """Expose only the first ``new_len`` items of the full base index.

        This re-selects from the unfiltered base index. Any earlier ``max_size``
        subsampling, x-flip doubling or ``set_class`` filtering is discarded.
        NOTE(review): ``self.xflip`` is not updated here.
        """
        self._raw_idx = self._base_raw_idx[:new_len]

    # Correctly spelled alias; the original (misspelled) name is kept for existing callers.
    set_dynamic_length = set_dynamic_lenght

    def _load_raw_labels(self) -> np.ndarray:
        """Load the per-item labels. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _load_raw_labels()"
        )

    def set_class(self, cls_list: List) -> None:
        """Restrict the exposed items to those whose label is in ``cls_list``.

        Selection is made from the full base index. Any earlier ``max_size``
        subsampling, x-flip doubling or ``set_dynamic_lenght`` call is discarded.
        Assumes labels are a 1-D array with one label per item. TODO: confirm
        for 2-D (e.g. one-hot) label layouts.

        Raises:
            ValueError: if ``cls_list`` is empty.
            AssertionError: if any requested class has no matching item.
        """
        if len(cls_list) == 0:
            raise ValueError("cls_list must contain at least one class")

        self._raw_labels = self._load_raw_labels()

        keep = np.isin(self._raw_labels, cls_list)                 # keep: [False, False, True, True, ...]
        self._raw_idx = self._base_raw_idx[np.flatnonzero(keep)]   # raw_idx: [2, 3, ...]

        # Every requested class must actually be present among the selected items.
        found = set(np.unique(self._raw_labels[self._raw_idx]).tolist())
        assert found == set(cls_list), (
            f"Requested classes {sorted(set(cls_list))} but found {sorted(found)}"
        )
        print(f"Training on the following classes: {cls_list}")

    def _get_raw_labels(self) -> np.ndarray:
        """Return the per-item labels, loading and validating them on first use.

        When ``use_labels`` is False (or the loader returns None), a float32
        placeholder of shape ``[raw_shape[0], 0]`` is used instead, i.e. a
        zero-width label per item.
        """
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)

            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]   # one label entry per raw item
            assert self._raw_labels.dtype.type in [np.float32, np.int64, np.str_]
            if self._raw_labels.dtype == np.int64:                   # integer class indices
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

In [64]:
labels = np.array(
    ["a dog playing", "a mountain landscape", "a sketch of a robot", "puppy"],
    dtype=str,
)  # shape: [N]

# Scratch check of the class-selection idea behind Dataset.set_class:
# flag every label that equals one of the wanted classes, then take their positions.
wanted = ["a sketch of a robot", "puppy"]
is_wanted = np.isin(labels, wanted)   # is_wanted: [False, False, True, True]
new_idcs = np.nonzero(is_wanted)      # new_idcs:  (array([2, 3]),)
new_idcs

(array([2, 3]),)

In [71]:
# Inline invariant instead of an eyeballed probe: the label array must be a NumPy
# array so that elementwise comparisons / np.isin produce boolean masks.
assert isinstance(labels, np.ndarray), f"expected np.ndarray, got {type(labels)}"
type(labels)

numpy.ndarray

In [69]:
# Zero-width label array, the same form Dataset._get_raw_labels falls back to
# when labels are disabled (here with 5 items).
np.zeros(shape=(5, 0), dtype=np.float32)

array([], shape=(5, 0), dtype=float32)

array([False, False,  True])