# Refactor-DataModule.ipynb

Efforts to simplify from the implementation of MultiTaskDataModule

Created on: Wednesday November 3rd, 2021  
Created by: Jacob A Rose



### Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%load_ext heat

In [3]:
import pytorch_lightning as pl
from torchvision import transforms
import torch
from typing import *

In [4]:
from pytorch_lightning.utilities.distributed import rank_zero_only
from lightning_hydra_classifiers.experiments.transfer_experiment import TransferExperiment, TransferExperimentConfig, Extant_to_PNAS_ExperimentConfig, Extant_to_Fossil_ExperimentConfig
from lightning_hydra_classifiers.utils.template_utils import get_logger

### Definitions

In [None]:
# @dataclass
# class TransferExperimentConfig:
#     source_root_dir: str = CSV_CATALOG_DIR_V1_0
#     experiment_root_dir: str = EXPERIMENTAL_DATASETS_DIR # 

In [5]:
from typing import *
from dataclasses import dataclass, field
from lightning_hydra_classifiers.experiments.configs.base import BaseConfig
from hydra.core.config_store import ConfigStore

# __all__ = ["MultiTaskDataModuleConfig", "TransferExperimentConfig", "Extant_to_PNAS_ExperimentConfig", "Extant_to_Fossil_ExperimentConfig", "TaskConfig", "register_configs"]
# @dataclass
# class TransferExperimentConfig:
#     source_root_dir: str = CSV_CATALOG_DIR_V1_0
#     experiment_root_dir: str = EXPERIMENTAL_DATASETS_DIR # 


@dataclass
class DataModuleConfig(BaseConfig):
    # Base configuration shared by datamodule configs.
    # `_target_` has no default here, so each concrete subclass has to supply one.
    # presumably a dotted import path consumed by Hydra instantiation — TODO confirm
    _target_: str
    # Image side length in pixels — presumably the final crop size; verify against the datamodule.
    image_size: int = 512
    # NOTE(review): OneTaskDataModule in this notebook resizes validation images to
    # image_size + image_buffer_size before center-cropping; confirm this field is used the same way.
    image_buffer_size: int = 32
    batch_size: int = 32
    num_workers: int = 4
    pin_memory: bool = True
    # Generic placeholder name; subclasses override it with a concrete dataset name.
    dataset_name: str = "dataset"


@dataclass
class SingleTaskDataModuleConfig(DataModuleConfig):
    # Concrete config: fills in the `_target_` left open by DataModuleConfig and
    # overrides the default dataset name. All other fields are inherited unchanged.
    _target_: str = "lightning_hydra_classifiers.experiments.multitask.datamodules.SingleTaskDataModule"
    dataset_name: str = "Extant_Leaves"
    

from lightning_hydra_classifiers.data.utils.make_catalogs import CSV_CATALOG_DIR_V1_0, EXPERIMENTAL_DATASETS_DIR



@dataclass
class TaskConfig:
    # Per-task settings: a required task name plus optional split fractions.
    name: str
    # presumably the fraction of samples held out for validation — TODO confirm against the consumer
    val_split: Optional[float] = 0.2
    # None by default; presumably means "do not carve out a test split" — TODO confirm
    test_split: Optional[float] = None
        



In [8]:
# catalog_registry.available_datasets.versions.keys()
# catalog_registry.available_datasets.tags
# type(catalog_registry.available_datasets.versions["v0_3"].__repr__())

In [4]:
# from lightning_hydra_classifiers import catalog_registry, available_datasets
# from rich import pretty, print
# pretty.install()
# ["Rich and pretty", True]

# self = catalog_registry.available_datasets.versions["v0_3"]
# self = catalog_registry.available_datasets.versions["v1_0"]
# print(self.__repr__())
# print(self)

# import rich
# c = rich.color.Color.default()
# dir(c)

# from rich import inspect
# inspect(self)
# import inspect
# # print(inspect.getsource(type(self)))#$.__repr__)
# self.__repr__()

In [19]:
from lightning_hydra_classifiers.data.utils import catalog_registry
from lightning_hydra_classifiers.utils.common_utils import Batch
import torchdata as td
from PIL import Image
import pandas as pd
from pathlib import Path
from dataclasses import dataclass, asdict
from rich import pretty, print
from typing import *
pretty.install()


import rich.repr

# torchvision.datasets.ImageFolder


