# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

# In [None]:
class ConvNet(nn.Module):
    """Conditioner network that predicts ``(log_s, t)`` from the masked input.

    When ``in_channels`` is given, a small convolutional network is built:
    3x3 conv -> ReLU -> 1x1 conv -> ReLU -> 3x3 conv with ``2 * in_channels``
    outputs. The last convolution (the "head") has zero weights and bias, so
    a freshly built network outputs ``log_s = 0`` and ``t = 0`` and the
    coupling layer using it starts out as the identity map.

    When ``in_channels`` is omitted, ``self.NN`` stays ``None`` (the original
    placeholder behaviour) and a module must be assigned to ``self.NN``
    before ``forward`` is called.

    TODO: only 2-d (B, C, H, W) inputs are supported by the built-in
    network; 1-d and 3-d data still need their own conditioner.

    Args:
        in_channels: number of channels C of the input. ``None`` leaves the
            network unset.
        hidden_channels: width of the hidden convolutions.
    """

    def __init__(self, in_channels=None, hidden_channels=64):
        super().__init__()
        if in_channels is None:
            self.NN = None
        else:
            head = nn.Conv2d(hidden_channels, 2 * in_channels, kernel_size=3, padding=1)
            # Zero head => log_s = t = 0 at initialisation, i.e. the coupling
            # layer starts as the identity, which stabilises early training.
            nn.init.zeros_(head.weight)
            nn.init.zeros_(head.bias)
            self.NN = nn.Sequential(
                nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1),
                nn.ReLU(),
                head,
            )

    def forward(self, x):
        """Predict the log-scale and translation of the affine transform.

        Args:
            x (torch.Tensor): masked input, shape (B, C, H, W).

        Returns:
            tuple: ``(log_s, t)``, each of shape (B, C, H, W).

        Raises:
            NotImplementedError: if no network has been configured.
        """
        if self.NN is None:
            raise NotImplementedError(
                "ConvNet has no network: pass in_channels to the constructor "
                "or assign a module to ConvNet.NN before calling forward()."
            )
        h = self.NN(x)  # shape: (B, 2C, H, W)
        log_s, t = torch.chunk(h, 2, dim=1)  # each of shape (B, C, H, W)
        # TODO: log_s is unconstrained here; exp(log_s) can overflow for
        # large activations, so bounding it may be needed.
        return log_s, t
        

