In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
# Move the kernel's working directory up one level so that the relative
# paths used further down ("PickledStuff/...") resolve from there.
# NOTE: the path is relative, so re-running this cell moves up one more
# directory each time; run it once per kernel session.
%cd ..

/scratch/km817/REC/iREC


In [4]:
import torch
import torch.distributions as dist
import math
import matplotlib.pyplot as plt
import pickle as pkl

from tqdm.notebook import tqdm
from rec.beamsearch.distributions.CodingSampler import CodingSampler
from rec.beamsearch.distributions.EmpiricalMixturePosterior import EmpiricalMixturePosterior
from rec.beamsearch.samplers.GreedySampling import GreedySampler
from rec.beamsearch.Coders.Encoder import Encoder as Empirical_Encoder
from models.BayesianLinRegressor import BayesLinRegressor
from rec.utils import kl_estimate_with_mc, plot_samples_in_2d, plot_running_sum_2d, plot_pairs_of_samples
from rec.OptimisingVars.FinalJointOptimiser import FinalJointOptimiser

In [5]:
def encode_sample(target, omega=5, epsilon=0.,
                  n_empirical_samples=10, seed=0, beamwidth=1, optimising_vars=False, aux_vars=None, dont_run=False):
    """Build an empirical-mixture encoder for ``target`` and optionally run it.

    Args:
        target: Target distribution handed to the encoder as its first argument.
        omega: Forwarded positionally to the encoder.
        epsilon: Forwarded to the encoder as the ``epsilon`` keyword.
        n_empirical_samples: Forwarded positionally to the encoder.
        seed: Forwarded positionally to the encoder.
        beamwidth: Forwarded to the encoder as the ``beamwidth`` keyword.
        optimising_vars: Accepted for call compatibility; not used in this function.
        aux_vars: When not ``None``, assigned to
            ``encoder.auxiliary_posterior.coding_sampler.auxiliary_vars``
            before the encoder is returned or run.
        dont_run: When true, return the configured encoder without running it.

    Returns:
        The encoder alone when ``dont_run`` is true; otherwise a tuple made of
        the encoder followed by every item returned by ``encoder.run_encoder()``.
    """
    coder = Empirical_Encoder(
        target, seed, CodingSampler, GreedySampler, EmpiricalMixturePosterior,
        omega, n_empirical_samples,
        epsilon=epsilon, beamwidth=beamwidth,
    )

    # Only override the sampler's auxiliary variables when the caller supplies them.
    if aux_vars is not None:
        coder.auxiliary_posterior.coding_sampler.auxiliary_vars = aux_vars

    # Guard clause: hand back the un-run encoder so its attributes can be inspected.
    if dont_run:
        return coder

    return (coder, *coder.run_encoder())

In [6]:
def create_blr_problem(dim, seed, prior_alpha=1, signal_std=1, num_targets=10000):
    """Set up a Bayesian linear regression problem and return its weight posterior.

    Args:
        dim: Length of the all-zero prior mean vector passed to the regressor.
        seed: Seed forwarded to ``BayesLinRegressor``.
        prior_alpha: Forwarded as ``prior_alpha`` (default 1, the value that
            was previously hard-coded).
        signal_std: Forwarded as ``signal_std`` (default 1, previously hard-coded).
        num_targets: Forwarded as ``num_targets`` (default 10000, previously
            hard-coded).

    Returns:
        A ``(blr, target)`` tuple: the regressor instance and its
        ``weight_posterior`` attribute read after ``posterior_update()``.
    """
    blr = BayesLinRegressor(prior_mean=torch.zeros(dim),
                            prior_alpha=prior_alpha,
                            signal_std=signal_std,
                            num_targets=num_targets,
                            seed=seed)
    # Same call order as before: inputs, then targets, then the posterior update.
    blr.sample_feature_inputs()
    blr.sample_regression_targets()
    blr.posterior_update()
    # Read the posterior only after the update has been performed.
    target = blr.weight_posterior
    return blr, target

In [7]:
# Experiment configuration (5-dimensional problem). These are notebook-level
# globals that the cells below read.
# `dim` sizes the prior mean in create_blr_problem and selects the
# "Dim{dim}" directory used for pickled files.
dim = 5
beamwidth = 1
omega = 5
blr_seed = 1
# b is the regressor object, t its weight posterior (the coding target).
b, t = create_blr_problem(dim=dim, seed=blr_seed)
# Number of samples to encode per experiment in the encoding loop.
num_compressed_samples = 50

