In [16]:
import numpy as np
import torch
from pyro.distributions import TransformModule
from pyro.distributions import transforms as T
from torch.distributions import constraints, TransformedDistribution

from torch import nn



class PartialTransform(TransformModule):
    """Apply ``transform`` to the leading entries of the last dimension only.

    The first ``dims_to_transform`` entries along the last dimension are mapped
    through ``transform``; the remaining entries are passed through unchanged.
    The log|det J| of the wrapped transform is computed alongside every forward
    or inverse pass and cached, so :meth:`log_abs_det_jacobian` can return it
    for the most recent ``(x, y)`` pair without recomputation.

    Args:
        transform: Transform applied to ``x[..., :dims_to_transform]``.
        dims_to_transform: Number of leading entries (along the last
            dimension) to transform.
    """

    # Acts on whole vectors, so log|det J| has one value per vector.
    domain = constraints.real_vector
    codomain = constraints.real_vector

    def __init__(self, transform, dims_to_transform):
        super().__init__(cache_size=1)
        self.transform = transform
        self.dims_to_transform = dims_to_transform
        # log|det J| belonging to the pair currently held in self._cached_x_y.
        self._cached_log_abs_det_jacobian = None

    def _split(self, z):
        """Split ``z`` along its last dimension into (transformed, passthrough) parts."""
        n_total = z.shape[-1]
        return torch.split(
            z, [self.dims_to_transform, n_total - self.dims_to_transform], -1
        )

    def _event_log_abs_det_jacobian(self, x_part, y_part):
        """Forward log|dy/dx| of the wrapped transform, one value per vector."""
        log_det = self.transform.log_abs_det_jacobian(x_part, y_part)
        # An element-wise inner transform (event_dim 0) yields one term per
        # entry; sum them so the result matches this transform's event_dim 1.
        if self.transform.domain.event_dim == 0:
            log_det = log_det.sum(-1)
        return log_det

    def _call(self, x):
        x_left, x_right = self._split(x)
        y_left = self.transform(x_left)
        self._cached_log_abs_det_jacobian = self._event_log_abs_det_jacobian(
            x_left, y_left
        )
        return torch.cat([y_left, x_right], -1)

    def _inverse(self, y):
        y_left, y_right = self._split(y)
        x_left = self.transform.inv(y_left)
        # log_abs_det_jacobian(x, y) always denotes the *forward* log|dy/dx|;
        # the inverse transform negates it itself, so no sign flip here.
        self._cached_log_abs_det_jacobian = self._event_log_abs_det_jacobian(
            x_left, y_left
        )
        return torch.cat([x_left, y_right], -1)

    def clear_cache(self):
        """Drop the cached log|det J| value."""
        self._cached_log_abs_det_jacobian = None

    def log_abs_det_jacobian(self, x, y):
        """Return the forward log|dy/dx| (one value per vector) for ``(x, y)``.

        Reuses the value cached by the most recent forward/inverse pass when
        ``(x, y)`` is exactly that pair (identity check); otherwise the value
        is recomputed from the transformed slices of ``x`` and ``y``.
        """
        x_old, y_old = self._cached_x_y
        if (
            x is x_old
            and y is y_old
            and self._cached_log_abs_det_jacobian is not None
        ):
            return self._cached_log_abs_det_jacobian
        x_left, _ = self._split(x)
        y_left, _ = self._split(y)
        return self._event_log_abs_det_jacobian(x_left, y_left)

In [28]:
# Inverse softplus (positive -> real) on the first 3 columns; the rest pass through.
t = PartialTransform(
    transform=T.SoftplusTransform().inv,
    dims_to_transform=3,
)

In [29]:
# Seeded, local generator so the displayed outputs are reproducible on
# Restart & Run All without touching the global RNG. exp() keeps every entry
# strictly positive, as the inverse-softplus columns require.
x = torch.randn(6, 5, generator=torch.Generator().manual_seed(0)).exp()

In [30]:
# Forward pass. t was built with cache_size=1, so it remembers this exact (x, y) pair.
y = t(x)

In [31]:
# Forward output: columns 0-2 are inverse-softplus of x; columns 3-4 are copied from x unchanged.
y

tensor([[-1.6787e+00,  4.0210e+00,  8.6261e+00,  2.4198e+00,  1.2607e+01],
        [ 9.1735e-01, -1.0803e+00,  4.5691e+00,  7.5617e-01,  9.8101e-01],
        [-2.2943e-01,  7.2165e-01,  5.3149e+00,  6.7265e+00,  1.0125e+01],
        [ 6.1891e-03, -4.6604e-01,  6.6302e-01,  7.6511e-01,  1.7297e+00],
        [ 2.3935e-01,  2.8913e-01,  2.1755e+00,  3.9033e+00,  2.2716e+00],
        [-4.9702e-01,  1.2060e+00, -1.2108e+00,  2.6857e-01,  4.3742e-01]])

In [32]:
# NOTE(review): y is the exact tensor object t returned in the forward cell, so
# t.inv(y) is answered from t's (x, y) cache and hands back x itself --
# PartialTransform._inverse never runs here, and the all-zero result does not
# validate it. Passing y.clone() instead would force the inverse path.
torch.sub(t.inv(y), x)

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [34]:
# Smallest passthrough entry (columns 3 onward). PartialTransform copies these
# columns into y unchanged, so the same value appears in y[:, 3:].
torch.min(x[:, 3:])

tensor(0.2686)