# Solve for Gaussian approximations using optimization

In [23]:
%load_ext autoreload
%autoreload 2
# this only works on startup!
from jax import config
config.update("jax_enable_x64", True)

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
from gpu_utils import limit_gpu_memory_growth
limit_gpu_memory_growth()

from cleanplots import *
from tqdm import tqdm
from information_estimation import *
from image_utils import *
from gaussian_process_utils import *

from led_array.bsccm_utils import *
from bsccm import BSCCM
from jax import jit
import numpy as onp
import jax.numpy as np

bsccm = BSCCM('/home/hpinkard_waller/data/BSCCM/')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Opening BSCCM


In [None]:
# Load BSCCM images for one LED channel and cut them into square patches.
# Which channel to read and how many images to pull from the dataset.
channel = 'LED119'
num_images = 1000
edge_crop = 32

# Patch geometry and count.
patch_size = 10
num_patches = 1000

# NOTE(review): the run_optimization calls visible in this notebook pass
# eigenvalue_floor=1e-3 explicitly rather than this value -- confirm which
# floor is intended.
eigenvalue_floor = 1e0

images = load_bsccm_images(
    bsccm,
    channel=channel,
    num_images=num_images,
    edge_crop=edge_crop,
    median_filter=False,
)
patches = extract_patches(images, patch_size, num_patches=num_patches)

In [None]:
from jax.scipy.linalg import toeplitz
from jax import grad, jit, value_and_grad

def make_doubly_toeplitz(top_row, patch_size):
    """
    Build a doubly Toeplitz (block Toeplitz with Toeplitz blocks) matrix
    from its top row, which is the minimum number of parameters needed
    to specify the matrix.

    top_row is cut into patch_size equal-length chunks; each chunk is
    expanded into a Toeplitz block, and block (i, j) of the result is the
    block built from chunk |i - j|.
    """
    # one Toeplitz block per equal-length chunk of the top row
    toeplitz_blocks = [toeplitz(chunk) for chunk in np.split(top_row, patch_size)]
    num_blocks = len(toeplitz_blocks)

    # tile the blocks so that the block index depends only on |row - col|
    block_rows = [
        np.hstack([toeplitz_blocks[abs(r - c)] for c in range(num_blocks)])
        for r in range(num_blocks)
    ]
    return np.vstack(block_rows)


def gaussian_likelihood(cov_mat, mean_vec, batch):
    """
    Evaluate the log likelihood of a multivariate gaussian
    for a batch of N x W x H samples.

    Each sample is flattened to a vector before being scored against the
    gaussian defined by mean_vec and cov_mat.

    Returns a length-N array with one log likelihood per sample.
    """
    def single_log_likelihood(sample):
        return jax.scipy.stats.multivariate_normal.logpdf(
            sample.reshape(-1), mean=mean_vec, cov=cov_mat
        )

    # vmap over the sample axis instead of a Python loop: when this function
    # is traced under jit, a Python loop is unrolled into one logpdf per
    # sample, whereas vmap emits a single batched computation that shares
    # the (unbatched) covariance.
    return jax.vmap(single_log_likelihood)(np.asarray(batch))

def batch_nll(log_likelihoods):
    """Return the negative mean of the given log likelihoods."""
    mean_log_likelihood = np.mean(log_likelihoods)
    return -mean_log_likelihood

def loss_function(eigvals, eig_vecs, mean_vec, data):
    """
    Mean negative log likelihood of data under a gaussian whose covariance
    is rebuilt from an eigendecomposition (eig_vecs @ diag(eigvals) @ eig_vecs.T).
    """
    spectrum = np.diag(eigvals)
    cov_mat = eig_vecs @ spectrum @ eig_vecs.T
    # No Toeplitz / positive-definite projection is applied here; the
    # covariance is used exactly as reconstructed from the inputs.
    return batch_nll(gaussian_likelihood(cov_mat, mean_vec, data))

def make_valid_stationary(eigvals, eig_vecs, eigenvalue_floor, patch_size=None):
    """
    Project an eigendecomposed covariance onto doubly Toeplitz matrices
    whose eigenvalues are all at least eigenvalue_floor.

    Steps: floor the eigenvalues, rebuild the covariance, replace it with
    the doubly Toeplitz matrix generated from its first row, eigendecompose
    that matrix, and floor the eigenvalues again.

    patch_size is the side length of the (square) patch. When omitted it is
    inferred from the number of eigenvalues, so the function no longer
    depends on a module-level ``patch_size`` global matching the data.

    Returns the (eigvals, eig_vecs) pair of the projected covariance.
    Raises ValueError if patch_size is omitted and the number of
    eigenvalues is not a perfect square.
    """
    if patch_size is None:
        # shapes are static Python ints, so this also works under jit
        num_pixels = eigvals.shape[0]
        patch_size = int(round(num_pixels ** 0.5))
        if patch_size * patch_size != num_pixels:
            raise ValueError(
                f"Cannot infer a square patch size from {num_pixels} eigenvalues; "
                "pass patch_size explicitly"
            )

    eigvals = np.where(eigvals < eigenvalue_floor, eigenvalue_floor, eigvals)
    cov_mat = eig_vecs @ np.diag(eigvals) @ eig_vecs.T
    # only the first row is kept: it fully determines the doubly Toeplitz matrix
    dt_cov_mat = make_doubly_toeplitz(cov_mat[0], patch_size)
    eigvals, eig_vecs = np.linalg.eigh(dt_cov_mat)
    # the Toeplitz projection can push eigenvalues back below the floor
    eigvals = np.where(eigvals < eigenvalue_floor, eigenvalue_floor, eigvals)
    return eigvals, eig_vecs

