In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from nb_004c import *

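# Carvana

Experiments with data-augmentation transforms applied jointly to Carvana car images and their segmentation masks.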

In [None]:
# Dataset root plus derived folders: PNG copies of the masks, and the 128px resized images / masks.
PATH = Path('data/carvana')
PATH_PNG = PATH.joinpath('train_masks_png')
PATH_X = PATH.joinpath('train-128')
PATH_Y = PATH.joinpath('train_masks-128')

## Convert and resize data

In [None]:
# Create every output folder up front; exist_ok makes re-running this cell harmless.
for out_dir in (PATH_PNG, PATH_X, PATH_Y):
    out_dir.mkdir(exist_ok=True)

In [None]:
def convert_img(fn):
    "Save a PNG copy of the image at `fn` into `PATH_PNG`, keeping its base name."
    # `with` closes the source file deterministically; Pillow may otherwise keep the
    # handle open (e.g. for multi-frame formats) across thousands of threaded calls.
    # `fn.stem` strips an extension of any length, unlike slicing off 4 characters.
    with Image.open(fn) as img: img.save(PATH_PNG/f'{fn.stem}.png')

In [None]:
files = list((PATH/'train_masks').iterdir())
# `Executor.map` is lazy: consuming its results re-raises any exception from a worker,
# which would otherwise be silently dropped.
with ThreadPoolExecutor(8) as e: list(e.map(convert_img, files))

In [None]:
def resize_img(fn, dirname, sz=128):
    "Resize the image at `fn` to `sz`x`sz` and save it, same file name, into `dirname` next to `fn`'s folder."
    # `with` closes the source file once the resized copy has been written.
    with Image.open(fn) as img:
        img.resize((sz,sz)).save(fn.parent.parent/dirname/fn.name)

In [None]:
files = list(PATH_PNG.iterdir())
# Target folder name comes from PATH_Y so it cannot drift from the constant; consuming
# the map results surfaces any worker exception.
with ThreadPoolExecutor(8) as e: list(e.map(partial(resize_img, dirname=PATH_Y.name), files))

In [None]:
files = list((PATH/'train').iterdir())
# Target folder name comes from PATH_X so it cannot drift from the constant; consuming
# the map results surfaces any worker exception.
with ThreadPoolExecutor(8) as e: list(e.map(partial(resize_img, dirname=PATH_X.name), files))

## Basic transforms

In [None]:
# Take whichever resized image the directory listing yields first as the running example.
img_f = next(iter(PATH_X.iterdir()))
img_x = open_image(img_f)
show_image(img_x)

In [None]:
def get_y_fn(x_fn):
    "Mask file name for image file name `x_fn`: its extension is replaced by `_mask.png`."
    import os
    # splitext handles any extension length (e.g. `.jpeg`) and str or Path input,
    # unlike slicing off a fixed 4 characters.
    return f'{os.path.splitext(x_fn)[0]}_mask.png'

In [None]:
# The mask paired with img_f lives in PATH_Y under the derived `_mask.png` name.
img_y_f = PATH_Y.joinpath(get_y_fn(img_f.name))
img_y = open_image(img_y_f)
show_image(img_y)

In [None]:
def x():
    "Load the sample image from disk, giving a fresh tensor on every call."
    return open_image(img_f)

def y():
    "Load the sample's mask from disk, giving a fresh tensor on every call."
    return open_image(img_y_f)

In [None]:
# Geometric transforms (flip / rotate / zoom) followed by lighting ones (contrast / brightness);
# `p` is presumably the probability each transform is applied.
tfms = [
    flip_lr_tfm(p=0.5),
    rotate_tfm(degrees=(-10,10.), p=0.25),
    zoom_tfm(scale=(0.8,1.2), p=0.25),
    contrast_tfm(scale=(0.8,1.2)),
    brightness_tfm(change=(0.4,0.6)),
]

In [None]:
# apply_pipeline takes the transforms first, matching its definition and every other call in this notebook.
_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_pipeline(tfms, x()), ax)

## Rotation

In [None]:
def xy():
    "Return a freshly loaded (image, mask) pair for the sample."
    img, mask = x(), y()
    return img, mask

In [None]:
# Peek at the keyword arguments resolve_args builds for `brightness` from a `change` range
# (the same kind of mapping rotate_rand unpacks into `rotate`).
resolve_args(brightness, change=(0.4,0.6))

In [None]:
def rotate_rand(x, y=None, smooth_y=True, degrees=(-45,45.)):
    """Rotate `x` (and `y`, if given) with one shared rotation matrix.

    The angle is whatever `resolve_args` picks for `rotate` within the `degrees` range.
    Returns the rotated `x`, or `(x, y)` when `y` is given; with `smooth_y=False` the
    rotated `y` is rounded in place.
    """
    m = rotate(**resolve_args(rotate, degrees=degrees))
    # Pass `m` by keyword: the `do_affine` redefined later in this notebook takes `img_y`
    # as its second positional argument, so `do_affine(x, m)` would misbind there.
    x = do_affine(x, m=m)
    if y is None: return x

    y = do_affine(y, m=m)
    # Undo interpolation blur on the mask by rounding back to integer values.
    if not smooth_y: torch.round_(y)
    return x, y

In [None]:
imgx,imgy = rotate_rand(*xy(), smooth_y=False)
# With smooth_y=False the rotated mask may contain no fractional values strictly between 0 and 1.
assert not ((imgy > 0.) & (imgy < 1.)).any()

In [None]:
# Top row: rotated images; bottom row: the matching rotated (rounded) masks.
_,axes = plt.subplots(2,4, figsize=(12,6))
for ax_img, ax_mask in zip(axes[0], axes[1]):
    imgx,imgy = rotate_rand(*xy(), smooth_y=False)
    show_image(imgx, ax_img)
    show_image(imgy, ax_mask)

In [None]:
# Passing the image as its own "mask" makes it easy to see both rows get the identical rotation.
_,axes = plt.subplots(2,4, figsize=(12,6))
for top, bottom in zip(axes[0], axes[1]):
    imgx,imgy = rotate_rand(x(),x())
    show_image(imgx, top)
    show_image(imgy, bottom)

In [None]:
# Image-only call: without a mask, rotate_rand returns just the rotated image.
_,axes = plt.subplots(1,4, figsize=(12,6))
for ax in axes:
    show_image(rotate_rand(x()), ax)

## Affine transforms

In [None]:
def do_affine(img_x, img_y=None, m=None, funcs=None, smooth_y=True):
    """Warp `img_x` (and `img_y`, if given) by affine matrix `m`, then coordinate transforms `funcs`.

    `m` defaults to `eye_new(img_x, 3)` (presumably the 3x3 identity). Returns the warped
    `img_x`, or `(img_x, img_y)` when `img_y` is given; with `smooth_y=False` the warped
    `img_y` is rounded in place.
    """
    if m is None: m = eye_new(img_x, 3)
    grid = affine_grid(img_x, img_x.new_tensor(m))
    coords = compose(funcs)(grid)
    # Both images are sampled from the same coordinate grid so they stay aligned.
    img_x = grid_sample(img_x, coords, padding='zeros')
    if img_y is None: return img_x

    img_y = grid_sample(img_y, coords, padding='zeros')
    if not smooth_y: torch.round_(img_y)
    return img_x, img_y

In [None]:
def apply_pixel_tfm(func):
    """Wrap pixel transform `func` so it operates in logit space on the image.

    The returned function takes `(x, y=None)`. It applies `func` to the logit of a copy of
    `x` and maps the result back with `sigmoid`. It returns the new `x`, or `(x, y)` with
    the mask `y` passed through unchanged.
    """
    def _inner(x, y=None):
        # Work on a copy: `logit_` is in-place and would otherwise overwrite the caller's tensor.
        x = x.clone()
        logit_(x)
        x = func(x).sigmoid()
        if y is None: return x
        # Lighting changes are meant for the input image only; the target mask keeps its labels.
        return x, y

    return _inner

In [None]:
def apply_pipeline(tfms, x, y=None, smooth_y=True):
    """Apply `tfms` to image `x` (and mask `y`, if given).

    Transforms are grouped by their return annotation into pixel, coordinate and affine
    groups (in `TfmType` order). Pixel transforms run first; then all affine matrices and
    coordinate functions are applied in a single `do_affine` warp. Returns `x`, or `(x, y)`
    when `y` is given, including when `tfms` is empty.
    """
    tfms = listify(tfms)
    # Keep the return shape consistent: callers that pass `y` unpack a pair.
    if not tfms: return x if y is None else (x, y)
    grouped_tfms = dict_groupby(tfms, lambda o: o.__annotations__['return'])
    pixel_tfms,coord_tfms,affine_tfms = map(grouped_tfms.get, TfmType)
    res = apply_pixel_tfm(compose(pixel_tfms))(x, y)
    if isinstance(res, tuple): x,y = res
    else: x = res
    matrices = [f() for f in listify(affine_tfms)]
    return do_affine(x, y, affines_mat(x, matrices), funcs=coord_tfms, smooth_y=smooth_y)

In [None]:
# Redefines `tfms` for the rest of the notebook: one rotation plus one brightness change.
tfms = [
    rotate_tfm(degrees=(-45,45.)),
    brightness_tfm(change=(0.3,0.7)),
]

In [None]:
imgx,imgy = apply_pipeline(tfms, *xy(), smooth_y=False)
# With smooth_y=False the transformed mask may contain no fractional values strictly between 0 and 1.
assert not ((imgy > 0.) & (imgy < 1.)).any()

In [None]:
# Top row: transformed images; bottom row: the matching transformed (rounded) masks.
_,axes = plt.subplots(2,4, figsize=(12,6))
for ax_img, ax_mask in zip(*axes):
    imgx,imgy = apply_pipeline(tfms, *xy(), smooth_y=False)
    show_image(imgx, ax_img)
    show_image(imgy, ax_mask)

In [None]:
# Feed the image in as both arguments to compare how the pipeline treats each of them.
_,axes = plt.subplots(2,4, figsize=(12,6))
for top, bottom in zip(*axes):
    imgx,imgy = apply_pipeline(tfms, x(),x())
    show_image(imgx, top)
    show_image(imgy, bottom)

In [None]:
# Image-only pipeline: no mask argument, so apply_pipeline returns just the image.
_,axes = plt.subplots(1,4, figsize=(12,6))
for ax in axes:
    show_image(apply_pipeline(tfms, x()), ax)

In [None]:
# A jitter-only pipeline, shown on four independent draws of the sample image.
tfms2 = [jitter_tfm(magnitude=(-0.1,0.1))]

_,axes = plt.subplots(1,4, figsize=(12,6))
for ax in axes:
    show_image(apply_pipeline(tfms2, x()), ax)