# Rectangular data loader

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_003a import *

In [None]:
# Build train/valid datasets from the image folders under data/caltech101.
DATA_PATH = Path('data')
PATH = DATA_PATH/'caltech101'

# Seed numpy before splitting; presumably FilesDataset.from_folder draws the
# split from np.random (TODO confirm), which would make it reproducible.
np.random.seed(42)
# test_pct=0.2: presumably ~20% of files go to valid_ds — verify in FilesDataset.
train_ds,valid_ds = FilesDataset.from_folder(PATH, test_pct=0.2)

# Sample input (the last training item) used by the experiments below.
x = train_ds[-1][0]
classes = train_ds.classes
c = len(classes)

len(train_ds),len(valid_ds),c

## Closest ntile

In [None]:
# Display one training image and the shape of the sample input `x`.
# NOTE(review): this displays item 1 while `x` is item -1, so the printed
# shape need not belong to the image shown.
show_image(train_ds[1][0], figsize=(6,3))
x.shape

In [None]:
# Width/height aspect ratio (PIL's `size` is (width, height)) of every training image.
# Each file is opened in a `with` block so it is closed as soon as its size has
# been read, instead of staying open until the Image object is garbage-collected.
asp_ratios = []
for fn in train_ds.fns:
    with Image.open(fn) as img: asp_ratios.append(operator.truediv(*img.size))
asp_ratios[:4]

In [None]:
# Five reference aspect ratios: the 2nd, 20th, 50th, 80th and 98th percentiles
# of the training images' ratios. Each image is later snapped to one of these.
asp_ntiles = np.percentile(asp_ratios, [2,20,50,80,98])
asp_ntiles

In [None]:
def closest_ntile(aspect, ntiles):
    "Return the element of `ntiles` closest to `aspect`, measuring distance on a log scale."
    # Log distance treats ratios symmetrically: 2:1 and 1:2 are equally far from 1:1.
    # np.log is used explicitly so array/list `ntiles` work regardless of which
    # `log` a star import happens to provide.
    dists = np.abs(np.log(aspect) - np.log(ntiles))
    return ntiles[np.argmin(dists)]

In [None]:
# Aspect ratio of the sample `x` and the reference ratio it is snapped to.
# assumes x is (channels, rows, cols), so this is width/height like the PIL ratios — TODO confirm
aspect = x.shape[2]/x.shape[1]
nearest_aspect = closest_ntile(aspect, asp_ntiles)
aspect,nearest_aspect

In [None]:
# Pixel budget per image: the area of a 128x128 square.
target_px = 128*128; target_px

In [None]:
# Rows/cols holding about `target_px` pixels at the bucketed aspect ratio:
# rows = sqrt(px/aspect), cols = rows*aspect. int() truncates both, so the
# resulting rows*cols is at most target_px.
target_r = int(math.sqrt(target_px/nearest_aspect))
target_c = int(target_r*nearest_aspect)
target_r,target_c,target_r*target_c

## SortAspectBatchSampler

In [None]:
# Snap every training image's aspect ratio to its closest reference n-tile.
asp_nearests = [closest_ntile(o, asp_ntiles) for o in asp_ratios]
asp_nearests[:10]

In [None]:
# Batch size used for the batch-count estimate below.
bs=32

In [None]:
from itertools import groupby

In [None]:
# Group dataset indices by bucketed aspect ratio. groupby only merges adjacent
# equal keys, hence the sort on the same key first.
# Each group is a list of (index, bucketed_aspect) pairs.
sort_nearest = sorted(enumerate(asp_nearests), key=itemgetter(1))
groups = [list(b) for a,b in groupby(sort_nearest, key=itemgetter(1))]
len(groups)

In [None]:
# First few (index, bucketed_aspect) pairs of the first group.
groups[0][:5]

In [None]:
# Total batches when each group is batched on its own: ceil(len/bs) per group,
# because a batch never mixes two groups.
sum(math.ceil(len(g)/bs) for g in groups)

