In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from skimage import data
from skimage.transform import rescale
import numpy as np

In [None]:
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor

In [None]:
import skimage as sk
import os
import random

In [None]:
from skimage.exposure import is_low_contrast

# Grayscale augmentation (save 3-channel grayscale copies of training images)

In [None]:
from skimage.color import rgb2gray

In [None]:
PATH = Path('data_draw/data-aug/')

# One sub-folder per class. Sort for a deterministic order and skip stray files
# and hidden entries (e.g. `.DS_Store`, `.ipynb_checkpoints`): the per-class
# worker functions call `os.listdir(PATH/c)` and would fail on a plain file.
classes = sorted(
    d.name for d in PATH.iterdir() if d.is_dir() and not d.name.startswith('.')
)

classes

In [None]:
def grayscale_internet(c):
    """Save a 3-channel grayscale copy of each internet image of class `c`.

    Internet images are the files in `PATH/c` whose first three characters
    differ from the class name's (files sharing that prefix are handled by
    `grayscale_movie`). Every image that `is_low_contrast` does not flag is
    converted with `rgb2gray`, stacked into 3 identical channels and written
    next to its source as `<stem>-gc.png`. Flagged images are only counted.

    Hidden files and files already named `*-gc.*` (outputs of an earlier run)
    are skipped, so re-running the cell does not create `-gc-gc` copies.

    Parameters
    ----------
    c : str
        Name of the class sub-folder of the module-level `PATH`.
    """
    lc_count = 0
    c_path = PATH/c
    fnames = [
        f for f in os.listdir(c_path)
        if f[:3] != c[:3]
        and not f.startswith('.')
        and not Path(f).stem.endswith('-gc')
    ]
    for fname in fnames:
        img = sk.io.imread(c_path/fname)
        if is_low_contrast(img):
            lc_count += 1
            continue
        # Newer scikit-image no longer accepts 2-D or RGBA input to rgb2gray:
        # keep single-channel images as they are and drop any alpha channel
        # (what older rgb2gray versions did implicitly).
        gray = img if img.ndim == 2 else rgb2gray(img[..., :3])
        gc_img = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
        # Path.stem keeps the full name of extension-less files, whereas
        # '.'.join(name.split('.')[:-1]) returned '' (-> colliding '-gc.png').
        sk.io.imsave(c_path/f'{Path(fname).stem}-gc.png', gc_img)
    print(f'{c} low contrast:{lc_count}')

