<a href="https://colab.research.google.com/github/ajcommercial/AI_color_grade_lut/blob/master/Random_image_grade.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
from scipy.interpolate import RegularGridInterpolator
from PIL import ImageFilter

from sklearn.neighbors import NearestNeighbors
import numpy as np
from sklearn.utils import shuffle

import glob
import os

import zipfile

from IPython.display import clear_output 

In [2]:
def list_to_lut(input_list, target_list,size=8,exp=2,frac=1):
    """Build a 3D colour LUT by inverse-distance weighting of control-point shifts.

    Each control point ``input_list[j]`` should map to ``target_list[j]``.  A
    regular ``size``**3 RGB grid over [0, 1] is generated (red varying fastest,
    then green, then blue -- the ordering used by .cube files and by
    PIL's ``Color3DLUT``).  Every grid node is displaced by the weighted mean
    of the control points' displacements ``target - input`` with weights
    ``1 / distance**exp`` (Shepard interpolation), scaled by ``frac``.

    Control points whose target equals their input are ignored entirely.  A
    grid node that coincides exactly with a moving control point takes that
    point's displacement directly (averaged if several coincide) instead of
    dividing by zero, which previously produced NaN entries.

    Args:
        input_list: sequence of RGB control-point positions.
        target_list: sequence of RGB targets, same length as ``input_list``.
        size: number of grid nodes per axis (>= 2).
        exp: inverse-distance exponent.
        frac: fraction of the interpolated displacement to apply.

    Returns:
        ndarray of shape (size**3, 3).  If no control point moves, the
        identity grid is returned.

    Raises:
        ValueError: if ``size`` < 2 or the two lists differ in length.
    """
    if size < 2:
        raise ValueError("size must be at least 2")
    target = np.asarray(target_list, dtype=float).reshape(-1, 3)
    source = np.asarray(input_list, dtype=float).reshape(-1, 3)
    if source.shape != target.shape:
        raise ValueError("input_list and target_list must have the same length")

    # Identity grid; index = b*size**2 + g*size + r, i.e. red changes fastest.
    axis = np.arange(size) / (size - 1)
    blue, green, red = np.meshgrid(axis, axis, axis, indexing="ij")
    lut = np.stack([red.ravel(), green.ravel(), blue.ravel()], axis=1)

    delta = target - source
    moving = np.linalg.norm(delta, axis=1) > 0
    if not moving.any():
        # Nothing to interpolate; avoids the 0/0 of an empty weighted mean.
        return lut
    src = source[moving]
    shifts = delta[moving]

    # Pairwise Euclidean distances: (grid nodes, moving control points).
    dist = np.sqrt(np.sum((lut[:, None, :] - src[None, :, :]) ** 2, axis=-1))
    with np.errstate(divide="ignore"):
        weights = 1.0 / dist ** exp
    # Zero distance gives an infinite weight: use only the coincident
    # point(s) for that node so the mean stays finite.
    exact = np.isinf(weights)
    hit_rows = exact.any(axis=1)
    weights[hit_rows] = exact[hit_rows]

    shift = (weights @ shifts) / weights.sum(axis=1, keepdims=True)
    return lut + shift * frac

def brightness(n,scale):
    """Draw a random piecewise brightness curve along the grey axis.

    ``n`` interior knots are placed near ``i / (n + 1)`` (std ``1 / (2n)``),
    each clipped to lie at least a third of a step above the previous knot
    and at most one step above its nominal position.  Knot 0 and a final
    knot at 1 are added.  Output levels are the knots plus Gaussian noise
    (std 0.6) clipped to [0, 1]; both ends are pinned to 0 and 1.  With
    probability 0.7 the white knot is dropped afterwards.

    Args:
        n: number of interior knots.
        scale: unused.

    Returns:
        (knots, knots - levels): two arrays of shape (k, 3) whose three
        columns are identical (grey values repeated per channel).
    """
    step = 1 / (n + 1)
    knots = [0]
    for idx in range(1, n + 1):
        nominal = idx * step
        lower = knots[-1] + step / 3
        upper = nominal + step
        jittered = np.random.normal(nominal, scale=(1 / (n * 2)))
        knots.append(np.clip(jittered, lower, upper))
    knots.append(1)
    knots = np.asarray(knots)

    levels = np.clip(np.random.normal(knots, scale=0.6), 0, 1)
    knots[0], knots[-1] = 0, 1
    levels[0], levels[-1] = 0, 1

    # Replicate each grey value across the R, G and B columns.
    knots = np.repeat(knots[:, np.newaxis], 3, axis=1)
    levels = np.repeat(levels[:, np.newaxis], 3, axis=1)

    keep_white = np.random.rand(1)[0] <= 0.3
    if not keep_white:
        knots = knots[:-1]
        levels = levels[:-1]

    return knots, knots - levels


