# Two Modes

In [None]:
import torch
import seaborn as sns
from torch import distributions

from torch import nn, optim
import math
from tqdm import tqdm

from pyro import distributions as dist
from pyro.distributions import transforms as T
import matplotlib.pyplot as plt

In [None]:
# Centres of the two modes and the standard deviation they share.
theta_mu1, theta_mu2 = -0.75, 0.75
theta_std = 0.5

# 1024 evenly spaced evaluation points on [-5, 5] used when plotting densities.
grid = torch.linspace(-5, 5, steps=1024)

In [None]:
# Log of the summed densities of the two Normal components, N(theta_mu1, theta_std)
# and N(theta_mu2, theta_std), at every grid point. grid[:, None] broadcasts the
# grid against the two means; logsumexp(-1) adds the component densities in log space.
# NOTE: no mixture weights are applied, so exp(logp) integrates to 2 rather than 1.
logp = distributions.Normal(torch.tensor([theta_mu1, theta_mu2]), torch.tensor(theta_std)).log_prob(grid[:, None]).logsumexp(-1)

In [None]:
# Plot the (unnormalised) two-component density over the grid.
plt.plot(grid, logp.exp())

In [None]:
# 1024 draws of theta: 512 from N(theta_mu1, theta_std) followed by 512 from
# N(theta_mu2, theta_std), concatenated in that order.
theta = torch.cat([theta_mu1 + torch.randn(512) * theta_std, theta_mu2 + torch.randn(512) * theta_std])

In [None]:
# Kernel density estimate of the sampled theta values.
sns.kdeplot(theta)

In [None]:
# Observations: |theta| plus Gaussian noise with standard deviation 0.1. Taking the
# absolute value discards the sign of theta. The literal 1024 must equal len(theta).
x = theta.abs() + torch.randn(1024) * 0.1

In [None]:
# Kernel density estimate of the observations x.
sns.kdeplot(x)

In [None]:
def logprob_x_theta(x, theta):
    """Return the joint log-density ``log p(theta) + log p(x | theta)``.

    The prior on ``theta`` is Uniform(-5, 5) and the likelihood of ``x`` is
    Normal(|theta|, 0.1), so ``theta`` and ``-theta`` receive the same score.

    Args:
        x: Observation tensor, broadcastable against ``theta``.
        theta: Parameter tensor at which the joint is evaluated.

    Returns:
        Tensor holding the joint log-density for each broadcast element.
    """
    prior = distributions.Uniform(-5, 5)
    likelihood = distributions.Normal(theta.abs(), .1)
    return prior.log_prob(theta) + likelihood.log_prob(x)