In [None]:
def grayscale_movie(c, seed=None):
    """Save 3-channel grayscale copies of a random half of class `c`'s frames.

    Frames are the files in `PATH/c` whose first three characters equal the
    class name's. The frame list is shuffled and the first `len // 2` entries
    are processed. Each one that `is_low_contrast` does not flag is converted
    with `rgb2gray`, stacked into 3 identical channels and written next to its
    source as `<stem>-gc.png`. Flagged frames are only counted.

    Files already named `*-gc.*` (outputs of an earlier run) are excluded, so
    they neither get re-converted nor enlarge the pool that is halved.

    Parameters
    ----------
    c : str
        Name of the class sub-folder of the module-level `PATH`.
    seed : int, optional
        When given, shuffle with a private `random.Random(seed)` so the chosen
        half is reproducible. When None (default), the global `random` module
        is used, as before.
    """
    lc_count = 0
    c_path = PATH/c
    # Sort first: os.listdir order is arbitrary, so even a seeded shuffle would
    # otherwise select a different half on another machine or filesystem.
    fnames = sorted(
        f for f in os.listdir(c_path)
        if f[:3] == c[:3] and not Path(f).stem.endswith('-gc')
    )
    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle
    shuffle(fnames)
    for fname in fnames[:len(fnames)//2]:
        img = sk.io.imread(c_path/fname)
        if is_low_contrast(img):
            lc_count += 1
            continue
        # Newer scikit-image no longer accepts 2-D or RGBA input to rgb2gray:
        # keep single-channel images as they are and drop any alpha channel.
        gray = img if img.ndim == 2 else rgb2gray(img[..., :3])
        gc_img = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
        # Path.stem keeps the full name of extension-less files.
        sk.io.imsave(c_path/f'{Path(fname).stem}-gc.png', gc_img)
    print(f'{c} low contrast:{lc_count}')

In [None]:
# test run
# grayscale_internet('hercules')

In [None]:
# test run
# grayscale_movie('pocahontas')

In [None]:
# Real run: one grayscale_internet task per class folder, 4 worker processes.
# Draining the map() iterator waits for every task and re-raises any worker
# exception here in the notebook.
with ProcessPoolExecutor(max_workers=4) as executor:
    for _ in executor.map(grayscale_internet, classes):
        pass

In [None]:
# Real run: one grayscale_movie task per class folder, 4 worker processes.
# Draining the map() iterator waits for every task and re-raises any worker
# exception here in the notebook.
with ProcessPoolExecutor(max_workers=4) as executor:
    for _ in executor.map(grayscale_movie, classes):
        pass

# Other types of augmentation (internet images only)

In [None]:
from skimage.util import random_noise
from skimage import util
from skimage.transform import rotate
from skimage import exposure
from skimage.transform import rescale

In [None]:
# Root of the augmented training set (one sub-folder per class). Presumably
# re-declared so this section can be run without the grayscale section above.
PATH = Path('data_draw/data-aug/')

In [None]:
# Class sub-folders of PATH, in a deterministic order. Stray files and hidden
# entries are skipped so the per-class workers can list every entry.
classes = sorted(
    d.name for d in PATH.iterdir() if d.is_dir() and not d.name.startswith('.')
)

In [None]:
def augmentation_internet(c):
    """Write several augmented copies of each internet image of class `c`.

    Internet images are the files in `PATH/c` whose first three characters
    differ from the class name's. For a source `<stem>.<ext>` the following
    files are written next to it, in this order:

    * `<stem>-invert.png` and `<stem>-noise.png` (Poisson noise), only when
      `is_low_contrast` does not flag the source;
    * `<stem>-intensity.png`: intensities rescaled from the 0.2-99.8
      percentile range. Flagged sources also get its inverse,
      `<stem>-intensity-invert.png`, in place of the plain invert;
    * `<stem>-log.png` / `<stem>-sigmoid.png`: log / sigmoid exposure
      adjustment, each saved only if the result is not low-contrast.

    Hidden files and files produced by an earlier run of this function are
    skipped, so re-running the cell does not augment its own outputs. Prints
    the number of low-contrast sources.

    Parameters
    ----------
    c : str
        Name of the class sub-folder of the module-level `PATH`.
    """
    # Stems written by this function; '.png.invert' is the name older runs
    # produced for the inverted intensity image.
    generated = ('-invert', '-noise', '-intensity', '-log', '-sigmoid',
                 '.png.invert')
    lc_count = 0
    c_path = PATH/c
    fnames = [
        f for f in os.listdir(c_path)
        if f[:3] != c[:3]
        and not f.startswith('.')
        and not Path(f).stem.endswith(generated)
    ]

    def save(stem, suffix, arr):
        # Path.stem keeps the full name of extension-less files.
        sk.io.imsave(c_path/f'{stem}{suffix}.png', arr)

    for fname in fnames:
        img = sk.io.imread(c_path/fname)
        stem = Path(fname).stem
        low_contrast = is_low_contrast(img)

        if not low_contrast:
            save(stem, '-invert', np.invert(img))
            save(stem, '-noise', random_noise(img, mode='poisson'))

        # Percentile contrast stretch, applied to every source.
        v_min, v_max = np.percentile(img, (0.2, 99.8))
        stretched = exposure.rescale_intensity(img, in_range=(v_min, v_max))
        save(stem, '-intensity', stretched)
        if low_contrast:
            lc_count += 1
            # No plain invert was written above; invert the stretched image.
            # (Previously saved as '<stem>-intensity.png.invert.png'.)
            save(stem, '-intensity-invert', np.invert(stretched))

        for suffix, adjust in (('-log', exposure.adjust_log),
                               ('-sigmoid', exposure.adjust_sigmoid)):
            aug_img = adjust(img)
            if not is_low_contrast(aug_img):
                save(stem, suffix, aug_img)

    print(f'{c} low contrast:{lc_count}')

In [None]:
# test run
# augmentation_internet('hercules')

In [None]:
# Real run: one augmentation_internet task per class folder, 4 worker
# processes. Draining the map() iterator waits for every task and re-raises
# any worker exception here in the notebook.
with ProcessPoolExecutor(max_workers=4) as executor:
    for _ in executor.map(augmentation_internet, classes):
        pass

# Add grayscale copies to the validation set (the test set will contain grayscale sketches)

In [None]:
from skimage.color import rgb2gray

# NOTE: this rebinds the module-level PATH that every worker function reads.
# Re-run the `data_draw/data-aug/` PATH cell before reusing the training-set
# functions defined above.
PATH = Path('data_draw/valid/')

# Class sub-folders of the validation set, in a deterministic order; stray
# files and hidden entries are skipped.
classes = sorted(
    d.name for d in PATH.iterdir() if d.is_dir() and not d.name.startswith('.')
)


In [None]:
def grayscale_valid(c):
    """Save a 3-channel grayscale copy of each image of class `c` under `PATH`.

    Intended for the validation set (`PATH` is rebound to the validation
    folder before this function is used). Every image that `is_low_contrast`
    does not flag is converted with `rgb2gray`, stacked into 3 identical
    channels and written next to its source as `<stem>-gc.png`. Flagged
    images are only counted.

    Hidden files and files already named `*-gc.*` (outputs of an earlier run)
    are skipped, so re-running the cell does not create `-gc-gc` copies.

    Parameters
    ----------
    c : str
        Name of the class sub-folder of the module-level `PATH`.
    """
    lc_count = 0
    c_path = PATH/c
    fnames = [
        f for f in os.listdir(c_path)
        if not f.startswith('.') and not Path(f).stem.endswith('-gc')
    ]
    for fname in fnames:
        img = sk.io.imread(c_path/fname)
        if is_low_contrast(img):
            lc_count += 1
            continue
        # Newer scikit-image no longer accepts 2-D or RGBA input to rgb2gray:
        # keep single-channel images as they are and drop any alpha channel.
        gray = img if img.ndim == 2 else rgb2gray(img[..., :3])
        gc_img = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
        # Path.stem keeps the full name of extension-less files.
        sk.io.imsave(c_path/f'{Path(fname).stem}-gc.png', gc_img)
    print(f'{c} low contrast:{lc_count}')

In [None]:
# One grayscale_valid task per validation class folder, 4 worker processes.
# Draining the map() iterator waits for every task and re-raises any worker
# exception here in the notebook.
with ProcessPoolExecutor(max_workers=4) as executor:
    for _ in executor.map(grayscale_valid, classes):
        pass