# Solve for Gaussian approximations using optimization

In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Switch JAX to 64-bit precision. Per the original note this only works on
# startup, so keep it ahead of the remaining imports.
from jax import config
config.update("jax_enable_x64", True)

# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Project helper; presumably stops the framework from reserving all GPU memory
# up front — TODO confirm in gpu_utils.
from gpu_utils import limit_gpu_memory_growth
limit_gpu_memory_growth()

# NOTE(review): the wildcard imports below hide which module provides the
# helpers used later in this notebook (load_bsccm_images, extract_patches,
# run_optimization); explicit imports would make that traceable.
from cleanplots import *
from tqdm import tqdm
from information_estimation import *
from image_utils import *
from gaussian_process_utils import *

from led_array.bsccm_utils import *
from bsccm import BSCCM
from jax import jit
# NOTE: in this notebook `np` is jax.numpy; plain NumPy is available as `onp`.
import numpy as onp
import jax.numpy as np

# Open the BSCCM dataset from a hard-coded, machine-specific absolute path.
bsccm = BSCCM('/home/hpinkard_waller/data/BSCCM/')

Opening BSCCM
Opened BSCCM


In [2]:
# Settings shared by the cells below, followed by loading the raw images.
# (Patch extraction happens later, inside the hyperparameter search loop.)

# -- image loading --
channel = 'LED119'
num_images = 20000
edge_crop = 32

# -- patch extraction --
patch_size = 10
num_patches = 10000

# -- covariance regularisation --
eigenvalue_floor = 1e0

images = load_bsccm_images(
    bsccm,
    channel=channel,
    num_images=num_images,
    edge_crop=edge_crop,
    median_filter=False,
)

## Search through hyperparameter combinations

In [None]:
# Random search over optimizer hyperparameters. Each trial draws one value from
# each grid below, extracts a set of patches, runs the optimization, and records
# the best loss that run_optimization reports for that setting.
learning_rates = np.logspace(1, -8, 20)
batch_sizes = np.linspace(2, 50, 20).astype(int)
momentums = np.linspace(0, 0.999, 20)

num_trials = 10000
# Seeded generator so the same settings are drawn on every Restart & Run All.
hp_rng = onp.random.default_rng(0)

# generate tuples of random hyperparameters (one independent draw per grid);
# the grids are jax arrays, so convert to NumPy before sampling
hyperparameter_tuples = [
    (
        hp_rng.choice(onp.asarray(learning_rates)),
        hp_rng.choice(onp.asarray(batch_sizes)),
        hp_rng.choice(onp.asarray(momentums)),
    )
    for _ in range(num_trials)
]

results = {}
# Running best over ALL trials. These must be initialised once, outside the
# loop: resetting best_hp_loss on every iteration would make the comparison
# below always succeed and leave best_hp equal to the last trial.
best_hp_loss = np.inf
best_hp = None
for i, (learning_rate, batch_size, momentum) in enumerate(hyperparameter_tuples):
    # The trial index is passed as the seed; presumably this gives each trial
    # its own patch sample — TODO confirm in extract_patches.
    patches = extract_patches(images, patch_size, num_patches=num_patches, seed=i)
    # NOTE(review): a floor of 1e-3 is passed here, while the notebook-level
    # `eigenvalue_floor` variable is set to 1e0 — confirm which is intended.
    best_cov_mat, cov_mat_initial, mean_vec, best_loss = run_optimization(patches, momentum, learning_rate, batch_size, eigenvalue_floor=1e-3)

    if best_loss < best_hp_loss:
        best_hp_loss = best_loss
        best_hp = (learning_rate, batch_size, momentum)

    # collect results
    # NOTE(review): the same setting can be drawn more than once; a repeat
    # overwrites the loss stored by the earlier trial.
    results[(learning_rate, batch_size, momentum)] = best_loss

    # print hyperparameters and their best loss
    print(f"best loss: {best_loss:.2f}\t\tLearning rate: {learning_rate:.3e}, Batch size: {batch_size}, Momentum: {momentum:.3e}")

In [4]:
# Rank every hyperparameter setting by its recorded loss (lowest first) and
# print one line per setting.
def _loss_of(entry):
    """Sort key: the loss stored as the value of a (hyperparameters, loss) pair."""
    return entry[1]

sorted_results = list(results.items())
sorted_results.sort(key=_loss_of)
for hp, loss in sorted_results:
    print(f"best loss: {loss:.2f}\t\tLearning rate: {hp[0]:.3e}, Batch size: {hp[1]}, Momentum: {hp[2]:.3e}")

best loss: 425.96		Learning rate: 2.069e-05, Batch size: 2, Momentum: 1.577e-01
best loss: 427.36		Learning rate: 1.833e-04, Batch size: 2, Momentum: 7.887e-01
best loss: 427.71		Learning rate: 1.833e-04, Batch size: 2, Momentum: 5.784e-01
best loss: 427.99		Learning rate: 2.069e-05, Batch size: 2, Momentum: 4.732e-01
best loss: 428.01		Learning rate: 2.976e-08, Batch size: 2, Momentum: 9.990e-01
best loss: 429.00		Learning rate: 2.637e-07, Batch size: 2, Momentum: 4.206e-01
best loss: 430.66		Learning rate: 1.000e-08, Batch size: 2, Momentum: 0.000e+00
best loss: 430.86		Learning rate: 8.859e-08, Batch size: 2, Momentum: 2.629e-01
best loss: 431.77		Learning rate: 5.456e-04, Batch size: 2, Momentum: 5.258e-01
best loss: 436.76		Learning rate: 2.976e-08, Batch size: 2, Momentum: 3.681e-01
best loss: 437.97		Learning rate: 4.833e-03, Batch size: 4, Momentum: 5.784e-01
best loss: 443.25		Learning rate: 2.069e-05, Batch size: 9, Momentum: 5.258e-02
best loss: 443.54		Learning rate: 7.848e