In [1]:
import INN
import torch
import torch.nn as nn

In [9]:
model = INN.Sequential(INN.Nonlinear(dim=3, method='NICE'),
                       INN.BatchNorm1d(3),
                       INN.Nonlinear(dim=3, method='RealNVP'))
model.eval()
''

''

In [2]:
model = INN.BatchNorm1d(3)#INN.Nonlinear(dim=3, method='iResNet', num_n=100, num_iter=50)

In [3]:
def linear_Jacobian_matrix(model, x):
    batch_size, dim = x.shape
    x.requires_grad = True
    model.computing_p(True)
    y, log_p, log_det = model(x)
    
    grad_list = []
    for i in range(dim):
        v = torch.zeros((batch_size, dim))
        v[:, i] = 1
        grad = INN.utilities.vjp(y, x, v)[0]
        grad_list.append(grad.detach())
    return torch.stack(grad_list, dim=1), log_det

In [4]:
def Jacobian_matrix(model, x):
    shape = x.shape
    dim = int(torch.prod(torch.Tensor(list(x.shape))).item())
    repeats = [dim]
    for i in range(len(x.shape)):
        repeats.append(1)
    
    x_hat = x.unsqueeze(0).repeat(tuple(repeats))
    x_hat.requires_grad = True
    model.computing_p(True)
    y, log_p, log_det = model(x_hat)
    
    v = torch.diag(torch.ones(dim)).reshape((dim, *x.shape))
    grad = INN.utilities.vjp(y, x_hat, v)[0]
    
    return grad.detach(), log_det.detach()

In [5]:
x = torch.randn(3)

In [10]:
model.eps = 1e-5
J, log_det = Jacobian_matrix(model, x)
J

tensor([[ 1.0207,  0.7147,  0.1159],
        [ 0.3893,  1.9171, -0.6180],
        [ 0.4169, -0.0570,  1.0998]])

In [11]:
torch.log(torch.abs(torch.det(J)))

tensor(0.4258)

In [12]:
torch.mean(log_det)

tensor(0.4258)

In [5]:
x = torch.randn((6, 3))

In [11]:
model.eval()
Js, log_det = linear_Jacobian_matrix(model, x)
real_log_det = torch.log(torch.abs(torch.det(Js)))

print(f'J_g={real_log_det},\nJ_c={log_det.detach()}')

J_g=tensor([-1.5020e-05, -1.5020e-05, -1.5020e-05, -1.5020e-05, -1.5020e-05,
        -1.5020e-05]),
J_c=-1.5020295904832892e-05


In [12]:
Js[0]

tensor([[1.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 1.0000]])

In [46]:
torch.sum(-1 * torch.log(torch.var(x, dim=0, unbiased=False) + model.eps) / 2)

tensor(0.7945, grad_fn=<SumBackward0>)

## Bug lists

1. `INN.BatchNorm1d()` fails on Jacobian tests [fixed]
2. `INN.iResNet()` has large differece to the ground-truth!

In [53]:
model = nn.BatchNorm1d(3, affine=False)

In [54]:
model.running_mean

tensor([0., 0., 0.])

In [55]:
model(x)

tensor([[-0.3725, -0.3559, -1.4208],
        [ 0.7690,  0.1532,  0.9104],
        [ 0.9525, -0.5532, -0.3847],
        [-2.0273,  0.2497,  1.6467],
        [ 0.4455,  1.8891, -0.3031],
        [ 0.2327, -1.3830, -0.4485]], grad_fn=<NativeBatchNormBackward>)

In [60]:
var = torch.var(x, dim=0, unbiased=False)
mean = torch.mean(x, dim=0)

(x - mean) / torch.sqrt(var + model.eps)

tensor([[-0.3725, -0.3559, -1.4208],
        [ 0.7690,  0.1532,  0.9104],
        [ 0.9525, -0.5532, -0.3847],
        [-2.0273,  0.2497,  1.6467],
        [ 0.4455,  1.8891, -0.3031],
        [ 0.2327, -1.3830, -0.4485]], grad_fn=<DivBackward0>)

In [19]:
x = torch.randn((5, 3))
bn = nn.BatchNorm1d(3)

In [20]:
bn(x)

tensor([[-1.1245,  1.3747, -0.5232],
        [ 1.1031,  0.5880,  0.0641],
        [-0.8741,  0.0202, -1.3274],
        [-0.3750, -0.3597,  0.0675],
        [ 1.2705, -1.6232,  1.7191]], grad_fn=<NativeBatchNormBackward>)