@dataclass
class PathSchema:
    """
    User provides a template str for instantiating this class, 
    which is then used to parse sample labels from file path names.
    """
    path_schema: str = Path("{family}_{genus}_{species}_{collection}_{catalog_number}")
        
    def __init__(self,
                 path_schema,
                 sep: str="_"):

        self.sep = sep
        self.schema_parts: List[str] = path_schema.split(sep)
        self.maxsplit: int = len(self.schema_parts) - 2
    
    def parse(self, path: Union[Path, str], sep: str="_"):
    
        parts = Path(path).stem.split(sep, maxsplit=self.maxsplit)
        if len(parts) == 5:
            family, genus, species, collection, catalog_number = parts
        elif len(parts) == 4:
            family, genus, species, catalog_number = parts
            collection = catalog_number.split("_")[0]
        else:
            print(f'len(parts)={len(parts)}, parts={parts}, path={path}')

        return family, genus, species, collection, catalog_number
    
    def split(self, sep):
        return self.schema_parts


@dataclass
class SampleSchema:
    """
    Data structure for representing the types of useful info extracted from a single sample's file path.

    Allows flexibility in returning Dataset items, so that any or all labels of a sample can be accessed
    via int indexing like a list, string indexing like a dict, or via named attributes.
    """
    path : Optional[Union[str, Path]] = None
    family : Optional[str] = None
    genus : Optional[str] = None
    species : Optional[str] = None
    collection : Optional[str] = None
    catalog_number : Optional[str] = None

    @classmethod
    def keys(cls):
        """Return the field names in declaration order."""
        return list(cls.__dataclass_fields__.keys())

    def __getitem__(self, index: Union[int, str]):
        """
        Look up a field by position or by name.

        Args:
            index: Either an int position into ``keys()`` (negative values count
                from the end) or the name of a field.

        Returns:
            The value stored in the selected field.

        Raises:
            KeyError: If ``index`` is a string that is not a field name.
            IndexError: If ``index`` is an int outside the range of ``keys()``.
        """
        if isinstance(index, str):
            # String access was documented but previously failed with a TypeError
            # because the name was used as a list index.
            if index not in self.__dataclass_fields__:
                raise KeyError(index)
            return getattr(self, index)
        return getattr(self, self.keys()[index])
    
###################################
###################################
###################################


# @rich.repr.auto
class FileDataset(td.Dataset):
    """
    Dataset over the files found recursively under a directory.

    Args:
        data_dir: Directory that is searched recursively.
        regex: Pattern handed to ``Path.rglob`` (a glob pattern, despite the name).
        name: Optional label, shown by ``__repr__``.

    Indexing returns the stored file path(s); ``len()`` is the number of files found.
    """
    def __init__(self,
                 data_dir: Path,
                 regex: str="*.jpg",
                 name: Optional[str]=None):
        super().__init__() # This is necessary
        self.data_dir = data_dir
        self.regex = regex
        self.name = name

        # Materialise the recursive glob once so indexing and len() are cheap afterwards.
        self.files = list(Path(data_dir).rglob(regex))

    def __getitem__(self, index):
        return self.files[index]

    def __len__(self) -> int:
        return len(self.files)

    def __repr__(self) -> str:
        summary_lines = (f"Dataset_name: {self.name}",
                         f"num_samples: {len(self)}")
        return "\n".join(summary_lines)
    

    
class ImageFileDataset(FileDataset):
    """FileDataset whose items are opened with ``Image.open`` instead of being returned as paths."""
    def __getitem__(self, index):
        image_path = self.files[index]
        return Image.open(image_path)



