In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
%cd ..

/scratch/km817/iREC


In [4]:
import torch
import torch.distributions as dist
import math
import matplotlib.pyplot as plt
import pickle as pkl

from tqdm.notebook import tqdm
from rec.beamsearch.distributions.CodingSampler import CodingSampler
from rec.beamsearch.distributions.EmpiricalMixturePosterior import EmpiricalMixturePosterior
from rec.beamsearch.samplers.GreedySampling import GreedySampler
from rec.beamsearch.Coders.Encoder import Encoder as Empirical_Encoder
from models.SimpleBayesianLinRegressor import BayesLinRegressor
from rec.utils import kl_estimate_with_mc, plot_samples_in_2d, plot_running_sum_2d, plot_pairs_of_samples, compute_variational_posterior
from rec.OptimisingVars.FinalJointOptimiser import FinalJointOptimiser

In [5]:
# Work in float64 throughout the notebook. `set_default_dtype` replaces the deprecated
# `set_default_tensor_type(torch.DoubleTensor)` (same effect for CPU tensors).
torch.set_default_dtype(torch.float64)

In [6]:
def encode_sample(target, omega=5, epsilon=0.,
                  n_empirical_samples=10, seed=0, beamwidth=1, optimising_vars=False, aux_vars=None, dont_run=False):
    """Build an ``Empirical_Encoder`` for ``target`` and, unless told not to, run it.

    The encoder is assembled from ``CodingSampler``, ``GreedySampler`` and
    ``EmpiricalMixturePosterior``.

    Args:
        target: Distribution passed to the encoder as its target.
        omega, n_empirical_samples, seed: Forwarded positionally to ``Empirical_Encoder``.
        epsilon, beamwidth: Forwarded as keyword arguments to ``Empirical_Encoder``.
        optimising_vars: Not read anywhere in this function.
            NOTE(review): appears to be kept only for call compatibility.
        aux_vars: When not ``None``, written to
            ``encoder.auxiliary_posterior.coding_sampler.auxiliary_vars`` after
            the encoder is constructed.
        dont_run: When true, return the encoder without calling ``run_encoder``.

    Returns:
        The encoder itself if ``dont_run`` is true, otherwise the tuple
        ``(encoder, *encoder.run_encoder())``.
    """
    enc = Empirical_Encoder(
        target, seed,
        CodingSampler, GreedySampler, EmpiricalMixturePosterior,
        omega, n_empirical_samples,
        epsilon=epsilon, beamwidth=beamwidth,
    )

    # Override the coding sampler's auxiliary variables (e.g. with pre-optimised ones).
    if aux_vars is not None:
        enc.auxiliary_posterior.coding_sampler.auxiliary_vars = aux_vars

    if dont_run:
        return enc
    return (enc, *enc.run_encoder())

In [7]:
def create_blr_problem(dim, seed, signal_std=1e-1, prior_alpha=1, num_training=None):
    """Create a synthetic Bayesian linear-regression problem and return its weight posterior.

    Args:
        dim: Number of regression weights; the prior mean is ``torch.zeros(dim)``.
        seed: Seed forwarded to ``BayesLinRegressor``.
        signal_std: Forwarded as ``signal_std`` (default ``1e-1``, the previously hard-coded value).
        prior_alpha: Forwarded as ``prior_alpha`` (default ``1``, the previously hard-coded value).
        num_training: Number of training points; defaults to ``dim``.
            ``num_targets`` is set to ``2 * num_training``.

    Returns:
        ``(blr, target)``, where ``target`` is ``blr.weight_posterior`` after features
        and targets have been sampled and ``posterior_update()`` has been called.

    Side effect:
        Sets torch's global default floating-point dtype to float64.
    """
    if num_training is None:
        num_training = dim
    # Must happen before any tensor below is created so the problem is built in float64.
    # `set_default_dtype` replaces the deprecated `set_default_tensor_type(torch.DoubleTensor)`.
    torch.set_default_dtype(torch.float64)
    blr = BayesLinRegressor(prior_mean=torch.zeros(dim),
                            prior_alpha=prior_alpha,
                            signal_std=signal_std,
                            num_targets=2 * num_training,
                            seed=seed,
                            num_train_points=num_training)
    blr.sample_feature_inputs()
    blr.sample_regression_targets()
    blr.posterior_update()
    target = blr.weight_posterior
    return blr, target

In [41]:
# --- Experiment configuration ---
dim = 50                      # number of BLR weights (dimension of the target)
blr_seed = 1                  # seed for the synthetic regression problem
omega = 5                     # passed to the encoder and the aux-var optimiser
beamwidth = 5                 # beam width used in the compression sweep
num_compressed_samples = 500  # samples encoded per epsilon value

# Build the problem once; `t` is the target posterior used by every later cell.
b, t = create_blr_problem(dim=dim, seed=blr_seed)

In [42]:
# Zero-mean, identity-covariance prior, built to match the target's size, dtype and device.
# NOTE(review): identity covariance assumes the BLR prior used prior_alpha=1 — revisit if that changes.
prior = dist.MultivariateNormal(
    loc=torch.zeros_like(t.mean),
    covariance_matrix=torch.eye(t.mean.shape[-1], dtype=t.mean.dtype, device=t.mean.device),
)

# KL from the prior to the target posterior `t`, and to the distribution returned by
# `compute_variational_posterior(t)`; both are reported to the same precision.
emp_kl = dist.kl_divergence(t, prior)
var_kl = dist.kl_divergence(compute_variational_posterior(t), prior)
print(f"Emp kl: {emp_kl.item():.2f}, Var kl: {var_kl.item():.2f}")

Emp kl: 195.36, Var kl: 219.0374861579449


In [None]:
from pathlib import Path

# Run FinalJointOptimiser to obtain `aux_vars`, then cache them to disk.
aux_vars = None
z_sample = t.mean
# Build (without running) an encoder purely to read `n_auxiliary` and `total_kl`.
# Pass the configured `omega` so these agree with the optimiser below and the compression sweep.
compute_params_enc = encode_sample(target=t, omega=omega, dont_run=True)
n_auxiliaries = compute_params_enc.n_auxiliary
print(n_auxiliaries)
kl_q_p = compute_params_enc.total_kl
optimising = FinalJointOptimiser(z_sample, omega, n_auxiliaries, kl_q_p, n_trajectories=512, total_var=1)
aux_vars = optimising.run_optimiser(epochs=2500)

aux_vars_path = Path(f"PickledStuff/BLR_RESULTS_v2/Dim{dim}/optimised_vars_emp.pkl")
aux_vars_path.parent.mkdir(parents=True, exist_ok=True)
# The context manager guarantees the pickle is flushed and the file handle closed.
with open(aux_vars_path, "wb") as f:
    pkl.dump(aux_vars, f)

  0%|          | 0/2500 [00:00<?, ?it/s]

40


The mean loss is 24.98242. The mean KL is: 5.18734:  13%|█▎        | 322/2500 [08:42<58:44,  1.62s/it]   

In [None]:
def compute_expected_coding_efficiency(kl, epsilon):
    """Return ``K + log(K + 1) + 1`` with ``K = (1 + epsilon) * kl`` (natural log).

    Args:
        kl: KL divergence, as a Python number or a tensor.
        epsilon: Relative overhead, as a Python number, a sequence of numbers or a tensor;
            broadcast against ``kl``.

    Returns:
        A tensor with the broadcast shape of ``kl`` and ``epsilon``.
    """
    # `as_tensor` returns tensors unchanged and converts plain numbers/lists;
    # the original `torch.log(K + 1)` raised a TypeError for Python floats.
    k = (1 + torch.as_tensor(epsilon)) * torch.as_tensor(kl)
    # log1p(k) == log(k + 1), computed accurately even for small k.
    return k + torch.log1p(k) + 1

In [None]:
# Grid of epsilon values to sweep; the same list drives the compression loop further down.
epsilons = [
    -0.3, -0.2, -0.1,
    0., 0.1, 0.2, 0.3, 0.4,
]
determine_epsilons = torch.as_tensor(epsilons)

In [None]:
compute_expected_coding_efficiency(kl=emp_kl, epsilon=determine_epsilons)

In [None]:
compute_expected_coding_efficiency(kl=var_kl, epsilon=determine_epsilons)

In [None]:
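# Optional: skip the slow optimiser cell and reload previously saved auxiliary variables instead.
# NOTE(review): the optimiser cell saves to BLR_RESULTS_v2, not BLR_RESULTS — use the path that matches your run.
# aux_vars = pkl.load(open(f"PickledStuff/BLR_RESULTS/Dim{dim}/optimised_vars_emp.pkl", "rb"))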

In [None]:
from pathlib import Path

# One fixed set of encoder seeds, reused for every epsilon so the runs are comparable.
torch.manual_seed(0)
seeds = torch.randint(low=0, high=int(1e6), size=(num_compressed_samples,))
results_dir = Path(f"PickledStuff/BLR_RESULTS_v2/Dim{dim}")
results_dir.mkdir(parents=True, exist_ok=True)

for eps in epsilons:
    exp_dict = {
        'seeds': seeds.numpy(),
        'target_mean': t.mean.numpy(),
        'target_covar': t.covariance_matrix.numpy(),
        'compressed_samples': [],
        'compressed_samples_idxs': [],
        'aux_vars': aux_vars,
    }
    pbar = tqdm(enumerate(seeds), total=num_compressed_samples)
    # Accumulate in a list and stack once; repeated torch.cat in the loop is O(n^2).
    log_prob_list = []
    for i, s in pbar:
        enc, z, idx = encode_sample(target=t, beamwidth=beamwidth, epsilon=eps, omega=omega,
                                    seed=s, n_empirical_samples=50, aux_vars=aux_vars)
        # NOTE(review): assumes entry 0 of the returned beam is the one to keep — confirm in run_encoder.
        idxs_to_transmit = idx[0]
        best_sample = z[0]
        sample_log_prob = t.log_prob(best_sample)  # computed once, reused below
        log_prob_list.append(sample_log_prob)
        exp_dict['compressed_samples'].append(best_sample.numpy())
        exp_dict['compressed_samples_idxs'].append(idxs_to_transmit.numpy())
        pbar.set_description(f"Coded sample {i + 1}, has log prob of {sample_log_prob}")

    # Empty sweep keeps the original behaviour (empty tensor -> mean is nan).
    log_probs = torch.stack(log_prob_list) if log_prob_list else torch.zeros([0])
    print(torch.mean(log_probs))
    with open(results_dir / f"Empirical_Epsilon{eps}_Beam{beamwidth}_Omega{omega}.pkl", "wb") as f:
        pkl.dump(exp_dict, f)

In [None]:
# Baseline: mean log-density of 100k exact draws from `t`, to compare with the coded samples.
t.log_prob(t.sample(sample_shape=(100_000,))).mean()

In [None]:
# Completion marker for long unattended runs.
print("done")