In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_001b import *
import sys, PIL, matplotlib.pyplot as plt, itertools, math, random, collections, torch
import scipy.stats, scipy.special

from enum import Enum, IntEnum
from torch import tensor, FloatTensor, LongTensor, ByteTensor, DoubleTensor, HalfTensor, ShortTensor
from operator import itemgetter, attrgetter
from numpy import cos, sin, tan, tanh, log, exp

from functools import reduce
from collections import defaultdict, abc, namedtuple
# The ABC aliases were removed from `collections` in Python 3.10; `collections.abc` works on every 3.x.
from collections.abc import Iterable
from PIL import Image

# CIFAR subset data

First we want to view our data to check if everything is how we expect it to be.

## Setup

In [None]:
DATA_PATH = Path('data')
PATH = DATA_PATH/'cifar10_dog_air'
TRAIN_PATH = PATH/'train'

In [None]:
dog_fn = list((TRAIN_PATH/'dog').iterdir())[0]
dog_image = Image.open(dog_fn)
dog_image.resize((256,256))

In [None]:
air_fn = list((TRAIN_PATH/'airplane').iterdir())[0]
air_image = Image.open(air_fn)
air_image.resize((256,256))

## Simple Dataset/Dataloader

We will build a Dataset class for our image files. A Dataset class needs to have two functions: length and get-item. Our FilesDataset additionally gets the image files from their respective directories and transforms them to tensors.

In [None]:
#export
def find_classes(folder):
    """Return the visible sub-directories of `folder`, sorted by name.

    Entries that are not directories, or whose name starts with a dot, are ignored.
    Fails an assertion when no sub-directory is left.
    """
    subdirs = []
    for entry in folder.iterdir():
        if entry.name.startswith('.') or not entry.is_dir(): continue
        subdirs.append(entry)
    assert(len(subdirs)>0)
    subdirs.sort(key=lambda entry: entry.name)
    return subdirs

def get_image_files(c):
    """List the entries of directory `c` that are neither hidden (dot-prefixed) nor directories."""
    files = []
    for entry in list(c.iterdir()):
        hidden = entry.name.startswith('.')
        if not hidden and not entry.is_dir(): files.append(entry)
    return files

def pil2tensor(image):
    """Convert `image` to a channel-first float tensor with values in [0, 1].

    `image` must expose `tobytes()` and `size` (width, height); the raw bytes are read as
    height x width x channels and then permuted to channels x height x width.
    """
    width, height = image.size[0], image.size[1]
    raw = torch.ByteStorage.from_buffer(image.tobytes())
    pixels = torch.ByteTensor(raw).view(height, width, -1)
    return pixels.permute(2,0,1).float().div_(255)

def open_image(fn):
    """Open the image file `fn`, force it to RGB and return it as a tensor via `pil2tensor`."""
    return pil2tensor(PIL.Image.open(fn).convert('RGB'))

In [None]:
#export
class FilesDataset(Dataset):
    """Dataset of image files stored in one sub-folder of `folder` per class.

    `classes` is the list of sub-folder names to use; when omitted it is discovered with
    `find_classes`. The label of an item is the index of its class in `self.classes`.
    """
    def __init__(self, folder, classes=None):
        if classes is None:
            classes = [d.name for d in find_classes(folder)]
        self.classes = classes
        self.fns, self.y = [], []
        for label, name in enumerate(classes):
            files = get_image_files(folder/name)
            self.fns.extend(files)
            self.y.extend([label] * len(files))

    def __len__(self):
        return len(self.fns)

    def __getitem__(self,i):
        # The image is read from disk on every access.
        return open_image(self.fns[i]), self.y[i]

In [None]:
train_ds = FilesDataset(PATH/'train')
valid_ds = FilesDataset(PATH/'test')

In [None]:
len(train_ds), len(valid_ds)

In [None]:
#export
def image2np(image):
    """Move `image` to the CPU, reorder its axes from (0,1,2) to (1,2,0) and return a numpy array."""
    channels_last = image.cpu().permute(1,2,0)
    return channels_last.numpy()

In [None]:
x,y = train_ds[0]
plt.imshow(image2np(x))
print(train_ds.classes[y])

In [None]:
bs=64

In [None]:
data = DataBunch(train_ds, valid_ds, bs=bs)
len(data.train_dl), len(data.valid_dl)

In [None]:
#export
def show_image(img, ax=None, figsize=(3,3), hide_axis=True):
    """Draw the image tensor `img` on `ax`.

    A new figure of size `figsize` is created when `ax` is None. The axis decorations are
    removed unless `hide_axis` is False.
    """
    if ax is None: ax = plt.subplots(figsize=figsize)[1]
    ax.imshow(image2np(img))
    if hide_axis: ax.axis('off')

def show_image_batch(dl, classes, rows=None, figsize=(12,15)):
    """Show the first batch of `dl` as a `rows` x `rows` grid of labelled images.

    `rows` defaults to the integer square root of the batch size. `figsize` is forwarded to
    `show_images` (it used to be accepted but silently ignored).
    """
    x,y = next(iter(dl))
    if rows is None: rows = int(math.sqrt(len(x)))
    n = rows*rows
    show_images(x[:n],y[:n],rows, classes, figsize=figsize)

def show_images(x,y,rows, classes, figsize=(9,9)):
    """Plot the images `x` on a `rows` x `rows` grid, titling each with `classes[y[i]]`.

    Cells beyond the number of available images are left blank instead of raising.
    """
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so `.flatten()` always works.
    fig, axs = plt.subplots(rows,rows,figsize=figsize, squeeze=False)
    for i, ax in enumerate(axs.flatten()):
        if i >= len(x):
            ax.axis('off')
            continue
        show_image(x[i], ax)
        ax.set_title(classes[y[i]])
    plt.tight_layout()

In [None]:
show_image_batch(data.train_dl, train_ds.classes, 6)

# Data augmentation

We are going to augment our data to increase our training set with artificial images. These new images are basically "free" data that we can use in our training to help our model generalize better (reduce overfitting).

## Lighting

We will start by changing the **brightness** and **contrast** of our images.

### Method

**Brightness**

Brightness refers to where our image stands on the dark-light spectrum. Brightness is applied by adding a positive constant to each of the image's channels. This works because each of the channels in an image goes from 0 (darkest) to 255 (brightest) in a dark-light continuum. (0, 0, 0) is black (total absence of light) and (255, 255, 255) is white (pure light). You can check how this works by experimenting by yourself [here](https://www.w3schools.com/colors/colors_rgb.asp).

_Parameters_

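1. **Change** How much brightness do we want to add to (or take from) the image. The value is passed through the logit before being added, so $0.5$ leaves the image unchanged, values above $0.5$ brighten it and values below $0.5$ darken it.

    $C \in (0, 1)$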
    
**Contrast**

Contrast refers to how sharp a distinction there is between brighter and darker sections of our image. To increase contrast we need darker pixels to be darker and lighter pixels to be lighter. In other words, we would like channels with a value smaller than 128 to decrease and channels with a value greater than 128 to increase.

_Parameters_

1. **Scale** How much contrast do we want to add to (or remove from) the image.

    $S \in [0, \infty)$
    
***On logit and sigmoid***

Notice that for both transformations we first apply the logit to our tensor, then apply the transformation and finally take the sigmoid. This is important for two reasons. 

First, we don't want to overflow our tensor values. In other words, we need our final tensor values to be $T_{ij} \in [0,1]$.  Imagine, for instance, a tensor value at 0.99. We want to increase its brightness, but we can’t go over 1.0. By doing logit first, which moves our space to $(-\infty, +\infty)$, this works fine. The same applies to contrast if we have a scale $S > 1$ (which might otherwise make some of our tensor values greater than one).

Second, when we apply contrast, we need to affect the dispersion of values around the middle value. Say we want to increase contrast. Then we need the bright values ($>0.5$) to get brighter and dark values ($<0.5$) to get darker. We must first transform our tensor values so our values which were originally $<0.5$ are now negative and our values which were originally $>0.5$ are positive. This way, when we multiply by a constant, the dispersion around 0 will increase. The logit function does exactly this and allows us to increase or decrease dispersion around a mid value.

In [None]:
#export
def logit(x):
    """Logit of `x`, computed as -log(1/x - 1); returns a new tensor."""
    odds_inv = 1/x - 1
    return -odds_inv.log()
def logit_(x):
    """In-place logit: overwrites `x` with -log(1/x - 1) and returns it."""
    x.reciprocal_().sub_(1)
    x.log_()
    return x.neg_()

def brightness(x, change):
    """Add logit(`change`) to `x` in place and return `x`; `change=0.5` adds zero."""
    shift = scipy.special.logit(change)
    return x.add_(shift)
def contrast(x, scale):
    """Multiply `x` by `scale` in place and return `x`."""
    x.mul_(scale)
    return x

def _apply_lighting(x, func):
    if func is None: return x
    return func(logit_(x)).sigmoid()

def apply_lighting(func):
    """Bind `func` to `_apply_lighting`, giving a callable that only takes the image."""
    return partial(_apply_lighting, func=func)

In [None]:
def apply_contrast(scale):
    """Build a lighting function that applies `contrast` with a fixed `scale`."""
    fixed_contrast = partial(contrast, scale=scale)
    return apply_lighting(fixed_contrast)

In [None]:
# The lighting helpers work in place (`_apply_lighting` calls `logit_` on its input), so each
# call of `x()` reloads the image from the dataset instead of sharing one tensor.
x = lambda: train_ds[1][0]

In [None]:
_,axes = plt.subplots(1,4, figsize=(12,3))

show_image(x(), axes[0])
show_image(apply_contrast(1.0)(x()), axes[1])
show_image(apply_contrast(0.5)(x()), axes[2])
show_image(apply_contrast(2.0)(x()), axes[3])

In [None]:
def apply_brightness(change):
    """Build a lighting function that applies `brightness` with a fixed `change`."""
    fixed_brightness = partial(brightness, change=change)
    return apply_lighting(fixed_brightness)

In [None]:
_,axes = plt.subplots(1,4, figsize=(12,3))

show_image(x(), axes[0])
show_image(apply_brightness(0.5)(x()), axes[1])
show_image(apply_brightness(0.8)(x()), axes[2])
show_image(apply_brightness(0.2)(x()), axes[3])

In [None]:
#export
def listify(p=None, q=None):
    """Make `p` list-like and, when it holds a single item, repeat it to match `q`.

    - `None` becomes `[]`.
    - A non-iterable, a string, or an iterable without a length (e.g. a 0-d tensor) is
      wrapped in a one-element list.
    - Any other sized iterable (list, tuple, ...) is kept as is.
    - `q` gives the target length: an int, a sized object, or None (meaning 1).
    """
    if p is None: p=[]
    # Strings are iterable but are meant as a single value here, not a sequence of characters.
    elif isinstance(p, str) or not isinstance(p, Iterable): p=[p]
    else:
        # Some iterables have no length (0-d tensors, generators): treat them as one item.
        try: len(p)
        except TypeError: p=[p]
    n = q if type(q)==int else 1 if q is None else len(q)
    if len(p)==1: p = p * n
    return p

def compose(funcs):
    """Chain `funcs` left to right into one function; returns None when `funcs` is empty or None.

    Every function receives the running value as first argument, followed by the same
    extra positional and keyword arguments.
    """
    if not funcs: return None
    def _inner(x, *args, **kwargs):
        return reduce(lambda acc, f: f(acc, *args, **kwargs), funcs, x)
    return _inner

In [None]:
def apply_brightness_contrast(scale_contrast, change_brightness):
    """Build a lighting function that applies `contrast` first and `brightness` second."""
    steps = [partial(contrast, scale=scale_contrast),
             partial(brightness, change=change_brightness)]
    return apply_lighting(compose(steps))

In [None]:
_,axes = plt.subplots(1,4, figsize=(12,3))

show_image(apply_brightness_contrast(0.75, 0.7)(x()), axes[0])
show_image(apply_brightness_contrast(1.3,  0.3)(x()), axes[1])
show_image(apply_brightness_contrast(1.3,  0.7)(x()), axes[2])
show_image(apply_brightness_contrast(0.75, 0.3)(x()), axes[3])

## Random lighting

Next, we will make our previous transforms random since we are interested in automatizing the pipeline. We will achieve this by making our parameters stochastic with a specific distribution. 

We will use a <a href="https://en.wikipedia.org/wiki/Uniform_distribution_(continuous)"> uniform</a> distribution for brightness change since, once passed through the logit, it acts as an additive shift on the real line and the impact varies linearly with the scale. For contrast we use [log_uniform](https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php) for two reasons. First, contrast scale has a domain of $[0, \infty)$. Second, the impact of the scale in the transformation is non-linear (i.e. 0.5 is as extreme as 2.0, 0.2 is as extreme as 5). The log_uniform function is appropriate because it has the same domain and correctly represents the non-linearity of the transform, $P(0.5) = P(2)$.

In [None]:
#export
def uniform(low, high, size=None):
    """Draw from U(`low`, `high`): a Python float when `size` is None, else a FloatTensor of that size."""
    if size is None: return random.uniform(low,high)
    return torch.FloatTensor(size).uniform_(low,high)

def log_uniform(low, high, size=None):
    """Draw a value whose logarithm is uniform between log(`low`) and log(`high`).

    Returns a scalar when `size` is None, otherwise a tensor exponentiated in place.
    """
    log_draw = uniform(log(low), log(high), size)
    if size is None: return exp(log_draw)
    return log_draw.exp_()

def rand_bool(p, size=None):
    """Return whether a U(0,1) draw is below `p` (element-wise when `size` is given)."""
    draw = uniform(0,1,size)
    return draw<p

# Kinds of transform. The member order matters: `apply_tfms` unpacks the groups by iterating
# over this enum (start, affine, coord, pixel, lighting).
TfmType = IntEnum('TfmType', 'Start Affine Coord Pixel Lighting')

def brightness(x, change:uniform) -> TfmType.Lighting:
    """Add logit(`change`) to `x` in place and return `x`.

    The annotations carry metadata: `change` is sampled with `uniform`, and the return
    annotation tags this as a lighting transform.
    """
    shift = scipy.special.logit(change)
    return x.add_(shift)

def contrast(x, scale:log_uniform) -> TfmType.Lighting:
    """Multiply `x` by `scale` in place and return `x`.

    The annotations carry metadata: `scale` is sampled with `log_uniform`, and the return
    annotation tags this as a lighting transform.
    """
    x.mul_(scale)
    return x

In [None]:
brightness.__annotations__

In [None]:
scipy.stats.gmean([log_uniform(0.5,2.0) for _ in range(1000)])

In [None]:
#export
import inspect

def get_default_args(func):
    """Map each parameter of `func` that declares a default value to that default."""
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is inspect.Parameter.empty: continue
        defaults[name] = param.default
    return defaults

def resolve_args(func, **kwargs):
    """Turn the argument specs in `kwargs` into concrete values for the annotated parameters of `func`.

    For every annotated parameter (the 'return' annotation is skipped):
    - if it is absent from `kwargs` and has a default, the default is used as is;
    - otherwise its spec (or `1` when absent) is listified and unpacked into the
      annotation, which is called as a sampler.
    Keyword arguments without an annotation are passed through unchanged.
    """
    defaults = get_default_args(func)
    for name, sampler in func.__annotations__.items():
        if name == 'return': continue
        if name not in kwargs and name in defaults:
            kwargs[name] = defaults[name]
            continue
        spec = listify(kwargs.get(name, 1))
        kwargs[name] = sampler(*spec)
    return kwargs

def noop(x=None, *args, **kwargs):
    """Return `x` unchanged, ignoring every other argument."""
    return x

In [None]:
resolve_args(brightness, change=(0.25,0.75),)

In [None]:
#export
class Transform():
    """Wrap `func` with a probability `p` of being applied and specs for its random arguments.

    `resolve` must be called before the transform is used: it samples the arguments and
    decides whether this round runs. Calling the transform then reuses that draw.
    """
    def __init__(self, func, p=1., **kwargs):
        self.func = func
        self.p = p
        self.kw = kwargs
        # The return annotation of `func` states which kind of transform it is.
        self.tfm_type = func.__annotations__['return']

    def __repr__(self):
        name, kind = self.func.__name__, self.tfm_type.name
        return f'{name}_tfm->{kind}; {self.kw} (p={self.p})'

    def resolve(self):
        self.resolved = resolve_args(self.func, **self.kw)
        self.do_run = rand_bool(self.p)

    def __call__(self, x, *args, **kwargs):
        if not self.do_run: return x
        return self.func(x, *args, **self.resolved, **kwargs)

In [None]:
contrast_tfm = partial(Transform, contrast)
tfm = contrast_tfm(scale=(0.3,3))
tfm

In [None]:
# all the same
tfm.resolve()

_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_lighting(tfm)(x()), ax)

In [None]:
tfm = contrast_tfm(scale=(0.3,3), p=0.5)

# different
_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes:
    tfm.resolve()
    show_image(apply_lighting(tfm)(x()), ax)

## Decorator and composition

We are interested in composing the transform functions so as to apply them all at once. We will try to feed a list of transforms to our pipeline for it to apply all of them.

Applying a function to our transforms before calling them in Python is easiest if we use a decorator. You can find more about decorators [here](https://www.thecodeship.com/patterns/guide-to-python-function-decorators/).

In [None]:
#export
def reg_partial(cl, func):
    """Publish `partial(cl, func)` as `<func name>_tfm` in the module of `func`; return `func`."""
    module = sys.modules[func.__module__]
    factory = partial(cl, func)
    setattr(module, func.__name__ + '_tfm', factory)
    return func

def reg_transform(func):
    """Decorator: register a `Transform` factory for `func` through `reg_partial`."""
    return reg_partial(Transform, func)

def resolve_tfms(tfms):
    """Call `resolve()` on every transform in `tfms` (a single transform, a list, or None)."""
    for tfm in listify(tfms):
        tfm.resolve()

@reg_transform
def brightness(x, change: uniform) -> TfmType.Lighting:
    """Add logit(`change`) to `x` in place and return `x`; the decorator registers `brightness_tfm`."""
    shift = scipy.special.logit(change)
    return x.add_(shift)

@reg_transform
def contrast(x, scale: log_uniform) -> TfmType.Lighting:
    """Multiply `x` by `scale` in place and return `x`; the decorator registers `contrast_tfm`."""
    x.mul_(scale)
    return x

In [None]:
def _apply_tfms(tfms, x):
    """Re-resolve `tfms`, then apply them as one lighting function to a copy of `x`."""
    resolve_tfms(tfms)
    lighting = apply_lighting(compose(tfms))
    # Work on a clone: the lighting functions modify their input in place.
    return lighting(x.clone())

def apply_tfms(tfms):
    """Bind `tfms` to `_apply_tfms`, giving a callable that only takes the image."""
    return partial(_apply_tfms, tfms)

In [None]:
x = train_ds[1][0]

In [None]:
tfms = [contrast_tfm(scale=(0.3,3.0), p=0.5),
        brightness_tfm(change=(0.35,0.65), p=0.5)]

_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_tfms(tfms)(x), ax)

In [None]:
_,axes = plt.subplots(2,4, figsize=(12,6))
for i in range(4):
    tfm = apply_tfms(tfms)
    show_image(tfm(x), axes[0][i])
    show_image(tfm(x), axes[1][i])

In [None]:
show_image(apply_tfms([])(x))

# Affine

We will now add affine transforms that operate on the coordinates instead of pixels like the lighting transforms we just saw. An [affine transformation](https://en.wikipedia.org/wiki/Affine_transformation) is a function "(...) between affine spaces which preserves points, straight lines and planes." It is basically a transformation $f:\mathcal{X}\mapsto\mathcal{Y}$ of the form $ \mathcal{x}\mapsto \mathcal{M}\mathcal{x}+\mathcal{b}$ where $\mathcal{M}$ is a linear transformation on $\mathcal{X}$ and $\mathcal{b}$ is a vector in $\mathcal{Y}$.


### Affine Method

Our implementation first creates a grid of coordinates for the original image. The grid is normalized to a [-1, 1] range with (-1, -1) representing the top left corner, (1, 1) the bottom right corner and (0, 0) the center. Next, we build an affine matrix representing our desired transform and we multiply it by our original grid coordinates. The result will be a set of x, y coordinates which references where in the input image will each of the pixels in the output image be mapped. It has a size of w \* h \* 2 since it needs two coordinates for each of the h * w pixels of the output image. 

This is clearest if we see it graphically. We will build an affine matrix of the following form:

$$
\begin{bmatrix}
a & b & e \\
c & d & f \\
\end{bmatrix}
$$

with which we will transform each pair of x, y coordinates in our original grid into our transformation grid:

$$
\begin{bmatrix}
a & b \\
c & d \\
\end{bmatrix}
\times
\begin{bmatrix}
x \\
y \\
\end{bmatrix}
+
\begin{bmatrix}
e \\
f \\
\end{bmatrix}
=
\begin{bmatrix}
x^\prime \\
y^\prime \\
\end{bmatrix}
$$

So after the transform we will get a new grid with which to map our input image into our output image. This will be our **map of where from exactly does our transformation source each pixel in the output image**:

$$
\begin{bmatrix}
-1, -1 & x_2, -1 & ... & x_{n/2}, -1 & ... & 1, -1 \\
-1, y_2 & x_2, y_2  & ... & x_{n/2}, y_2 & ... & 1, y_2 \\
\vdots   & \vdots & \vdots & \vdots & \vdots & \vdots \\ 
-1, y_{n/2} & x_2, y_{n/2} & ... & 0, 0 & ... & 1, y_{n/2} \\
\vdots   & \vdots & \vdots & \vdots & \vdots & \vdots \\ 
-1, 1 & x_2, 1 & ... & x_{n/2}, 1 & ... & 1, 1        \\
\end{bmatrix}
\longmapsto
\begin{bmatrix}
x^\prime_1, y^\prime_1 & x^\prime_2, y^\prime_1 & ... & x^\prime_{n/2}, y^\prime_1 & ... & x^\prime_n, y^\prime_1 \\
x^\prime_1, y^\prime_2 & x^\prime_2, y^\prime_2  & ... & x^\prime_{n/2}, y^\prime_2 & ... & x^\prime_n, y^\prime_2 \\
\vdots   & \vdots & \vdots & \vdots & \vdots & \vdots \\ 
x^\prime_1, y^\prime_{n/2} & x^\prime_2, y^\prime_{n/2} & ... & x^\prime_{n/2}, y^\prime_{n/2} & ... & x^\prime_n, y^\prime_{n/2} \\
\vdots   & \vdots & \vdots & \vdots & \vdots & \vdots \\ 
x^\prime_1, y^\prime_n & x^\prime_2, y^\prime_n & ... & x^\prime_{n/2}, y^\prime_n & ... & x^\prime_n, y^\prime_n        \\
\end{bmatrix}
$$

**Enter problems**

Affine transforms face two problems that must be solved independently:
1. **The interpolation problem**: The result of our transformation gives us float coordinates, and we need to decide, for each (i,j), how to assign these coordinates to pixels in the input image.
2. **The missing pixel problem**: The result of our transformation may have coordinates which exceed the [-1, 1] range of our original grid and thus fall outside of our original grid.

**Solutions to problems**

1.  **The interpolation problem**: We will perform a [bilinear interpolation](https://en.wikipedia.org/wiki/Bilinear_interpolation). This takes an average of the values of the pixels corresponding to the four points in the grid surrounding the result of our transformation, with weights depending on how close we are to each of those points. 
2. **The missing pixel problem**: For these values we need padding, and we face a few options:

    1. Adding zeros on the side (so the pixels that fall out will be black)
    2. Replacing them by the value at the border
    3. Mirroring the content of the picture on the other side (reflect padding).
    
    
### Transformation Method

**Zoom**

Zoom changes the focus of the image according to a scale. If a scale of >1 is applied, grid pixels will be mapped to coordinates that are more central than the pixel's coordinates (closer to 0,0) while if a scale of <1 is applied, grid pixels will be mapped to more peripheral coordinates (closer to the borders) in the input image.

We can also translate our transform to zoom into a non-central area of the image. For this we use $col_c$, which displaces the zoom along the x axis, and $row_c$, which displaces it along the y axis.

_Parameters_

1. **Scale** How much do we want to zoom in or out to our image.

    $S \in \mathbb{R}$
        
2. **Col_pct** How much do we want to displace our zoom along the x axis.

    $C_{pct} \in [0, 1]$
    
    
3. **Row_pct** How much do we want to displace our zoom along the y axis.

    $R_{pct} \in [0, 1]$
    

<u>Affine matrix</u>

$$
\begin{bmatrix}
\frac{1}{scale} & 0 & col_c\\
0 & \frac{1}{scale} & row_c\\
\end{bmatrix}
$$

<u>Transform tensor</u>

$$
\begin{bmatrix}
\frac{x_1}{scale}+col_c, \frac{y_1}{scale}+row_c & \frac{x_2}{scale}+col_c, \frac{y_1}{scale}+row_c & ... & \frac{x_n}{scale}+col_c, \frac{y_1}{scale}+row_c \\
\frac{x_1}{scale}+col_c, \frac{y_2}{scale}+row_c & \frac{x_2}{scale}+col_c, \frac{y_2}{scale}+row_c & ... & \frac{x_n}{scale}+col_c, \frac{y_2}{scale}+row_c \\
\vdots   & \vdots & \vdots & \vdots \\ 
\frac{x_1}{scale}+col_c, \frac{y_n}{scale}+row_c & \frac{x_2}{scale}+col_c, \frac{y_n}{scale}+row_c & ... & \frac{x_n}{scale}+col_c, \frac{y_n}{scale}+row_c \\                                                
\end{bmatrix}
$$

**Rotate**

Rotate shifts the image around its center in a given angle $\theta$. The rotation is counterclockwise if $\theta$ is positive and clockwise if $\theta$ is negative. If you are curious about the derivation of the rotation matrix you can find it [here](https://matthew-brett.github.io/teaching/rotation_2d.html).

_Parameters_

1. **Degrees** By which angle do we want to rotate our image.

    $D \in \mathbb{R}$
        
<u>Affine matrix</u>

$$
\begin{bmatrix}
\cos(\theta) & -\sin(\theta) & 0\\
\sin(\theta) &  \cos(\theta) & 0\\
\end{bmatrix}
$$

<u>Transform tensor</u>

$$
\begin{bmatrix}
\cos(\theta)\cdot{x_1}-\sin(\theta)\cdot{y_1}, \sin(\theta)\cdot{x_1}+\cos(\theta)\cdot{y_1} & \cos(\theta)\cdot{x_2}-\sin(\theta)\cdot{y_1}, \sin(\theta)\cdot{x_2}+\cos(\theta)\cdot{y_1} & ... & \cos(\theta)\cdot{x_n}-\sin(\theta)\cdot{y_1}, \sin(\theta)\cdot{x_n}+\cos(\theta)\cdot{y_1} \\
\cos(\theta)\cdot{x_1}-\sin(\theta)\cdot{y_2}, \sin(\theta)\cdot{x_1}+\cos(\theta)\cdot{y_2} & \cos(\theta)\cdot{x_2}-\sin(\theta)\cdot{y_2}, \sin(\theta)\cdot{x_2}+\cos(\theta)\cdot{y_2} & ... & \cos(\theta)\cdot{x_n}-\sin(\theta)\cdot{y_2}, \sin(\theta)\cdot{x_n}+\cos(\theta)\cdot{y_2} \\
\vdots   & \vdots & \vdots & \vdots \\ 
\cos(\theta)\cdot{x_1}-\sin(\theta)\cdot{y_n}, \sin(\theta)\cdot{x_1}+\cos(\theta)\cdot{y_n} & \cos(\theta)\cdot{x_2}-\sin(\theta)\cdot{y_n}, \sin(\theta)\cdot{x_2}+\cos(\theta)\cdot{y_n} & ... & \cos(\theta)\cdot{x_n}-\sin(\theta)\cdot{y_n}, \sin(\theta)\cdot{x_n}+\cos(\theta)\cdot{y_n} \\
\end{bmatrix}
$$

## Deterministic affine

In [None]:
#export
def grid_sample_nearest(input, coords, padding_mode='zeros'):
    """Nearest-neighbour sampling of `input` at the grid positions `coords`.

    `input` is a 4-D batch (bs, ch, h, w). `coords` holds (x, y) pairs in its last
    dimension, with [-1, 1] spanning the image; only the first grid, `coords[0]`, is used.
    With `padding_mode='zeros'`, positions that fall outside the image are set to 0. With
    any other mode, out-of-range indices are clamped to the nearest border pixel.

    `coords` is no longer modified; the original rescaled the caller's grid in place.
    """
    # `clamp` is not in place: the original call discarded its result.
    if padding_mode=='border': coords = coords.clamp(-1,1)
    bs,ch,h,w = input.size()
    # Built from `coords` so that it shares its dtype and device.
    sz = coords.new_tensor([w,h])[None,None]
    # NOTE(review): this maps coordinate 1 to index w (resp. h), one past the last pixel, so
    # with 'zeros' padding the far edge of an identity grid is blanked. Mapping kept as in
    # the original; confirm the intended grid convention before changing it.
    idx = ((coords + 1) * (sz/2))[0].round().long()
    cols,rows = idx[...,0],idx[...,1]
    if padding_mode=='zeros':
        # `|` works on both uint8 and bool comparison results, whatever the PyTorch version.
        mask = (cols < 0) | (rows < 0) | (cols >= w) | (rows >= h)
    result = input[...,rows.clamp(0,h-1),cols.clamp(0,w-1)]
    if padding_mode=='zeros': result[...,mask] = 0
    return result

In [None]:
#export
def grid_sample(x, coords, mode='bilinear', padding_mode='reflect'):
    """Sample the single image `x` at the grid `coords` and return a single image.

    `padding_mode='reflect'` is translated to the name 'reflection'. `mode='nearest'` is
    handled by `grid_sample_nearest`; every other mode goes to `F.grid_sample`.
    """
    if padding_mode=='reflect': padding_mode='reflection'
    batch = x[None]  # add a batch dimension, removed again with [0] below
    if mode=='nearest':
        return grid_sample_nearest(batch, coords, padding_mode)[0]
    return F.grid_sample(batch, coords, mode=mode, padding_mode=padding_mode)[0]

def affine_grid(x, matrix, size=None):
    """Build the sampling grid of the affine `matrix` for one image.

    `size` is the output size without the batch dimension: None keeps the size of `x`, an
    int `s` gives `(x.size(0), s, s)`, and any sequence (tuple, list, `torch.Size`) is used
    as given. Only the first two rows of `matrix` are passed to `F.affine_grid`.
    """
    if size is None: size = x.size()
    elif isinstance(size, int): size = (x.size(0), size, size)
    # tuple() lets lists through; `(1,)+size` raised TypeError for them.
    return F.affine_grid(matrix[None,:2], torch.Size((1,)+tuple(size)))

In [None]:
def rotate(x, degrees):
    """Return the 3x3 rotation matrix for an angle of `degrees`; `x` is not used."""
    theta = degrees * math.pi / 180
    c,s = cos(theta),sin(theta)
    return [[c, -s, 0.],
            [s,  c, 0.],
            [0., 0., 1.]]

In [None]:
m = rotate(x, 30)
m = x.new_tensor(m)
c = affine_grid(x, m)
img2 = grid_sample(x, c, padding_mode='zeros')
show_image(img2)

In [None]:
#export
def affines_mat(matrices=None):
    """Multiply the given 3x3 matrices, in order, into a single tensor.

    Returns None when `matrices` is None or empty. Entries that are None are skipped, so a
    list containing only None yields the identity.
    """
    if matrices is None or len(matrices) == 0: return None
    product = torch.eye(3)
    for m in matrices:
        if m is None: continue
        product = torch.matmul(product, FloatTensor(m))
    return product

def affine_mult(c,m):
    """Apply the affine matrix `m` to every (x, y) pair of `c`, keeping the shape of `c`.

    The top-left 2x2 block of `m` is the linear part, and the first two entries of its
    last column are the translation.
    """
    shape = c.size()
    linear,shift = m[:2,:2],m[:2,2]
    moved = torch.addmm(shift, c.view(-1,2), linear.t())
    return moved.view(shape)

def _apply_affine(img, size=None, mats=None, func=None, **kwargs):
    """Resample `img` on an identity grid transformed by `func` and then by `mats`.

    `size` sets the output size (see `affine_grid`). `func`, when given, receives the grid
    and the size of `img`. `mats` is a list of affine matrices combined with `affines_mat`.
    Remaining keyword arguments go to `grid_sample`.
    """
    grid = affine_grid(img, torch.eye(3), size=size)
    if func is not None:
        grid = func(grid, img.size())
    if mats:
        combined = img.new_tensor(affines_mat(mats))
        grid = affine_mult(grid, combined)
    return grid_sample(img, grid, **kwargs)

def apply_affine(mats=None, func=None):
    """Bind `mats` and `func` to `_apply_affine`, giving a callable that takes the image."""
    return partial(_apply_affine, mats=mats, func=func)

In [None]:
def zoom(x, scale: uniform, row_pct = 0.5, col_pct = 0.5):
    """Return the 3x3 matrix that zooms by `scale`; `x` is not used.

    `row_pct` and `col_pct` place the zoom window: 0.5 gives a zero offset on that axis.
    """
    inv = 1/scale
    slack = 1-inv
    col_c = slack * (2*col_pct - 1)
    row_c = slack * (2*row_pct - 1)
    return [[inv, 0,   col_c],
            [0,   inv, row_c],
            [0,   0,   1.   ]]

In [None]:
show_image(apply_affine([zoom(x, 0.6)])(x))

In [None]:
show_image(apply_affine([zoom(x, 0.6)])(x, padding_mode='zeros'))

In [None]:
show_image(apply_affine([zoom(x, 2, 0.2, 0.2)])(x))

In [None]:
img2 = apply_affine([rotate(x, 30)])(x)
img2 = apply_affine([zoom(x, 1.6)])(img2)
show_image(img2)

In [None]:
show_image(apply_affine([zoom(x,1.6), rotate(x,30)])(x))

In [None]:
show_image(x)

In [None]:
m = [zoom(x,1.6, 0.8, 0.2), rotate(x,30)]
show_image(apply_affine(m)(x, size=48))

In [None]:
m = [zoom(x,1.6, 0.8, 0.2), rotate(x,30)]
show_image(apply_affine(m)(x, size=24), hide_axis=False)

In [None]:
m = [zoom(x,1.6, 0.8, 0.2), rotate(x,30)]
show_image(apply_affine(m)(x, size=48, mode='nearest'))

In [None]:
show_image(apply_affine([zoom(x,1.6)])(x))

In [None]:
show_image(apply_affine([])(x))

## Random affine

As we did with the Lighting transform, we now want to build randomness into our pipeline so we can increase the automatization of the transform process. 

We will use a uniform distribution for both our transforms since their impact is linear and their domain is $\mathbb{R}$.

**Apply all transforms**

We will build a function called *apply_tfms* which will apply all the transforms to our image. We will apply all the transforms while trying to do as few calculations as possible.

We do only one affine transformation by multiplying all the affine matrices of the transforms, then we apply to the coords any non-affine transformation we might want (jitter, elastic distortion). Next, we crop the coordinates we want to keep and, by doing it before the interpolation, we don't need to compute pixel values that won't be used afterwards. Finally we perform the interpolation and we apply all the transforms that operate pixelwise (brightness, contrast).

In [None]:
#export
class AffineTransform(Transform):
    """`Transform` whose function takes no image: calling it returns the function's result
    (an affine matrix), or None when this round was drawn as skipped."""
    def __call__(self, *args, **kwargs):
        if not self.do_run: return None
        return self.func(*args, **self.resolved, **kwargs)
    
def dict_groupby(iterable, key=None):
    """Group the items of `iterable` by `key` into a dict mapping each key to a list of items.

    The items are sorted by `key` first, so equal keys end up in the same group.
    """
    ordered = sorted(iterable, key=key)
    groups = {}
    for k, members in itertools.groupby(ordered, key=key):
        groups[k] = list(members)
    return groups

def _apply_tfm_funcs(pixel_func,lighting_func,affine_func,start_func, x,**kwargs):
    if not np.any([pixel_func,lighting_func,affine_func,start_func]): return x
    x = x.clone()
    if start_func is not None:  x = start_func(x)
    if affine_func is not None: x = affine_func(x, **kwargs)
    if lighting_func is not None: x = lighting_func(x)
    if pixel_func is not None: x = pixel_func(x)
    return x

def apply_tfms(tfms):
    """Resolve `tfms` once and build a single function that applies all of them to an image.

    The transforms are grouped by `tfm_type`: affine ones are evaluated to matrices here,
    coord ones become the grid function of `apply_affine`, and lighting, pixel and start
    ones are composed per group. The returned callable takes the image plus keyword
    arguments for the affine stage.
    """
    resolve_tfms(tfms)
    by_type = dict_groupby(listify(tfms), lambda o: o.tfm_type)
    # Missing groups come back as None, which `compose` and `listify` both accept.
    start_tfms   = by_type.get(TfmType.Start)
    affine_tfms  = by_type.get(TfmType.Affine)
    coord_tfms   = by_type.get(TfmType.Coord)
    pixel_tfms   = by_type.get(TfmType.Pixel)
    lighting_tfms = by_type.get(TfmType.Lighting)
    mats = [tfm() for tfm in listify(affine_tfms)]
    return partial(_apply_tfm_funcs,
                   compose(pixel_tfms),
                   apply_lighting(compose(lighting_tfms)),
                   apply_affine(mats, func=compose(coord_tfms)),
                   compose(start_tfms))

def reg_affine(func):
    """Decorator: register an `AffineTransform` factory for `func` through `reg_partial`."""
    return reg_partial(AffineTransform, func)

In [None]:
#export
@reg_affine
def rotate(degrees:uniform) -> TfmType.Affine:
    """Return the 3x3 rotation matrix for an angle of `degrees`; registers `rotate_tfm`."""
    theta = degrees * math.pi / 180
    c,s = cos(theta),sin(theta)
    return [[c, -s, 0.],
            [s,  c, 0.],
            [0., 0., 1.]]

def get_zoom_mat(sw, sh, c, r):
    """Return the 3x3 matrix that scales x by `sw` and y by `sh`, then shifts by (`c`, `r`)."""
    x_row = [sw, 0,  c]
    y_row = [0, sh,  r]
    return [x_row, y_row, [0, 0, 1.]]

@reg_affine
def zoom(scale:uniform=1.0, row_pct:uniform=0.5, col_pct:uniform=0.5) -> TfmType.Affine:
    """Return the matrix that zooms by `scale`; registers `zoom_tfm`.

    `row_pct` and `col_pct` place the zoom window: 0.5 gives a zero offset on that axis.
    """
    inv = 1/scale
    slack = 1-inv
    return get_zoom_mat(inv, inv, slack * (2*col_pct - 1), slack * (2*row_pct - 1))

@reg_affine
def squish(scale:uniform=1.0, row_pct:uniform=0.5, col_pct:uniform=0.5) -> TfmType.Affine:
    """Return a matrix that rescales a single axis; registers `squish_tfm`.

    For `scale <= 1` the x axis is multiplied by `scale` and `col_pct` sets the x offset.
    Otherwise the y axis is multiplied by `1/scale` and `row_pct` sets the y offset.
    """
    if scale <= 1:
        x_shift = (1-scale) * (2*col_pct - 1)
        return get_zoom_mat(scale, 1, x_shift, 0.)
    inv = 1/scale
    y_shift = (1-inv) * (2*row_pct - 1)
    return get_zoom_mat(1, inv, 0., y_shift)

In [None]:
tfms = [rotate_tfm(degrees=(-45,45.), p=0.75),
        zoom_tfm(scale=(0.5,2.0), p=0.75)]

In [None]:
_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_tfms(tfms)(x), ax)

In [None]:
tfms = [rotate_tfm(degrees=(-45,45.), p=0.75),
        zoom_tfm(scale=(1.0,2.0), row_pct=(0,1.), col_pct=(0,1.))]

_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_tfms(tfms)(x, size=64, padding_mode='zeros'), ax)

In [None]:
scales = [0.75,0.9,1.1,1.33]

_,axes = plt.subplots(1,4, figsize=(12,3))
for i, ax in enumerate(axes): 
    show_image(apply_affine([squish(scales[i])])(x, size=64, padding_mode='zeros'), ax)

In [None]:
tfms = [squish_tfm(scale=(0.5,2), row_pct=(0,1.), col_pct=(0,1.))]

_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_tfms(tfms)(x, size=64, padding_mode='zeros'), ax)

# Coord and pixel

## Jitter / flip

The last two transforms we will use are **jitter** and **flip**. 

**Jitter**

Jitter is a transform which adds a random value to the sampling coordinates of each pixel, so that every output pixel is taken from a slightly displaced position in the input image. In our implementation we first get a random number between (-1, 1) and we multiply it by a constant $M$ which scales it.

_Parameters_

1. **Magnitude** How much random noise do we want to add to each of the pixels in our image.

    $M \in [0, 1]$
    
**Flip**

Flip is a transform that reflects the image on a given axis.

_Parameters_

1. **P** Probability of applying the transformation to an input.

    $P \in [0, 1]$

In [None]:
#export
@reg_transform
def jitter(x, size, magnitude: uniform) -> TfmType.Coord:
    """Add uniform noise in [-`magnitude`, `magnitude`) to `x` in place; `size` is not used."""
    noise = (torch.rand_like(x)-0.5)*magnitude*2
    return x.add_(noise)

@reg_transform
def flip_lr(x) -> TfmType.Pixel:
    """Return `x` reversed along dimension 2; registers `flip_lr_tfm`."""
    flipped = x.flip(2)
    return flipped

In [None]:
tfm = jitter_tfm(magnitude=(0,0.1))

_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_tfms(tfm)(x), ax)

In [None]:
tfm = flip_lr_tfm(p=0.5)

_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_tfms(tfm)(x), ax)

## Combine

In [None]:
tfms = [flip_lr_tfm(p=0.5),
        rotate_tfm(degrees=(-45,45.), p=0.5),
        zoom_tfm(scale=(0.6,1.6), p=0.8),
        contrast_tfm(scale=(0.5,2.0)),
        brightness_tfm(change=(0.3,0.7))
]

In [None]:
_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_tfms(tfms)(x), ax)

In [None]:
_,axes = plt.subplots(2,4, figsize=(12,6))
for i in range(4):
    # `apply_tfms` resolves the random parameters when the pipeline is built, so the two
    # calls below reuse the same draw and differ only in their sampling options.
    tfm = apply_tfms(tfms)
    show_image(tfm(x, padding_mode='zeros', size=48), axes[0][i])
    show_image(tfm(x, mode='nearest'), axes[1][i])

## RandomResizedCrop (Torchvision version)

In [None]:
#export
def compute_zs_mat(sz, scale, squish, invert, row_pct, col_pct):
    """Return the zoom/squish matrix of the first candidate that stays inside the picture.

    `scale`, `squish` and `invert` are parallel sequences of candidates, tried in order.
    Plain scalars (the defaults of `zoom_squish`) are accepted as one-candidate sequences.
    `sz` is indexed as `sz[1]` and `sz[2]` to get the aspect ratio; presumably it is the
    (channels, height, width) size that `_apply_affine` passes on (TODO confirm).
    When no candidate fits, a matrix emulating a center crop is returned.
    """
    def _candidates(v):
        # Floats, bools and 0-d tensors are not iterable: wrap them so `zip` accepts them.
        try: iter(v)
        except TypeError: return [v]
        return v

    orig_ratio = math.sqrt(sz[2]/sz[1])
    for s,r,i in zip(_candidates(scale), _candidates(squish), _candidates(invert)):
        s,r = math.sqrt(s),math.sqrt(r)
        if s * r <= 1 and s / r <= 1: #Test if we are completely inside the picture
            w,h = (s/r, s*r) if i else (s*r,s/r)
            w /= orig_ratio
            h *= orig_ratio
            col_c = (1-w) * (2*col_pct - 1)
            row_c = (1-h) * (2*row_pct - 1)
            return get_zoom_mat(w, h, col_c, row_c)

    #Fallback, hack to emulate a center crop without cropping anything yet.
    if orig_ratio > 1: return get_zoom_mat(1/orig_ratio**2, 1, 0, 0.)
    else:              return get_zoom_mat(1, orig_ratio**2, 0, 0.)

@reg_transform
def zoom_squish(c, sz, scale: uniform = 1.0, squish: uniform=1.0, invert: rand_bool = False, 
                row_pct:uniform = 0.5, col_pct:uniform = 0.5) -> TfmType.Coord:
    """Apply a random zoom/squish matrix to the coordinate grid `c`; registers `zoom_squish_tfm`.

    `scale`, `squish` and `invert` are meant to hold several candidates (say 10 each), so
    that a few zoom/squish combinations can be tried before falling back to a center crop,
    like torchvision.RandomResizedCrop.
    """
    mat = FloatTensor(compute_zs_mat(sz, scale, squish, invert, row_pct, col_pct))
    return affine_mult(c, mat)

In [None]:
random_resized_crop = zoom_squish_tfm(scale=(0.5,1,10), squish=(0.75,1.33,10), invert=(0.5,10),
                                      row_pct=(0,1.), col_pct=(0,1.))

In [None]:
_,axes = plt.subplots(2,4, figsize=(12,6))
for i in range(4):
    tfm = apply_tfms(random_resized_crop)
    show_image(tfm(x, size=(3,48,48)), axes[0][i])
    show_image(tfm(x, mode='nearest', size=(3,32,32)), axes[1][i])