class SupervisedImageFileDataset(ImageFileDataset):
    """
    Image dataset whose labels are parsed from each file's name.

    Every file found by the parent class is parsed once at construction
    (see ``process_all``), which populates ``samples``, ``targets``,
    ``samples_df`` and ``classes``.

    Args:
        data_dir: Directory that is searched recursively for files.
        regex: Pattern handed to ``Path.rglob`` by the parent class.
        name: Optional dataset label, shown by ``__repr__``.
        path_schema: Template used to build the ``PathSchema`` that parses file names.
        y_col: Name of the ``SampleSchema`` field used as the target.
    """
    def __init__(self,
                 data_dir: Path,
                 regex: str="*.jpg",
                 name: Optional[str]=None,
                 path_schema: str = "{family}_{genus}_{species}_{collection}_{catalog_number}",
                 y_col: str="family"):
        super().__init__(data_dir=data_dir,
                         regex=regex,
                         name=name)
        self.path_schema = PathSchema(path_schema)
        self.y_col = y_col
        self.process_all()

    # TODO: an earlier sketch had a setup() accepting an external samples_df / label_encoder;
    # reintroduce it here if precomputed sample tables need to be supported.

    @property
    def class_counts(self):
        """Mapping built from ``samples_df.value_counts(y_col)``: number of samples per class."""
        return self.samples_df.value_counts(self.y_col).to_dict()

    def process_all(self):
        """Parse every file and (re)build ``index``, ``samples``, ``targets``, ``samples_df`` and ``classes``."""
        self.index = list(range(len(self)))
        self.samples = [self.parse_sample(idx) for idx in self.index]
        # Read the target through y_col so that targets agree with fetch_item();
        # the previous positional lookup always returned the family field.
        self.targets = [getattr(sample, self.y_col) for sample in self.samples]
        self.samples_df = pd.DataFrame(self.samples).convert_dtypes()
        self.classes = list(self.class_counts.keys())

    def parse_sample(self, index: int):
        """Return a ``SampleSchema`` holding the path at ``index`` and the labels parsed from its name."""
        path = self.files[index]
        family, genus, species, collection, catalog_number = self.path_schema.parse(path)

        return SampleSchema(path=path,
                            family=family,
                            genus=genus,
                            species=species,
                            collection=collection,
                            catalog_number=catalog_number)

    def fetch_item(self, index: int) -> "Batch":
        """
        Open the image at ``index`` and bundle it with its target and metadata.

        Returns:
            ``Batch(image=..., target=..., metadata=...)`` where ``target`` is the
            sample's ``y_col`` field and ``metadata`` is the full ``SampleSchema``.
        """
        sample = self.parse_sample(index)
        image = Image.open(sample.path)
        return Batch(image=image,
                     target=getattr(sample, self.y_col),
                     metadata=sample)

    def __getitem__(self, index):
        """
        Fetch one item for an integer index, or a list of items for an iterable of indices.

        Any ``numbers.Integral`` (e.g. a numpy integer scalar) counts as a single
        index; previously only built-in ``int`` did and other integer types were
        wrongly iterated over.
        """
        import numbers

        if isinstance(index, numbers.Integral):
            return self.fetch_item(index)
        return [self.fetch_item(i) for i in index]

    def __repr__(self) -> str:
        out = super().__repr__() + "\n"
        out += f"num_classes: {len(self.classes)}"
        return out
        
        
        
        
        
from lightning_hydra_classifiers.data.utils import catalog_registry

class DatasetsRegistry(catalog_registry.AvailableDatasets):
    """
    Registry subclass that can return a search hit as a ready-to-use dataset.
    """

    @classmethod
    def get_as_dataset(cls, query: str, version: Optional[str]="v1_0") -> SupervisedImageFileDataset:
        """
        Build a ``SupervisedImageFileDataset`` from the first result of ``cls.search``.

        Args:
            query: Forwarded to ``cls.search``.
            version: Forwarded to ``cls.search``.

        Returns:
            A dataset over the first matching entry. Entries whose name contains
            ``'PNAS'`` are parsed with a schema that has no ``collection`` field.
        """
        matches = cls.search(query=query, version=version)
        dataset_name, data_dir = list(matches.items())[0]

        if 'PNAS' in dataset_name:
            schema = "{family}_{genus}_{species}_{catalog_number}"
        else:
            schema = "{family}_{genus}_{species}_{collection}_{catalog_number}"

        return SupervisedImageFileDataset(data_dir,
                                          path_schema=schema,
                                          name=dataset_name)
    
    

In [21]:
# data = DatasetsRegistry.get_as_dataset(query="Fossil_f", version="v1_0")
# Build a dataset from the first registry entry matching "PNAS" (catalog version v1_0)
# and print its summary as produced by the dataset's __repr__.
data = DatasetsRegistry.get_as_dataset(query="PNAS", version="v1_0")

print("Loaded data:\n" + repr(data))

