<a href="https://colab.research.google.com/github/Rohit1217/Flow/blob/main/Flow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import matplotlib.pyplot as plt
import numpy as np

import torch
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset,DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [41]:
class AffineCoupling(nn.Module):
  def __init__(self,input_layer):
    super(AffineCoupling,self).__init__()
    self.inp=input_layer//2
    self.fc1=nn.Linear(self.inp,self.inp)

  def coupling_shuffler(self,x):
    b,c=x.shape
    layer=torch.rand(b,c)
    for i in range(c):
      if i%2==0:
        layer[:,c//2+i//2]=x[:,i]
      else:
        layer[:,i//2]=x[:,i]
    return layer

  def inverse(self,x):
    b,c=x.shape
    x0=x[:,:c//2]
    x1=x[:,c//2:]
    x1=x1-F.relu(self.fc1(x0))
    x=torch.cat((x0, x1), 1)
    layer=torch.rand(b,c)
    for i in range(c):
      if i>=c//2:
        layer[:,2*i-c]=x[:,i]
      else:
        layer[:,i*2+1]=x[:,i]
    return layer


  def forward(self,x):
    b,c=x.shape
    x=self.coupling_shuffler(x)
    x0=x[:,:c//2]
    x1=x[:,c//2:]
    x1=x1+F.relu(self.fc1(x0))
    x=torch.cat((x0, x1), 1)
    return x

x=torch.rand(1,6)
print(x)
aff=AffineCoupling(6)
y=aff(x)
print(aff(x))
aff.inverse(y)


tensor([[0.1609, 0.6871, 0.6674, 0.0603, 0.2123, 0.6034]])
tensor([[0.6871, 0.0603, 0.6034, 0.3910, 0.6674, 0.2123]],
       grad_fn=<CatBackward0>)


tensor([[0.1609, 0.6871, 0.6674, 0.0603, 0.2123, 0.6034]],
       grad_fn=<CopySlices>)

In [43]:
class Flow(nn.Module):
  def __init__(self,input_layer):
    super(Flow,self).__init__()
    self.inp=input_layer
    self.ac1=AffineCoupling(input_layer)
    self.ac2=AffineCoupling(input_layer)
    self.ac3=AffineCoupling(input_layer)
    self.ac4=AffineCoupling(input_layer)
    self.fc5=nn.Linear(input_layer,input_layer)

  def inverse(self,x):
    #x=torch.rand(self.inp)
    x= torch.log((x/(1-x+1e-8))+1e-8)
    x=self.ac4.inverse(x)
    x=self.ac3.inverse(x)
    x=self.ac2.inverse(x)
    x=self.ac1.inverse(x)
    return x


  def forward(self,x):
    x=self.ac1(x)
    x=self.ac2(x)
    x=self.ac3(x)
    x=self.ac4(x)
    x=F.sigmoid(x)
    return x

x=torch.rand(1,6)
print(x)
flow=Flow(6)
y=flow(x)
x=flow.inverse(y)
print(y,x)


tensor([[0.7808, 0.0852, 0.3202, 0.5240, 0.2204, 0.2721]])
tensor([[0.5213, 0.6281, 0.7360, 0.6859, 0.6881, 0.5759]],
       grad_fn=<SigmoidBackward0>) tensor([[0.7808, 0.0852, 0.3202, 0.5240, 0.2204, 0.2721]],
       grad_fn=<CopySlices>)
