In [1]:
import math
import numpy as np
import torch
import torch.nn as nn   
from functorch import jacrev, vmap


In [2]:
# squashed gaussian

In [None]:
MIN_LOG_NN_OUTPUT = -5
MAX_LOG_NN_OUTPUT = 2
SMALL_NUMBER = 1e-6
class TorchSquashedGaussian:
    """A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

    The distribution will never return low or high exactly, but
    `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
    """

    def __init__(
        self,
        inputs,
        model,
        low: float = -1.0,
        high: float = 1.0,
    ):
        """Parameterizes the distribution via `inputs`.

        Args:
            low: The lowest possible sampling value
                (excluding this value).
            high: The highest possible sampling value
                (excluding this value).
        """
        super().__init__()
        # Split inputs into mean and log(std).
        mean, log_std = torch.chunk(self.inputs, 2, dim=-1)
        # Clip `scale` values (coming from NN) to reasonable values.
        log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT)
        std = torch.exp(log_std)
        self.dist = torch.distributions.normal.Normal(mean, std)
        assert np.all(np.less(low, high))
        self.low = low
        self.high = high
        self.mean = mean
        self.std = std

    def deterministic_sample(self):
        self.last_sample = self._squash(self.dist.mean)
        return self.last_sample

    def sample(self):
        # Use the reparameterization version of `dist.sample` to allow for
        # the results to be backprop'able e.g. in a loss term.

        normal_sample = self.dist.rsample()
        self.last_sample = self._squash(normal_sample)
        return self.last_sample

    def logp(self, x):
        # Unsquash values (from [low,high] to ]-inf,inf[)
        unsquashed_values = self._unsquash(x)
        # Get log prob of unsquashed values from our Normal.
        log_prob_gaussian = self.dist.log_prob(unsquashed_values)
        # For safety reasons, clamp somehow, only then sum up.
        log_prob_gaussian = torch.clamp(log_prob_gaussian, -100, 100)
        log_prob_gaussian = torch.sum(log_prob_gaussian, dim=-1)
        # Get log-prob for squashed Gaussian.
        unsquashed_values_tanhd = torch.tanh(unsquashed_values)
        log_prob = log_prob_gaussian - torch.sum(
            torch.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), dim=-1
        )
        return log_prob

    def sample_logp(self):
        z = self.dist.rsample()
        actions = self._squash(z)
        return actions, torch.sum(
            self.dist.log_prob(z) - torch.log(1 - actions * actions + SMALL_NUMBER),
            dim=-1,
        )

    def entropy(self):
        raise ValueError("Entropy not defined for SquashedGaussian!")

    def kl(self, other ):
        raise ValueError("KL not defined for SquashedGaussian!")

    def _squash(self, raw_values):
        # Returned values are within [low, high] (including `low` and `high`).
        squashed = ((torch.tanh(raw_values) + 1.0) / 2.0) * (
            self.high - self.low
        ) + self.low
        return torch.clamp(squashed, self.low, self.high)

    def _unsquash(self, values):
        normed_values = (values - self.low) / (self.high - self.low) * 2.0 - 1.0
        # Stabilize input to atanh.
        save_normed_values = torch.clamp(
            normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER
        )
        unsquashed = torch.atanh(save_normed_values)
        return unsquashed


In [7]:
x = torch.from_numpy(np.array([120]))
x_max = 100

loss = torch.clip( x - x_max , min = 0, max = None)

In [8]:
loss

tensor([20])

In [2]:
net = nn.Sequential(
    nn.Linear(10,10),
    nn.ReLU(),
    nn.Linear(10,1) 
)
x = torch.randn(3,10)

In [3]:
jacob = vmap(jacrev(net))(x)

  warn_deprecated('jacrev')
  warn_deprecated('vmap', 'torch.vmap')


In [4]:
jacob.shape

torch.Size([3, 1, 10])

In [5]:
jacob_norm = torch.norm(jacob,2,dim=(1,2)).unsqueeze(1)

In [6]:
jacob_norm.shape

torch.Size([3, 1])