In [1]:
import INN
import torch
import torch.nn as nn

In [9]:
model = INN.Sequential(INN.Nonlinear(dim=3, method='NICE'),
                       INN.BatchNorm1d(3),
                       INN.Nonlinear(dim=3, method='RealNVP'))
model.eval()
''

''

In [2]:
model = INN.BatchNorm1d(3)#INN.Nonlinear(dim=3, method='iResNet', num_n=100, num_iter=50)

In [3]:
def linear_Jacobian_matrix(model, x):
    """
    Build per-sample Jacobians of `model` for a 2-d batch via autograd.

    * model: callable returning a `(y, log_p, log_det)` tuple; must expose
      `computing_p(bool)` (presumably switches on the log-det bookkeeping).
    * x: tensor of shape `[batch_size, dim]`. NOTE: `requires_grad` is
      switched on in place on the caller's tensor.

    Returns `(J, log_det)`: `J` stacks one gradient per output feature along
    dim 1 (so `J[b, i]` is the gradient of output feature `i` w.r.t. `x[b]`,
    assuming samples do not interact inside the model), and `log_det` is the
    model's own value, not detached.
    """
    batch_size, dim = x.shape
    x.requires_grad = True
    model.computing_p(True)
    y, _, log_det = model(x)

    grad_list = []
    for i in range(dim):
        # One-hot seed selecting output feature i for every sample. It is
        # allocated to match y so the vjp also works on GPU / float64 inputs.
        v = torch.zeros((batch_size, dim), dtype=y.dtype, device=y.device)
        v[:, i] = 1
        grad = INN.utilities.vjp(y, x, v)[0]
        grad_list.append(grad.detach())
    return torch.stack(grad_list, dim=1), log_det

In [4]:
def Jacobian_matrix(model, x):
    """
    Jacobian of `model` at a single (un-batched) input, in one backward pass.

    The input is tiled `x.numel()` times along a new leading axis and the
    model is run on that stack; an identity matrix reshaped to
    `[numel, *x.shape]` is then used as the vjp seed, so row `i` of the
    result is the gradient of output element `i`.

    * model: callable returning a `(y, log_p, log_det)` tuple; must expose
      `computing_p(bool)` (presumably switches on the log-det bookkeeping).
    * x: a single sample of any shape (no batch axis).

    Returns `(grad, log_det)`, both detached. `grad` has shape
    `[numel, *x.shape]`; `log_det` is whatever the model reports for the
    tiled batch.
    """
    dim = x.numel()

    # Tile a detached copy so the caller's tensor is neither modified nor
    # pulled into the autograd graph built here.
    x_hat = x.detach().unsqueeze(0).repeat(dim, *([1] * x.dim()))
    x_hat.requires_grad = True
    model.computing_p(True)
    y, _, log_det = model(x_hat)

    # Identity seed created on y's device/dtype (works for GPU / float64).
    v = torch.eye(dim, dtype=y.dtype, device=y.device).reshape((dim, *x.shape))
    grad = INN.utilities.vjp(y, x_hat, v)[0]

    return grad.detach(), log_det.detach()

In [5]:
x = torch.randn(3)

In [10]:
model.eps = 1e-5
J, log_det = Jacobian_matrix(model, x)
J

tensor([[ 1.0207,  0.7147,  0.1159],
        [ 0.3893,  1.9171, -0.6180],
        [ 0.4169, -0.0570,  1.0998]])

In [11]:
torch.log(torch.abs(torch.det(J)))

tensor(0.4258)

In [12]:
torch.mean(log_det)

tensor(0.4258)

In [14]:
x = torch.randn((6, 3))

In [15]:
# Compare the model's own log-determinant with one computed from autograd.
# eval mode: presumably makes normalisation layers use running statistics,
# so samples in the batch do not influence each other -- TODO confirm.
model.eval()
Js, log_det = linear_Jacobian_matrix(model, x)
# log|det J| of every per-sample Jacobian
real_log_det = torch.log(torch.abs(torch.det(Js)))

# J_g: value derived from autograd, J_c: value returned by the model
print(f'J_g={real_log_det},\nJ_c={log_det.detach()}')

J_g=tensor([ 1.9498, -0.0728,  1.3574,  0.7588,  1.3956,  0.7595]),
J_c=tensor([ 1.9498, -0.0728,  1.3574,  0.7588,  1.3956,  0.7595])


