In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder for 28x28 single-channel images.

    The encoder flattens the input and maps it to the mean and log-variance of
    a 2-dimensional diagonal Gaussian posterior. The decoder maps a latent
    sample back to a tensor of shape (B, 1, 28, 28).

    The decoder ends in ``Tanh``, so reconstructions lie in [-1, 1]. Inputs
    should be scaled to the same range for a pixel-wise loss to be meaningful.
    """

    def __init__(self):
        super().__init__()

        # Shared encoder trunk: 784 -> 196 -> 48.
        self.common_fc = nn.Sequential(
            nn.Linear(28 * 28, out_features=196), nn.Tanh(),
            nn.Linear(196, out_features=48), nn.Tanh(),
        )

        # Posterior-mean head. There is no output activation: a Tanh here would
        # clamp the mean to (-1, 1). The Linear layers stay at indices 0 and 2,
        # so state_dict keys are unchanged.
        self.mean_fc = nn.Sequential(
            nn.Linear(48, out_features=16), nn.Tanh(),
            nn.Linear(16, out_features=2),
        )

        # Posterior log-variance head, also unbounded. A Tanh here would force
        # std = exp(0.5 * log_var) into roughly [0.61, 1.65], so the encoder
        # could never express a low-variance (confident) posterior.
        self.log_var_fc = nn.Sequential(
            nn.Linear(48, out_features=16), nn.Tanh(),
            nn.Linear(16, out_features=2),
        )

        # Decoder: 2 -> 16 -> 48 -> 196 -> 784. The final Tanh bounds pixels to [-1, 1].
        self.decoder_fcs = nn.Sequential(
            nn.Linear(2, out_features=16), nn.Tanh(),
            nn.Linear(16, out_features=48), nn.Tanh(),
            nn.Linear(48, out_features=196), nn.Tanh(),
            nn.Linear(196, out_features=28 * 28), nn.Tanh(),
        )

    def encode(self, x):
        """Map inputs to posterior parameters.

        Args:
            x: tensor whose dims after the first flatten to 784 elements,
                e.g. (B, 1, 28, 28).

        Returns:
            (mean, log_var), each of shape (B, 2).
        """
        out = self.common_fc(torch.flatten(x, start_dim=1))
        mean = self.mean_fc(out)
        log_var = self.log_var_fc(out)
        return mean, log_var

    def sample(self, mean, log_var):
        """Draw z = mean + std * eps with eps ~ N(0, I) (reparameterization trick).

        Sampling this way keeps z differentiable w.r.t. ``mean`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mean

    def decode(self, z):
        """Map latent codes of shape (B, 2) to images of shape (B, 1, 28, 28)."""
        out = self.decoder_fcs(z)
        # `z.shape` is a torch.Size (a tuple), so it is indexed, not called.
        return out.reshape(z.shape[0], 1, 28, 28)

    def forward(self, x):
        """Encode, sample and decode.

        Args:
            x: image batch, expected (B, 1, 28, 28).

        Returns:
            (mean, log_var, reconstruction), where the reconstruction has shape
            (B, 1, 28, 28).
        """
        mean, log_var = self.encode(x)
        z = self.sample(mean, log_var)
        out = self.decode(z)
        return mean, log_var, out


In [None]:
#training parameters


In [None]:
from torchvision.datasets import MNIST


def create_data(root='./data', batch_size=64, train=True, download=True,
                shuffle=True, num_workers=0):
    """Build an MNIST DataLoader whose pixel range matches the VAE decoder.

    Images are converted with ``ToTensor`` (values in [0, 1]) and then
    normalized with mean 0.5 and std 0.5, which maps them to [-1, 1]. That is
    the range of the decoder's final Tanh.

    Args:
        root: directory where MNIST is stored or downloaded to.
        batch_size: samples per batch.
        train: load the training split if True, otherwise the test split.
        download: fetch the dataset into ``root`` if it is not already present.
        shuffle: reshuffle the data every epoch.
        num_workers: DataLoader worker processes; 0 loads in the main process.

    Returns:
        torch.utils.data.DataLoader yielding ``(images, labels)`` batches.
    """
    from torchvision import transforms  # local import keeps this cell self-contained

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = MNIST(root=root, train=train, download=download, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers)


def train_vae(model=None, dataloader=None, num_epochs=5, lr=1e-3,
              kl_weight=1e-5, device=None):
    """Train a VAE with a reconstruction + KL-divergence objective.

    Per-batch loss:
        recon = mean squared error between reconstruction and input (per-pixel mean)
        kl    = batch mean of 0.5 * sum_j(exp(log_var_j) + mean_j**2 - 1 - log_var_j),
                i.e. KL( N(mean, exp(log_var)) || N(0, I) )
        loss  = recon + kl_weight * kl

    Args:
        model: module whose forward returns (mean, log_var, reconstruction), with
            the reconstruction shaped like the input. Defaults to ``VAE()``.
        dataloader: iterable of batches. Each batch is either an image tensor or
            a list/tuple whose first element is the images (e.g. the
            (images, labels) pairs from a torchvision DataLoader).
            Defaults to ``create_data()``.
        num_epochs: number of full passes over ``dataloader``.
        lr: Adam learning rate.
        kl_weight: multiplier on the KL term. The reconstruction term is a
            per-pixel mean while the KL is summed over latent dims, so a small
            weight keeps the KL from dominating; tune as needed.
        device: torch device (or string) to train on. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        (model, epoch_losses): the trained model and the mean loss of each epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    if model is None:
        model = VAE()
    if dataloader is None:
        dataloader = create_data()
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    recon_criterion = nn.MSELoss()

    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        running_loss, num_batches = 0.0, 0
        for batch in dataloader:
            # torchvision loaders yield (images, labels); labels are unused here.
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)

            mean, log_var, recon = model(images)
            recon_loss = recon_criterion(recon, images)
            kl_loss = torch.mean(
                0.5 * torch.sum(torch.exp(log_var) + mean ** 2 - 1 - log_var, dim=-1)
            )
            loss = recon_loss + kl_weight * kl_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError('dataloader yielded no batches')
        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch + 1}/{num_epochs} | loss {epoch_loss:.4f}')

    return model, epoch_losses



In [None]:

if __name__ == '__main__':
    # Entry point: train the VAE. `__name__` is '__main__' both inside a
    # notebook kernel and when the exported script is run directly; comparing
    # against 'main' would never match, so this block would never execute.
    train_vae()