In [None]:
# Optimise the auxiliary variables for the current target `t` and cache the
# result on disk under the "Dim{dim}" directory.
z_sample = t.mean
# Build the encoder without running it: only its n_auxiliary and total_kl
# attributes are needed here. omega is passed explicitly so the encoder and
# the optimiser below always use the same value.
compute_params_enc = encode_sample(target=t, omega=omega, dont_run=True)
n_auxiliaries = compute_params_enc.n_auxiliary
kl_q_p = compute_params_enc.total_kl
optimising = FinalJointOptimiser(z_sample, omega, n_auxiliaries, kl_q_p, n_trajectories=50, total_var=1)
aux_vars = optimising.run_optimiser()
# Use a context manager so the file handle is flushed and closed
# deterministically (the bare open() previously left it to garbage collection).
with open(f"PickledStuff/Optimising_Vars/Dim{dim}/optimised_vars_emp.pkl", "wb") as f:
    pkl.dump(aux_vars, f)

In [24]:
# Show the current value of aux_vars as this cell's output.
aux_vars

tensor([1.9289e-01, 1.5763e-01, 1.2632e-01, 1.0058e-01, 8.0449e-02, 6.4419e-02,
        5.1856e-02, 4.1906e-02, 3.4024e-02, 2.7578e-02, 2.2377e-02, 1.8130e-02,
        1.4668e-02, 1.1969e-02, 9.9020e-03, 8.2610e-03, 6.8251e-03, 5.5668e-03,
        4.5170e-03, 3.6594e-03, 2.9773e-03, 2.4206e-03, 1.9879e-03, 1.6337e-03,
        1.3359e-03, 1.1012e-03, 9.0214e-04, 7.4210e-04, 6.1020e-04, 4.9781e-04,
        4.2361e-04, 3.5228e-04, 2.8650e-04, 2.2831e-04, 1.8004e-04, 1.4104e-04,
        1.1112e-04, 8.8751e-05, 7.2240e-05, 6.0294e-05, 5.1803e-05, 4.6251e-05,
        4.3582e-05, 4.5084e-05, 1.2711e-04])

In [15]:
# Experiment configuration (50-dimensional problem). This rebinds the same
# global names as the earlier configuration cell, so every cell run after this
# one sees these values and a freshly built target `t`.
dim = 50
beamwidth = 20
omega = 5
blr_seed = 1
# b is the regressor object, t its weight posterior (the coding target).
b, t = create_blr_problem(dim=dim, seed=blr_seed)
# Number of samples to encode per epsilon setting in the loop below.
num_compressed_samples = 50

In [18]:
# Encode `num_compressed_samples` samples of the target for several epsilon
# settings and pickle one results dictionary per setting.
torch.manual_seed(0)
# One encoder seed per compressed sample, drawn reproducibly from the seed above.
seeds = torch.randint(low = 0, high = int(1e6), size=(num_compressed_samples,))
epsilons = [0., 0.05, 0.1, 0.15, 0.2]

# Load the cached auxiliary variables for the current `dim`.
# NOTE(review): in top-to-bottom order the optimiser cell runs while dim = 5,
# so the "Dim50" file read here must already exist from an earlier run - confirm.
# SECURITY: pickle.load can execute arbitrary code; only load files you produced.
with open(f"PickledStuff/Optimising_Vars/Dim{dim}/optimised_vars_emp.pkl", "rb") as f:
    aux_vars = pkl.load(f)

# Only the first three epsilon values are evaluated.
for eps in epsilons[:3]:
    exp_dict = {
        'seeds': seeds.numpy(),
        'target_mean': t.mean.numpy(),
        'target_covar': t.covariance_matrix.numpy(),
        'compressed_samples': [],
        'compressed_samples_idxs': [],
        'aux_vars': aux_vars,
    }
    pbar = tqdm(enumerate(seeds), total=num_compressed_samples)
    # Collect per-sample log-probabilities in a list and stack once at the end
    # instead of re-allocating a tensor with torch.cat on every iteration.
    log_probs = []
    for i, s in pbar:
        enc, z, idx = encode_sample(target=t, beamwidth=beamwidth, epsilon=eps, omega=omega,
                                    seed=s, n_empirical_samples=50, aux_vars=aux_vars)
        # The first entry of each returned sequence is the one kept and transmitted.
        idxs_to_transmit = idx[0]
        best_sample = z[0]
        # Evaluate the log-probability once and reuse it for both the running
        # statistics and the progress-bar text.
        log_prob = t.log_prob(best_sample)
        log_probs.append(log_prob)
        exp_dict['compressed_samples'].append(best_sample.numpy())
        exp_dict['compressed_samples_idxs'].append(idxs_to_transmit.numpy())
        pbar.set_description(f"Coded sample {i + 1}, has log prob of {log_prob}")

    # Mean log-probability of the coded samples under the target for this epsilon.
    print(torch.mean(torch.stack(log_probs)))
    with open(f"PickledStuff/Optimising_Vars/Dim{dim}/Empirical_Epsilon{eps}_Beam{beamwidth}_Omega{omega}.pkl", "wb") as f:
        pkl.dump(exp_dict, f)

  0%|          | 0/50 [00:00<?, ?it/s]

tensor(-1083.8163)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(-567.1396)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(-195.5009)


In [19]:
# Completion marker: prints once execution reaches this cell.
print('done')

done
