# Solve for Gaussian approximations using optimization

In [1]:
%load_ext autoreload
%autoreload 2
# Notebook setup: configure JAX and GPU visibility, import project helpers,
# and open the BSCCM dataset. The statement order below is deliberate.

# Enable 64-bit floating point in JAX.
# this only works on startup! (i.e. it must run before JAX is otherwise used)
from jax import config
config.update("jax_enable_x64", True)

# Restrict this process to a single GPU: devices are ordered by PCI bus ID and
# only device index 1 is made visible.
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# Project helper; presumably stops the framework from grabbing all GPU memory
# up front -- TODO confirm against gpu_utils.
from gpu_utils import limit_gpu_memory_growth
limit_gpu_memory_growth()

# NOTE(review): the star imports below hide where names come from. Helpers used
# later in this notebook (load_bsccm_images, extract_patches, run_optimization)
# are presumably provided by these modules -- verify.
from cleanplots import *
from tqdm import tqdm
from information_estimation import *
from image_utils import *
from gaussian_process_utils import *

from led_array.bsccm_utils import *
from bsccm import BSCCM
from jax import jit
# Naming convention used throughout this notebook:
#   onp -> regular NumPy, np -> jax.numpy
import numpy as onp
import jax.numpy as np

# Open the dataset. NOTE(review): this is a hard-coded absolute, user-specific
# path; it will need changing on any other machine.
bsccm = BSCCM('/home/hpinkard_waller/data/BSCCM/')

Opening BSCCM
Opened BSCCM


In [2]:
# Experiment configuration, followed by loading the image stack.
# (Patch extraction and optimization happen in later cells, not here.)

# -- which data to load --
channel = 'LED119'
num_images = 20000
edge_crop = 32  # forwarded to the loader as `edge_crop`

# -- patch sampling settings, consumed later by extract_patches --
patch_size = 10
num_patches = 1000

# NOTE(review): the hyperparameter-search cell passes its own literal
# eigenvalue_floor to run_optimization rather than this value -- confirm
# which one is intended.
eigenvalue_floor = 1e0

images = load_bsccm_images(
    bsccm,
    channel=channel,
    num_images=num_images,
    edge_crop=edge_crop,
    median_filter=False,
)

## Search through hyperparameter combinations

In [5]:
# Random search over optimizer hyperparameters.
#
# Each trial draws (learning rate, batch size, momentum) independently from the
# grids below, extracts a fresh set of patches (seeded by the trial index), runs
# the optimization, and records the best loss that run reported.
#
# Produces:
#   results      -- dict mapping (learning_rate, batch_size, momentum) -> best loss
#   best_hp      -- the hyperparameter tuple with the lowest loss seen so far
#                   (None if no trial has produced a loss below +inf)
#   best_hp_loss -- the loss corresponding to best_hp

# candidate values for each hyperparameter
learning_rates = np.logspace(1, -8, 20)
batch_sizes = np.linspace(2, 50, 20).astype(int)
momentums = np.linspace(0, 0.999, 20)

# generate tuples of random hyperparameters
# NOTE(review): the global NumPy RNG is not seeded here, so the sampled
# combinations differ between kernel sessions.
hyperparameter_tuples = []
for _ in range(10000):
    lr = onp.random.choice(learning_rates)
    bs = onp.random.choice(batch_sizes)
    m = onp.random.choice(momentums)
    hyperparameter_tuples.append((lr, bs, m))

results = {}
# Running best across ALL trials. These must be initialised once, outside the
# loop: resetting best_hp_loss on every iteration would make every finite loss
# look like a new best, so best_hp would simply track the most recent trial.
best_hp_loss = np.inf
best_hp = None
for i, (learning_rate, batch_size, momentum) in enumerate(hyperparameter_tuples):
    # a different patch sample per trial, reproducible via the trial index
    patches = extract_patches(images, patch_size, num_patches=num_patches, seed=i)
    best_cov_mat, cov_mat_initial, mean_vec, best_loss, train_loss_history, val_loss_history = run_optimization(patches, momentum, learning_rate, batch_size, eigenvalue_floor=1e-3)

    if best_loss < best_hp_loss:
        best_hp_loss = best_loss
        best_hp = (learning_rate, batch_size, momentum)

    # collect results
    # NOTE(review): a combination that is drawn more than once overwrites its
    # earlier entry, whichever loss was lower.
    results[(learning_rate, batch_size, momentum)] = best_loss

    # print hyperparameters and their best loss
    print(f"best loss: {best_loss:.2f}\t\tLearning rate: {learning_rate:.3e}, Batch size: {batch_size}, Momentum: {momentum:.3e}")

Initial loss:  766508.7981262293
best loss: 839.84		Learning rate: 5.456e-04, Batch size: 29, Momentum: 2.629e-01
Initial loss:  334167.50536042696
best loss: 454.55		Learning rate: 7.848e-07, Batch size: 2, Momentum: 8.413e-01
Initial loss:  457.36715387610764
best loss: 456.71		Learning rate: 3.360e+00, Batch size: 32, Momentum: 3.681e-01
Initial loss:  195647.36018990882
best loss: 453.17		Learning rate: 6.952e-06, Batch size: 17, Momentum: 9.464e-01
Initial loss:  456.4432017029313
best loss: 459.04		Learning rate: 4.281e-02, Batch size: 47, Momentum: 5.784e-01
Initial loss:  309247.33653273515
best loss: 878.40		Learning rate: 1.833e-04, Batch size: 37, Momentum: 9.464e-01
Initial loss:  478.88158725723986
best loss: 457.70		Learning rate: 1.274e-01, Batch size: 24, Momentum: 8.938e-01
Initial loss:  171884.67195788733
best loss: 1375.07		Learning rate: 1.000e+01, Batch size: 34, Momentum: 4.732e-01
Initial loss:  879570.5153264167
best loss: 62453.70		Learning rate: 5.456e-04, Ba

KeyboardInterrupt: 

In [6]:
# print the hyperparameters ranked from best to worst
sorted_results = sorted(results.items(), key=lambda x: x[1])
for hp, loss in sorted_results:
    print(f"best loss: {loss:.2f}\t\tLearning rate: {hp[0]:.3e}, Batch size: {hp[1]}, Momentum: {hp[2]:.3e}")

best loss: 452.22		Learning rate: 5.456e-04, Batch size: 2, Momentum: 8.413e-01
best loss: 453.17		Learning rate: 6.952e-06, Batch size: 17, Momentum: 9.464e-01
best loss: 453.41		Learning rate: 1.000e+01, Batch size: 37, Momentum: 8.938e-01
best loss: 453.52		Learning rate: 2.637e-07, Batch size: 19, Momentum: 5.784e-01
best loss: 453.79		Learning rate: 3.360e+00, Batch size: 7, Momentum: 9.990e-01
best loss: 453.89		Learning rate: 2.069e-05, Batch size: 4, Momentum: 4.732e-01
best loss: 453.93		Learning rate: 2.336e-06, Batch size: 2, Momentum: 5.258e-02
best loss: 454.38		Learning rate: 2.976e-08, Batch size: 2, Momentum: 3.681e-01
best loss: 454.40		Learning rate: 6.952e-06, Batch size: 44, Momentum: 8.938e-01
best loss: 454.55		Learning rate: 7.848e-07, Batch size: 2, Momentum: 8.413e-01
best loss: 454.76		Learning rate: 2.336e-06, Batch size: 44, Momentum: 3.681e-01
best loss: 455.05		Learning rate: 1.129e+00, Batch size: 19, Momentum: 3.681e-01
best loss: 455.25		Learning rate: 