In [22]:
# data = DatasetsRegistry.get_as_dataset(query="Fossil_f", version="v1_0")
# Same as the previous cell, but for the first registry entry matching "Extant".
# Rebinds `data`, so later cells inspect this dataset.
data = DatasetsRegistry.get_as_dataset(query="Extant", version="v1_0")

print("Loaded data:\n" + repr(data))

In [20]:
data

In [15]:
# print(data.samples_df.value_counts("family").to_dict())
# data.y_col

In [18]:
# data.process_all()
# data.class_counts
data.classes

In [2]:
data[0].metadata

In [3]:
# data_dir = Path(catalog_registry.available_datasets.get("Extant_Leaves_original"))
# dataset_name = "PNAS_original"
# dataset_name = "PNAS_family_100_512"
# path_schema: str = "{family}_{genus}_{species}_{collection}_{catalog_number}"
# if 'PNAS' in dataset_name:
#     path_schema: str = "{family}_{genus}_{species}_{catalog_number}"

# data_dir = Path(catalog_registry.available_datasets.get(dataset_name))
# data = SupervisedImageFileDataset(data_dir, path_schema=path_schema)
# data = ImageFileDataset(data_dir)#, path_schema=path_schema)
# data = FileDataset(data_dir)#, path_schema=path_schema)

In [48]:
# from lightning_hydra_classifiers.data.utils import catalog_registry

# class DatasetsRegistry(catalog_registry.AvailableDatasets):
#     """
#     """
    
#     @classmethod
#     def get_as_dataset(cls, query: str, version: Optional[str]="v1_0") -> SupervisedImageFileDataset:
#         result = cls.search(query=query, version=version)
#         dataset_name, data_dir = list(result.items())[0]
        
#         path_schema: str = "{family}_{genus}_{species}_{collection}_{catalog_number}"
#         if 'PNAS' in dataset_name:
#             path_schema: str = "{family}_{genus}_{species}_{catalog_number}"
        
#         return SupervisedImageFileDataset(data_dir,
#                                           path_schema=path_schema,
#                                           name=dataset_name)


    # def get_dataset(self, data_dir: str) -> td.Dataset:
    #     self.search

# dir(catalog_registry)
# catalog_registry.leavesdbv1_0
# , available_datasets

# registry = DatasetsRegistry()
# data = DatasetsRegistry.get_as_dataset(query="Fossil_f", version="v1_0")
# data[0].metadata

In [7]:
# for d in data[range(10)]:
#     print(d[-1])

# labels = files.map(lambda x: Path(x).parent.parts[-2])


# @rich.repr.auto
# class Dataset(td.Dataset):
#     def __init__(self, data_dir: Path):
#         super().__init__() # This is necessary
#         self.data_dir = data_dir
#         self.files = [file for file in Path(data_dir).glob("*")]

#     def __getitem__(self, index):
#         return self.files[index]
#     # return Image.open(self.files[index])

#     def __len__(self) -> int:
#         return len(self.files)
    
    # def __rich__(self) -> str:
    #     return f"[cyan]{type(self)}"


In [8]:
catalog_registry.available_datasets.search("512")

In [10]:
# type(catalog_registry.available_datasets.versions)
from typing import *

def query_dict(target: Any,
               query: str) -> Dict[str,str]:
    """
    Searches a target for k,v pairs for which the key contains a substring equal to {query}.
    Returned dictionary contains a minimum of 0 and maximum of len(target) items.

    Args:
        target: Either a mapping, whose own items are searched, or any object with a
            ``__dict__``, whose attributes (``vars(target)``) are searched.
        query: Substring that a key must contain to be kept.

    Returns:
        A new dict with the matching k,v pairs, in the target's iteration order.
    """
    from collections import abc

    # vars() raises TypeError on a plain dict, so real mappings are searched directly;
    # attribute containers still go through vars() as before.
    items = target.items() if isinstance(target, abc.Mapping) else vars(target).items()
    return {k: v for k, v in items if query in k}


# query = 'PNAS'
# # result = {k:v for k,v in vars(catalog_registry.available_datasets.versions['v0_3']).items() if query in k}
# target = catalog_registry.available_datasets.versions['v0_3']
# result = query_dict(target, query=query)

# Look up one catalog version and keep only the entries whose key contains `query`.
version = 'v0_3'
query = 'Extant' #'PNAS'
# result = {k:v for k,v in vars(catalog_registry.available_datasets.versions['v0_3']).items() if query in k}
# `target` is passed to query_dict, which reads its attributes via vars().
target = catalog_registry.available_datasets.versions[version]
result = query_dict(target, query=query)

