# Solve for Gaussian approximations using optimization

In [1]:
%load_ext autoreload
%autoreload 2
# this only works on startup!
from jax import config
config.update("jax_enable_x64", True)

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
from gpu_utils import limit_gpu_memory_growth
limit_gpu_memory_growth()

from cleanplots import *
from tqdm import tqdm
from information_estimation import *
from image_utils import *
from gaussian_process_utils import *

from led_array.bsccm_utils import *
from bsccm import BSCCM
from jax import jit
import numpy as onp
import jax.numpy as np

bsccm = BSCCM('/home/hpinkard_waller/data/BSCCM/')

Opening BSCCM
Opened BSCCM


In [2]:
# load images, extract patches, and compute cov mats
edge_crop = 32
patch_size = 10
num_images = 20000
num_patches = 1000
channel = 'LED119'
eigenvalue_floor = 1e0

images = load_bsccm_images(bsccm, channel=channel, num_images=num_images, edge_crop=edge_crop, median_filter=False)

## Search through hyperparameter combos 

In [3]:
learning_rates = np.logspace(1, -8, 20)
batch_sizes = np.linspace(2, 50, 20).astype(int)
momentums = np.linspace(0, 0.999, 20)

# generate tuples of random hyperparameters
hyperparameter_tuples = []
for i in range(10000):
    lr = onp.random.choice(learning_rates)
    bs = onp.random.choice(batch_sizes)
    m = onp.random.choice(momentums)
    hyperparameter_tuples.append((lr, bs, m))

results = {}
for i, (learning_rate, batch_size, momentum) in enumerate(hyperparameter_tuples):
    best_hp_loss = np.inf

    patches = extract_patches(images, patch_size, num_patches=num_patches, seed=i)
    best_cov_mat, cov_mat_initial, mean_vec, best_loss = run_optimization(patches, momentum, learning_rate, batch_size, eigenvalue_floor=1e-3)

    if best_loss < best_hp_loss:
        best_hp_loss = best_loss
        best_hp = (learning_rate, batch_size, momentum)
        
    # collect results
    results[(learning_rate, batch_size, momentum)] = best_loss

    # print hyperparameters and their best loss
    print(f"best loss: {best_loss:.2f}\t\tLearning rate: {learning_rate:.3e}, Batch size: {batch_size}, Momentum: {momentum:.3e}")

Initial loss:  630162.0724254091
best loss: 443.67		Learning rate: 2.976e-08, Batch size: 4, Momentum: 8.413e-01
Initial loss:  247570.5514104695
best loss: 1370.08		Learning rate: 3.360e+00, Batch size: 17, Momentum: 7.887e-01
Initial loss:  457.2525714721856
best loss: 451.93		Learning rate: 1.438e-02, Batch size: 37, Momentum: 9.464e-01
Initial loss:  184826.8824663631
best loss: 933.66		Learning rate: 1.624e-03, Batch size: 34, Momentum: 7.361e-01
Initial loss:  457.2256598088446
best loss: 450.88		Learning rate: 6.952e-06, Batch size: 42, Momentum: 7.887e-01
Initial loss:  329348.24360659026
best loss: 1326.84		Learning rate: 1.000e+01, Batch size: 47, Momentum: 3.681e-01
Initial loss:  191530.43684099428
best loss: 442.68		Learning rate: 8.859e-08, Batch size: 7, Momentum: 6.309e-01
Initial loss:  100856.11155377117
best loss: 487.34		Learning rate: 2.637e-07, Batch size: 42, Momentum: 2.103e-01
Initial loss:  887472.1436367106
best loss: 1297.42		Learning rate: 1.129e+00, Batch 

Traceback (most recent call last):
  File "/home/hpinkard_waller/mambaforge/envs/phenotypes/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_9866/3563418900.py", line 18, in <module>
    best_cov_mat, cov_mat_initial, mean_vec, best_loss = run_optimization(patches, momentum, learning_rate, batch_size, eigenvalue_floor=1e-3)
  File "/home/hpinkard_waller/GitRepos/EncodingInformation/gaussian_process_utils.py", line 488, in run_optimization
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/hpinkard_waller/mambaforge/envs/phenotypes/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 1997, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/home/hpinkard_waller/mambaforge/envs/phenotypes/lib/python3.10/site-packages/IPython/core/ultratb.py", line 11

In [4]:
# print the hyperparameters ranked from best to worst
sorted_results = sorted(results.items(), key=lambda x: x[1])
for hp, loss in sorted_results:
    print(f"best loss: {loss:.2f}\t\tLearning rate: {hp[0]:.3e}, Batch size: {hp[1]}, Momentum: {hp[2]:.3e}")

best loss: 422.89		Learning rate: 2.976e-08, Batch size: 2, Momentum: 9.990e-01
best loss: 425.53		Learning rate: 4.281e-02, Batch size: 2, Momentum: 1.052e-01
best loss: 427.50		Learning rate: 1.000e-08, Batch size: 2, Momentum: 8.413e-01
best loss: 428.09		Learning rate: 6.952e-06, Batch size: 2, Momentum: 3.681e-01
best loss: 429.96		Learning rate: 1.129e+00, Batch size: 2, Momentum: 4.732e-01
best loss: 430.36		Learning rate: 2.069e-05, Batch size: 2, Momentum: 7.887e-01
best loss: 430.54		Learning rate: 2.637e-07, Batch size: 2, Momentum: 8.938e-01
best loss: 430.77		Learning rate: 6.952e-06, Batch size: 2, Momentum: 0.000e+00
best loss: 431.62		Learning rate: 2.637e-07, Batch size: 2, Momentum: 1.577e-01
best loss: 431.76		Learning rate: 7.848e-07, Batch size: 2, Momentum: 7.361e-01
best loss: 432.17		Learning rate: 4.833e-03, Batch size: 4, Momentum: 0.000e+00
best loss: 432.28		Learning rate: 1.438e-02, Batch size: 4, Momentum: 4.206e-01
best loss: 432.48		Learning rate: 2.069e