# In [4]:
class AffineCoupling(nn.Module):
    """Affine coupling layer driven by a binary mask.

    Positions where ``mask == 1`` pass through unchanged and are the only
    input to the conditioner network. Positions where ``mask == 0`` are
    scaled and shifted:

        y = mask * x + (1 - mask) * (x * exp(log_s) + t)

    Args:
        mask (torch.Tensor): 0/1 tensor broadcastable to the input. It is
            combined arithmetically with the input, so it is presumably a
            floating-point tensor (NOTE(review): confirm against callers).
        convnet (nn.Module, optional): conditioner mapping the masked input
            to ``(log_s, t)``. Defaults to ``ConvNet()``.
    """

    def __init__(self, mask: torch.Tensor, convnet=None):
        super().__init__()
        self.convnet = ConvNet() if convnet is None else convnet
        # Buffer: follows .to()/.cuda() and is saved in the state dict, but
        # is not a trainable parameter.
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Apply the coupling transform x -> y.

        Args:
            x (torch.Tensor): input, shape (B, ...).

        Returns:
            tuple: ``(y, log_det)`` where ``y`` has the shape of ``x`` and
            ``log_det`` has shape (B,), the log-determinant of dy/dx.
        """
        keep = self.mask
        free = 1 - keep
        x1 = keep * x
        log_s, t = self.convnet(x1)
        # Transform only the free positions. Multiplying the whole affine
        # result by `free` keeps t from leaking into the pass-through half,
        # which would break invertibility.
        y = x1 + free * (x * torch.exp(log_s) + t)
        # Only transformed positions contribute to the Jacobian.
        log_det = (free * log_s).flatten(start_dim=1).sum(dim=1)  # shape: (B,)
        return y, log_det

    def backward(self, y):
        """Apply the inverse coupling transform y -> x.

        Args:
            y (torch.Tensor): output of ``forward``, shape (B, ...).

        Returns:
            tuple: ``(x, log_det)`` where ``log_det`` has shape (B,) and is
            the log-determinant of dx/dy (the negative of the forward one).
        """
        keep = self.mask
        free = 1 - keep
        # The pass-through half is identical in x and y, so the conditioner
        # reproduces the same (log_s, t) as in forward().
        y1 = keep * y
        log_s, t = self.convnet(y1)
        x = y1 + free * ((y - t) * torch.exp(-log_s))
        log_det = -(free * log_s).flatten(start_dim=1).sum(dim=1)  # shape: (B,)
        return x, log_det
                        


class BatchNorm(nn.Module):
    """Batch normalisation as an invertible flow layer with log-determinant.

    Wraps ``nn.BatchNorm2d`` and additionally returns the log-determinant of
    the per-channel map

        y = weight * (x - mean) / sqrt(var + eps) + bias

    where ``weight`` and ``bias`` are only present when ``affine=True``.

    Args:
        num_channels: number of channels C of the (B, C, H, W) input.
        eps: value added to the variance for numerical stability.
        momentum: momentum of the running-statistics update.
        affine: whether to learn a per-channel scale and shift.
    """

    def __init__(self, num_channels: int, eps: float = 1e-8, momentum: float = 0.1, affine: bool = False):
        super().__init__()
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        # The original attribute name was misspelled; kept so existing code
        # that reads it does not break.
        self.momemtum = momentum
        self.batchnorm = nn.BatchNorm2d(num_channels, momentum=momentum, affine=self.affine, eps=self.eps)

    def _log_scale(self, var: torch.Tensor) -> torch.Tensor:
        """Return the per-channel log|dy/dx| for the given variance, shape (C,)."""
        log_scale = -0.5 * torch.log(var + self.batchnorm.eps)
        if self.affine:
            # abs(): a negative learned scale still has a real log-determinant.
            log_scale = log_scale + torch.log(torch.abs(self.batchnorm.weight))
        return log_scale

    def forward(self, x: torch.Tensor):
        """Normalise ``x`` and return the log-determinant of the transform.

        In training mode the log-determinant uses the biased batch variance,
        which is what ``nn.BatchNorm2d`` normalises with in that mode. In
        eval mode it uses the running variance.

        Args:
            x (torch.Tensor): input, shape (B, C, H, W).

        Returns:
            tuple: ``(y, log_det)`` with ``y`` of shape (B, C, H, W) and
            ``log_det`` of shape (B,) (the same value for every sample).
        """
        B, _, H, W = x.shape
        if self.batchnorm.training:
            var = x.var(dim=(0, 2, 3), unbiased=False)
        else:
            var = self.batchnorm.running_var
        y = self.batchnorm(x)
        # One scalar per channel, applied at each of the H*W positions.
        log_det = (H * W * self._log_scale(var).sum()).expand(B)
        return y, log_det

    def backward(self, y: torch.Tensor):
        """Invert the normalisation using the running statistics.

        The batch statistics of the original input are not recoverable from
        ``y``, so the inverse always uses the running mean and variance; it
        is exact only with respect to an eval-mode ``forward``.

        Args:
            y (torch.Tensor): normalised tensor, shape (B, C, H, W).

        Returns:
            tuple: ``(x, log_det)`` with ``log_det`` of shape (B,), the
            log-determinant of dx/dy.
        """
        B, _, H, W = y.shape
        bn = self.batchnorm
        # Reshape per-channel stats to (1, C, 1, 1) so they broadcast over
        # the channel axis rather than the trailing width axis.
        std = torch.sqrt(bn.running_var + bn.eps).view(1, -1, 1, 1)
        mean = bn.running_mean.view(1, -1, 1, 1)
        if self.affine:
            y = (y - bn.bias.view(1, -1, 1, 1)) / bn.weight.view(1, -1, 1, 1)
        x = y * std + mean
        log_det = (-H * W * self._log_scale(bn.running_var).sum()).expand(B)
        return x, log_det
        
        
class ActNorm(nn.Module):
    """Activation normalisation: per-channel affine map with data-dependent init.

    Computes ``y = (x + bias) * exp(-m)`` with one ``m`` (log-scale) and one
    ``bias`` per channel. On the first forward pass both are set from that
    batch so that the output has zero mean and unit variance per channel;
    afterwards they are ordinary learnable parameters.

    NOTE: the ``init`` flag is a plain Python attribute, not part of the
    state dict. After loading trained weights, set ``layer.init = True`` so
    the first forward pass does not overwrite them.

    Args:
        num_channels: number of channels C of the (B, C, H, W) input.
        eps: added to the batch standard deviation before taking the log.
    """

    def __init__(self, num_channels: int, eps: float = 1e-5):
        super().__init__()
        self.m = nn.Parameter(torch.zeros(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.init = False
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """Normalise ``x``; initialise the parameters on the first call.

        Args:
            x (torch.Tensor): input, shape (B, C, H, W).

        Returns:
            tuple: ``(y, log_det)`` with ``y`` of shape (B, C, H, W) and
            ``log_det`` of shape (B,) (the same value for every sample).
        """
        B, C, H, W = x.shape
        if not self.init:
            with torch.no_grad():
                # copy_ writes into the existing Parameters; assigning a
                # plain tensor to a registered Parameter name raises.
                self.m.copy_(torch.log(torch.std(x, dim=(0, 2, 3)) + self.eps))
                # Negative mean, because the forward map *adds* the bias.
                self.bias.copy_(-torch.mean(x, dim=(0, 2, 3)))
            self.init = True
        # (1, C, 1, 1) so the parameters broadcast over the channel axis.
        shift = self.bias.view(1, C, 1, 1)
        log_scale = self.m.view(1, C, 1, 1)
        y = (x + shift) * torch.exp(-log_scale)
        # Every one of the H*W positions of channel c is scaled by exp(-m[c]).
        log_det = (-H * W * torch.sum(self.m)).expand(B)
        return y, log_det

    def backward(self, y: torch.Tensor):
        """Invert the normalisation: ``x = y * exp(m) - bias``.

        Args:
            y (torch.Tensor): input, shape (B, C, H, W).

        Returns:
            tuple: ``(x, log_det)`` with ``log_det`` of shape (B,), the
            log-determinant of dx/dy.
        """
        B, C, H, W = y.shape
        shift = self.bias.view(1, C, 1, 1)
        log_scale = self.m.view(1, C, 1, 1)
        x = y * torch.exp(log_scale) - shift
        log_det = (H * W * torch.sum(self.m)).expand(B)
        return x, log_det



class Invertible1x1_conv(nn.Module):
    """Placeholder for an invertible 1x1 convolution layer.

    Only the ``nn.Module`` base initialisation is performed: the class
    defines no parameters, no sub-modules and no ``forward`` yet.
    """

    def __init__(self):
        super().__init__()
    
    

        

              
        
        
        