In [6]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from physicsnemo.models.mlp.fully_connected import FullyConnected

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
# Model dimensions shared by every cell below; edit these to match the data.
input_size: int = 512   # width that each input is flattened to before encoding
hidden_size: int = 512  # width of the hidden layer in Encoder / Decoder
latent_size: int = 256  # dimensionality of the latent vector z

In [15]:
class Encoder(nn.Module):
    """Two-layer MLP encoder producing the parameters of the latent Gaussian.

    The input is passed through one hidden ``Linear`` + ``ReLU`` layer and then
    projected to ``2 * latent_features`` values. The output is twice the latent
    width so that it can be split into two equal halves (mean and
    log-variance) by the caller.

    Args:
        in_features: Width of the input vector. When ``None`` the notebook-level
            ``input_size`` is used.
        hidden_features: Width of the hidden layer. When ``None`` the
            notebook-level ``hidden_size`` is used.
        latent_features: Dimensionality of the latent code. When ``None`` the
            notebook-level ``latent_size`` is used.
    """

    def __init__(self, in_features=None, hidden_features=None, latent_features=None):
        super().__init__()
        # Resolve defaults lazily so that explicit sizes never touch the
        # notebook globals, while ``Encoder()`` keeps working as before.
        if in_features is None:
            in_features = input_size
        if hidden_features is None:
            hidden_features = hidden_size
        if latent_features is None:
            latent_features = latent_size

        # in_features -> hidden_features
        self.Lin1 = nn.Linear(in_features, hidden_features)
        self.act = nn.ReLU()
        # hidden_features -> 2 * latent_features (mean half + log-variance half)
        self.outp = nn.Linear(hidden_features, 2 * latent_features)

    def forward(self, x):
        """Return the ``2 * latent_features`` wide encoding of ``x``."""
        hidden = self.act(self.Lin1(x))
        return self.outp(hidden)
    
