# datamodule

> I will try to follow the approach of Lightning and Hydra

In [1]:
#| default_exp datamodule

In [2]:
#| export
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Tuple, List, Mapping, Callable
import numpy as np
from fastcore.test import test_eq, ExceptionExpected
from torch.utils.data import Dataset, DataLoader

In [3]:
#| hide
from nbdev.showdoc import *

## Transform

It should:
* transform items according to its descriptor
* compose some transforms together

In [4]:
#| export
class Desc(Enum):
  'Tags naming the kind of each item handed to a `Transform`'
  # an item that is tagged as an image
  IMAGE = 1
  # an item that is tagged as a label
  LABEL = 2

def noop(x):
  'Identity function: hand `x` back untouched'
  return x

class Transform:
  '''Base class for transforms that are applied item by item.

  Every item is paired with an entry of a descriptor; the callable registered
  for that entry in `kinds` is applied to the item, and items whose entry has
  no registered callable are returned unchanged.
  '''
  dtype:np.dtype = np.float32
  # Class-level default shared by all instances; subclasses are expected to
  # assign their own mapping instead of mutating this one.
  kinds:Mapping[Desc, Callable] = {}
  # Default descriptor. Declared here so a transform that never assigned one
  # reaches the explicit "empty descriptor" error in `__call__` instead of
  # failing with an AttributeError.
  descriptor:Tuple[Desc, ...] = None

  def do(self, item, desc:Desc):
    'Transform only items declared in `self.kinds`'
    func = self.kinds.get(desc, noop)
    return func(item)

  def __call__(self, items: List[Any], descriptor: Tuple[Desc, ...] = None) -> Any:
    '''Apply the transform to `items` and return the results as a list.

    `descriptor` given here takes precedence over `self.descriptor`.
    Raises `Exception` when neither of them is set (or both are empty).
    '''
    descriptor = descriptor or self.descriptor
    if not descriptor:
      raise Exception(f'{self.__class__.__name__} got empty descriptor')
    # `zip` stops at the shorter of `items` and `descriptor`.
    return [self.do(item, desc) for item, desc in zip(items, descriptor)]

class Compose(Transform):
  'Chain transforms: each one receives the output of the previous one'
  transforms:List[Transform]

  def __init__(self, transforms:List[Transform]) -> None:
    'Store `transforms`; they are applied in the given order'
    self.transforms = transforms

  def __call__(self, items: List[Any], descriptor:Tuple[Desc] = None) -> Any:
    '''Run `items` through every transform in turn and return the final items.

    `descriptor` is forwarded unchanged to each transform; when it is `None`
    each transform resolves its own descriptor.
    '''
    # Feed the result of one transform into the next one. Collecting the
    # result of every transform applied to the original items would not
    # compose them.
    for transform in self.transforms:
      items = transform(items, descriptor)
    return items

class ToFloat(Transform):
  'Convert `Desc.IMAGE` items to arrays of `self.dtype` divided by 255'

  def __init__(self, descriptor:Tuple[Desc] = None) -> None:
    'Remember the optional `descriptor` and register the image handler'
    self.descriptor = descriptor
    self.kinds = {Desc.IMAGE: self.image_to_float}

  def image_to_float(self, item):
    'Return `item` as an array of `self.dtype` divided by 255'
    as_array = np.array(item, dtype=self.dtype)
    return as_array / 255

A transform can get its descriptor in the constructor, in which case the transformation is static, or in the call itself, in which case the transformation is dynamic per call. A descriptor given in the call takes precedence over the one given in the constructor.

In [5]:
dummy_image, label = np.random.randint(0,256,size=(32,32,3)), 5
descriptor = (Desc.IMAGE, Desc.LABEL)
# NOTE: `image.all()` collapses the whole array to a single bool, so comparing
# it with 0 and 1 is always true. The value range is checked through
# `min`/`max` instead.
# in constructor
#
tfm = ToFloat(descriptor)
image, label = tfm([dummy_image, label])
image_transformed = bool(image.min() >= 0 and image.max() <= 1 and image.dtype==np.float32)
test_eq(image_transformed, True)
test_eq(label, 5)
# in call
#
tfm = ToFloat()
image, label = tfm([dummy_image, label], descriptor)
image_transformed = bool(image.min() >= 0 and image.max() <= 1 and image.dtype==np.float32)
test_eq(image_transformed, True)
test_eq(label, 5)
# different descriptors in both: the one passed to the call must win
#
tfm = ToFloat((Desc.IMAGE, Desc.IMAGE, Desc.LABEL))
image, label = tfm([dummy_image, label], descriptor)
image_transformed = bool(image.min() >= 0 and image.max() <= 1 and image.dtype==np.float32)
test_eq(image_transformed, True)
test_eq(label, 5)

In [None]:
class Normalize(Transform):
  'Transform for `Desc.IMAGE` items that validates them against `stats`'

  def __init__(self, descriptor:Tuple[Desc] = None, stats:Any = None) -> None:
    '''Store `descriptor` and `stats` (converted to a numpy array).

    Raises `Exception` when `stats` is not a list, tuple or numpy array,
    which includes the default `None`.
    '''
    self.kinds = {Desc.IMAGE:self.normalize}
    self.descriptor = descriptor
    # `np.array` is a factory function, not a type: using it in `isinstance`
    # raises TypeError as soon as the check reaches it. `np.ndarray` is the
    # actual array class.
    if not isinstance(stats, (list, tuple, np.ndarray)):
      raise Exception(
        f'Supported types for statistic are `list`, `tuple` or `np.array` got {type(stats)}')
    self.stats = np.array(stats)

  def normalize(self, item):
    '''Check the shape of `item` against `self.stats` and scale it by 255.

    Raises `Exception` when `item.shape[1:]` differs from `self.stats.shape`.
    '''
    if item.shape[1:] != self.stats.shape:
      raise Exception(
        f'Expected shape is {self.stats.shape}, got {item.shape[1:]} (excluding batch dimension).')
    # NOTE(review): `self.stats` is only used for the shape check above; the
    # values themselves are never applied. This looks unfinished - confirm the
    # intended normalization formula.
    return np.array(item, dtype=self.dtype) / 255

Let's try to create a pretty-print function

## DataModule

It should:
* prepare dataset(s)
* split the dataset(s)
* transform items
* create train and val (and maybe more) dataloaders

In [6]:
#| export
class DataModule(ABC):
  'Placeholder base class for data modules; it declares no members yet'


In [7]:
#| hide
# Run nbdev's export step from inside the notebook.
import nbdev; nbdev.nbdev_export()