In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
%cd ..

/scratch/km817/New_iREC/iREC


In [69]:
import torch
import torch.distributions as dist
import math
import matplotlib.pyplot as plt
import pickle as pkl
from tqdm.notebook import tqdm
from rec.beamsearch.distributions.CodingSampler import CodingSampler
from rec.beamsearch.distributions.VariationalPosterior import VariationalPosterior
from rec.OptimisingVars.VariationalOptimiser import VariationalOptimiser
from rec.beamsearch.samplers.GreedySampling import GreedySampler
from rec.beamsearch.Coders.Encoder_Variational import Encoder as Variational_Encoder
from models.SimpleBayesianLinRegressor import BayesLinRegressor
from rec.utils import kl_estimate_with_mc, compute_variational_posterior, plot_samples_in_2d, plot_running_sum_2d, plot_pairs_of_samples

In [70]:
torch.set_default_tensor_type(torch.DoubleTensor)

In [71]:
def encode_sample(target, omega=5, epsilon=0.,
                  n_empirical_samples=10, seed=0, beamwidth=1, optimising_vars=False, aux_vars=None, dont_run=False):
    
    target = compute_variational_posterior(target)
    encoder = Variational_Encoder(target,
                                  seed,
                                  CodingSampler,
                                  GreedySampler,
                                  VariationalPosterior,
                                  omega,
                                  epsilon=epsilon,
                                  beamwidth=beamwidth
                                  )
    if aux_vars is not None:
        encoder.auxiliary_posterior.coding_sampler.auxiliary_vars = aux_vars
    
    if dont_run:
        return encoder
    else:
        return encoder, *encoder.run_encoder()

In [75]:
def create_blr_problem(dim, seed):
    if dim in [2, 5]:
        signal_std = 1e-3
    else:
        signal_std = 1e-1
    torch.set_default_tensor_type(torch.DoubleTensor)
    blr = BayesLinRegressor(prior_mean=torch.zeros(dim),
                            prior_alpha=1,
                            signal_std=signal_std,
                            num_targets=100,
                            seed=1)
    blr.sample_feature_inputs()
    blr.sample_regression_targets()
    blr.posterior_update()
    target = blr.weight_posterior
    return blr, target

In [118]:
dim = 90
omega = 5
blr_seed = 1
b, t = create_blr_problem(dim=dim, seed=blr_seed)
num_compressed_samples = 50

In [119]:
torch.set_default_tensor_type(torch.DoubleTensor)
compute_params_enc = encode_sample(target=t, dont_run=True)
n_auxiliaries = compute_params_enc.n_auxiliary
kl_q_p = compute_params_enc.total_kl
optimising = VariationalOptimiser(compute_params_enc.target, omega, n_auxiliaries, kl_q_p, n_trajectories=50, total_var=1)
aux_vars = optimising.run_optimiser(epochs=500)
pkl.dump(aux_vars, open(f"PickledStuff/Correlated/Dim{dim}/optimised_vars_var.pkl", "wb"))

The mean loss is 9.98281. The mean KL is: 5.16202: 100%|██████████| 500/500 [41:08<00:00,  4.94s/it]   


In [145]:
dim = 90
omega = 5
blr_seed = 1
b, t = create_blr_problem(dim=dim, seed=blr_seed)
num_compressed_samples = 50

In [146]:
beamwidth = 1
torch.manual_seed(0)
seeds = torch.randint(low = 0, high = int(1e6), size=(num_compressed_samples,))
epsilons = [0., 0.05, 0.1, 0.15, 0.2]
aux_vars = torch.tensor(pkl.load(open(f"PickledStuff/Correlated/Dim{dim}/approx_vars_var.pkl", "rb")))
for eps in epsilons:
    exp_dict = {}
    exp_dict['seeds'] = seeds.numpy()
    exp_dict['target_mean'] = t.mean.numpy()
    exp_dict['target_covar'] = t.covariance_matrix.numpy()
    exp_dict['compressed_samples'] = []
    exp_dict['compressed_samples_idxs'] = []
    exp_dict['aux_vars'] = aux_vars
    pbar = tqdm(enumerate(seeds), total=num_compressed_samples)
    log_probs = torch.zeros([0])
    for i, s in pbar:
        enc, z, idx = encode_sample(target=t, beamwidth=beamwidth, epsilon=eps, omega=omega, 
                                    seed=s, n_empirical_samples=50, aux_vars=aux_vars)
        idxs_to_transmit = idx[0]
        best_sample = z[0]
        log_probs = torch.cat((log_probs, t.log_prob(best_sample)[None]))
        exp_dict['compressed_samples'].append(best_sample.numpy())
        exp_dict['compressed_samples_idxs'].append(idxs_to_transmit.numpy())
        pbar.set_description(f"Coded sample {i + 1}, has log prob of {t.log_prob(best_sample)}")
    
    print(torch.mean(log_probs))
    with open(f"PickledStuff/CorrelatedApprox/Dim{dim}/Variational_Epsilon{eps}_Beam{beamwidth}_Omega{omega}.pkl", "wb") as f:
        pkl.dump(exp_dict, f)

  0%|          | 0/50 [00:00<?, ?it/s]

tensor(-3902.1856)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(-1880.0473)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(-948.4976)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(-306.2101)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(62.4933)


In [149]:
t.log_prob(t.sample((500,)))

tensor([237.7575, 235.1231, 225.8538, 245.3285, 237.4480, 243.0981, 229.5325,
        219.9356, 235.5261, 236.4867, 240.6794, 239.4462, 215.9359, 232.9367,
        237.6035, 224.7586, 249.4764, 231.9974, 233.7869, 221.6480, 226.0308,
        252.9307, 234.8585, 232.6449, 231.7204, 233.2937, 233.4280, 241.1942,
        231.6637, 244.1568, 238.8343, 241.2729, 234.4596, 239.2966, 238.2857,
        233.3617, 226.7123, 236.3445, 230.7741, 236.0703, 213.8009, 239.9170,
        229.2661, 235.0919, 232.1030, 219.1356, 227.1950, 235.2411, 228.4886,
        230.4157, 236.6605, 238.4412, 234.2439, 224.3609, 231.7023, 237.5311,
        230.9115, 224.1803, 232.9024, 228.1904, 238.3929, 216.5781, 233.8996,
        231.7972, 227.0916, 233.7885, 218.9938, 230.5910, 234.2969, 235.9528,
        235.1866, 226.6141, 228.1459, 233.1242, 237.8167, 228.4055, 227.2538,
        239.3667, 237.0315, 227.9855, 224.9064, 234.0480, 238.2546, 226.1197,
        233.4401, 220.7370, 234.0281, 237.8153, 244.2737, 232.65