def random_lut(splits = 1,scale1 = 0.05,scale2=0.1,scale3=0.04):
    """Generate a random colour-grading LUT with Gaussian jitter.

    Control points lie on the grey axis (r == g == b).  They are black, a
    random split point clipped to [0.2, 0.8], white, plus one extra point
    drawn per segment.  Each interior control point gets an independent
    per-channel random target (std ``scale2``).  All targets are then
    offset by a random brightness curve from ``brightness`` (scaled by a
    uniform random factor) and clipped to [0, 1].  ``list_to_lut`` turns
    the control points into a LUT.  Gaussian noise (std ``scale3``) is
    added to every entry, and the first and last entries (grid black and
    white) are pinned to (0, 0, 0) and (1, 1, 1).

    Args:
        splits: unused; kept for call compatibility.
        scale1: std-dev of the split point around 0.5.
        scale2: std-dev of the per-channel jitter of interior targets.
        scale3: std-dev of the noise added to every LUT entry.

    Returns:
        ndarray of shape (size**3, 3) as produced by ``list_to_lut`` with its
        default size, plus noise.
    """
    # Grey-axis anchors: black, a random split point, white.
    pt = [0, np.clip(np.random.normal(0.5, scale=scale1), 0.2, 0.8), 1]

    # One extra point per segment.  rnd is unbounded, so the new point may
    # fall outside its segment (control points are not guaranteed sorted).
    pt_new = []
    for i in range(2):
        rnd = np.random.normal(0.5, scale=0.2)
        pt_new.append(pt[i])
        pt_new.append(pt[i + 1] * rnd + pt[i] * (1 - rnd))
    pt_new.append(1)

    # Initial targets: black and white fixed, interior points jittered per channel.
    val = [(0, 0, 0)]
    for p in pt_new[1:-1]:
        val.append(np.random.normal((p, p, p), scale=scale2))
    val.append((1, 1, 1))
    pt_new = np.reshape(np.repeat(np.asarray(pt_new), 3), (-1, 3))

    # ---------------------------- Brightness only ----------------------------
    n = np.random.choice(3, 1)[0] + 1
    brightness_factor = np.random.rand(1)[0]
    brightness_pt, brightness_diff_list = brightness(n, 1)

    for i in range(pt_new.shape[0]):
        diff = brightness_pt - pt_new[i, :]
        dist = np.sum(np.abs(diff) ** 2, axis=-1)  # squared distance per knot
        if np.amin(dist) == 0:
            # Control point sits exactly on a knot: take that knot's offset.
            # (Index by the per-knot distance, not the flattened diff array.)
            brightness_diff = brightness_diff_list[np.argmin(dist), :]
        else:
            weights = 1 / dist ** 1.3
            brightness_diff = np.average(brightness_diff_list, weights=weights, axis=0)
        val[i] += brightness_diff * brightness_factor
    val = np.clip(np.asarray(val), 0, 1)
    # -------------------------------------------------------------------------

    LUT = list_to_lut(pt_new, val)
    # Previously this noisy LUT was built and then discarded (the clean LUT was
    # returned), which made ``scale3`` a no-op for every caller.
    LUT_noise = np.random.normal(LUT, scale=scale3)
    LUT_noise[0, :] = (0, 0, 0)
    LUT_noise[-1, :] = (1, 1, 1)
    return LUT_noise

def write_lut(filename,lut,output_dir='content/output'):
    """Write ``lut`` as a 3D LUT text file ``<output_dir>/<filename>.cube``.

    NaNs are replaced with 0 (``np.nan_to_num``).  Each value is truncated
    toward zero to 4 decimal places.  ``LUT_3D_SIZE`` is derived from the
    number of rows, which must be a perfect cube.  The default
    ``output_dir`` is relative to the current working directory.

    Args:
        filename: base name without extension.
        lut: array-like of shape (size**3, 3), red varying fastest.
        output_dir: directory to write into.

    Returns:
        The path written, as a string (also printed).

    Raises:
        ValueError: if the number of rows is not a perfect cube.
    """
    LUT = np.asarray(np.nan_to_num(lut), dtype=float)

    lut_size = int(round(len(LUT) ** (1 / 3)))
    if lut_size ** 3 != len(LUT):
        raise ValueError("LUT has %d rows, which is not a perfect cube" % len(LUT))

    output = os.path.join(output_dir, filename + '.cube')
    print(output)

    # Truncate (not round) to 4 decimals, matching the original file format.
    rows = [' '.join(str(int(c * 10000) / 10000) for c in row[:3]) for row in LUT]
    with open(output, "w") as file:
        file.write('TITLE "test"\n')
        file.write('LUT_3D_SIZE ' + str(lut_size) + '\n\n')
        file.write(''.join(r + '\n' for r in rows))

    return output