@jit
def optmization_step(eigvals, eig_vecs, velocity, data, mean_vec, momentum, learning_rate, eigenvalue_floor):
    """
    One momentum gradient step on the eigenvalues followed by a projection.

    The gradient of loss_function is taken with respect to eigvals only;
    eig_vecs change solely through the projection in make_valid_stationary.

    Returns (eigvals, eig_vecs, new_velocity, loss), where loss is evaluated
    on the same batch AFTER the update and projection.
    """
    eigval_grad = grad(loss_function)(eigvals, eig_vecs, mean_vec, data)
    new_velocity = momentum * velocity - learning_rate * eigval_grad
    stepped_eigvals = eigvals + new_velocity

    # prox operator: make sure positive definite, make sure doubly toeplitz
    # NOTE(review): the projection re-runs eigh, so the returned eigenbasis may
    # differ from the one the velocity was accumulated in -- confirm intended.
    projected_eigvals, projected_eig_vecs = make_valid_stationary(
        stepped_eigvals, eig_vecs, eigenvalue_floor
    )
    post_step_loss = loss_function(projected_eigvals, projected_eig_vecs, mean_vec, data)
    return projected_eigvals, projected_eig_vecs, new_velocity, post_step_loss


def run_optimization(data, momentum, learning_rate, batch_size, eigenvalue_floor=1e-3, num_iterations=1000):
    """
    Fit a stationary gaussian to data by projected momentum gradient descent
    on the eigenvalues of its covariance matrix.

    Parameters:
        data: array of samples indexed along the first axis; each sample is
            flattened before being scored.
        momentum, learning_rate: parameters of the momentum update.
        batch_size: number of samples drawn (with replacement) per iteration.
        eigenvalue_floor: lower bound enforced on covariance eigenvalues.
        num_iterations: number of optimization steps to run.

    Returns:
        (best_cov_mat, cov_mat_initial, mean_vec, best_loss), where
        best_cov_mat is rebuilt from the parameters with the lowest batch
        loss seen, and best_loss is that loss (inf if no step produced a
        loss below inf, in which case the initial parameters are used).

    Raises:
        ValueError: if the initial covariance gives a nan log likelihood
            for the first sample of data.
    """
    # Initialize parameters, hyperparameters.
    # The mean vector has one entry per pixel of a flattened sample.
    num_pixels = int(onp.prod(data.shape[1:]))
    mean_vec = np.ones(num_pixels) * np.mean(data)

    # initialize covariance matrix so likelihood is not nan
    cov_mat_initial = make_positive_definite(compute_stationary_cov_mat(data), eigenvalue_floor=eigenvalue_floor)

    initial_evs, initial_eig_vecs = make_valid_stationary(*np.linalg.eigh(cov_mat_initial), eigenvalue_floor)
    print('Initial loss: ', loss_function(initial_evs, initial_eig_vecs, mean_vec, data[:batch_size]))

    cov_mat_initial = initial_eig_vecs @ np.diag(initial_evs) @ initial_eig_vecs.T

    # Check against the data passed in, not a notebook-level global.
    if np.isnan(jax.scipy.stats.multivariate_normal.logpdf(data[0].flatten(), mean=mean_vec, cov=cov_mat_initial)):
        raise ValueError("Initial likelihood is nan")

    # Training loop
    eigvals = initial_evs
    eig_vecs = initial_eig_vecs
    velocity = np.zeros_like(eigvals)
    best_loss = np.inf
    # Fall back to the initial parameters if no iteration ever improves on
    # best_loss (e.g. every loss is nan); otherwise these would be unbound.
    best_eigvals = initial_evs
    best_eig_vecs = initial_eig_vecs
    key = jax.random.PRNGKey(onp.random.randint(0, 100000))
    for i in range(num_iterations):
        # split first and sample with the fresh subkey, so no key is both
        # consumed by a sampler and split again
        key, subkey = jax.random.split(key)
        # select a random batch
        batch_indices = jax.random.randint(subkey, shape=(batch_size,), minval=0, maxval=len(data))
        batch = data[batch_indices]

        eigvals, eig_vecs, velocity, loss = optmization_step(eigvals, eig_vecs, velocity,
                                                             batch, mean_vec, momentum, learning_rate, eigenvalue_floor)

        if loss < best_loss:
            best_loss = loss
            best_eigvals = eigvals
            best_eig_vecs = eig_vecs
        print(f"Iteration {i+1}, Loss: {loss}", end='\r')
    eigvals, eig_vecs = make_valid_stationary(best_eigvals, best_eig_vecs, eigenvalue_floor)
    best_cov_mat = eig_vecs @ np.diag(eigvals) @ eig_vecs.T
    return best_cov_mat, cov_mat_initial, mean_vec, best_loss
 