In [None]:
@torch.no_grad()
def rejection_sampling(fn, rng, num_samples=1024, batch_size=32768, max_log_prob=0, n_dims=1):
    """Draw samples from an unnormalised log-density by rejection sampling.

    Candidates are proposed uniformly on ``[low, high)`` in every dimension
    and each one is kept with probability ``exp(fn(z) - max_log_prob)``.

    Args:
        fn: Callable mapping a ``(batch, n_dims)`` tensor of candidates to one
            log-density per candidate (shape ``(batch,)`` or ``(batch, 1)``).
        rng: ``(low, high)`` bounds of the uniform proposal.
        num_samples: Number of accepted samples to return.
        batch_size: Number of candidates proposed per iteration.
        max_log_prob: Value subtracted from ``fn`` before exponentiating; the
            sampler is only exact when ``fn`` never exceeds it on the range.
        n_dims: Dimensionality of each sample.

    Returns:
        Tensor of shape ``(num_samples, n_dims)`` with the accepted samples.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    if num_samples <= 0:
        return torch.empty((0, n_dims))

    low, high = rng
    accepted = []
    num_accepted = 0
    while num_accepted < num_samples:
        # Propose a full batch per round: the acceptance rate can be far below
        # one, so proposing only `num_samples` candidates needs many rounds.
        candidates = torch.rand((batch_size, n_dims)) * (high - low) + low
        thresholds = torch.rand(batch_size)

        # reshape (rather than squeeze) keeps a one-element batch one-dimensional.
        log_density = fn(candidates).reshape(batch_size)
        keep = torch.exp(log_density - max_log_prob) > thresholds

        kept = candidates[keep]
        accepted.append(kept)
        num_accepted += kept.shape[0]

    # Concatenate once at the end and drop any surplus from the last batch.
    return torch.cat(accepted)[:num_samples]

In [None]:
# First observation, used to inspect a single posterior p(theta | x0).
x0 = x[0]

In [None]:
# 128 rejection samples of theta for the single observation x0, proposed on (-5, 5).
# max_log_prob=1.2 is the value rejection_sampling subtracts from the joint
# log-density before turning it into an acceptance probability.
p_theta_x0 = rejection_sampling(lambda theta: logprob_x_theta(x0, theta), (-5, 5), max_log_prob=1.2, num_samples=128)

In [None]:
# Kernel density estimate of the posterior samples for x0.
sns.kdeplot(p_theta_x0.squeeze())

In [None]:
# Approximate each per-observation posterior p(theta | x_i) with 128 rejection
# samples; results are collected in the same order as the observations in x.
theta_samples = []
for x0 in tqdm(x):
    p_theta_x0 = rejection_sampling(
        lambda theta: logprob_x_theta(x0, theta),
        (-5, 5),
        num_samples=128,
        max_log_prob=1.2,
    )
    theta_samples.append(p_theta_x0)

In [None]:
# Flow-based density model for theta: a standard Normal base distribution
# combined with four transforms created by T.planar(1).
base_dist = dist.Normal(torch.zeros(1), torch.ones(1))
transform = [T.planar(1) for _ in range(4)]
composed_transform = T.ComposeTransformModule(transform)

# The distribution is built on the *inverse* of the composed transform.
# NOTE(review): presumably so that log_prob runs the planar layers in their
# forward direction — confirm against the Pyro documentation.
transformed_dist = dist.TransformedDistribution(base_dist, composed_transform.inv)
# Adam optimises the parameters of the composed transform with lr = 1e-2.
optimizer = torch.optim.Adam(composed_transform.parameters(), lr=1e-2)


In [None]:
# Number of optimisation steps for the direct fit.
steps = 1024

# Maximum-likelihood fit of the flow on the theta draws themselves: each step
# minimises the mean negative log-density of all theta values.
for step in range(steps):
    optimizer.zero_grad()
    # theta[:, None] adds a trailing axis so each draw is passed as a length-1 vector.
    loss = -transformed_dist.log_prob(theta[:, None]).mean()
    loss.backward()
    optimizer.step()
    
    # Clear each transform's cache after the parameter update.
    for t in transformed_dist.transforms:
        t.clear_cache()
        
    # Report the loss every 32 steps.
    if step % 32 == 0:
        print(loss.item())
        


In [None]:
# Plot the density of the fitted flow over the grid; no_grad avoids building
# an autograd graph for the evaluation.
with torch.no_grad():
    plt.plot(grid, transformed_dist.log_prob(grid[:, None]).exp())

In [None]:
# Stack the per-observation sample tensors along a new leading axis, turning the
# list into one tensor indexed as (observation, sample, dimension).
theta_samples = torch.stack(theta_samples)

In [None]:
# Display the shape of the stacked posterior samples.
theta_samples.shape

In [None]:
# Re-create the flow from scratch for the fit on per-observation posterior
# samples: a standard Normal base distribution and the inverse of four stacked
# transforms created by T.planar(1).
base_dist = dist.Normal(torch.zeros(1), torch.ones(1))
transform = [T.planar(1) for _ in range(4)]
composed_transform = T.ComposeTransformModule(transform)
transformed_dist = dist.TransformedDistribution(base_dist, composed_transform.inv)


In [None]:
# New Adam optimiser (lr = 1e-2) over the parameters of the composed transform.
optimizer = torch.optim.Adam(composed_transform.parameters(), lr=1e-2)


In [None]:
# Kernel density estimate of the theta draws, for reference.
sns.kdeplot(theta.squeeze())

In [None]:
# Number of optimisation steps for the fit on per-observation posterior samples.
steps = 131072

for step in range(steps):
    optimizer.zero_grad()
    # Pick one sample index at random and take that sample from every
    # observation's posterior. The index range is read from theta_samples rather
    # than a literal, so it stays valid if the samples-per-observation count changes.
    theta_batch = theta_samples[:, torch.randint(0, theta_samples.shape[1], (1,))]
    logprob = transformed_dist.log_prob(theta_batch.view(-1, 1)).view(theta_batch.shape)

    # Uniform prior so don't need to divide.
    # With a single sampled index per step the reduced axis has length 1.
    logprob = torch.logsumexp(logprob, dim=1)
    logprob = logprob.mean()
    loss = -logprob
    loss.backward()
    optimizer.step()

    # Clear each transform's cache after the parameter update.
    for t in transformed_dist.transforms:
        t.clear_cache()

    # Every 128 steps: report the loss and save a diagnostic figure.
    if step % 128 == 0:
        print(step, loss.item())

        with torch.no_grad():
            fig, ax1 = plt.subplots(facecolor='white')
            # Second y-axis so the KDE of the per-observation means has its own scale.
            ax2 = ax1.twinx()

            ax1.plot(grid, transformed_dist.log_prob(grid[:, None]).exp(), label=r'$p(\theta | \mathbf{w})$')
            # logp is the sum of two component densities; halving it normalises it.
            ax1.plot(grid, logp.exp() / 2, label='Ground Truth', color='black')
            sns.kdeplot(theta_samples.mean(1).squeeze(), ax=ax2, color='red', label='Average', alpha=0.5)
            ax1.set_ylabel(r'$p(\theta|\mathbf{X})$', fontsize=22)
            ax1.set_xlabel(r'$\theta$', fontsize=22)

            ax1.tick_params(labelsize=16)
            ax2.tick_params(labelsize=16)
            plt.xticks(fontsize=16)
            ax1.legend()

            ax2.legend(loc='lower right')

            plt.tight_layout()
            plt.savefig(f"toy_result_{step}.pdf")
            plt.show()
            # Release the figure explicitly; otherwise backends that do not close
            # figures on show() accumulate one open figure per report.
            plt.close(fig)


In [None]:
# Kernel density estimate of the per-observation means of the posterior samples.
sns.kdeplot(theta_samples.mean(1).squeeze())

In [None]:
# Overlay kernel density estimates of the sampled posteriors p(theta | x_i) for
# the first 32 observations and save the figure as a PDF.
plt.figure(facecolor='white')

for i in range(32):
    # Three of the posteriors are drawn fully opaque; the rest are faded.
    alpha = 1 if i in (0, 5, 4) else 0.1
    # Only the first curve gets a legend entry, so the legend has a single row.
    label = r'$\theta \sim p(\theta|x_i)$' if i == 0 else None
    sns.kdeplot(theta_samples[i, :, 0], color='black', alpha=alpha, label=label)
plt.legend()
plt.ylabel(r'$p(\theta|\mathbf{x}_i)$', fontsize=22)
plt.xlabel(r'$\theta$', fontsize=22)
plt.yticks(fontsize=16)
plt.xticks(fontsize=16)

plt.tight_layout()
plt.savefig('p_theta_x.pdf', bbox_inches='tight')
# plt.title(r"Posteriors $p(\theta|\mathbf{x}_i)$ for $i \in \{1, \dots, N\}$")