In [16]:
Js[0]

tensor([[ 1.9760, -1.3065,  1.2205],
        [ 0.5567, -0.8171, -1.6935],
        [ 0.0597,  2.1638,  2.1138]])

In [46]:
# Sum over features of -1/2 * log(biased batch variance + eps).
# NOTE(review): presumably the log-det a batch norm using batch statistics
# should report, computed by hand for comparison -- confirm.
torch.sum(-1 * torch.log(torch.var(x, dim=0, unbiased=False) + model.eps) / 2)

tensor(0.7945, grad_fn=<SumBackward0>)

## Bug lists

1. `INN.BatchNorm1d()` fails on Jacobian tests [fixed]
2. `INN.iResNet()` shows a large difference from the ground truth!

In [53]:
model = nn.BatchNorm1d(3, affine=False)

In [54]:
model.running_mean

tensor([0., 0., 0.])

In [55]:
model(x)

tensor([[-0.3725, -0.3559, -1.4208],
        [ 0.7690,  0.1532,  0.9104],
        [ 0.9525, -0.5532, -0.3847],
        [-2.0273,  0.2497,  1.6467],
        [ 0.4455,  1.8891, -0.3031],
        [ 0.2327, -1.3830, -0.4485]], grad_fn=<NativeBatchNormBackward>)

In [60]:
# Batch normalisation written out by hand with biased batch statistics.
var = torch.var(x, dim=0, unbiased=False)
mean = torch.mean(x, dim=0)

# The displayed values match the output of `model(x)` in the previous cell.
(x - mean) / torch.sqrt(var + model.eps)

tensor([[-0.3725, -0.3559, -1.4208],
        [ 0.7690,  0.1532,  0.9104],
        [ 0.9525, -0.5532, -0.3847],
        [-2.0273,  0.2497,  1.6467],
        [ 0.4455,  1.8891, -0.3031],
        [ 0.2327, -1.3830, -0.4485]], grad_fn=<DivBackward0>)

In [17]:
x = torch.randn((5, 3))
bn = nn.BatchNorm1d(3, affine=False)

In [18]:
bn(x)

tensor([[-1.6941,  0.2933, -0.2451],
        [-0.1313, -0.2711,  1.4740],
        [ 0.2754, -0.2282,  0.4445],
        [ 0.1287, -1.4409, -0.0721],
        [ 1.4213,  1.6469, -1.6014]])

In [5]:
x = torch.randn((3,3,3))

In [29]:
list(x.shape)

[3, 3, 3]

In [26]:
x.shape = 5

AttributeError: attribute 'shape' of 'torch._C._TensorBase' objects is not writable

In [31]:
[1,2,3,4,5][4:]

[5]

# New Modules

## Conv1d

In [48]:
class _default_1d_coupling_function(nn.Module):
    def __init__(self, channels, kernel_size, activation_fn=nn.ReLU, w=4):
        super(_default_1d_coupling_function, self).__init__()
        if kernel_size % 2 != 1:
            raise ValueError(f'kernel_size must be an odd number, but got {kernel_size}')
        r = kernel_size // 2
        
        self.f = nn.Sequential(nn.Conv1d(channels, channels*w, kernel_size, padding=r),
                               activation_fn(),
                               nn.Conv1d(w*channels, w*channels, kernel_size, padding=r),
                               activation_fn(),
                               nn.Conv1d(w*channels, channels, kernel_size, padding=r)
                              )
    def forward(self, x):
        return self.f(x)

In [100]:
INNModule = INN.INNAbstract.INNModule

