# Normalizing flows
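Here is some of our work examining how normalizing flows adapt to differently shaped priors. In the experiment below, a RealNVP flow is trained so that points sampled uniformly on the unit circle (the prior), mapped through the inverse flow, match a 1-D standard Gaussian lying on the x-axis of the plane.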

In [1]:
import torch
import torch.nn as nn
import FrEIA.modules as Fm
import FrEIA.framework as Ff
import numpy as np

In [2]:
# Use the third GPU when it exists; otherwise fall back to CPU so the notebook still runs
# on machines with fewer GPUs (a hard-coded 'cuda:2' fails there on first use).
device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cpu'

In [3]:
# Target distribution: a 1-D standard normal (zero mean, unit covariance) placed on `device`.
true_dist = torch.distributions.MultivariateNormal(
    loc=torch.zeros([1], device=device),
    covariance_matrix=torch.eye(1, device=device),
)

In [4]:
def true_sample(n=1000):
    """Return n target points on the x-axis: column 0 drawn from true_dist, column 1 zero.

    The result is a fresh (n, 2) tensor from torch.zeros with no device argument, so it
    lives on the default device even when true_dist lives on a GPU; callers move it with
    .to(device).
    """
    points = torch.zeros([n, 2])
    draws = true_dist.rsample([n])
    points[:, 0] = draws.squeeze()
    return points

In [62]:
def _smooth_bump(offset, alpha):
    """Compactly supported bump exp(1/alpha^2 - 1/(alpha^2 - offset^2)) on |offset| < alpha, 0 elsewhere.

    Scaled so that the value at offset == 0 is 1 (up to float rounding). Vectorised over `offset`.
    """
    inside = offset.abs() < alpha
    # Outside the support the denominator is <= 0 and exp(-1/denom) could overflow to inf.
    # Substitute a harmless 1 there; those entries are replaced by 0 below anyway.
    denom = torch.where(inside, alpha**2 - offset.pow(2), torch.ones_like(offset))
    bump = np.exp(1/alpha**2) * (-1/denom).exp()
    return torch.where(inside, bump, torch.zeros_like(bump))


def smoothed_gaussian_density(pts, alpha=0.2):
    """Density-like weight of 2-D points: 1.5 * pdf_true_dist(x) * bump(y).

    pts is indexed as pts[:, 0] (x) and pts[:, 1] (y). The weight is zero wherever |y| >= alpha.
    NOTE(review): the 1.5 factor is unexplained in the source (presumably an approximate
    normalisation of the y-bump) - confirm.

    Returns a 1-D tensor with one value per row of pts.
    """
    probs = 1.5*true_dist.log_prob(pts[:, 0].unsqueeze(1)).exp()
    return probs * _smooth_bump(pts[:, 1], alpha)


def smoothed_circle_density(pts, alpha=0.2):
    """Density-like weight concentrated on the unit circle: bump(1 - |p|) for each row p of pts.

    The weight is zero wherever the Euclidean norm of a row differs from 1 by alpha or more.
    Returns a 1-D tensor with one value per row of pts.
    """
    return _smooth_bump(1 - pts.norm(dim=1), alpha)

In [8]:
t = torch.tensor([[0.4,0],[0.4,0.1],[0.4,0.3]], device=device)
# Spot-check the smoothed Gaussian density at fixed x = 0.4. y = 0 gives the full (scaled) Gaussian value,
# y = 0.1 is heavily damped by the bump, and y = 0.3 lies outside |y| < alpha = 0.2, so it is 0.
smoothed_gaussian_density(t)

tensor([5.5241e-01, 1.3278e-04, 0.0000e+00], device='cuda:2')

In [57]:
num_pts = 200
# Grid coordinates on [-3, 3], converted from float64 linspace to float32 exactly as the per-point tensors were.
axis = torch.tensor(np.linspace(-3, 3, num_pts), dtype=torch.float, device=device)
# Row k of `pts` is (axis[k // num_pts], axis[k % num_pts]). This is the same x-outer / y-inner ordering as a
# nested loop, so after the reshape arr[i, j] is the density at (x_i, y_j).
pts = torch.cartesian_prod(axis, axis)
arr = smoothed_gaussian_density(pts).reshape(num_pts, num_pts).cpu().numpy()

In [16]:
%matplotlib widget
import seaborn as sns
sns.heatmap(arr.transpose())

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<AxesSubplot:>

In [87]:
def heatmap(mod, num_pts=170, xlim=(-2, 2), ylim=(-0.4, 0.4), batch_size=200):
    """Density that a flow `mod` induces on a 2-D grid when its output space carries the circle prior.

    For every grid point p, with (z, log_det) = mod(p), this computes
    smoothed_circle_density(z) * exp(log_det), the change-of-variables formula
    p_data(p) = p_prior(f(p)) * |det J_f(p)|.

    Args:
        mod: callable mapping a (B, 2) batch to (outputs, log_jacobian_determinants), the tuple a
            FrEIA graph net's forward pass returns. The log-determinants may be shaped (B,) or (B, 1).
        num_pts: grid resolution along each axis.
        xlim, ylim: (min, max) extent of the grid along x and y.
        batch_size: number of grid points passed to `mod` per call (memory knob only).

    Returns:
        numpy array of shape (num_pts, num_pts). Entry [j, i] is the value at (x_i, y_j), i.e. y runs
        along rows, ready for sns.heatmap.
    """
    xs = torch.tensor(np.linspace(*xlim, num_pts), dtype=torch.float, device=device)
    ys = torch.tensor(np.linspace(*ylim, num_pts), dtype=torch.float, device=device)
    # Row k is (xs[k // num_pts], ys[k % num_pts]), so a reshape to (num_pts, num_pts) indexes [x, y].
    grid = torch.cartesian_prod(xs, ys)

    val_chunks, log_det_chunks = [], []
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for chunk in grid.split(batch_size):
            vals, log_dets = mod(chunk)
            val_chunks.append(vals.detach())
            log_det_chunks.append(log_dets.detach())
    vals = torch.cat(val_chunks)
    log_dets = torch.cat(log_det_chunks)

    probs = smoothed_circle_density(vals.reshape(-1, 2)).reshape(num_pts, num_pts)
    # Change of variables multiplies by the Jacobian determinant itself, i.e. exp(log-determinant).
    density = probs * log_dets.reshape(num_pts, num_pts).exp()
    return density.cpu().numpy().transpose()

In [82]:
arr = heatmap(lambda x: (x,torch.ones([x.shape[0],1], device=device)))
%matplotlib widget
sns.heatmap(arr)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<AxesSubplot:>

In [83]:
width = 512

def fc_constr(c_in, c_out):
    """Subnet constructor for the INN coupling blocks: a plain ReLU MLP.

    Layout: Linear(c_in, width) -> ReLU, then five Linear(width, width) -> ReLU pairs,
    then a final Linear(width, c_out) with no activation. Modules are created in that order.
    """
    stack = [nn.Linear(c_in, width), nn.ReLU()]
    stack.extend(
        module
        for _ in range(5)
        for module in (nn.Linear(width, width), nn.ReLU())
    )
    stack += [nn.Linear(width, c_out)]
    return nn.Sequential(*stack)

## KL divergence estimated from samples

In [84]:
def KL_divergence(s1, s2, k=1):
    """k-nearest-neighbour estimate of KL(P || Q) from samples s1 ~ P and s2 ~ Q.

    Estimator from Qing Wang, Sanjeev R. Kulkarni, and Sergio Verdu, "Divergence estimation for
    multidimensional densities via k-nearest-neighbor distances", IEEE Trans. Inf. Theory 55(5):
    2392-2405, 2009:

        D = log(m / (n - 1)) + (d / n) * sum_i log(nu_k(i) / rho_k(i)),

    where nu_k(i) is the distance from s1[i] to its k-th nearest neighbour in s2, and rho_k(i) is
    the distance to its k-th nearest neighbour in s1. Zero distances are skipped in both searches,
    which excludes the point itself from rho and avoids log(0) for duplicated points.

    Args:
        s1: (n, d) tensor of samples from P.
        s2: (m, d) tensor of samples from Q (same d, same device as s1).
        k: which nearest neighbour to use.

    Returns:
        0-d tensor on s1's device, differentiable w.r.t. s1 and s2.

    Raises:
        ValueError: if s1 has fewer than two points, or some point of s1 has fewer than k
            neighbours at non-zero distance.
    """
    n, m = len(s1), len(s2)
    if n < 2:
        raise ValueError('KL_divergence needs at least two samples in s1')
    d = s1.shape[1]

    def kth_nonzero_distance(others):
        # All pairwise distances at once (O(n * m * d) memory): row i holds |s1[i] - others[j]|.
        dists = (s1.unsqueeze(1) - others.unsqueeze(0)).norm(dim=-1)
        # Push zero distances (self-matches, duplicates) to +inf so kthvalue skips them.
        dists = torch.where(dists == 0, torch.full_like(dists, float('inf')), dists)
        kth = dists.kthvalue(k, dim=1)[0]
        if torch.isinf(kth).any():
            raise ValueError(f'some point has fewer than k={k} neighbours at non-zero distance')
        return kth

    nu = kth_nonzero_distance(s2)
    rho = kth_nonzero_distance(s1)
    log_size_ratio = torch.tensor(m / (n - 1), dtype=torch.float, device=s1.device).log()
    return log_size_ratio + (d / n) * (nu / rho).log().sum()

In [85]:
layers = [Ff.InputNode(2, name='Input')]

# Five RealNVP affine coupling blocks, each followed by a PermuteRandom node that shuffles the
# two coordinates. Fm.GINCouplingBlock is an alternative coupling block (hence the 'GIN' node names).
for i in range(5):
    coupling = Ff.Node(
        [layers[-1].out0],
        Fm.RNVPCouplingBlock,
        {'subnet_constructor': fc_constr},
        name=f'GIN {i}',
    )
    permute = Ff.Node([coupling.out0], Fm.PermuteRandom, {})
    layers += [coupling, permute]

layers.append(Ff.OutputNode([layers[-1].out0], name='Output'))

model = Ff.ReversibleGraphNet(layers, verbose=False)
_ = model.to(device)

In [88]:
import os
import matplotlib.pyplot as plt
plt.ioff()

N_EPOCHS = 200
N_SAMPLES = 1000                      # prior and target samples drawn per step
FRAME_DIR = 'gaussian/circle_flow'    # where per-epoch density frames are written
# savefig does not create missing directories, so make sure the frame directory exists up front.
os.makedirs(FRAME_DIR, exist_ok=True)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=0
)

for epoch in range(N_EPOCHS):
    model.zero_grad()

    # Uniform samples on the unit circle (the prior), mapped through the inverse flow.
    angles = 2*np.pi*torch.rand([N_SAMPLES, 1], device=device)
    samples = torch.cat([angles.cos(), angles.sin()], dim=1)
    z_hat, _ = model(samples, rev=True)

    true = true_sample(N_SAMPLES).to(device)

    # Sum of both KL directions between generated and target samples.
    loss = KL_divergence(z_hat, true, k=1) + KL_divergence(true, z_hat, k=1)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch} loss: {loss.item()}')

    # Save one frame of the density the current flow assigns to the grid.
    arr = heatmap(model)
    fig, ax = plt.subplots()
    sns.heatmap(arr, ax=ax)
    ax.set_title(f'Epoch {epoch} density')
    fig.savefig(f'{FRAME_DIR}/epoch-{epoch}.png')
    plt.close(fig)  # avoid accumulating one open figure per epoch

Epoch 0 loss: 21.70546531677246
Epoch 1 loss: 21.489843368530273
Epoch 2 loss: 20.755199432373047
Epoch 3 loss: 20.114192962646484
Epoch 4 loss: 19.569581985473633
Epoch 5 loss: 18.985801696777344
Epoch 6 loss: 18.285877227783203
Epoch 7 loss: 17.637351989746094
Epoch 8 loss: 16.93558120727539
Epoch 9 loss: 16.593053817749023
Epoch 10 loss: 16.39727020263672
Epoch 11 loss: 15.725844383239746
Epoch 12 loss: 14.936430931091309
Epoch 13 loss: 14.200597763061523
Epoch 14 loss: 13.37991714477539
Epoch 15 loss: 12.699832916259766
Epoch 16 loss: 10.304031372070312
Epoch 17 loss: 9.436872482299805
Epoch 18 loss: 9.209270477294922
Epoch 19 loss: 8.556109428405762
Epoch 20 loss: 7.46096134185791
Epoch 21 loss: 8.385290145874023
Epoch 22 loss: 8.822319030761719
Epoch 23 loss: 9.59246826171875
Epoch 24 loss: 10.271574020385742
Epoch 25 loss: 10.336284637451172
Epoch 26 loss: 9.386771202087402
Epoch 27 loss: 10.301192283630371
Epoch 28 loss: 10.640898704528809
Epoch 29 loss: 9.044719696044922
Epoch

In [132]:
# Push the data through the flow in the forward direction and collect the outputs on the CPU.
# NOTE(review): `loader` is not defined in any cell shown here, and the (10000, 4) output of the next
# cell does not match the 2-input model built above; this cell appears to come from a different
# experiment. Confirm before re-running.
outs = []
with torch.no_grad():  # inference only; skip building autograd graphs
    for pts, _ in loader:
        # torch.stack needs a sequence of tensors here; TODO confirm the resulting (batch, features) layout.
        pts = torch.stack(pts)
        yhat, _ = model(pts)
        outs.append(yhat.detach().cpu())
outs = torch.cat(outs)

In [133]:
outs.size()  # shape of the collected outputs: (num_samples, num_features)

torch.Size([10000, 4])

In [134]:
# Per-dimension statistics of the collected outputs (std uses the unbiased N-1 divisor).
means = outs.mean(dim=0)
stds = outs.std(dim=0, unbiased=True)

In [135]:
print(means, stds, sep=' ')  # per-dimension means, then per-dimension standard deviations

tensor([ 0.0778,  0.1543, -0.1162, -0.0541]) tensor([1.0456e+00, 1.1054e+00, 5.9175e-05, 7.0614e-05])


In [136]:
# Draw fresh points from an axis-aligned Gaussian fitted to `outs` (per-dimension means/stds),
# one draw per row of `outs`. randn_like keeps shape, dtype and device in step with `outs`
# instead of hard-coding (10000, 4).
samples = stds*torch.randn_like(outs) + means

In [137]:
# Dummy all-zero labels: the loop that consumes loader2 discards them.
# NOTE(review): `make_dataloader` and `circle` are not defined in any cell shown here; 100 is
# presumably the batch size. Confirm against make_dataloader.
dummy_labels = torch.zeros_like(circle)
loader2 = make_dataloader(samples, dummy_labels, 100)

In [138]:
outs = []
for pts, _ in loader2:
    pts = torch.stack(pts)
    yhat, _ = model(pts, rev=True)
    outs.append(yhat.detach().cpu())
outs = torch.cat(outs)

In [144]:
import matplotlib.pyplot as plt
%matplotlib widget

plt.scatter(outs[:,0], outs[:,1], s=5)
plt.xlim([-1.5,1.5])
plt.ylim([-1.5,1.5])
plt.title('Generated data from Gaussian priors (n=10,000)')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Generated data from Gaussian priors (n=10,000)')

In [145]:
%matplotlib widget
import seaborn as sns
sns.histplot(outs[:,:2].norm(dim=1))
plt.title('Distances from zero')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Distances from zero')