In [None]:
#| default_exp datasets

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from __future__ import annotations
from cfnet.import_essentials import *
from sklearn.preprocessing import StandardScaler,MinMaxScaler,OneHotEncoder
from urllib.request import urlretrieve

In [None]:
#| export
try:
    import torch.utils.data as torch_data
except ModuleNotFoundError:
    torch_data = None

## Dataloader

In [None]:
#| export
class Dataset:
    """Minimal array-backed dataset pairing features `X` with targets `y`.

    Both containers must expose `.shape` and agree on the size of their first
    axis; an `AssertionError` is raised otherwise.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y
        n_feature_rows, n_target_rows = self.X.shape[0], self.y.shape[0]
        assert n_feature_rows == n_target_rows

    def __len__(self):
        """Number of samples, measured on `X`."""
        n_samples = len(self.X)
        return n_samples

    def __getitem__(self, idx):
        """Return `(X[idx], y[idx])`; `idx` is forwarded to both containers unchanged."""
        features = self.X[idx]
        targets = self.y[idx]
        return features, targets

In [None]:
#| export
class BaseDataLoader(ABC):
    """Common interface of the backend-specific dataloaders.

    The constructor only fixes the call signature (everything after `backend`
    is keyword-only) and stores nothing. `__len__`, `__next__` and `__iter__`
    raise `NotImplementedError` until a subclass overrides them.
    """

    def __init__(
        self,
        dataset,
        backend: str,
        *,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        num_workers: int = 0,
        drop_last: bool = False,
        **kwargs
    ):
        """Accept the shared dataloader options; intentionally does nothing with them."""

    def __len__(self):
        """Number of batches; must be provided by the subclass."""
        raise NotImplementedError

    def __next__(self):
        """Next batch; must be provided by the subclass."""
        raise NotImplementedError

    def __iter__(self):
        """Iterator over batches; must be provided by the subclass."""
        raise NotImplementedError

#### Pytorch Dataloader

In [None]:
#| exporti
# copy from https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html#data-loading-with-pytorch
def _numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [_numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

def _convert_dataset_pytorch(dataset: Dataset):
    class DatasetPytorch(torch_data.Dataset):
        def __init__(self, dataset: Dataset): self.dataset = dataset
        def __len__(self): return len(self.dataset)
        def __getitem__(self, idx): return self.dataset[idx]
    
    return DatasetPytorch(dataset)

In [None]:
#| export
class DataLoaderPytorch(BaseDataLoader):
    """Dataloader backed by `torch.utils.data.DataLoader`.

    The dataset is wrapped with `_convert_dataset_pytorch` and batches are
    collated with `_numpy_collate`, so batches come back as NumPy arrays.
    Extra `**kwargs` are forwarded to the PyTorch `DataLoader`.

    Raises:
        ModuleNotFoundError: if PyTorch could not be imported.
    """

    def __init__(
        self, 
        dataset: Dataset,
        backend: str = 'pytorch', # positional argument
        *,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        num_workers: int = 0,
        drop_last: bool = False,
        **kwargs
    ):
        if torch_data is None:
            # The two sentences used to be concatenated without a separating space.
            raise ModuleNotFoundError(
                "`pytorch` library needs to be installed. Try `pip install torch`. "
                "Please refer to pytorch documentation for details: https://pytorch.org/get-started/."
            )

        dataset = _convert_dataset_pytorch(dataset)
        self.dataloader = torch_data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            collate_fn=_numpy_collate,
            **kwargs
        )
        # Iterator used by `__next__`; created lazily and discarded once exhausted.
        self._iterator = None

    def __len__(self):
        """Number of batches reported by the underlying PyTorch loader."""
        return len(self.dataloader)

    def __next__(self):
        """Return the next batch when the loader is driven with `next(loader)`.

        A `torch.utils.data.DataLoader` is an iterable, not an iterator, so
        `next()` cannot be called on it directly; an iterator is kept on the
        instance instead. When it is exhausted it is dropped, so the following
        `next()` call starts a new pass.
        """
        if self._iterator is None:
            self._iterator = iter(self.dataloader)
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = None
            raise

    def __iter__(self):
        """Return a fresh iterator from the underlying PyTorch loader."""
        return iter(self.dataloader)

#### Jax Dataloader

In [None]:
#| export
class DataLoaderJax(BaseDataLoader):
    """Pure NumPy-indexing dataloader that needs no extra backend library.

    Each batch is obtained by indexing `dataset` with an array of sample
    indices. `num_workers` and `**kwargs` are accepted for signature
    compatibility with the other loaders and are not used.

    The number of batches produced per pass always equals `len(loader)`:
    `floor(N / batch_size)` with `drop_last=True`, `ceil(N / batch_size)`
    otherwise.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """

    def __init__(
        self, 
        dataset: Dataset,
        backend: str,
        *,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        num_workers: int = 0,
        drop_last: bool = False,
        **kwargs
    ):
        if batch_size < 1:
            # A non-positive batch size would never advance the cursor.
            raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}.")

        # Attributes from pytorch data loader (implemented)
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.data_len: int = len(dataset)  # Length of the dataset
        self.indices: np.ndarray = np.arange(self.data_len) # sample order used for the current pass
        self.pose: int = 0  # record the current position in the dataset

    def __len__(self):
        """Number of batches in one pass over the dataset."""
        n_full, remainder = divmod(len(self.dataset), self.batch_size)
        if self.drop_last or remainder == 0:
            return n_full
        return n_full + 1

    def _reset(self):
        """Rewind to the start of the dataset so the loader can be iterated again."""
        self.pose = 0
        self.indices = np.arange(self.data_len)

    def __next__(self):
        """Return the next batch, or rewind and raise `StopIteration` at the end of a pass."""
        remaining = self.data_len - self.pose
        # Stop when nothing is left (this also covers dataset sizes that are an
        # exact multiple of `batch_size`, which previously produced a trailing
        # empty batch), or when only an incomplete batch is left and
        # `drop_last` is set (previously the last *full* batch was discarded).
        if remaining <= 0 or (self.drop_last and remaining < self.batch_size):
            self._reset()
            raise StopIteration

        if self.shuffle and self.pose == 0:
            # One permutation per pass visits every sample exactly once.
            self.indices = np.random.permutation(self.data_len)

        batch_indices = self.indices[self.pose : self.pose + self.batch_size]
        self.pose += self.batch_size
        return self.dataset[batch_indices]

    def __iter__(self):
        """The loader is its own iterator."""
        return self

#### Dataloader

In [None]:
#| export
# Registry mapping a backend name to its dataloader class.
# A value of `None` marks a recognised backend that has no implementation yet.
backend2dataloader = dict(
    jax=DataLoaderJax,
    pytorch=DataLoaderPytorch,
    tensorflow=None,
    merlin=None,
)

In [None]:
#| exporti
def _dispatch_datalaoder(backend: str):
    """Look up the dataloader class registered for `backend`.

    Raises:
        ValueError: if `backend` is not a key of `backend2dataloader`.
        NotImplementedError: if `backend` is registered with `None`.
    """
    known_backends = backend2dataloader.keys()
    if backend in known_backends:
        loader_cls = backend2dataloader[backend]
        if loader_cls is not None:
            return loader_cls
        raise NotImplementedError(f'backend=`{backend}` is not supported yet.')

    raise ValueError(f"backend=`{backend}` is an invalid backend for dataloader. "
        f"Should be one of {known_backends}.")


In [None]:
#| export
class DataLoader(BaseDataLoader):
    """Front-end that turns itself into the dataloader class chosen by `backend`.

    After construction the instance's class is the one returned by
    `_dispatch_datalaoder(backend)`, and that class's `__init__` has been run
    on it with the same arguments.
    """

    def __init__(
        self,
        dataset,
        backend,
        *,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        num_workers: int = 0,
        drop_last: bool = False,
        **kwargs
    ):
        # Swap the class first: `self.__init__` below must resolve to the
        # backend-specific initializer, not to this method.
        concrete_cls = _dispatch_datalaoder(backend)
        self.__class__ = concrete_cls

        loader_options = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
        )
        loader_options.update(kwargs)
        self.__init__(dataset=dataset, backend=backend, **loader_options)

We want to train a simple regression model.

In [None]:
from sklearn.datasets import make_regression

In [None]:
def loss(w, x, y):
    """Mean of `optax.l2_loss` between the linear predictions `x @ w.T` and `y`."""
    predictions = x @ w.T
    per_sample_loss = vmap(optax.l2_loss)(predictions, y)
    return jnp.mean(per_sample_loss)

def step(w, x, y):
    """Apply one gradient-descent update to `w` with a fixed learning rate of 0.1."""
    learning_rate = 0.1
    gradient = jax.grad(loss)(w, x, y)  # differentiates w.r.t. the first argument, `w`
    w -= learning_rate * gradient
    return w

def train(dataloader, key):
    """Fit a (1, 20) weight matrix by running `step` over `dataloader` for 10 epochs."""
    n_epochs = 10
    w = jax.random.normal(key, shape=(1, 20))
    for _epoch in range(n_epochs):
        for batch in dataloader:
            x, y = batch
            w = step(w, x, y)
    return w
    

In [None]:
# Synthetic regression problem: 10000 samples with 20 features each.
X, y = make_regression(n_samples=10000, n_features=20)
# Targets are reshaped into a single column so they line up with `x @ w.T` in `loss`.
dataset = Dataset(X, y.reshape(-1, 1))
# Key sequence seeded with 0; `next(keys)` supplies the key passed to `train`
# (`hk` is presumably haiku, brought in by the essentials import — confirm).
keys = hk.PRNGSequence(0)

In [None]:
#| slow
# Train with the PyTorch-backed loader; `block_until_ready()` waits for the
# computation to finish before the shape is displayed.
dataloader = DataLoader(dataset, 'pytorch', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
w.shape

(1, 20)

In [None]:
#| slow
# Same training run, this time with the pure-NumPy 'jax' loader.
dataloader = DataLoader(dataset, 'jax', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
w.shape

(1, 20)

In [None]:
#| include: false
# TODO: bug when N_data % batch_size == 0
# dataloader = DataLoader(dataset, batch_size=100, shuffle=True)
# w = train(dataloader, next(keys)).block_until_ready()
# train(dataloader)

## DataModule

In [None]:
#| export
def find_imutable_idx_list(
    imutable_col_names: List[str],
    discrete_col_names: List[str],
    continuous_col_names: List[str],
    cat_arrays: List[List[str]],
) -> List[int]:
    """Map immutable column names to their positions in the encoded feature vector.

    The layout assumed here is: one slot per continuous column (in the given
    order), followed by one slot per category of each discrete column, with
    `cat_arrays[i]` holding the categories of `discrete_col_names[i]`.
    """
    # Continuous columns occupy the leading slots, one each.
    frozen_positions = [
        position
        for position, name in enumerate(continuous_col_names)
        if name in imutable_col_names
    ]

    # Discrete columns follow; each spans as many slots as it has categories.
    offset = len(continuous_col_names)
    for name, categories in zip(discrete_col_names, cat_arrays):
        width = len(categories)
        if name in imutable_col_names:
            frozen_positions.extend(range(offset, offset + width))
        offset += width
    return frozen_positions

In [None]:
#| export
def _data_name2configs(data_name: str):
    """Build the data-module configuration for a dataset referenced by name.

    Reads `../assets/configs/<data_name>.json` (relative to the current working
    directory), copies the recognised entries into a new dict, and sets
    `data_dir` to the CSV path returned by `_download_data`.

    Raises:
        FileNotFoundError: if the JSON config does not exist.
        KeyError: if the config lacks `discret_cols` or `continous_cols`.
    """
    with open('../assets/configs/{}.json'.format(data_name)) as json_file:
        data = json.load(json_file)

    # A fresh dict is required here: the function used to write into a
    # `data_configs` name it never defined, which raises `NameError` unless a
    # global of that name happens to exist (and then silently mutates it).
    data_configs = {
        'data_name': data_name,
        'discret_cols': data['discret_cols'],
        'continous_cols': data['continous_cols'],
        'imutable_cols': data.get('imutable_cols', []),
        # Absent optional entries are reported as None, matching the
        # `Optional[...] = None` defaults of the classes consuming this dict.
        'sample_frac': data.get('sample_frac'),
        'normalizer': data.get('normalizer'),
        'encoder': data.get('encoder'),
    }
    data_configs['data_dir'] = _download_data(data_name)
    return data_configs

def _download_data(data_name: str):
        url = 'https://github.com/BirkhoffG/cfnet/raw/master/assets/data/{}.csv'.format(data_name)
        path = Path(os.getcwd())
        path = path / "cf_data"
        if not path.exists():
            os.makedirs(path)
        path = path / f'{data_name}.csv'
        if path.is_file():
            return path
        else:
            urlretrieve(url,path)
            return path

In [None]:
#| export
# Declarative configuration for a data module. Field names mirror the
# attributes of `TabularDataModule`. `BaseParser` is defined outside this
# file; presumably it validates/parses these fields — TODO confirm.
class DataModuleConfigs(BaseParser):
    batch_size: int  # the only field without a default
    discret_cols: List[str] = []  # names of the discrete (categorical) columns
    continous_cols: List[str] = []  # names of the continuous columns
    imutable_cols: List[str] = []  # names of the columns flagged as immutable
    normalizer: Optional[Any] = None  # transformer for the continuous columns
    encoder: Optional[Any] = None  # transformer for the discrete columns
    sample_frac: Optional[float] = None  # fraction of the training rows to keep
    backend: str = 'jax'  # dataloader backend name

In [None]:
#| export
class TabularDataModule:
    """Load a tabular CSV dataset, preprocess it, and expose it through dataloaders.

    `data_configs` is either a dataset name (resolved with
    `_data_name2configs`) or a dict. The dict must contain `data_dir`, the
    path of the CSV file; every entry of the dict is copied onto the instance
    as an attribute, overriding the class-level defaults below.

    The last column of the CSV is used as the target, all other columns as
    features. Continuous columns are scaled (a `MinMaxScaler` is fitted when
    no `normalizer` is given) and discrete columns are one-hot encoded (a
    `OneHotEncoder` is fitted when no `encoder` is given). The encoded matrix
    holds the continuous columns first, then the one-hot blocks.

    Raises:
        TypeError: if `data_configs` is None.
        KeyError: if a config dict has no `data_dir` entry.
        AssertionError: if an immutable column is neither continuous nor discrete.
    """

    # Class-level defaults; note that the list defaults are shared between
    # instances until a config entry replaces them.
    discret_cols: List[str] = []
    continous_cols: List[str] = []
    imutable_cols: List[str] = []
    normalizer: Optional[Any] = None
    encoder: Optional[OneHotEncoder] = None
    data: Optional[pd.DataFrame] = None
    sample_frac: Optional[float] = None
    batch_size: int = 128
    backend: str = 'jax'
    data_name: str = ""

    def __init__(self, data_configs: dict | str = None):
        if data_configs is None:
            # Fail early with a clear message instead of the AttributeError
            # that `None.items()` would raise further down.
            raise TypeError(
                "`data_configs` must be a dataset name (str) or a configuration dict, got None."
            )
        if isinstance(data_configs, str):
            data_configs = _data_name2configs(data_configs)
        if isinstance(data_configs, dict):
            # read data
            self.data = pd.read_csv(Path(data_configs['data_dir']))

        # update configs
        self._update_configs(data_configs)
        self.check_cols()
        # update cat_idx
        self.cat_idx = len(self.continous_cols)
        # prepare data
        self.prepare_data()

    def check_cols(self):
        """Cast continuous columns to float and validate the immutable column names."""
        # The builtin `float` is what the `np.float` alias pointed to; the alias
        # is deprecated since NumPy 1.20 and no longer exists in recent releases.
        self.data = self.data.astype({col: float for col in self.continous_cols})
        # check imutable cols
        cols = self.continous_cols + self.discret_cols
        for col in self.imutable_cols:
            assert (
                col in cols
            ), f"imutable_cols=[{col}] is not specified in `continous_cols` or `discret_cols`."

    def _update_configs(self, configs):
        """Copy every config entry onto the instance as an attribute."""
        for k, v in configs.items():
            setattr(self, k, v)

    @staticmethod
    def _make_onehot_encoder():
        """Create a dense-output `OneHotEncoder` across scikit-learn versions."""
        try:
            # Newer scikit-learn releases name the argument `sparse_output`.
            return OneHotEncoder(sparse_output=False)
        except TypeError:
            # Older releases only know the original `sparse` argument.
            return OneHotEncoder(sparse=False)

    def prepare_data(self):
        """Encode the features and build the train / validation / test datasets."""
        features = self.data[self.data.columns[:-1]]
        target = self.data[[self.data.columns[-1]]]
        n_rows = len(features)

        def encode_block(transformer, cols, fit: bool):
            # No columns of this kind: contribute a zero-width block so that the
            # concatenation below still works, including for zero rows.
            if not cols:
                return np.zeros((n_rows, 0))
            block = features[cols]
            return transformer.fit_transform(block) if fit else transformer.transform(block)

        # preprocessing: a transformer supplied through the configs is used
        # as-is, otherwise a new one is created and fitted on this data.
        fit_normalizer = not self.normalizer
        if fit_normalizer:
            self.normalizer = MinMaxScaler()
        X_cont = encode_block(self.normalizer, self.continous_cols, fit_normalizer)

        fit_encoder = not self.encoder
        if fit_encoder:
            self.encoder = self._make_onehot_encoder()
        X_cat = encode_block(self.encoder, self.discret_cols, fit_encoder)

        X = np.concatenate((X_cont, X_cat), axis=1)
        # get categorical arrays
        self.cat_arrays = self.encoder.categories_ if self.discret_cols else []
        self.imutable_idx_list = find_imutable_idx_list(
            imutable_col_names=self.imutable_cols,
            discrete_col_names=self.discret_cols,
            continuous_col_names=self.continous_cols,
            cat_arrays=self.cat_arrays,
        )

        # prepare train & test; `shuffle=False` keeps the original row order
        train_X, test_X, train_y, test_y = (
            part.astype(jnp.float32)
            for part in train_test_split(X, target.to_numpy(), shuffle=False)
        )
        if self.sample_frac:
            # keep only the leading fraction of the training rows
            train_size = int(len(train_X) * self.sample_frac)
            train_X, train_y = train_X[:train_size], train_y[:train_size]
        self.train_dataset = Dataset(train_X, train_y)
        self.val_dataset = Dataset(test_X, test_y)
        # validation and test share the same held-out split
        self.test_dataset = self.val_dataset

    def _make_dataloader(self, dataset, batch_size):
        """Build a shuffling dataloader over `dataset` for the configured backend."""
        if batch_size is None:
            batch_size = self.batch_size
        return DataLoader(
            dataset,
            self.backend,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=False
        )

    def train_dataloader(self, batch_size=None):
        """Dataloader over the training split; `batch_size` defaults to `self.batch_size`."""
        return self._make_dataloader(self.train_dataset, batch_size)

    def val_dataloader(self, batch_size=None):
        """Dataloader over the validation split; `batch_size` defaults to `self.batch_size`."""
        return self._make_dataloader(self.val_dataset, batch_size)

    def test_dataloader(self, batch_size=None):
        """Dataloader over the test split; `batch_size` defaults to `self.batch_size`."""
        return self._make_dataloader(self.test_dataset, batch_size)

    def get_sample_X(self, frac: float | None = None):
        """Return only the features of `get_samples(frac)`."""
        train_X, _ = self.get_samples(frac)
        return train_X

    def get_samples(self, frac: float | None = None):
        """Return the leading `frac` (default 0.1) of the training features and targets."""
        if frac is None:
            frac = 0.1
        train_X, train_y = self.train_dataset[:]
        train_size = int(len(train_X) * frac)
        return train_X[:train_size], train_y[:train_size]

In [None]:
#| include: false
# Hand-written configuration used by the test cell below. `data_dir` is a
# relative path, so it is resolved against the current working directory.
# No `imutable_cols`, `normalizer` or `encoder` entries are given.
data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "batch_size": 256,
    'sample_frac': 0.1,
    "continous_cols": [
        "age",
        "hours_per_week"
    ],
    "discret_cols": [
        "workclass",
        "education",
        "marital_status",
        "occupation",
        "race",
        "gender"
    ],
}

In [None]:
# Build the data module from the explicit config dict.
dm = TabularDataModule(data_configs)
seed=42
batch_size=data_configs["batch_size"]
# The first training batch must be full and have the expected encoded width.
t_dataloader = dm.train_dataloader(batch_size)
x, y = next(iter(t_dataloader))
assert x.shape[0] == 256
assert x.shape[1] == 29
assert dm.sample_frac == 0.1

# A full pass must yield exactly as many batches as `len()` reports.
l = 0
t_dataloader = dm.train_dataloader(batch_size)
for i in t_dataloader:
    l += 1
assert l == len(t_dataloader)

# Same shape checks on the validation loader.
t_dataloader = dm.val_dataloader(batch_size)
x, y = next(iter(t_dataloader))
assert x.shape[0] == 256
assert x.shape[1] == 29



Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  self.data = self.data.astype({col: np.float for col in self.continous_cols})


In [None]:
# Build the data module from a dataset name. Every check below must use
# `dm_2`; the loaders were previously created from `dm`, so loading by name
# was never actually tested.
dm_2 = TabularDataModule("adult")
seed=42
batch_size=dm_2.batch_size
# Expected width of the encoded features: one slot per continuous column plus
# one slot per category of every discrete column.
n_features = len(dm_2.continous_cols) + sum(len(cats) for cats in dm_2.cat_arrays)
print(dm_2.data.shape)  # a short summary instead of dumping the whole frame

t_dataloader = dm_2.train_dataloader(batch_size)
x, y = next(iter(t_dataloader))
assert x.shape[0] == batch_size
assert x.shape[1] == n_features
assert dm_2.data_name == "adult"

# A full pass must yield exactly as many batches as `len()` reports.
l = 0
t_dataloader = dm_2.train_dataloader(batch_size)
for i in t_dataloader:
    l += 1
assert l == len(t_dataloader)

t_dataloader = dm_2.val_dataloader(batch_size)
x, y = next(iter(t_dataloader))
assert x.shape[0] == batch_size
assert x.shape[1] == n_features

        age  hours_per_week      workclass     education marital_status  \
0      42.0            45.0        Private       HS-grad        Married   
1      32.0            40.0  Self-Employed  Some-college        Married   
2      35.0            40.0        Private         Assoc         Single   
3      36.0            40.0        Private       HS-grad         Single   
4      57.0            35.0        Private        School        Married   
...     ...             ...            ...           ...            ...   
32556  66.0            40.0        Private     Bachelors        Married   
32557  35.0            80.0  Self-Employed       HS-grad        Married   
32558  21.0            10.0     Government  Some-college         Single   
32559  24.0            40.0        Private  Some-college        Married   
32560  46.0            40.0        Private       HS-grad        Married   

         occupation   race  gender  income  
0       Blue-Collar  White    Male       1  
1       B

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  self.data = self.data.astype({col: np.float for col in self.continous_cols})