## Search through hyperparameter combos 

In [None]:
import jax
import jax.numpy as jnp


# Candidate values for each hyperparameter
learning_rates = np.logspace(1, -8, 20)
batch_sizes = np.linspace(2, 50, 20).astype(int)
momentums = np.linspace(0, 0.999, 20)

# generate tuples of random hyperparameters (random search over the grids)
hyperparameter_tuples = [
    (
        onp.random.choice(learning_rates),
        onp.random.choice(batch_sizes),
        onp.random.choice(momentums),
    )
    for _ in range(10000)
]

results = {}
# Track the best combination ACROSS all runs: these must be initialised once,
# outside the loop, otherwise every run trivially beats inf and best_hp ends
# up being whichever combination ran last.
best_hp_loss = np.inf
best_hp = None
for learning_rate, batch_size, momentum in hyperparameter_tuples:
    best_cov_mat, cov_mat_initial, mean_vec, best_loss = run_optimization(patches, momentum, learning_rate, batch_size, eigenvalue_floor=1e-3)

    if best_loss < best_hp_loss:
        best_hp_loss = best_loss
        best_hp = (learning_rate, batch_size, momentum)

    # collect results
    results[(learning_rate, batch_size, momentum)] = best_loss

    # print hyperparameters and their best loss
    print(f"best loss: {best_loss:.2f}\t\tLearning rate: {learning_rate:.3e}, Batch size: {batch_size}, Momentum: {momentum:.3e}")

Initial loss:  323799.07958974945
Iteration 72, Loss: 797.44085060267017

KeyboardInterrupt: 

In [None]:
# List every hyperparameter combination tried, lowest loss first
for hp in sorted(results, key=results.get):
    loss = results[hp]
    print(f"best loss: {loss:.2f}\t\tLearning rate: {hp[0]:.3e}, Batch size: {hp[1]}, Momentum: {hp[2]:.3e}")

best loss: 438.59		Learning rate: 6.952e-05, Batch size: 2, Momentum: 8.413e-01
best loss: 438.59		Learning rate: 1.833e+00, Batch size: 2, Momentum: 9.464e-01
best loss: 438.59		Learning rate: 6.158e-02, Batch size: 2, Momentum: 5.258e-01
best loss: 443.39		Learning rate: 6.158e-02, Batch size: 4, Momentum: 5.258e-01
best loss: 443.39		Learning rate: 4.833e-03, Batch size: 4, Momentum: 7.361e-01
best loss: 443.39		Learning rate: 5.456e-06, Batch size: 4, Momentum: 5.784e-01
best loss: 443.39		Learning rate: 4.833e-03, Batch size: 4, Momentum: 2.103e-01
best loss: 447.74		Learning rate: 1.833e+00, Batch size: 32, Momentum: 2.629e-01
best loss: 447.74		Learning rate: 1.274e-05, Batch size: 32, Momentum: 5.258e-02
best loss: 450.95		Learning rate: 1.438e-01, Batch size: 47, Momentum: 7.887e-01
best loss: 450.95		Learning rate: 3.793e-04, Batch size: 47, Momentum: 3.155e-01
best loss: 451.99		Learning rate: 7.848e-01, Batch size: 29, Momentum: 9.464e-01
best loss: 451.99		Learning rate: 1

## Optimize with the best hyperparameters

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, value_and_grad



# Previously tried settings:
# learning_rate = 1e0
# momentum = 0.5
# batch_size = 4

# Settings taken from this search result line:
# best loss: 533.36		Learning rate: 2.976e-05, Batch size: 27, Momentum: 5.258e-02
learning_rate = 2.976e-05
momentum = 5.258e-02
batch_size = 27

# run_optimization returns four values (the last is the best loss); unpacking
# only three raises ValueError once the optimization has finished.
cov_mat, cov_mat_initial, mean_vec, best_loss = run_optimization(patches, momentum, learning_rate, batch_size, eigenvalue_floor=1e-3)


Initial loss:  324759.2031881873
Iteration 350, Loss: 560.0870344734547

KeyboardInterrupt: 

# TODO: draw samples from the initial and the optimized covariance matrices and compare them (before vs. after)