class Decoder(nn.Module):
    """Two-layer MLP decoder mapping a latent vector back to input space.

    The latent vector is passed through one hidden ``Linear`` + ``ReLU`` layer
    and then projected to ``out_features`` values. No output activation is
    applied, so the result is unbounded.

    Args:
        latent_features: Dimensionality of the latent code. When ``None`` the
            notebook-level ``latent_size`` is used.
        hidden_features: Width of the hidden layer. When ``None`` the
            notebook-level ``hidden_size`` is used.
        out_features: Width of the reconstructed vector. When ``None`` the
            notebook-level ``input_size`` is used.
    """

    def __init__(self, latent_features=None, hidden_features=None, out_features=None):
        super().__init__()
        # Resolve defaults lazily so that explicit sizes never touch the
        # notebook globals, while ``Decoder()`` keeps working as before.
        if latent_features is None:
            latent_features = latent_size
        if hidden_features is None:
            hidden_features = hidden_size
        if out_features is None:
            out_features = input_size

        # latent_features -> hidden_features
        self.Lin1 = nn.Linear(latent_features, hidden_features)
        self.act = nn.ReLU()
        # hidden_features -> out_features
        self.outp = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Return the reconstruction for the latent batch ``x``."""
        hidden = self.act(self.Lin1(x))
        return self.outp(hidden)


In [16]:
def ELBO_loss(y_hat,y,mu,logvar,beta = 1):
    """Negative ELBO: summed reconstruction loss plus a weighted KL term.

    Args:
        y_hat: Reconstruction logits produced by the decoder.
        y: Reconstruction target with the same number of elements as
            ``y_hat``; it is reshaped to ``y_hat``'s shape before comparison.
            NOTE(review): binary cross-entropy expects targets in [0, 1] --
            confirm the data is scaled accordingly.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.
        beta: Weight applied to the KL term (``1`` gives the plain ELBO).

    Returns:
        Scalar tensor ``BCE + beta * KLD``, both terms summed over all elements.
    """
    # Match the target to the logits instead of assuming a fixed feature width,
    # so the loss works for any input size.
    target = y.reshape(y_hat.shape)
    BCE = nn.functional.binary_cross_entropy_with_logits(y_hat, target, reduction="sum")
    # KL divergence between N(mu, exp(logvar)) and the standard normal N(0, I).
    KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))
    return BCE + KLD*beta

In [17]:
class VAEPL(pl.LightningModule):
    """Variational autoencoder wrapped as a PyTorch Lightning module.

    Both the encoder and the decoder are ``FullyConnected`` networks built
    from the keyword dictionaries passed to the constructor. The encoder
    output is split into two equal halves along ``dim=1`` (mean and
    log-variance), so its ``out_features`` must be twice the latent width.

    Args:
        encoderkwgs: Keyword arguments forwarded to ``FullyConnected`` for the
            encoder.
        decoderkwgs: Keyword arguments forwarded to ``FullyConnected`` for the
            decoder.
    """

    def __init__(self,encoderkwgs,decoderkwgs):
        super().__init__()
        self.encoder = FullyConnected(**encoderkwgs)
        self.decoder = FullyConnected(**decoderkwgs)

    def lossF(self, y_hat, y, mu, logvar, beta=1):
        """Loss used by the training and validation steps (delegates to ``ELBO_loss``)."""
        return ELBO_loss(y_hat, y, mu, logvar, beta)

    def reparameterise(self,mu,log_var,mode: str = "train"):
        """Draw a latent sample with the reparameterisation trick.

        In ``"train"`` mode returns ``mu + eps * std`` with ``eps`` drawn from
        a standard normal; in any other mode returns ``mu`` unchanged so the
        result is deterministic.
        """
        if mode != "train":
            return mu
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def shared_step(self, x,mode: str = "test"):
        """Encode ``x``, sample a latent vector and decode it.

        ``x`` is flattened to rows of the notebook-level ``input_size`` before
        encoding.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        encoded = self.encoder(x.view(-1,input_size))
        mu, log_var = torch.chunk(encoded, 2, dim=1)
        z = self.reparameterise(mu, log_var,mode)
        return self.decoder(z), mu, log_var

    def forward(self, x, mode: str = "test"):
        """Make the module callable; ``self(x, mode=...)`` runs ``shared_step``."""
        return self.shared_step(x, mode)

    def _loss_for(self, x, mode):
        """Run the model on ``x`` in the given mode and return the loss."""
        x_hat, mu, log_var = self(x, mode=mode)
        return self.lossF(x_hat, x, mu, log_var)

    def training_step(self,x,batch_idx):
        """Lightning training hook: log and return the loss for one batch.

        NOTE(review): assumes the dataloader yields a bare tensor batch rather
        than an ``(input, label)`` tuple -- confirm against the dataloader.
        """
        loss = self._loss_for(x, 'train')
        self.log("Train",loss)
        # Lightning back-propagates whatever this hook returns.
        return loss

    def validation_step(self,x,batch_idx):
        """Lightning validation hook: log and return the loss for one batch.

        Uses a non-``"train"`` mode, so the latent vector is the posterior
        mean rather than a random sample.
        """
        loss = self._loss_for(x, 'validation')
        self.log("validation",loss)
        return loss

    def configure_optimizers(self):
        """Return the optimizer Lightning should use.

        Adam with the library defaults; NOTE(review): tune the learning rate
        or swap the optimizer as needed for the actual dataset.
        """
        return torch.optim.Adam(self.parameters())

    def generate(self,N):
        """Decode ``N`` latent vectors drawn from a standard normal.

        The latent width is the notebook-level ``latent_size``.
        """
        z = torch.randn(N,latent_size).to(self.device)
        return self.decoder(z)

In [18]:
# Constructor arguments for the two FullyConnected networks inside VAEPL.
# The encoder emits 2 * latent_size values because VAEPL splits its output
# into a mean half and a log-variance half; the decoder maps a latent vector
# back to input_size values.
encoderkwgs = dict(in_features=input_size, out_features=2 * latent_size)
decoderkwgs = dict(in_features=latent_size, out_features=input_size)

In [19]:
# Build the VAE from the encoder / decoder keyword dictionaries defined above.
model = VAEPL(
    encoderkwgs=encoderkwgs,
    decoderkwgs=decoderkwgs,
)

In [13]:
# Show the encoder's layer structure (the module repr is the cell output).
model.encoder

FullyConnected(
  (layers): ModuleList(
    (0-5): 6 x FCLayer(
      (activation_fn): SiLU()
      (linear): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (final_layer): FCLayer(
    (activation_fn): Identity()
    (linear): Linear(in_features=512, out_features=512, bias=True)
  )
)

In [20]:
# Show the decoder's layer structure (the module repr is the cell output).
model.decoder

FullyConnected(
  (layers): ModuleList(
    (0): FCLayer(
      (activation_fn): SiLU()
      (linear): Linear(in_features=256, out_features=512, bias=True)
    )
    (1-5): 5 x FCLayer(
      (activation_fn): SiLU()
      (linear): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (final_layer): FCLayer(
    (activation_fn): Identity()
    (linear): Linear(in_features=512, out_features=512, bias=True)
  )
)