# Injective flow

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from ciflows.flows import Injective1x1Conv

import numpy as np
import torch.nn as nn


  from .autonotebook import tqdm as notebook_tqdm


In [3]:

import tensorflow as tf
from keras import layers
import scipy


class invertible_1x1_conv(layers.Layer):
    """Invertible 1x1 convolutional layers"""

    def __init__(self, **kwargs):
        super(invertible_1x1_conv, self).__init__()
        self.type = kwargs.get('op_type', 'bijective')
        self.gamma = kwargs.get('gamma', 0.0)
        self.activation = kwargs.get('activation', 'linear')
        
    def build(self, input_shape, w=None):
        _, height, width, channels = input_shape
        
        if self.type=='bijective':
            random_matrix = np.random.randn(channels, channels).astype("float32")
            np_w = scipy.linalg.qr(random_matrix)[0].astype("float32")
            self.activation = 'linear'
            
        else:
            if self.activation == 'linear':
                random_matrix_1 = np.random.randn(channels//2, channels//2).astype("float32")
                random_matrix_2 = np.random.randn(channels//2, channels//2).astype("float32")
                np_w_1 = scipy.linalg.qr(random_matrix_1)[0].astype("float32")
                np_w_2 = scipy.linalg.qr(random_matrix_2)[0].astype("float32")
                np_w = np.concatenate([np_w_1, np_w_2], axis=0)/(np.sqrt(2.0))
                
            elif self.activation == 'relu':
                random_matrix_1 = np.random.randn(channels//2, channels//2).astype("float32")
                np_w = scipy.linalg.qr(random_matrix_1)[0].astype("float32")
        
        if w is not None:
            np_w = w
        self.w = tf.Variable(np_w, name='W', trainable=True)


    def call(self, x, reverse=False):
        # If height or width cannot be statically determined then they end up as
        # tf.int32 tensors, which cannot be directly multiplied with a floating
        # point tensor without a cast.
        _, height, width, channels = x.get_shape().as_list()
        s = tf.linalg.svd(self.w, 
            full_matrices=False, compute_uv=False)
        
        
        log_s = tf.math.log(s + self.gamma**2/(s + 1e-8))
        objective = tf.reduce_sum(log_s) * \
            tf.cast(height * width, log_s.dtype)
    
        if not reverse:
            
            if self.activation == 'relu':
                x = x[:,:,:,:channels//2] - x[:,:,:,channels//2:]
            w = tf.reshape(self.w , [1, 1] + self.w.get_shape().as_list())
            print('forward pass: ', x.shape, w.shape)
            x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME", data_format="NHWC")
        
        
        else:
            
            if self.activation=='relu':
                prefactor = tf.matmul(self.w, self.w, transpose_a=True) + \
                    self.gamma**2*tf.eye(tf.shape(self.w)[1])
                
                w_inv = tf.matmul(tf.linalg.inv(prefactor), self.w, transpose_b=True)
                conv_filter = tf.concat([w_inv, -w_inv], axis=1)
                conv_filter = tf.reshape(conv_filter, [1, 1] + conv_filter.get_shape().as_list())
                x = tf.nn.conv2d(x, conv_filter, [1, 1, 1, 1], "SAME", data_format="NHWC")
                x = tf.nn.relu(x)
            
            else:
                prefactor = tf.matmul(self.w, self.w, transpose_a=True) + \
                    self.gamma**2*tf.eye(tf.shape(self.w)[1])

                w_inv = tf.matmul(  tf.linalg.inv(prefactor) , self.w, transpose_b=True)
                print('reverse pass: ', w_inv.shape)
                conv_filter = w_inv
                conv_filter = tf.reshape(conv_filter, [1, 1] + conv_filter.get_shape().as_list())
                x = tf.nn.conv2d(x, conv_filter, [1, 1, 1, 1], "SAME", data_format="NHWC")
            
            objective *= -1
        return x, objective

    def jvp(self, x):
        """Calculates the jacobian-vector product with reverse=True"""
        # here J is a diagonal matrix and therefore, v^T J or Jv will 
        # be the same, except the transpose. 

        return self.call(x, reverse=True)

    def vjp(self, v):
        """Calculates the vector-Jacobian product with reverse=True"""
        ## here we just transpose the pseudo-inverted filter and apply to v
        if self.activation=='relu':
            prefactor = tf.matmul(self.w, self.w, transpose_a=True) + \
                self.gamma**2*tf.eye(tf.shape(self.w)[1])
            
            w_inv = tf.transpose(tf.matmul(tf.linalg.inv(prefactor), self.w, transpose_b=True))
            conv_filter = tf.concat([w_inv, -w_inv], axis=1)
            conv_filter = tf.reshape(conv_filter, [1, 1] + conv_filter.get_shape().as_list())
            out = tf.nn.conv2d(v, conv_filter, [1, 1, 1, 1], "SAME", data_format="NHWC")
            out = tf.nn.relu(out)
        
        else:
            prefactor = tf.matmul(self.w, self.w, transpose_a=True) + \
                self.gamma**2*tf.eye(tf.shape(self.w)[1])

            w_inv = tf.transpose(tf.matmul(tf.linalg.inv(prefactor) , self.w, transpose_b=True))
            conv_filter = w_inv
            conv_filter = tf.reshape(conv_filter, [1, 1] + conv_filter.get_shape().as_list())
            out = tf.nn.conv2d(out, conv_filter, [1, 1, 1, 1], "SAME", data_format="NHWC")

        return out

In [4]:

# Create a sample input tensor (e.g., batch size of 1, 4 channels, 32x32 image)
input_tensor = tf.random.normal((1, 32, 32, 4))

# Instantiate the layer
invertible_layer = invertible_1x1_conv(op_type='injective', gamma=0.0, activation='linear')

# Build the layer to initialize weights
invertible_layer.build(input_tensor.shape)

# Perform a forward pass
output_tensor, objective = invertible_layer(input_tensor)

# Perform a reverse pass
reconstructed_tensor, _ = invertible_layer(output_tensor, reverse=True)

# Print the results
print("Input Tensor Shape:", input_tensor.shape)
print("Output Tensor Shape:", output_tensor.shape)
print("Reconstructed Tensor Shape:", reconstructed_tensor.shape)

# Check if the reconstruction is close to the original input
print("Is reconstruction close to input?", tf.reduce_all(tf.abs(input_tensor - reconstructed_tensor) < 1e-5).numpy())
print(tf.reduce_mean(tf.abs(input_tensor - reconstructed_tensor)))


forward pass:  (1, 32, 32, 4) (1, 1, 4, 2)
reverse pass:  (2, 4)
Input Tensor Shape: (1, 32, 32, 4)
Output Tensor Shape: (1, 32, 32, 2)
Reconstructed Tensor Shape: (1, 32, 32, 4)
Is reconstruction close to input? False
tf.Tensor(0.5807822, shape=(), dtype=float32)


In [5]:
# now we will start from the latent space and try to reconstruct the image
input_tensor = tf.random.normal((1, 32, 32, 2))

# Perform a forward pass
output_tensor, objective = invertible_layer(input_tensor, reverse=True)

# Perform a reverse pass
reconstructed_tensor, _ = invertible_layer(output_tensor, reverse=False)

# Print the results
print("Input Tensor Shape:", input_tensor.shape)
print("Output Tensor Shape:", output_tensor.shape)
print("Reconstructed Tensor Shape:", reconstructed_tensor.shape)

# Check if the reconstruction is close to the original input
print("Is reconstruction close to input?", tf.reduce_all(tf.abs(input_tensor - reconstructed_tensor) < 1e-5).numpy())
print(tf.reduce_mean(tf.abs(input_tensor - reconstructed_tensor)))

reverse pass:  (2, 4)
forward pass:  (1, 32, 32, 4) (1, 1, 4, 2)
Input Tensor Shape: (1, 32, 32, 2)
Output Tensor Shape: (1, 32, 32, 4)
Reconstructed Tensor Shape: (1, 32, 32, 2)
Is reconstruction close to input? True
tf.Tensor(3.7361964e-08, shape=(), dtype=float32)


In [6]:
import torch.nn.functional as F
import scipy.linalg


class Invertible1x1Conv(nn.Module):
    """Invertible 1x1 convolutional layers"""

    def __init__(self, op_type="bijective", gamma=0.0, activation="linear"):
        super(Invertible1x1Conv, self).__init__()
        self.op = op_type
        self.gamma = gamma
        self.activation = activation
        self.w = None  # Placeholder for weight matrix

    def build(self, channels, w=None):
        if self.op == "bijective":
            random_matrix = np.random.randn(channels, channels).astype("float32")
            np_w = scipy.linalg.qr(random_matrix)[0].astype("float32")

        else:  # injective operation
            if self.activation == "linear":
                random_matrix_1 = np.random.randn(channels // 2, channels // 2).astype(
                    "float32"
                )
                np_w_1 = scipy.linalg.qr(random_matrix_1)[0].astype("float32")
                random_matrix_2 = np.random.randn(channels // 2, channels // 2).astype(
                    "float32"
                )
                np_w_2 = scipy.linalg.qr(random_matrix_2)[0].astype("float32")
                np_w = np.concatenate([np_w_1, np_w_2], axis=1) / (np.sqrt(2.0))
            elif self.activation == "relu":
                random_matrix_1 = np.random.randn(channels // 2, channels // 2).astype(
                    "float32"
                )
                np_w = scipy.linalg.qr(random_matrix_1)[0].astype("float32")

        if w is not None:
            np_w = w

        self.w = nn.Parameter(
            torch.tensor(np_w, dtype=torch.float32), requires_grad=True
        )

    def forward(self, x, reverse=False):
        # XXX: Matches now tensorflow implementation forward and reverse.
        _, channels, height, width = x.size()
        s = torch.svd(self.w).S

        log_s = torch.log(s + self.gamma**2 / (s + 1e-8))
        objective = torch.sum(log_s) * (height * width)

        if not reverse:
            if self.activation == "relu":
                # For injective, we are combining halves
                x_a = x[:, : channels // 2, :, :]
                x_b = x[:, channels // 2 :, :, :]
                x = x_a - x_b

            w = self.w.view(self.w.shape[0], self.w.shape[1], 1, 1)
            print('Forward: ', x.shape, w.shape)
            x = F.conv2d(x, w, stride=1, padding=0)

        else:
            prefactor = torch.matmul(
                self.w, self.w.T
            ) + self.gamma**2 * torch.eye(self.w.shape[0])
            w_inv = torch.matmul(self.w.T, torch.linalg.inv(prefactor))
            if self.activation == "relu":
                # prefactor = torch.mm(self.w.t(), self.w) + self.gamma**2 * torch.eye(
                #     self.w.size(1)
                # )
                # w_inv = torch.mm(torch.linalg.inv(prefactor), self.w.t())
                conv_filter = torch.cat([w_inv, -w_inv], dim=0)
                conv_filter = conv_filter.view(
                     conv_filter.size(0), conv_filter.size(1), 1, 1,
                )
                x = F.conv2d(x, conv_filter, stride=1, padding=0)
                x = F.relu(x)

            else:
                # prefactor = torch.matmul(
                #     self.w, self.w.T
                # ) + self.gamma**2 * torch.eye(self.w.shape[0])
                # w_inv = torch.matmul(self.w.T, torch.linalg.inv(prefactor))
                print('pt reverse: ', w_inv.shape)
                conv_filter = w_inv.view(w_inv.size(0), w_inv.size(1), 1, 1)
                x = F.conv2d(x, conv_filter, stride=1, padding=0)

            objective *= -1

        return x, objective

    def jvp(self, v):
        """Calculates the Jacobian-vector product with reverse=True"""
        return self.forward(v, reverse=True)

    def vjp(self, v):
        """Calculates the vector-Jacobian product with reverse=True"""
        prefactor = torch.matmul(
            self.w, self.w.T
        ) + self.gamma**2 * torch.eye(self.w.shape[0])
        w_inv = torch.matmul(self.w.T, torch.linalg.inv(prefactor))
        if self.activation == "relu":
            # prefactor = torch.mm(self.w.t(), self.w) + self.gamma**2 * torch.eye(
            #     self.w.size(1)
            # )
            # w_inv = self.w.t() @ torch.linalg.inv(prefactor)
            conv_filter = torch.cat([w_inv, -w_inv], dim=0)
            conv_filter = conv_filter.view(
                conv_filter.size(0), conv_filter.size(1), 1, 1
            )
            out = F.conv2d(v, conv_filter, stride=1, padding=0)
            out = F.relu(out)

        else:
            # prefactor = torch.mm(self.w.t(), self.w) + self.gamma**2 * torch.eye(
            #     self.w.size(1)
            # )
            # w_inv = self.w.t() @ torch.linalg.inv(prefactor)
            conv_filter = w_inv.view(w_inv.size(0), w_inv.size(1), 1, 1)
            out = F.conv2d(v, conv_filter, stride=1, padding=0)

        return out

In [7]:
# test the w matrix setup:

your_channel_count = 4
batch_size = 2
height = 28
width = 28

# Create a sample input tensor (e.g., batch size of 1, 4 channels, 32x32 image)
input_tensor = tf.random.normal((1, 28, 28, 4))
invertible_layer = invertible_1x1_conv(op_type='injective', gamma=0.0, activation='linear')
invertible_layer.build(input_tensor.shape)

conv_layer = Invertible1x1Conv(op_type="injective", activation="linear")
conv_layer.build(channels=your_channel_count)

print(conv_layer.w.shape, invertible_layer.w.shape)
print(np.round((conv_layer.w.T @ conv_layer.w).detach(), 3))
print(np.round((torch.Tensor(np.array(invertible_layer.w)) @ torch.Tensor(np.array(invertible_layer.w)).T).detach(), 3))

# assert torch.allclose(conv_layer.w, torch.Tensor(np.array(invertible_layer.w))), "Weight matrix is not symmetric"

torch.Size([2, 4]) (4, 2)
tensor([[ 0.5000, -0.0000, -0.1550, -0.4750],
        [-0.0000,  0.5000,  0.4750, -0.1550],
        [-0.1550,  0.4750,  0.5000, -0.0000],
        [-0.4750, -0.1550, -0.0000,  0.5000]])
tensor([[ 0.5000,  0.0000,  0.4960, -0.0590],
        [ 0.0000,  0.5000,  0.0590,  0.4960],
        [ 0.4960,  0.0590,  0.5000, -0.0000],
        [-0.0590,  0.4960, -0.0000,  0.5000]])


In [15]:
activation = 'relu'

# Create a sample input tensor (e.g., batch size of 1, 4 channels, 32x32 image)
input_tensor = tf.random.normal((1, 2, 2, 4))
input_tensor_pt = torch.Tensor(np.array(input_tensor)).permute(0, 3, 1, 2)  # NHWC to NCHW

channels = input_tensor.shape[-1]
random_matrix_1 = np.random.randn(channels // 2, channels // 2).astype("float32")
np_w_1 = scipy.linalg.qr(random_matrix_1)[0].astype("float32")
random_matrix_2 = np.random.randn(channels // 2, channels // 2).astype("float32")
np_w_2 = scipy.linalg.qr(random_matrix_2)[0].astype("float32")
np_w = np.concatenate([np_w_1, np_w_2], axis=0) / (np.sqrt(2.0))
print('W shape for TF: ', np_w.shape)
if activation == "relu":
    random_matrix_1 = np.random.randn(channels // 2, channels // 2).astype("float32")
    np_w = scipy.linalg.qr(random_matrix_1)[0].astype("float32")

# test tensorflow
invertible_layer = invertible_1x1_conv(op_type='injective', gamma=0.0, activation=activation)
invertible_layer.build(input_tensor.shape, w=np_w)
output_tensor, objective = invertible_layer(input_tensor)
# print(output_tensor.shape)
recon_tensor, objective = invertible_layer(output_tensor, reverse=True)

# Create MSE metric
mse_metric = tf.keras.metrics.MeanSquaredError()

# Calculate MSE
mse = mse_metric(input_tensor, recon_tensor)
print('MSE: ', mse)

# test pytorch
if activation == 'linear':
    np_w = np.concatenate([np_w_1, np_w_2], axis=1) / (np.sqrt(2.0))
print('W shape for PT: ', np_w.shape)
conv_layer = Invertible1x1Conv(op_type="injective", activation=activation)
conv_layer.build(channels=your_channel_count, w=np_w)
output_tensor_pt, objective_pt = conv_layer(input_tensor_pt)

recon_tensor_pt, objective_pt = conv_layer(output_tensor_pt, reverse=True)

# compute mse
print(input_tensor_pt.shape, output_tensor_pt.shape, recon_tensor_pt.shape)

mse = torch.nn.functional.mse_loss(input_tensor_pt, recon_tensor_pt)
print('MSE: ', mse)


# mse = torch.nn.functional.mse_loss(input_tensor_pt, torch.Tensor(np.array(recon_tensor)).permute(0, 3, 1, 2))
# print(f'MSE: ', mse)

print(torch.nn.functional.mse_loss(recon_tensor_pt, torch.Tensor(np.array(recon_tensor)).permute(0, 3, 1, 2)))
print(torch.nn.functional.mse_loss(output_tensor_pt, torch.Tensor(np.array(output_tensor)).permute(0, 3, 1, 2)))
# np.testing.assert_approx_equal(objective_pt.detach().numpy(), np.array(objective), significant=5), "Objective values are not equal"
# print(output_tensor, output_tensor_pt)
# print(output_tensor.shape, output_tensor_pt.shape)

W shape for TF:  (4, 2)
forward pass:  (1, 2, 2, 2) (1, 1, 2, 2)
MSE:  tf.Tensor(1.0731158, shape=(), dtype=float32)
W shape for PT:  (2, 2)
Forward:  torch.Size([1, 2, 2, 2]) torch.Size([2, 2, 1, 1])
torch.Size([1, 4, 2, 2]) torch.Size([1, 2, 2, 2]) torch.Size([1, 4, 2, 2])
MSE:  tensor(1.0731, grad_fn=<MseLossBackward0>)
tensor(0., grad_fn=<MseLossBackward0>)
tensor(0., grad_fn=<MseLossBackward0>)


In [21]:
print(conv_layer.w)

Parameter containing:
tensor([[-0.0573,  0.9984],
        [ 0.9984,  0.0573]], requires_grad=True)


In [22]:
# test pytorch glow
flow = Injective1x1Conv(num_channels_in=your_channel_count, activation='relu', preset_W=torch.Tensor(np_w))
output_tensor_pt, objective_pt = flow.inverse(input_tensor_pt)
recon_tensor_pt, objective_pt = flow.forward(output_tensor_pt)

# compute mse
print(input_tensor_pt.shape, output_tensor_pt.shape, recon_tensor_pt.shape)

mse = torch.nn.functional.mse_loss(output_tensor_pt, torch.Tensor(np.array(output_tensor)).permute(0, 3, 1, 2))
print('MSE: ', mse)
mse = torch.nn.functional.mse_loss(recon_tensor_pt, torch.Tensor(np.array(recon_tensor)).permute(0, 3, 1, 2))
print('MSE: ', mse)

print(input_tensor_pt.shape, output_tensor_pt.shape, recon_tensor_pt.shape)
mse = torch.nn.functional.mse_loss(input_tensor_pt, recon_tensor_pt)
print('MSE: ', mse)


torch.Size([1, 4, 2, 2]) torch.Size([1, 2, 2, 2]) torch.Size([1, 4, 2, 2])
MSE:  tensor(0., grad_fn=<MseLossBackward0>)
MSE:  tensor(0., grad_fn=<MseLossBackward0>)
torch.Size([1, 4, 2, 2]) torch.Size([1, 2, 2, 2]) torch.Size([1, 4, 2, 2])
MSE:  tensor(1.0731, grad_fn=<MseLossBackward0>)


In [178]:
prefactor = torch.matmul(
    conv_layer.w, conv_layer.w.T
) + conv_layer.gamma**2 * torch.eye(conv_layer.w.shape[0])
print(prefactor)
w_inv_pt = torch.matmul(conv_layer.w.T, torch.linalg.inv(prefactor))
print('pt reverse: ', w_inv_pt.shape, w_inv_pt)

prefactor = tf.matmul(
    invertible_layer.w, invertible_layer.w, transpose_a=True
) + invertible_layer.gamma**2 * torch.eye(invertible_layer.w.shape[1])
print(prefactor)
w_inv = tf.matmul(  tf.linalg.inv(prefactor) , invertible_layer.w, transpose_b=True)
print(w_inv)

assert torch.allclose(w_inv_pt.T, torch.Tensor(np.array(w_inv))), "Weight matrix same"


# w_inv = torch.matmul(conv_layer.w, torch.linalg.inv(prefactor), conv_layer.w)
# print('pt reverse: ', w_inv.shape)
# conv_filter = w_inv.view(w_inv.size(0), w_inv.size(1), 1, 1)

tensor([[ 1.0000e+00, -1.3758e-07],
        [-1.3758e-07,  1.0000e+00]], grad_fn=<AddBackward0>)
pt reverse:  torch.Size([4, 2]) tensor([[-0.1462,  0.6918],
        [ 0.6918,  0.1462],
        [-0.2179, -0.6727],
        [-0.6727,  0.2179]], grad_fn=<MmBackward0>)
tf.Tensor(
[[ 1.0000001e+00 -1.3758199e-07]
 [-1.3758199e-07  1.0000000e+00]], shape=(2, 2), dtype=float32)
tf.Tensor(
[[-0.14615329  0.69183755 -0.21789008 -0.67269886]
 [ 0.6918376   0.14615338 -0.67269903  0.21789002]], shape=(2, 4), dtype=float32)


In [125]:
# your_channel_count = 4
# batch_size = 2
# height = 28
# width = 28

# conv_layer = Invertible1x1Conv(op_type="injective", activation="linear")
# conv_layer.build(channels=your_channel_count)

# # Test with random input -> latents
# x = torch.randn(batch_size, your_channel_count, height, width)
# output, objective = conv_layer.forward(x, reverse=False)

# # reverse the latents to get the inputs
# x_reversed, objective_reversed = conv_layer.forward(output, reverse=False)

# assert torch.allclose(x, x_reversed)
# print(output.shape, x.shape, objective)