In [None]:
#| default_exp data

# Data

In [None]:
#| export
from datasets.dataset_dict import DatasetDict
import torchvision.transforms.functional as TF
from datasets import load_dataset
from operator import itemgetter
from typing import Tuple, Mapping, Sequence
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, default_collate

In [None]:
#| export
def get_dls(train_ds: Dataset, valid_ds: Dataset,
            bs: int, **kwargs) -> Tuple[DataLoader, DataLoader]:
    """Build a (training, validation) pair of ``DataLoader`` objects.

    Args:
        train_ds: dataset wrapped by the first loader (created with ``shuffle=True``).
        valid_ds: dataset wrapped by the second loader (created with ``shuffle=False``).
        bs: batch size, passed to both loaders.
        **kwargs: extra keyword arguments forwarded unchanged to both ``DataLoader`` calls.

    Returns:
        A ``(train_loader, valid_loader)`` tuple.
    """
    train_loader = DataLoader(train_ds, bs, shuffle=True, **kwargs)
    valid_loader = DataLoader(valid_ds, bs, shuffle=False, **kwargs)
    return train_loader, valid_loader

class DataLoaders:
    """Container holding a training and a validation ``DataLoader``.

    The loaders are exposed as the ``train`` and ``valid`` attributes.
    """

    def __init__(self, *dls: DataLoader) -> None:
        """Store the first two loaders as ``train`` and ``valid``.

        Loaders beyond the second are ignored; passing fewer than two raises
        ``ValueError`` from the tuple unpacking.
        """
        self.train, self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd: DatasetDict, bs: int, **kwargs):
        """Create loaders from the splits of a dataset dictionary.

        Args:
            dd: mapping of split name to dataset. It must contain a ``'train'``
                key (used to build the collate function); the first two values,
                in iteration order, become the training and validation datasets.
            bs: batch size for both loaders.
            **kwargs: forwarded to ``get_dls`` (and from there to ``DataLoader``).

        Returns:
            A new instance wrapping the two loaders.
        """
        fn = collate_dict(dd['train'])
        # Only the first two splits are used, mirroring ``__init__``. Splatting
        # every split would raise TypeError for dictionaries with three or more
        # splits (e.g. train/validation/test).
        splits = list(dd.values())[:2]
        dls = get_dls(*splits, bs=bs, collate_fn=fn, **kwargs)

        return cls(*dls)

In [None]:
# Names of the image and label columns used as keys into each batch.
x, y = "image", "label"
name = "fashion_mnist"
dsd = load_dataset(name)

def transform(b):
    """Replace every image in batch ``b`` (under key ``x``) with a flattened tensor.

    The batch is modified in place and also returned.
    """
    b[x] = [torch.flatten(TF.to_tensor(o)) for o in b[x]]
    return b

bs = 1024
tds = dsd.with_transform(transform)
# The class is spelled ``DataLoaders``; the previous ``Dataloaders`` raised NameError.
dls = DataLoaders.from_dd(tds, bs)

Found cached dataset fashion_mnist (/Users/tk541/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48)


  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
#| export
# Default device string: "mps" when available, otherwise "cuda", otherwise "cpu".
if torch.backends.mps.is_available():
    def_device = "mps"
elif torch.cuda.is_available():
    def_device = "cuda"
else:
    def_device = "cpu"

In [None]:
#| export
def to_device(x: Tensor, device: str = def_device) -> Tensor:
    """Recursively move the tensors contained in ``x`` onto ``device``.

    Args:
        x: a tensor, a mapping whose values are handled recursively, or any
            other iterable whose items are handled recursively.
        device: target device; defaults to the module-level ``def_device``.

    Returns:
        A tensor for tensor input, a plain ``dict`` for mapping input, and
        otherwise a new container built with ``type(x)`` from the moved items.
    """
    if isinstance(x, Tensor):
        return x.to(device)

    if isinstance(x, Mapping):
        # Recurse instead of calling ``.to`` on the values directly, so that
        # lists/tuples/dicts nested inside a mapping are handled as well.
        return {key: to_device(value, device) for key, value in x.items()}

    # ``device`` must be forwarded explicitly; otherwise nested items would be
    # moved to the default device rather than the one the caller asked for.
    return type(x)(to_device(item, device) for item in x)

def to_cpu(x: Tensor) -> Tensor:
    """Recursively detach the tensors contained in ``x`` and move them to the CPU.

    Args:
        x: a tensor, a mapping whose values are handled recursively, or a
            sequence whose items are handled recursively.

    Returns:
        A plain ``dict`` for mapping input, a new ``type(x)`` container for
        sequence input, and otherwise the detached CPU tensor. Tensors of
        dtype ``torch.float16`` are converted with ``.float()``.
    """
    if isinstance(x, Mapping):
        # ``.items()`` is required here: iterating a mapping directly yields
        # only its keys, which breaks the ``key, value`` unpacking.
        return {key: to_cpu(value) for key, value in x.items()}

    if isinstance(x, Sequence):
        return type(x)(to_cpu(item) for item in x)

    result = x.detach().cpu()

    return result.float() if result.dtype == torch.float16 else result

In [None]:
#| export
def collate_dict(ds):
    """Return a collate function that yields the columns of ``ds`` positionally.

    The returned function runs ``default_collate`` on a batch and then picks
    the entries named by ``ds.features``, in that order, via ``itemgetter``.
    """
    columns = tuple(ds.features)
    pick_columns = itemgetter(*columns)

    def _collate(batch):
        collated = default_collate(batch)
        return pick_columns(collated)

    return _collate

def collate_device(b):
    """Collate batch ``b`` with ``default_collate`` and pass the result to ``to_device``.

    ``to_device`` is called without a device argument, so its default is used.
    """
    collated = default_collate(b)
    return to_device(collated)

In [None]:
# Run nbdev's export step for this project.
import nbdev
nbdev.nbdev_export()