In [None]:
@dataclass
class SortAspectBatchSampler(Sampler):
    "Batch sampler that only puts indices of `ds` whose images share an aspect-ratio bucket in the same batch."
    ds:Dataset; bs:int; shuffle:bool = False

    def __post_init__(self):
        # Width/height ratio of every image in `ds.fns`; `with` closes each file once its size is read.
        asp_ratios = []
        for fn in self.ds.fns:
            with Image.open(fn) as img: asp_ratios.append(operator.truediv(*img.size))
        # Snap every ratio to the closest of five percentiles of this dataset's ratios.
        asp_ntiles = np.percentile(asp_ratios, [2,20,50,80,98])
        asp_nearests = [closest_ntile(o, asp_ntiles) for o in asp_ratios]
        # groupby only merges adjacent equal keys, so sort by bucket first.
        # Each group is a list of (dataset_index, bucketed_aspect) pairs.
        sort_nearest = sorted(enumerate(asp_nearests), key=itemgetter(1))
        self.groups = [list(members) for _,members in groupby(sort_nearest, key=itemgetter(1))]
        # Groups are batched separately, so each contributes ceil(len/bs) batches.
        # Must use this sampler's own batch size, not a module-level `bs`.
        self.n = sum(math.ceil(len(g)/self.bs) for g in self.groups)

    def __len__(self):
        "Number of batches yielded per pass."
        return self.n

    def __iter__(self):
        "Yield lists of `(index, {'aspect': bucket})` items; a batch never mixes buckets."
        # Shuffle into a local list rather than reassigning `self.groups` during iteration.
        groups = sample(self.groups, len(self.groups)) if self.shuffle else self.groups
        for group in groups:
            # The (index, dict) form is what TfmDataset.__getitem__ unpacks as (idx, xtra).
            group = [(idx,{'aspect':asp}) for idx,asp in group]
            if self.shuffle: group = sample(group, len(group))
            for i in range(0, len(group), self.bs): yield group[i:i+self.bs]

In [None]:
# First batch from an unshuffled sampler with batch size 4: a list of (index, {'aspect': ...}) pairs.
next(iter(SortAspectBatchSampler(train_ds, 4)))

In [None]:
# Same, with shuffle=True passed as the third positional field.
next(iter(SortAspectBatchSampler(train_ds, 4, True)))

## Rectangular dataset

In [None]:
class TfmDataset(Dataset):
    "Wrap `ds` so that each item's input is run through `tfms` when it is fetched."
    def __init__(self, ds: Dataset, tfms: Collection[Callable] = None, **kwargs):
        # `kwargs` are forwarded to the transforms for every item.
        self.ds,self.tfms,self.kwargs = ds,tfms,kwargs

    def __len__(self): return len(self.ds)

    def __getitem__(self,idx):
        # `idx` is a plain index or an `(index, xtra)` pair, where `xtra` holds
        # per-item transform arguments (e.g. the `{'aspect': ...}` dicts that
        # SortAspectBatchSampler yields).
        if isinstance(idx, tuple): idx,xtra = idx
        else: xtra={}
        x,y = self.ds[idx]
        # Apply this dataset's own transforms (not a global `tfms`). Merge the
        # keyword dicts so per-item `xtra` overrides constructor kwargs instead
        # of raising TypeError on a duplicate key.
        return apply_tfms(self.tfms)(x, **{**self.kwargs, **xtra}), y

In [None]:
# Augmentations: rotation with degrees in (-20, 20) and zoom with scale in (1, 2).
# presumably each is sampled per call from these ranges — confirm in rotate_tfm/zoom_tfm
tfms = [
    rotate_tfm(degrees=(-20,20.)),
    zoom_tfm(scale=(1.,2.))
]

In [None]:
# Wrap the training set so its items are transformed when accessed.
train_tds = TfmDataset(train_ds, tfms)

In [None]:
# Show item 1 four times; if the transforms are random, each panel should differ.
_,axes = plt.subplots(2,2, figsize=(8,6))
for ax in axes.flat: show_image(train_tds[1][0], ax, hide_axis=False)

In [None]:
# Same rotation, a wider zoom range (1 to 3), plus a crop/pad step.
# presumably crop_pad_tfm crops/pads to a `size` supplied at fetch time — TODO confirm
tfms = [
    rotate_tfm(degrees=(-20,20.)),
    zoom_tfm(scale=(1.,3.)),
    crop_pad_tfm()
]

train_tds = TfmDataset(train_ds, tfms)

In [None]:
# Fetch item 1 using the (index, xtra) form; xtra's entries are passed to the transforms.
xtra = {'size':100}
train_tds[(1,xtra)][0].shape

In [None]:
# Four transformed versions of item 1, each fetched with the same per-item xtra.
_,axes = plt.subplots(2,2, figsize=(8,6))
for ax in axes.flat: show_image(train_tds[(1,xtra)][0], ax, hide_axis=False)