<a href="https://colab.research.google.com/github/Hgherzog/NoisyBatchNorm-YeoJohnsonBatchNorm/blob/main/yjnorm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch, math, copy
import numpy as np
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from scipy import optimize

In [None]:
!pip install d2l==1.0.0-alpha1.post0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting d2l==1.0.0-alpha1.post0
  Downloading d2l-1.0.0a1.post0-py3-none-any.whl (93 kB)
[K     |████████████████████████████████| 93 kB 335 kB/s 
[?25hCollecting matplotlib-inline
  Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)
Collecting jupyter
  Downloading jupyter-1.0.0-py2.py3-none-any.whl (2.7 kB)
Collecting qtconsole
  Downloading qtconsole-5.4.0-py3-none-any.whl (121 kB)
[K     |████████████████████████████████| 121 kB 28.3 MB/s 
Collecting jedi>=0.10
  Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
[K     |████████████████████████████████| 1.6 MB 91.0 MB/s 
Collecting qtpy>=2.0.1
  Downloading QtPy-2.3.0-py3-none-any.whl (83 kB)
[K     |████████████████████████████████| 83 kB 2.7 MB/s 
Installing collected packages: jedi, qtpy, qtconsole, matplotlib-inline, jupyter, d2l
Successfully installed d2l-1.0.0a1.post0 jedi-0.18.2 jupyter-1.0.0 matplotlib-i

In [None]:
#Yeo Johnson Transformation with pytorch tensors 
#Base code acquired from scikit learn source code
#yeo johnson
def _yeo_johnson_transform(x, lmbda):
  """Return transformed input x following Yeo-Johnson transform with
  parameter lambda, works on single feature or channel 
  over a batch
  """
  out = torch.zeros_like(x)
  pos = x >= 0  # binary mask

  eps = torch.finfo(torch.float64).eps
  #when x>= 0
  # print(lmbda)
  if abs(lmbda) < eps:
    print(torch.all(1+x[pos]))
    out[pos] = torch.log(1 + x[pos])
  else:  # lmbda != 0
    out[pos] = (torch.pow(x[pos] + 1, lmbda) - 1) / lmbda

  # when x < 0
  if abs(2 - lmbda) > eps:
    out[~pos] = -(torch.pow(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda)
  else:  # lmbda == 2
    print(torch.all(1 -x[~pos]))
    out[~pos] = -torch.log(1 -x[~pos])

  return out

def YJ_Transform(X, lmbda):
  """Apply a per-feature (2-D input) or per-channel (4-D input) Yeo-Johnson transform.

  Features of a fully connected layer and channels of a convolutional layer
  both live on dimension 1 of ``X``, so ``lmbda[i]`` is applied to the slice
  ``X[:, i]``.

  Args:
    X: (batch, num_features) tensor from a fully connected layer, or
      (batch, channels, H, W) tensor from a 2-D convolutional layer.
    lmbda: 1-D sequence/tensor with one lambda per feature/channel.

  Returns:
    Tensor with the same shape as ``X``.

  Raises:
    ValueError: if ``len(lmbda)`` does not match ``X.shape[1]``.
  """
  if len(lmbda) != X.shape[1]:
    raise ValueError(
        f"expected one lambda per feature/channel ({X.shape[1]}), got {len(lmbda)}")
  # TODO: vectorize across features/channels instead of looping.
  # Slicing with i:i+1 keeps dim 1 so the pieces can be concatenated back.
  columns = [
      _yeo_johnson_transform(X[:, i:i + 1], lmbda[i])
      for i in range(len(lmbda))
  ]
  return torch.cat(columns, dim=1)


def _yeo_johnson_inverse_transform(x, lmbda):
  """Return inverse-transformed input x following Yeo-Johnson inverse
  transform with parameter lambda, across specific feture or channel
  """
  x_inv = torch.zeros_like(x)
  pos = x >= 0

  # when x >= 0
  eps = torch.finfo(torch.float64).eps
  if abs(lmbda) < eps:
    x_inv[pos] = torch.exp(x[pos]) - 1
  else:  # lmbda != 0
    x_inv[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1

  # when x < 0
  if abs(lmbda - 2) > eps:
    x_inv[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
  else:  # lmbda == 2
    x_inv[~pos] = 1 - torch.exp(-x[~pos])

  return x_inv

def YJINV_Transform(X, lmbda):
  """Invert a per-feature (2-D input) or per-channel (4-D input) Yeo-Johnson transform.

  Counterpart of ``YJ_Transform``: ``lmbda[i]`` is used to invert the slice
  ``X[:, i]`` (features and channels both live on dimension 1).

  Args:
    X: (batch, num_features) tensor from a fully connected layer, or
      (batch, channels, H, W) tensor from a 2-D convolutional layer.
    lmbda: 1-D sequence/tensor with one lambda per feature/channel.

  Returns:
    Tensor with the same shape as ``X``.

  Raises:
    ValueError: if ``len(lmbda)`` does not match ``X.shape[1]``.
  """
  if len(lmbda) != X.shape[1]:
    raise ValueError(
        f"expected one lambda per feature/channel ({X.shape[1]}), got {len(lmbda)}")
  # TODO: vectorize across features/channels instead of looping.
  # Slicing with i:i+1 keeps dim 1 so the pieces can be concatenated back.
  columns = [
      _yeo_johnson_inverse_transform(X[:, i:i + 1], lmbda[i])
      for i in range(len(lmbda))
  ]
  # Returned for both 2-D and 4-D input.
  return torch.cat(columns, dim=1)

# TODO: update this function (original note left unfinished: "so that we can do ...")
def _yeo_johnson_optimize(x):
  """Estimate the Yeo-Johnson lambda for observed data ``x`` by maximum likelihood.

  Intended to be called once per feature/channel. As in scikit-learn (and
  like Box-Cox), the negative log-likelihood is minimised with Brent's method.

  Args:
    x: tensor of any shape. It is detached, moved to CPU, cast to float64 and
      flattened; NaN entries are ignored.

  Returns:
    The lambda minimising the negative log-likelihood. Returns 1.0 (the
    identity transform) when fewer than two non-NaN values remain, because
    the likelihood is undefined there.
  """
  x_tiny = torch.finfo(torch.float64).tiny
  # float64 keeps the variance and log-sum accumulations accurate for big batches.
  x = x.detach().cpu().to(torch.float64)
  # NaNs would poison the likelihood, so drop them (boolean indexing also flattens).
  x = x[~torch.isnan(x)]
  n_samples = x.numel()
  if n_samples < 2:
    return 1.0

  # The data-dependent factor of the Jacobian term does not depend on lambda,
  # so compute it once instead of on every Brent evaluation.
  log_jacobian_sum = torch.sum(torch.sign(x) * torch.log1p(torch.abs(x))).item()

  def _neg_log_likelihood(lmbda):
    """Return the negative log-likelihood of ``x`` as a function of lambda."""
    x_trans = _yeo_johnson_transform(x, lmbda)
    # MLE uses the biased (population) variance, as np.var does in scikit-learn.
    x_trans_var = torch.var(x_trans, unbiased=False).item()

    # Reject lambdas whose transformed variance would make log() blow up.
    if x_trans_var < x_tiny:
      return np.inf

    loglike = -n_samples / 2 * math.log(x_trans_var)
    loglike += (lmbda - 1) * log_jacobian_sum
    return -loglike

  # Bracket (-2, 2), as scikit-learn uses for Box-Cox.
  return optimize.brent(_neg_log_likelihood, brack=(-2, 2))



In [None]:
from d2l import torch as d2l
#currently testing with noise addition
def yj_batch_norm(X, gamma, beta, lmbda, moving_mean, moving_var, moving_lmbda, eps, momentum):
    """Batch normalization that also Gaussianizes each feature/channel with Yeo-Johnson.

    Training mode (grad enabled):
      1. Standardize X with the batch mean/variance.
      2. Fit one Yeo-Johnson lambda per feature/channel by MLE on the
         standardized data, and apply that transform.
      3. Update the running mean/var/lambda.
    Prediction mode (grad disabled) uses the running statistics instead.

    In both modes the result then goes through the inverse Yeo-Johnson
    transform with the learnable ``lmbda``, followed by ``gamma``/``beta``.

    Args:
      X: (batch, num_features) or (batch, channels, H, W) tensor.
      gamma, beta: learnable scale and shift, broadcastable to X.
      lmbda: learnable per-feature/channel lambda used by the inverse transform.
      moving_mean, moving_var: running statistics, broadcastable to X.
      moving_lmbda: running per-feature/channel MLE lambda.
      eps: added to the variance for numerical stability.
      momentum: weight of the current batch in the running averages.

    Returns:
      Tuple ``(Y, moving_mean, moving_var, moving_lmbda)``; the running
      statistics are returned via ``.data`` (detached from autograd).
    """
    # `is_grad_enabled` distinguishes training from prediction (d2l convention).
    if not torch.is_grad_enabled():
        # Prediction mode: normalize with the running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
        X_hat = YJ_Transform(X_hat, moving_lmbda)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected layer: per-feature (dim 1) statistics over the batch.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 2-D convolution: per-channel (dim 1) statistics over batch and
            # spatial dims; keepdim so they broadcast against X.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Fit one lambda per feature/channel on the standardized data it will
        # be applied to (the same space moving_lmbda is used in at prediction
        # time). Features/channels live on dim 1. The fit runs in SciPy, so
        # lmbda_mle carries no gradient; build it on X's device to match the
        # running statistics.
        lmbda_mle = torch.tensor(
            [_yeo_johnson_optimize(X_hat[:, i]) for i in range(len(lmbda))],
            dtype=X.dtype, device=X.device)
        X_hat = YJ_Transform(X_hat, lmbda_mle)
        # Update the running statistics with an exponential moving average.
        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
        moving_var = (1.0 - momentum) * moving_var + momentum * var
        moving_lmbda = (1.0 - momentum) * moving_lmbda + momentum * lmbda_mle

    # Map out of the Gaussianized space with the learnable lambda, then scale and shift.
    X_untrans = YJINV_Transform(X_hat, lmbda)
    Y = gamma * X_untrans + beta
    return Y, moving_mean.data, moving_var.data, moving_lmbda.data

In [None]:
class YJ_Norm(nn.Module):
    """Batch-norm layer that Gaussianizes activations with a Yeo-Johnson transform.

    Args:
      num_features: number of outputs of a fully connected layer, or number
        of output channels of a convolutional layer.
      num_dims: 2 for a fully connected layer, 4 for a convolutional layer.
      eps: constant added to the variance for numerical stability.
      momentum: weight of the current batch in the running averages.
    """

    def __init__(self, num_features, num_dims, eps=1e-5, momentum=0.1):
        super().__init__()
        shape = (1, num_features) if num_dims == 2 else (1, num_features, 1, 1)
        self.eps = eps
        self.momentum = momentum
        # Learnable scale and shift, initialized to 1 and 0 respectively.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Learnable lambda for the inverse transform; 1 is the Yeo-Johnson identity.
        self.lmbda = nn.Parameter(torch.ones(num_features))
        # Running statistics are buffers: not trained, but saved in state_dict
        # and moved by `.to(device)` together with the parameters.
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))
        self.register_buffer('moving_lmbda', torch.ones(num_features))

    def forward(self, X):
        # Safety net in case the module was not moved with `.to()` before use.
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
            self.moving_lmbda = self.moving_lmbda.to(X.device)
        # Store the updated running statistics back on the module.
        Y, self.moving_mean, self.moving_var, self.moving_lmbda = yj_batch_norm(
            X, self.gamma, self.beta, self.lmbda, self.moving_mean,
            self.moving_var, self.moving_lmbda, eps=self.eps, momentum=self.momentum)
        return Y

In [None]:
# Training hyperparameters.
# NOTE(review): `momentum` is the optimizer momentum (presumably SGD — confirm
# in the training cell); it is unrelated to the 0.1 running-statistics
# momentum that YJ_Norm passes to yj_batch_norm.
epochs = 3
lr = 0.01
momentum = 0.9
criterion = torch.nn.CrossEntropyLoss()  # expects raw (unnormalized) logits