# Solve for Gaussian approximations using optimization

In [1]:
# Automatically reload edited imported modules before each cell runs, so changes to
# project helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# this only works on startup!
# Enable 64-bit floats in JAX. The flag has to be set before any JAX arrays are
# created, which is why it is the first code executed after kernel start.
from jax import config
config.update("jax_enable_x64", True)

import os
# GPU selection must happen before anything initializes CUDA: number devices by
# PCI bus order and expose only device 0 to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Project helper from gpu_utils; presumably stops the framework from grabbing all
# GPU memory up front -- NOTE(review): confirm against gpu_utils.
from gpu_utils import limit_gpu_memory_growth
limit_gpu_memory_growth()

from cleanplots import *
from tqdm import tqdm
from information_estimation import *
from image_utils import *
from gaussian_process_utils import *

from led_array.bsccm_utils import *
from bsccm import BSCCM
from jax import jit
import numpy as onp
import jax.numpy as np

# Root of the BSCCM dataset. Set the BSCCM_DIR environment variable to point at a
# local copy instead of editing this machine-specific default path.
BSCCM_DIR = os.environ.get('BSCCM_DIR', '/home/hpinkard_waller/data/BSCCM/')
bsccm = BSCCM(BSCCM_DIR)

Opening BSCCM
Opened BSCCM


In [2]:
# Settings shared by the cells below, then load the raw images.
# (Patches are extracted and fitted later, inside the hyperparameter search.)
channel = 'LED119'       # channel name passed to load_bsccm_images
num_images = 20000       # how many images load_bsccm_images should return
edge_crop = 32           # forwarded as edge_crop to load_bsccm_images
patch_size = 20          # patch size later passed to extract_patches
num_patches = 1000       # patches drawn per call to extract_patches
# NOTE(review): the hyperparameter search passes its own eigenvalue floor to
# run_optimization rather than this value -- confirm which one is intended.
eigenvalue_floor = 1e0

images = load_bsccm_images(
    bsccm,
    channel=channel,
    num_images=num_images,
    edge_crop=edge_crop,
    median_filter=False,
)

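## Random search over hyperparameter combinations

Each trial samples a learning rate, batch size, and momentum from fixed grids, fits the Gaussian model to a fresh set of image patches with `run_optimization`, and records the best loss it reaches.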

In [5]:
# Random search over (learning rate, batch size, momentum). Each trial draws one
# value per hyperparameter from the grids below, extracts a fresh set of patches
# (seeded by the trial index) and runs run_optimization on them.
learning_rates = np.logspace(1, -8, 20)
batch_sizes = np.linspace(2, 50, 20).astype(int)
momentums = np.linspace(0, 0.999, 20)

num_trials = 10000
hyperparameter_seed = 0  # makes the sampled hyperparameter tuples reproducible
# NOTE(review): this differs from the `eigenvalue_floor` constant defined where the
# images are loaded -- confirm which floor is intended for the search.
search_eigenvalue_floor = 1e-3

# Sample every trial up front from a seeded generator so the search can be
# reproduced. Converting the JAX grids to host numpy once avoids a conversion
# on every draw.
hp_rng = onp.random.default_rng(hyperparameter_seed)
hyperparameter_tuples = list(zip(
    hp_rng.choice(onp.asarray(learning_rates), size=num_trials),
    hp_rng.choice(onp.asarray(batch_sizes), size=num_trials),
    hp_rng.choice(onp.asarray(momentums), size=num_trials),
))

# Best trial over the whole search. Initialized once, before the loop, so best_hp
# is the best combination seen so far rather than just the most recent one.
best_hp_loss = np.inf
best_hp = None
# (learning_rate, batch_size, momentum) -> lowest best_loss seen for that combination
results = {}
for i, (learning_rate, batch_size, momentum) in enumerate(hyperparameter_tuples):
    patches = extract_patches(images, patch_size, num_patches=num_patches, seed=i)
    best_cov_mat, cov_mat_initial, mean_vec, best_loss, train_loss_history, val_loss_history = run_optimization(
        patches, momentum, learning_rate, batch_size, eigenvalue_floor=search_eigenvalue_floor)

    hp_key = (learning_rate, batch_size, momentum)
    if best_loss < best_hp_loss:
        best_hp_loss = best_loss
        best_hp = hp_key

    # Random draws can repeat a combination (each time on different patches);
    # keep its best loss instead of overwriting it with the latest one.
    results[hp_key] = min(best_loss, results.get(hp_key, np.inf))

    # print hyperparameters and this trial's best loss
    print(f"best loss: {best_loss:.2f}\t\tLearning rate: {learning_rate:.3e}, Batch size: {batch_size}, Momentum: {momentum:.3e}")

Initial loss:  1602746.422515065
best loss: 4707.55		Learning rate: 3.793e-01, Batch size: 19, Momentum: 4.732e-01
Initial loss:  1434216.410025181
best loss: 609800.56		Learning rate: 1.000e-08, Batch size: 47, Momentum: 2.103e-01
Initial loss:  55755.08502441772
best loss: 4035.03		Learning rate: 1.438e-02, Batch size: 32, Momentum: 5.258e-02
Initial loss:  870590.6982012092
best loss: 4414.16		Learning rate: 1.438e-02, Batch size: 19, Momentum: 7.361e-01
Initial loss:  1859.4591420254694
best loss: 1912.88		Learning rate: 8.859e-08, Batch size: 14, Momentum: 6.835e-01
Initial loss:  327827.0813128798
best loss: 4163.06		Learning rate: 1.624e-03, Batch size: 24, Momentum: 9.990e-01
Initial loss:  312433.0524452216
best loss: 1907.54		Learning rate: 2.637e-07, Batch size: 50, Momentum: 2.629e-01
Initial loss:  1144733.7858392452
best loss: 1859.83		Learning rate: 7.848e-07, Batch size: 39, Momentum: 0.000e+00
Initial loss:  1438979.6445443
best loss: 72693.18		Learning rate: 2.976e-08

KeyboardInterrupt: 

In [6]:
# Rank the evaluated hyperparameter combinations from lowest (best) to highest loss.
# `results` maps (learning_rate, batch_size, momentum) -> the loss recorded for it.
sorted_results = list(results.items())
sorted_results.sort(key=lambda hp_and_loss: hp_and_loss[1])

# One line per combination, best first.
for hp, loss in sorted_results:
    print(f"best loss: {loss:.2f}\t\tLearning rate: {hp[0]:.3e}, Batch size: {hp[1]}, Momentum: {hp[2]:.3e}")

best loss: 1780.07		Learning rate: 3.793e-01, Batch size: 14, Momentum: 5.784e-01
best loss: 1787.82		Learning rate: 3.793e-01, Batch size: 24, Momentum: 9.464e-01
best loss: 1793.12		Learning rate: 3.793e-01, Batch size: 32, Momentum: 5.258e-01
best loss: 1793.39		Learning rate: 7.848e-07, Batch size: 34, Momentum: 8.413e-01
best loss: 1799.64		Learning rate: 2.336e-06, Batch size: 24, Momentum: 9.464e-01
best loss: 1802.56		Learning rate: 3.360e+00, Batch size: 44, Momentum: 8.413e-01
best loss: 1803.68		Learning rate: 6.952e-06, Batch size: 42, Momentum: 3.681e-01
best loss: 1809.83		Learning rate: 8.859e-08, Batch size: 47, Momentum: 7.361e-01
best loss: 1818.00		Learning rate: 7.848e-07, Batch size: 12, Momentum: 4.732e-01
best loss: 1819.19		Learning rate: 1.438e-02, Batch size: 29, Momentum: 7.361e-01
best loss: 1820.06		Learning rate: 1.000e-08, Batch size: 50, Momentum: 9.990e-01
best loss: 1833.46		Learning rate: 2.637e-07, Batch size: 9, Momentum: 3.681e-01
best loss: 1835.1