In [1]:
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as la

from src.structured_random_features.src.models.weights import V1_weights

# Packages for fft and fitting data
from scipy import fftpack as fft
from sklearn.linear_model import Lasso

# Package for importing image representation
from PIL import Image, ImageOps

from src.V1_Compress import generate_Y, compress
import pandas as pd
import itertools
import dask
from dask.distributed import Client, progress
import seaborn as sns
import time
import os.path

In [2]:
def opt_hyperparams(data, rep_col=None, top_n=5):
    """Return the hyperparameter row with the lowest reconstruction error.

    Parameters
    ----------
    data : DataFrame or anything ``pd.DataFrame`` accepts
        One row per run; must contain an ``'error'`` column.
    rep_col : str, optional
        Name of the repetition column (e.g. ``'rep'``). When given, runs are
        grouped by every other hyperparameter column and their errors averaged
        over repetitions before ranking. Default ``None`` ranks individual runs.
    top_n : int
        Number of best rows printed for inspection (default 5).

    Returns
    -------
    DataFrame
        A one-row DataFrame holding the best (lowest-error) row.
    """
    df = pd.DataFrame(data)

    if rep_col is not None:
        # Average over repetitions so a single lucky run does not win.
        keys = [c for c in df.columns if c not in (rep_col, 'error')]
        if keys:
            df = df.groupby(keys, as_index=False)['error'].mean()
        else:
            # Nothing to group by: the only meaningful summary is the overall mean.
            df = pd.DataFrame({'error': [df['error'].mean()]})

    ranked = df.sort_values(by='error', ascending=True)
    print(ranked.head(top_n))

    return ranked.head(1)

In [3]:
def run_sim(rep, alp, num, sz, freq, img_arr):
    """Reconstruct ``img_arr`` from V1-weight measurements and score the result.

    Parameters
    ----------
    rep : int
        Repetition index. Not used in the computation; it only makes
        otherwise-identical grid points distinct runs.
    alp : float
        Regularization strength forwarded to ``compress``.
    num : int or float
        Number of V1 cells (rows of the weight matrix); cast to ``int``.
    sz, freq :
        Forwarded unchanged to ``V1_weights`` (presumably cell size and
        spatial frequency -- see ``V1_weights``).
    img_arr : array-like
        Image to reconstruct; must be 2-D after size-1 dims are squeezed.

    Returns
    -------
    error, theta, reform, s
        ``error`` is the RMS pixel error between ``img_arr`` and the
        reconstruction ``reform``; ``theta`` and ``s`` come straight from
        ``compress``.
    """
    num = int(num)
    img_arr = np.array([img_arr]).squeeze()
    dim = img_arr.shape
    n, m = dim  # raises if the squeezed image is not 2-D

    # Generate V1 weights and the corresponding measurements y
    W = V1_weights(num, dim, sz, freq)
    y = generate_Y(W, img_arr)
    W_model = W.reshape(num, n, m)

    theta, reform, s = compress(W_model, y, alp)

    # Score against the image that was actually passed in, not the notebook-global
    # `img` (which made the result depend on hidden state and on what got pickled
    # to Dask workers). Cast to float so uint8 pixels cannot wrap around.
    diff = img_arr.astype(float) - reform
    error = np.linalg.norm(diff, 'fro') / np.sqrt(m * n)

    return error, theta, reform, s

In [None]:
# DataFrame version of the hyperparameter grid search, parallelized with Dask.

# Hyperparameters that affect the reconstruction error
alpha = np.logspace(-3, 3, 7)   # regularization strengths passed to compress
rep = np.arange(10)             # repetition index: each grid point is run 10 times
num_cell = [100, 200, 500]
cell_sz = [2, 5, 7]
sparse_freq = [1, 2, 5]

# Load the image as a grayscale array (context manager closes the file handle)
image_path = 'image/tree_part1.jpg'
# File name without directory or extension; robust to nested dirs and extra dots
image_nm = os.path.splitext(os.path.basename(image_path))[0]
with Image.open(image_path) as raw_img:
    img = ImageOps.grayscale(raw_img)
img_arr = np.asarray(img)

save_path = os.path.join("result", image_nm)
# to_csv does not create missing directories
os.makedirs(save_path, exist_ok=True)

search_list = [rep, alpha, num_cell, cell_sz, sparse_freq]

# All combinations of hyperparameters to try
search = list(itertools.product(*search_list))
search_df = pd.DataFrame(search, columns=['rep', 'alp', 'num_cell', 'cell_sz', 'sparse_freq'])
print(search_df.head())

# Start a local Dask client
client = Client()

# One lazy task per grid point. Iterate the product tuples rather than
# search_df.values: .values upcasts every column to float64 (because `alp` is
# float), so integer hyperparameters would reach run_sim as floats.
delay_list = [dask.delayed(run_sim)(*p, img_arr) for p in search]
print('running dask completed')

futures = dask.persist(*delay_list)
print('futures completed')
progress(futures)
print('progressing futures')

# Compute the result (one tuple per grid point, same order as search_df)
results = dask.compute(*futures)
print('result computed')
results_df = pd.DataFrame(results, columns=['error', 'theta', 'reform', 's'])

# Add error onto parameter
params_result_df = search_df.join(results_df['error'])