def concat_img(im1, im2):
    """Place ``im2`` to the right of ``im1`` on a new RGB canvas.

    The canvas is as wide as both images together and as tall as ``im1``;
    ``im2`` is pasted at x = ``im1.width`` and cropped if it is taller.
    """
    canvas = Image.new('RGB', (im1.width + im2.width, im1.height))
    for picture, x_offset in ((im1, 0), (im2, im1.width)):
        canvas.paste(picture, (x_offset, 0))
    return canvas

def create_data(files, file_num, splits = 1,scale1 = 0.05,scale2=0.1,scale3=0.04):
    """Grade each image with a random LUT and save the image pair plus the LUT.

    For every path in ``files``:
    - Draw one LUT with ``random_lut`` (forwarding ``splits``/``scale*``).
    - Convert the image to RGB and resize it to ``size`` x ``size``, where
      ``size`` is a module-level global.
    - Grade the image with that same LUT.
    - Save original | graded side by side as ``/content/output/NNNNN.jpg``.
    - Write the LUT via ``write_lut`` under the same zero-padded number.

    Previously the image was graded with a second, independently drawn
    default LUT, so the saved .cube file did not match the saved image.

    Args:
        files: image paths to process.
        file_num: number used for the first output; incremented per file.
        splits, scale1, scale2, scale3: forwarded to ``random_lut``.

    Returns:
        The next unused file number.
    """
    num_files = len(files)
    for i, file in enumerate(files, start=1):
        clear_output()
        filename = str(file_num).zfill(5)
        file_num += 1

        # One LUT per image, used both for grading and for the saved file.
        # nan_to_num mirrors what write_lut does, so image and file agree.
        LUT = np.nan_to_num(random_lut(splits=splits, scale1=scale1, scale2=scale2, scale3=scale3))
        lut_size = int(round(len(LUT) ** (1 / 3)))

        # Color3DLUT with channels=3 needs an RGB image; convert first so
        # greyscale/CMYK/palette JPEGs don't abort the whole run.
        with Image.open(file) as source:
            image = source.convert('RGB').resize((size, size))
        im_lut = image.filter(ImageFilter.Color3DLUT(lut_size, LUT, channels=3))

        concat_img(image, im_lut).save('/content/output/' + filename + '.jpg')
        write_lut(filename, LUT)
        print("processing file " + str(i) + " of " + str(num_files))
    return file_num

  


In [3]:
file_num = 0  # running counter used to name outputs 00000, 00001, ...
size = 256    # side length in pixels that create_data resizes every image to

# write_lut builds the relative path 'content/output/...'; working from the
# filesystem root makes that resolve to /content/output.
os.chdir("/")

# The archive this notebook produces must never be picked up as input.
forbidden = ['/content/output.zip']

zip_candidates = sorted(f for f in glob.glob("/content/*.zip") if f not in forbidden)
for file in zip_candidates:
    print(file)
if not zip_candidates:
    raise FileNotFoundError("No input .zip archive found in /content; upload one first.")
# If several archives are present, the last one in sorted order is used.
path_to_zip_file = zip_candidates[-1]

base = os.path.basename(path_to_zip_file)
filename = os.path.splitext(base)[0]

PATH = os.path.join('/content/', filename + '/')

print(PATH)
try:
    os.mkdir('/content/output')
except FileExistsError:
    # Only an existing directory is benign; other OSErrors should surface.
    print ("Directory  already exists" )
else:
    print ("Successfully created the directory " )

with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
    zip_ref.extractall(PATH)

# Sorted so file numbering is reproducible across runs.
files = sorted(glob.glob(os.path.join(PATH, '*jpg')))
if not files:
    print("Warning: no *jpg files found directly under " + PATH)

file_num = create_data(files,file_num)
file_num = create_data(files,file_num,scale2=0.05)
file_num = create_data(files,file_num,scale2=0.02,scale3 = 0.01)
file_num = create_data(files,file_num,scale1 = 0.002,scale2=0.002,scale3 = 0.001)

def zipdir(path, ziph):
    """Add every entry directly inside directory ``path`` to ``ziph``.

    ``ziph`` is an open, writable ``zipfile.ZipFile``.  Entries are written
    with ``ziph.write(file)``, so each archive name is the entry's path
    (zipfile strips any leading separator).  Only the top level of ``path``
    is listed (``'*'``, which skips dotfiles); sub-directories are not
    recursed into.

    Previously ``path`` was ignored and ``/content/output`` was always zipped.
    """
    for file in sorted(glob.glob(os.path.join(path, '*'))):
        ziph.write(file)


# Bundle everything generated in /content/output into /content/output.zip.
# The context manager finalizes and closes the archive even if zipping fails.
with zipfile.ZipFile('/content/output.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipdir('/content/output/', zipf)



content/output/03003.cube
processing file 751 of 751
