In [18]:
import numpy as np
import torch
from torch import optim
import matplotlib.pyplot as plt
import scipy.optimize as so
from cryoBIFE import neglogpost_cryobife, neglogpost_cryobife_pytorch, normal_prior
from cryoBIFE.Generate_Gaussian_Images_Toymodel import sample_grid_data, get_num_images
from tqdm import tqdm

In [19]:
def get_optimal_string_fe(posterior_matrix, log_prior_fxn, kappa=1.0):
    number_of_nodes = posterior_matrix.shape[1]
    G_init = 2.0 * np.random.randn(number_of_nodes)  # Initial free energy differences
    G_op = so.minimize(neglogpost_cryobife, G_init, method='L-BFGS-B', args=(kappa, posterior_matrix, log_prior_fxn))
#     assert(G_op.success)
    optimal_string_fe = G_op.x
    optimal_logpost = -1 * G_op.fun
    return optimal_string_fe, optimal_logpost


def get_optimal_string_fe_pytorch(posterior_matrix, x_init=None, kappa=1.0, num_iter = 10000, 
                                  lr=1e-3, print_freq=1000):
    number_of_nodes = posterior_matrix.shape[1]
    if x_init is None:
        x_init = 0.1* torch.randn(number_of_nodes)  # Initial free energy differences
    x  = torch.clone(x_init)
    x.requires_grad_()
    optimizer = optim.SGD([x], lr=lr)
    
    for i in range(num_iter):
        optimizer.zero_grad()
        output = neglogpost_cryobife_pytorch(x, kappa, posterior_matrix)
        output.backward()
        optimizer.step()
        if i % print_freq == 0:
            print(i, output)
    
    optimal_logpost = -1 * x
    output = neglogpost_cryobife_pytorch(x, kappa, posterior_matrix)
    output.backward()
    print(x.grad)
    
    return x.detach(), optimal_logpost



def histogram_raw_data(coords):
    y = coords[:, 0]
    x = coords[:, 1]
    print(np.min(x), np.max(x))
    print(np.min(y), np.max(y))

    hist, xedges, yedges = np.histogram2d(x, y, bins=100, range=((-5, 25), (-5, 25)))

    plt.imshow(hist.T, interpolation='nearest', origin='lower',
               extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])

    plt.show()

In [20]:
kappa = 0.0
dT = 1.

In [21]:
black_logposts = []
orange_logposts = []

black_fes = []
orange_fes = []

N_total = 10000

inv_T = 1.
grid_info, black_info, orange_info = sample_grid_data(inverse_T=inv_T, N_total = N_total, sigma=1.0)

In [22]:
coords, Grid, Num_images = grid_info
#     print(len(coords))
black, Post_Matrix_black = black_info
#     orange, Post_Matrix_orange = orange_info
black_fe, black_logpost = get_optimal_string_fe(Post_Matrix_black, normal_prior, kappa=kappa)
black_fe = torch.from_numpy(black_fe).float()
print(black_logpost)

-94409.15429669988


In [23]:
# Further optimize using pytorch... around 80% of sampled grids
# I immediately get NaN and I don't know why.
black_fe, black_logpost = get_optimal_string_fe_pytorch(torch.from_numpy(Post_Matrix_black).float(), 
                                                        x_init = black_fe,
                                                        kappa=kappa,
                                                        lr=1e-3)
black_logposts.append(black_logpost)
black_fes.append(black_fe)
print(black_logpost)

0 tensor(94409.1641, grad_fn=<NegBackward>)
1000 tensor(nan, grad_fn=<NegBackward>)
2000 tensor(nan, grad_fn=<NegBackward>)
3000 tensor(nan, grad_fn=<NegBackward>)
4000 tensor(nan, grad_fn=<NegBackward>)
5000 tensor(nan, grad_fn=<NegBackward>)
6000 tensor(nan, grad_fn=<NegBackward>)
7000 tensor(nan, grad_fn=<NegBackward>)
8000 tensor(nan, grad_fn=<NegBackward>)
9000 tensor(nan, grad_fn=<NegBackward>)
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       grad_fn=<MulBackward0>)
