# Solve for Gaussian approximations using optimization

In [1]:
%load_ext autoreload
%autoreload 2
# Enable 64-bit floating point in JAX.
# this only works on startup!
from jax import config
config.update("jax_enable_x64", True)

# Restrict this process to a single GPU: order devices by PCI bus ID and
# expose only device index '2'. Set before the GPU helper below is imported and called.
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
# Project helper; presumably keeps the framework from grabbing all GPU memory
# up front — TODO confirm against gpu_utils.
from gpu_utils import limit_gpu_memory_growth
limit_gpu_memory_growth()

# NOTE(review): wildcard imports — helpers used later in this notebook
# (e.g. load_bsccm_images, extract_patches, run_optimization) are not imported
# by name, so they presumably come from one of these modules.
from cleanplots import *
from tqdm import tqdm
from information_estimation import *
from image_utils import *
from gaussian_process_utils import *

from led_array.bsccm_utils import *
from bsccm import BSCCM
from jax import jit
# NOTE: `np` is jax.numpy in this notebook; plain NumPy is available as `onp`.
import numpy as onp
import jax.numpy as np

# Open the BSCCM dataset from a hard-coded absolute path (machine-specific;
# adjust when running elsewhere).
bsccm = BSCCM('/home/hpinkard_waller/data/BSCCM/')

Opening BSCCM
Opened BSCCM


In [2]:
# Experiment settings, then load the raw images.
# This cell only loads images; patches are extracted later, inside the
# hyperparameter search loop.

# -- image loading --
channel = "DPC_Right"
num_images = 20000
edge_crop = 32

# -- patch sampling (consumed by the search cell) --
patch_size = 10
num_patches = 1000

# NOTE(review): not referenced anywhere else in the visible notebook; the
# search cell passes its own eigenvalue_floor literal to run_optimization.
eigenvalue_floor = 1.0

images = load_bsccm_images(
    bsccm,
    channel=channel,
    num_images=num_images,
    edge_crop=edge_crop,
    median_filter=False,
)

## Random search over hyperparameter combinations

In [5]:
# Candidate grid: 20 values per hyperparameter.
learning_rates = np.logspace(1, -8, 20)
batch_sizes = np.linspace(2, 50, 20).astype(int)
momentums = np.linspace(0, 0.999, 20)

# generate tuples of random hyperparameters
# A seeded generator (instead of the global NumPy RNG) makes the sampled
# sequence identical after Restart Kernel -> Run All. Draw order per tuple is
# learning rate, batch size, momentum.
hp_search_seed = 0
hp_rng = onp.random.RandomState(hp_search_seed)
hyperparameter_tuples = []
for i in range(10000):
    lr = hp_rng.choice(learning_rates)
    bs = hp_rng.choice(batch_sizes)
    m = hp_rng.choice(momentums)
    hyperparameter_tuples.append((lr, bs, m))

# Maps (learning_rate, batch_size, momentum) -> best loss of that run.
# NOTE: tuples are sampled with replacement, so a repeated tuple overwrites
# the loss recorded by its earlier run.
results = {}

# Running best across the WHOLE search. These must be initialised once, before
# the loop: resetting best_hp_loss on every iteration would make the comparison
# below always succeed, leaving best_hp equal to the last tuple tried.
best_hp = None
best_hp_loss = np.inf