class CouplingConv(INNModule):
    '''
    General invertible convolution layer for coupling methods.

    Holds a per-feature mask that splits the channels into two groups.

    * num_feature: number of features (channels) of the input
    * mask: optional tensor with `num_feature` elements; when omitted, the
      first `num_feature // 2` features are set to 1 and the rest to 0
    '''
    def __init__(self, num_feature, mask=None):
        super(CouplingConv, self).__init__()
        self.num_feature = num_feature
        if mask is None:
            self.mask = self._mask(num_feature)
        else:
            self.mask = mask

    def _mask(self, n):
        '''
        Default mask: ones for the first `n // 2` entries, zeros elsewhere.
        '''
        m = torch.zeros(n)
        m[:(n // 2)] = 1
        return m

    def working_mask(self, x):
        '''
        Generate the feature mask, broadcastable against `x`.
        x.shape = [batch_size, feature, *]
        mask.shape = [1, feature, *(1)]

        The mask is a plain tensor attribute, so it does not follow the
        module across `.cuda()` / `.to()` calls; it is therefore cast to the
        device and dtype of `x` here to keep `mask * x` valid.
        '''
        _, _, *other = x.shape
        mask = self.mask.reshape(1, self.num_feature, *[1] * len(other))
        return mask.to(device=x.device, dtype=x.dtype)


class CouplingConv1d(CouplingConv):
    '''
    Base class of the 1-d coupling convolution layers.

    Adds nothing on top of `CouplingConv`; the constructor only forwards
    `num_feature` and `mask` to it.
    '''
    def __init__(self, num_feature, mask=None):
        super().__init__(num_feature, mask=mask)


class Conv1dNICE(CouplingConv1d):
    '''
    1-d invertible convolution layer using additive (NICE) coupling.

    Two coupling steps share the same network `m`: first the unmasked
    features are shifted by `m` of the masked ones, then the roles swap.

    * channels: number of features (channels) of the input
    * kernel_size, w, activation_fn: passed to the default coupling network
      when `m` is not given
    * m: optional custom coupling network
    * mask: optional feature mask, see `CouplingConv`
    '''
    def __init__(self, channels, kernel_size, w=4, activation_fn=nn.ReLU, m=None, mask=None):
        super().__init__(num_feature=channels, mask=mask)
        self.m = m if m is not None else _default_1d_coupling_function(channels, kernel_size, activation_fn, w=w)

    def forward(self, x):
        keep = self.working_mask(x)
        free = 1 - keep

        # step 1: update the unmasked features from the masked ones
        x = x + free * self.m(keep * x)
        # step 2: update the masked features from the (new) unmasked ones
        x = x + keep * self.m(free * x)
        return x

    def inverse(self, y):
        keep = self.working_mask(y)
        free = 1 - keep

        # undo the two additive steps in reverse order
        y = y - keep * self.m(free * y)
        y = y - free * self.m(keep * y)
        return y

    def logdet(self, **args):
        # always returns 0, whatever keyword arguments are passed
        return 0


class Conv1dNVP(CouplingConv1d):
    '''
    1-d invertible convolution layer using affine coupling
    (scale `exp(log_s)` and shift `t`).
    TODO: inverse error is too large

    Two coupling steps share the same `log_s` and `t` networks; the second
    step uses the complementary mask.

    * channels: number of features (channels) of the input
    * kernel_size, w, activation_fn: passed to the default coupling networks
      when `s` / `t` are not given
    * s: optional custom network producing the log-scale
    * t: optional custom network producing the shift
    * mask: optional feature mask, see `CouplingConv`

    NOTE(review): `logdet` returns 0 although the scale is not 1 in general;
    confirm whether the log-determinant is handled elsewhere.
    '''
    def __init__(self, channels, kernel_size, w=4, activation_fn=nn.ReLU, s=None, t=None, mask=None):
        super().__init__(num_feature=channels, mask=mask)
        # log_s is built before t, as both default networks are randomly initialised
        self.log_s = s if s is not None else _default_1d_coupling_function(channels, kernel_size, activation_fn, w=w)
        self.t = t if t is not None else _default_1d_coupling_function(channels, kernel_size, activation_fn, w=w)

    def s(self, x):
        log_scale = self.log_s(x)
        return torch.exp(log_scale)

    def _couple(self, x, keep):
        # features selected by `keep` pass through; the rest are scaled and shifted
        kept = keep * x
        return (1-keep) * (self.s(kept) * x + self.t(kept)) + kept

    def _uncouple(self, y, keep):
        # exact algebraic inverse of `_couple` for the same `keep`
        kept = keep * y
        return (1-keep) * (y - self.t(kept)) / torch.exp(self.log_s(kept)) + kept

    def forward(self, x):
        keep = self.working_mask(x)
        x = self._couple(x, keep)
        x = self._couple(x, 1 - keep)
        return x

    def inverse(self, y):
        # start with the complementary mask: the last forward step is undone first
        keep = 1 - self.working_mask(y)
        y = self._uncouple(y, keep)
        y = self._uncouple(y, 1 - keep)
        return y

    def logdet(self, **args):
        return 0

In [2]:
import INN
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
x = torch.randn((3, 5, 8)) * 10

model = INN.Linear1d(5)

In [3]:
model(x)

(tensor([[[-4.5962e-01, -8.3237e-01, -5.2095e+00, -6.3192e+00, -9.2027e+00,
           -2.3169e+01, -2.2931e+00, -1.5648e+00],
          [ 2.1111e+00,  4.1545e+00, -4.4625e+00,  9.5212e+00,  4.8327e+00,
            8.9690e+00, -2.5365e+01, -1.0859e+01],
          [-3.9164e+00,  1.7029e+01,  4.7630e+00,  1.9899e+01, -2.3675e+01,
           -5.6664e+00,  1.8262e+01,  4.9501e+00],
          [ 3.7490e+00, -1.3335e+00,  3.8716e+00,  1.2091e+01, -2.0418e+00,
           -1.1282e+01,  1.1244e+00, -9.5155e-01],
          [ 2.0400e+00,  1.3096e+01, -4.2499e-01,  3.8104e+00, -9.6784e+00,
            1.1031e+01,  6.2131e+00, -1.5430e+01]],
 
         [[ 1.3582e+01, -1.9608e+01, -4.0115e+00,  1.2407e+00, -8.5277e+00,
            7.0188e+00, -7.1900e+00, -1.7199e+00],
          [-5.4620e+00, -9.5898e+00,  5.4368e+00,  2.3484e+00, -1.3124e+01,
           -4.8956e+00, -2.5614e+00,  3.3316e+00],
          [-3.0688e+00, -1.9365e-01, -4.9798e+00,  2.3397e+00,  1.3315e+01,
           -4.1992e-01,  1.6362e

In [42]:
logdet

tensor([17.2584,  8.3290, 27.7419], grad_fn=<AddBackward0>)

In [43]:
nn.L1Loss()(model.inverse(y), x)

tensor(9.3338, grad_fn=<L1LossBackward>)

In [47]:
model.inverse(y) - x

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00, -5.9605e-08],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00, -4.7684e-07],
         [-1.9073e-06, -2.3842e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[ 2.0160e+00,  4.5725e+00, -6.5988e+00, -6.9353e+00, -1.9093e+01,
           6.4344e+00,  4.9652e+00, -9.1800e+00],
         [-5.2022e+00, -1.6432e+01, -1.0714e+00,  3.1963e+01, -4.8019e+01,
          -2.0962e+02, -1.3407e-02,  1.0146e+01],
         [-3.1409e-01, -8.2552e+00, -7.9736e+00,  2.8125e+00,  1.7196e+01,
          -1.2859e+01, -4.4796e+00, -1.7652e+00]

In [74]:
mat = INN.utilities.PLUMatrix(5)
mat.cuda()

PLUMatrix()

In [71]:
weight = mat.W().unsqueeze(-1)
F.conv1d(x, weight).shape

torch.Size([3, 5, 8])

In [73]:
x.shape

torch.Size([3, 5, 8])

In [75]:
mat.W()

tensor([[ 0.6578, -0.3593,  0.5535, -0.2352, -0.2766],
        [-0.4351, -0.4065, -0.1952, -0.1574, -0.7633],
        [ 0.2616, -0.7114, -0.5362,  0.1648,  0.3330],
        [-0.1821, -0.0787, -0.0500, -0.9152,  0.3472],
        [ 0.5258,  0.4398, -0.6046, -0.2349, -0.3309]], device='cuda:0',
       grad_fn=<MmBackward>)

In [80]:
class Conv1d1x1(nn.Module):
    '''
    1x1 convolution over 1-d inputs with an explicit inverse.

    * num_feature: number of channels; only used to build the default matrix
    * mat: optional object providing `W()` and `inv_W()`; defaults to
      `INN.utilities.PLUMatrix(num_feature)`
    '''
    def __init__(self, num_feature, mat=None):
        super().__init__()
        self.mat = INN.utilities.PLUMatrix(num_feature) if mat is None else mat

    def weight(self):
        # trailing axis of size 1 turns the matrix into a conv1d kernel
        kernel = self.mat.W()
        return kernel.unsqueeze(-1)

    def weight_inv(self):
        kernel = self.mat.inv_W()
        return kernel.unsqueeze(-1)

    def forward(self, x, log_p0=0, log_det_J=0):
        # log_p0 and log_det_J are accepted but not used
        kernel = self.weight()
        return F.conv1d(x, kernel)

    def inverse(self, y):
        kernel = self.weight_inv()
        return F.conv1d(y, kernel)

In [81]:
model = Conv1d1x1(5)

In [82]:
y = model(x).detach()

In [85]:
model.inverse(y) - x

tensor([[[ 9.5367e-07, -7.7486e-07,  0.0000e+00, -2.5034e-06,  0.0000e+00,
           1.0729e-06, -1.9073e-06,  5.3644e-07],
         [-1.9073e-06,  9.5367e-07,  9.5367e-07,  2.8610e-06,  0.0000e+00,
          -1.9073e-06,  1.9073e-06, -2.3842e-07],
         [ 0.0000e+00,  0.0000e+00,  4.7684e-07,  0.0000e+00, -1.9073e-06,
           0.0000e+00,  3.5763e-07, -9.5367e-07],
         [-3.8147e-06,  0.0000e+00,  0.0000e+00,  1.9073e-06,  3.8147e-06,
          -4.7684e-07,  0.0000e+00, -1.1921e-07],
         [-1.6689e-06, -9.5367e-07,  4.7684e-07,  0.0000e+00,  4.7684e-07,
           9.5367e-07, -1.9073e-06,  1.4305e-06]],

        [[-7.1526e-07,  1.9073e-06,  4.7684e-07, -1.4305e-06,  0.0000e+00,
           2.3842e-06, -1.4305e-06,  0.0000e+00],
         [ 0.0000e+00,  9.5367e-07, -4.7684e-07, -2.6226e-06, -1.9073e-06,
          -1.9073e-06,  9.5367e-07, -1.4305e-06],
         [ 0.0000e+00,  0.0000e+00,  9.5367e-07,  1.9073e-06,  0.0000e+00,
           4.7684e-07, -4.7684e-07,  4.7684e-07]

In [6]:
nn.Conv2d(5, 5, 1).weight.shape

torch.Size([5, 5, 1, 1])

## Reshape

In [None]:
import INN
import torch
import torch.nn as nn
import torch.nn.functional as F

In [13]:
class reshape(nn.Module):
    '''
    Invertible reshape

    * shape_in: shape of the input. Note that batch_size don't need to be included.
    * shape_out: shape of the output

    Either shape may contain a single -1, which is inferred at run time.
    A ValueError is raised at construction when the two shapes cannot
    describe the same number of elements.
    '''
    def __init__(self, shape_in, shape_out):
        super(reshape, self).__init__()

        self._check_shape(shape_in, shape_out)
        self.shape_in = shape_in
        self.shape_out = shape_out

    @staticmethod
    def _known_size(shape):
        '''
        Return (product of the explicit dims, number of -1 dims) of `shape`.
        '''
        size, wildcards = 1, 0
        for d in shape:
            if d == -1:
                wildcards += 1
            else:
                size *= d
        return size, wildcards

    def _check_shape(self, shape_in, shape_out):
        '''
        Check if the in and out are in the same size

        -1 entries are treated as "to be inferred": with one wildcard the
        fully specified side must be divisible by the explicit part of the
        other side; with a wildcard on both sides nothing can be checked.
        '''
        s_in, w_in = self._known_size(shape_in)
        s_out, w_out = self._known_size(shape_out)

        if w_in > 1 or w_out > 1:
            raise ValueError('only one dimension of a shape can be -1.')

        if w_in == 0 and w_out == 0:
            if s_in != s_out:
                raise ValueError(f'shape_in and shape_out must have the same size, but got {s_in} and {s_out}.')
            return

        if w_in == 1 and w_out == 1:
            # both sizes depend on the actual input; checked by torch at run time
            return

        # exactly one side has a wildcard
        total, known = (s_in, s_out) if w_out == 1 else (s_out, s_in)
        if known != 0 and total % known != 0:
            raise ValueError(f'shape_in and shape_out must have the same size, but got {s_in} and {s_out}.')
        return

    def forward(self, x):
        batch_size = x.shape[0]
        return x.reshape(batch_size, *self.shape_out)

    def inverse(self, x):
        batch_size = x.shape[0]
        return x.reshape(batch_size, *self.shape_in)

In [24]:
x = torch.randn((3, 5, 9))
model = reshape(shape_in=(5, 9), shape_out=(-1,))

In [25]:
model(x).shape

torch.Size([3, 45])