# Use one timestamp for both files so each param/result pair can be matched.
timestamp = "_".join(time.ctime().replace(":", "_").split())
params_result_df.to_csv(os.path.join(save_path, "param_" + timestamp + ".csv"))
# NOTE(review): theta/reform/s hold numpy arrays; to_csv writes their str() repr,
# which numpy truncates with '...' for large arrays, so those columns cannot be
# recovered from this CSV. Consider results_df.to_pickle(...) if they are needed.
results_df.to_csv(os.path.join(save_path, "result_" + timestamp + ".csv"))




   rep    alp  num_cell  cell_sz  sparse_freq
0    0  0.001       100        2            1
1    0  0.001       100        2            2
2    0  0.001       100        2            5
3    0  0.001       100        5            1
4    0  0.001       100        5            2


Perhaps you already have a cluster running?
Hosting the HTTP server on port 40601 instead
distributed.diskutils - INFO - Found stale lock file and directory '/home/bans/Documents/research/dask-worker-space/worker-n_phe0v5', purging
distributed.diskutils - INFO - Found stale lock file and directory '/home/bans/Documents/research/dask-worker-space/worker-y5_se46z', purging
distributed.diskutils - INFO - Found stale lock file and directory '/home/bans/Documents/research/dask-worker-space/worker-nx9l0vbo', purging
distributed.diskutils - INFO - Found stale lock file and directory '/home/bans/Documents/research/dask-worker-space/worker-huinoxe4', purging
distributed.diskutils - INFO - Found stale lock file and directory '/home/bans/Documents/research/dask-worker-space/worker-s2cu7eq4', purging
distributed.diskutils - INFO - Found stale lock file and directory '/home/bans/Documents/research/dask-worker-space/worker-ogwbiikx', purging
distributed.diskutils - INFO - Found stale lock file and d

running dask completed
futures completed
progressing futures


In [17]:
# Rebuild the per-run results table from the raw Dask outputs and display it.
results_df = pd.DataFrame(data=results, columns=['error', 'theta', 'reform', 's'])
results_df

Unnamed: 0,error,theta,reform,s
0,120.462825,"[[1.3314698558103133, 1.4450490146676267, 0.25...","[[214.5524440606901, 421.37164766963167, 68.86...","[4006.5678338553676, 549.7383766153997, 859.24..."
1,1174.880316,"[[2.2658501931210258, 2.159015441463411, -0.10...","[[418.31973370410583, 628.6285591126782, -611....","[4427.58524644345, 338.71426137802086, 358.299..."
2,3218.789938,"[[-8.056397776508536, -10.914925866409241, -9....","[[1176.5329070761652, 201.77278920080857, -319...","[3627.6932363731135, 625.568655037431, 485.559..."
3,130.523395,"[[-1.99959489673949, -1.913904812362233, -0.11...","[[199.00489307999956, 199.4695008695916, 295.7...","[3814.4248700667736, 246.31643509804474, 373.2..."
4,1968.693287,"[[-3.165958141351731, 0.02044877171218687, 4.4...","[[96.07424176748549, 369.1576036817019, 83.208...","[4264.176171847444, 254.82562175562234, 621.15..."
...,...,...,...,...
1885,29.270912,"[[-4.493912303133953, 3.166294401211231, 1.970...","[[199.37130161646087, 200.44801128157897, 202....","[3908.755555159066, 329.7636997529017, 478.699..."
1886,29.716334,"[[0.4154514135117409, 0.6893957861801724, -2.9...","[[233.55422382725052, 230.97741318161283, 226....","[4008.6603909481546, 344.83701158362015, 514.8..."
1887,41.519969,"[[-2.9000191272963067, -1.6099846134944231, 4....","[[169.83444240575696, 169.53310055167825, 168....","[3560.416059777383, 220.3930225906392, 169.506..."
1888,29.674167,"[[2.8734925259700486, -2.0652079758592263, -3....","[[218.15129451564832, 216.52314569720008, 213....","[3943.4299242582847, 337.40189053378356, 466.3..."


In [20]:
# Pair every hyperparameter combination with its error, then show the best ones.
temp = search_df.join(results_df.loc[:, 'error'])
opt_hyperparams(temp)

      rep   alp  num_cell  cell_sz  sparse_freq      error
808     4  0.01       500        7            2  12.431534
52      0  0.01       500        7            2  12.445648
1018    5  0.10       500        2            2  12.492383
1750    9  0.01       500        5            2  12.547963
1564    8  0.01       500        7            2  12.556109


Unnamed: 0,rep,alp,num_cell,cell_sz,sparse_freq,error
808,4,0.01,500,7,2,12.431534


In [21]:
image_path = "image/tree_part1.jpg"  # relative path to the source image

In [40]:
# Image name = file name without directory or extension (e.g. 'tree_part1').
# os.path handles nested directories and dotted names, unlike split('/')[1] / split('.')[0].
img_nm = os.path.splitext(os.path.basename(image_path))[0]

'image/tree_part1'

In [38]:
# Display the image name derived from image_path in the grid-search cell.
# NOTE(review): the saved output below ('image/tree_part1.jpg') does not match what the
# current derivation code yields ('tree_part1'); it looks stale -- re-run to refresh.
image_nm

'image/tree_part1.jpg'