# The loop is long (10000 runs); interrupting the kernel part-way keeps
# everything gathered so far in `results`, `best_hp` and `best_hp_loss`.
for i, (learning_rate, batch_size, momentum) in enumerate(hyperparameter_tuples):
    # A fresh set of patches per run, seeded by the run index.
    patches = extract_patches(images, patch_size, num_patches=num_patches, seed=i)
    # NOTE(review): eigenvalue_floor is passed as a literal here, not taken from
    # the module-level `eigenvalue_floor` variable — confirm this is intended.
    best_cov_mat, cov_mat_initial, mean_vec, best_loss, train_loss_history, val_loss_history = run_optimization(patches, momentum, learning_rate, batch_size, eigenvalue_floor=1e-3)

    # Track the best-performing hyperparameters seen so far.
    if best_loss < best_hp_loss:
        best_hp_loss = best_loss
        best_hp = (learning_rate, batch_size, momentum)

    # collect results
    results[(learning_rate, batch_size, momentum)] = best_loss

    # print hyperparameters and their best loss
    print(f"best loss: {best_loss:.2f}\t\tLearning rate: {learning_rate:.3e}, Batch size: {batch_size}, Momentum: {momentum:.3e}")

Initial loss:  2406530.5159880687
best loss: 893.70		Learning rate: 5.456e-04, Batch size: 29, Momentum: 2.629e-01
Initial loss:  2757627.764051999
best loss: 498.60		Learning rate: 7.848e-07, Batch size: 2, Momentum: 8.413e-01
Initial loss:  1297362.5932994783
best loss: 1336.25		Learning rate: 3.360e+00, Batch size: 32, Momentum: 3.681e-01
Initial loss:  609355.4151318023
best loss: 499.57		Learning rate: 6.952e-06, Batch size: 17, Momentum: 9.464e-01
Initial loss:  1525201.8966025258
best loss: 1110.92		Learning rate: 4.281e-02, Batch size: 47, Momentum: 5.784e-01
Initial loss:  3085218.4325907985
best loss: 849.17		Learning rate: 1.833e-04, Batch size: 37, Momentum: 9.464e-01
Initial loss:  252622.9955842237
best loss: 1229.99		Learning rate: 1.274e-01, Batch size: 24, Momentum: 8.938e-01
Initial loss:  1102906.2787131583
best loss: 1457.13		Learning rate: 1.000e+01, Batch size: 34, Momentum: 4.732e-01
Initial loss:  2729193.9829202136
best loss: 831.46		Learning rate: 5.456e-04, B

KeyboardInterrupt: 

In [6]:
# Rank every evaluated hyperparameter tuple by its recorded loss, lowest
# (best) first, and print one line per tuple.
sorted_results = list(results.items())
sorted_results.sort(key=lambda entry: entry[1])  # entry = (hyperparameter tuple, loss)
for hp, loss in sorted_results:
    # hp is (learning_rate, batch_size, momentum)
    print(f"best loss: {loss:.2f}\t\tLearning rate: {hp[0]:.3e}, Batch size: {hp[1]}, Momentum: {hp[2]:.3e}")

best loss: 493.50		Learning rate: 1.000e-08, Batch size: 2, Momentum: 5.258e-01
best loss: 495.63		Learning rate: 6.952e-06, Batch size: 24, Momentum: 2.103e-01
best loss: 497.02		Learning rate: 4.833e-03, Batch size: 47, Momentum: 7.361e-01
best loss: 497.31		Learning rate: 3.793e-01, Batch size: 19, Momentum: 2.103e-01
best loss: 497.69		Learning rate: 2.637e-07, Batch size: 42, Momentum: 9.990e-01
best loss: 498.22		Learning rate: 7.848e-07, Batch size: 7, Momentum: 7.887e-01
best loss: 498.29		Learning rate: 2.976e-08, Batch size: 17, Momentum: 3.155e-01
best loss: 498.52		Learning rate: 2.336e-06, Batch size: 4, Momentum: 4.206e-01
best loss: 498.53		Learning rate: 1.000e+01, Batch size: 4, Momentum: 2.629e-01
best loss: 498.60		Learning rate: 7.848e-07, Batch size: 2, Momentum: 8.413e-01
best loss: 498.82		Learning rate: 2.336e-06, Batch size: 22, Momentum: 4.206e-01
best loss: 499.23		Learning rate: 6.952e-06, Batch size: 44, Momentum: 8.938e-01
best loss: 499.91		Learning rate:

: 