# Bare expression: displays the filtered dict as the cell output.
result

In [18]:
# Scratch cell: inspect the registry, then list every file of one dataset.
catalog_registry.available_datasets
# catalog_registry.AvailableDatasets()

# NOTE(review): `registry` and `Dataset` are only defined in commented-out code in the
# visible part of this notebook, so this cell raises NameError unless they are defined elsewhere.
tag, data_dir = list(registry.search("PNAS_original").items())[0]
data = Dataset(data_dir)
# NOTE(review): .rglob requires data_dir to be a Path-like object, not a plain str — confirm.
files = list(data_dir.rglob("*"))

In [None]:
# from lightning_hydra_classifiers.data.utils.catalog_registry import available_datasets
from lightning_hydra_classifiers.data.utils import catalog_registry

    from lightning_hydra_classifiers.data.utils.catalog_registry import available_datasets
    
    available_datasets.get(tag='Fossil_2048', version='v1_0')
    >>['/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Fossil/General_Fossil/2048/full/jpg',
       '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Fossil/Florissant_Fossil/2048/full/jpg']

    available_datasets.get(tag='Extant_Leaves_family_20_1024', version='v1_0')
    >>'/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/1024/20/jpg'

    
    print(available_datasets().tags)
    print(available_datasets())


In [None]:
class SETIDataset(Dataset):
    """
    Dataset that loads one ``.npy`` array per sample together with its target.

    Args:
        images_filepaths: Indexable collection of paths readable by ``np.load``.
        targets: Indexable collection of targets; each entry must support ``.reshape``.
        transform: Optional callable invoked as ``transform(image=...)`` whose result
            is indexed with ``"image"``.
    """
    def __init__(self, images_filepaths, targets, transform=None):
        self.transform = transform
        self.targets = targets
        self.images_filepaths = images_filepaths

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; ``label`` is the target flattened to 1-D."""
        raw = np.load(self.images_filepaths[idx]).astype(np.float32)
        # Stack along the first axis, then swap the two resulting axes.
        stacked = np.vstack(raw).transpose((1, 0))

        if self.transform is None:
            # No transform: add a leading axis and convert to a float tensor.
            image = torch.from_numpy(stacked[np.newaxis,:,:]).float()
        else:
            image = self.transform(image=stacked)["image"]

        return image, self.targets[idx].reshape(-1,)

In [None]:


import logging

# `logger` is used throughout OneTaskDataModule but is never bound in the visible notebook
# cells; bind a module-level logger so the class does not fail with NameError.
# NOTE(review): the project helper `get_logger` is imported above and unused — swap it in
# here if its rank-aware behaviour is wanted.
logger = logging.getLogger(__name__)


class OneTaskDataModule(pl.LightningDataModule):
    """
    LightningDataModule that serves the datasets of one task at a time.

    All tasks are obtained once at construction from
    ``TransferExperiment.get_multitask_datasets``; ``set_task`` / ``setup`` then
    select which task's train/val/test datasets the dataloaders use.

    Args:
        batch_size: Batch size used by every dataloader.
        task_id: Task selected at construction; must be in ``experiment.valid_tasks``.
        image_size: Output size of the train crop and of the val center crop.
        image_buffer_size: Extra pixels added to ``image_size`` for the val resize
            that precedes the center crop.
        num_workers: Forwarded to every ``DataLoader``.
        pin_memory: Forwarded to every ``DataLoader``.
        experiment_config: Passed to ``TransferExperiment``.
    """
    # TBD: Merge this with previous BaseDataModule from common.py
    # NOTE(review): not referenced anywhere else in this class as shown.
    dataset_names: Dict[str,str] = {"task_0":"Extant_family_10",
                                    "task_1":"PNAS_family_100",
                                    "task_2":"Fossil_family_3"}
    def __init__(self,
                 batch_size,
                 task_id: int=0,
                 image_size: int=224,
                 image_buffer_size: int=32,
                 num_workers: int=4,
                 pin_memory: bool=True,
                 experiment_config: Optional[TransferExperimentConfig]=None):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        self.image_size = image_size
        self.image_buffer_size = image_buffer_size

        # TBD: Replace this with a stats accumulator instance that can manage stats calculation caching & computation.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # TBD Make Base and SingleTask versions of experiment.
        self.experiment = TransferExperiment(experiment_config)
        self.experiment_config = self.experiment.config

        # The transforms must exist before the task datasets are requested,
        # because they are handed to the experiment below.
        self.__init_transforms()
        self.tasks = self.experiment.get_multitask_datasets(train_transform=self.train_transform,
                                                            val_transform=self.val_transform)
        self.task_tag = None
        self.set_task(task_id)

    def __init_transforms(self):
        """Build ``train_transform`` and ``val_transform`` from the current size and normalization stats."""
        # Train augmentation policy
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=self.image_size,
                                         scale=(0.25, 1.2),
                                         ratio=(0.7, 1.3),
                                         interpolation=2),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(self.mean, self.std),
            transforms.Grayscale(num_output_channels=3)
        ])

        # Validation: resize slightly larger than the target, then center-crop to image_size.
        self.val_transform = transforms.Compose([
            transforms.Resize(self.image_size+self.image_buffer_size),
            transforms.ToTensor(),
            transforms.CenterCrop(self.image_size),
            transforms.Normalize(self.mean, self.std),
            transforms.Grayscale(num_output_channels=3)
        ])

    @rank_zero_only
    def update_stats(self,
                     mean: Optional[List[float]]=None,
                     std: Optional[List[float]]=None):
        """
        Replace the stored normalization stats; arguments left as ``None`` are kept.

        NOTE(review): the transforms built in ``__init__`` are not rebuilt here, so
        they keep using the stats they were created with — confirm whether that is intended.
        """
        logger.warning(f"Updating stats: mean={mean}, std={std}")
        if mean is not None:
            self.mean = mean
        if std is not None:
            self.std = std

        # Previously called `logging.info`, but `logging` was never imported in the
        # visible cells; use the same logger as the rest of the class.
        logger.info("DataModule image normalization stats updated:\n" + f"mean={self.mean}, std={self.std}")

    def set_task(self, task_id: int):
        """Select the active task; ``task_id`` must be one of ``experiment.valid_tasks``."""
        assert task_id in self.experiment.valid_tasks
        self.task_id = task_id
        logger.info(f"set_task(task_id={self.task_id})")

    @property
    def current_task(self):
        """The entry of ``self.tasks`` for the active ``task_id``."""
        return self.tasks[self.task_id]

    def setup(self, stage=None, task_id: int=None):
        """
        Bind the active task's datasets for the requested stage.

        Args:
            stage: ``'fit'``, ``'test'`` or ``None``; ``None`` sets up both.
            task_id: If an int, switch to this task before binding datasets.
        """
        if isinstance(task_id , int):
            self.set_task(task_id=task_id)
        task = self.current_task

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self.train_dataset = task['train']
            self.val_dataset = task['val']

            self.classes = self.train_dataset.classes
            self.num_classes = len(self.train_dataset.label_encoder)
            self.label_encoder = self.train_dataset.label_encoder

            self.full_name = self.train_dataset.config.full_name
            if hasattr(self.train_dataset.config, "task_tag"):
                self.task_tag = self.train_dataset.config.task_tag
            logger.info(f"Task_{self.task_id} ({self.task_tag}): datamodule.setup(stage=fit)")

        if stage == 'test' or stage is None:
            self.test_dataset = task['test']
            logger.info(f"Task_{self.task_id}: datamodule.setup(stage=test)")

        # presumably reset so Lightning will run setup() again after a task switch —
        # TODO confirm against the installed pytorch_lightning version
        self._has_setup_fit = False
        self._has_setup_test = False

    def get_dataset(self, stage: str="train"):
        """Return the ``{stage}_dataset`` attribute; ``'fit'`` is treated as ``'train'``."""
        if stage=="fit": stage="train"
        assert hasattr(self, f"{stage}_dataset")
        return getattr(self, f"{stage}_dataset")

    def _make_dataloader(self, dataset, *, train: bool):
        """Build a DataLoader with the shared settings; shuffling and drop_last are enabled only for training."""
        return torch.utils.data.DataLoader(dataset,
                          batch_size=self.batch_size,
                          pin_memory=self.pin_memory,
                          num_workers=self.num_workers,
                          shuffle=train,
                          drop_last=train)

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, train=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, train=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset, train=False)