## Imports

In [1]:
import torch
from torch import nn
#import torch.nn.functional as F
#import csv
#import pandas as pd

## Model Class

In [None]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE,self).__init__()
        
        
        self.common_fc = nn.Sequential(
            nn.Linear(28*28, out_features=196), nn.Sigmoid(),
            nn.Linear(196, out_features=48), nn.Sigmoid(),
        )
        
        self.mean_fc = nn.Sequential(
            nn.Linear(48, out_features=16), nn.Sigmoid(),
            nn.Linear(16, out_features=2), nn.Sigmoid()
        )
        
        self.log_var_fc = nn.Sequential(
            nn.Linear(48, out_features=16), nn.Sigmoid(),
            nn.Linear(16, out_features=2), nn.Sigmoid()
        )
        
        self.decoder_fcs = nn.Sequential(
            nn.Linear(2, out_features=16), nn.Sigmoid(), 
            nn.Linear(16, out_features=48), nn.Sigmoid(),
            nn.Linear(48, out_features=196), nn.Sigmoid(),
            nn.Linear(196, out_features=28*28), nn.Sigmoid(),
        )
        

    def encode(self,x):
        # B,C,H,W
        flat_x = torch.flatten(x, start_dim=1)
        print(flat_x.shape)
        out = self.common_fc(flat_x)
        mean = self.mean_fc(out)
        log_var = self.log_var_fc(out)
        return mean, log_var


    def sample(self, mean, log_var):
        std = torch.exp(0.5*torch.flatten(log_var, start_dim=-1))
        z = torch.randn_like(torch.flatten(std, start_dim=-1))
        return z * std + mean
    
    
    def decode(self, z):
        out = self.decoder_fcs(z)
        #out = torch.reshape(out, [1, 28*28])
        return out
        
    
    def forward(self, batch_x):
        #B,C,H,W
        outputs = []
        logv_arr = []
        mean_arr = []
        #Encoder
        mean, log_var = self.encode(batch_x)
        #Sampling
        z = self.sample(mean,log_var)
        #Decoder
        logv_arr.append(log_var)
        mean_arr.append(mean)
        outputs.append(self.decode(z))

        mean_arr = torch.stack(mean_arr, dim=0)
        logv_arr = torch.stack(logv_arr, dim=0)
        out = torch.stack(outputs, dim=0)
        return mean_arr, logv_arr, out


    def generate(self):
        n_sample = torch.normal(
            0.,torch.tensor(1,dtype=float)).to(device) # A sample from the standard normal distribution
        gen = self.decode(n_sample)
        return gen

## Training Parameters

In [3]:
# Device init
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

W_DECAY = 1e-4
LEARN_RATE = 1e-5
EPOCH_NUM = 5
BATCH_SIZE = 16

BETA = 1e-1 # for the KL divergence term

## Data Creation

In [4]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import torchvision.transforms.v2 as transforms

def create_data():
    # ---INITIALIZE DATASET ---
    #Convert pilimage dataset to a standart numpy dataset
    dataset = MNIST(
        root='./data',
        download=True,  # Add this to download the dataset if needed
        transform= transforms.ToTensor()
    )

    TRIM_LEN = int(30_000)  # 60,000 - 30,000 = 10,000 SAMPLES
    TRAIN_PORTION = 0.9 # 90% training 10% everything else
    TRAIN_LEN = int((len(dataset) - TRIM_LEN) * TRAIN_PORTION)
    
    # ---SPLIT DATASET---
    train_ds, test_ds, _ = random_split(
        dataset,  # Split the dataset, not the dataloader!
        [TRAIN_LEN, len(dataset) - TRIM_LEN - TRAIN_LEN, TRIM_LEN]
    )
    #print(f"train length: {len(train_ds)} test_length: {len(test_ds)}")
    
    # ---CREATE DATALOADERS from the split datasets---
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
    
    return train_loader, test_loader

## Training

In [5]:
def train_vae(model, train_loader, optimizer, epoch_num):
    # Method 1: One-liner
    #single_batch = next(iter(train_loader))[0].to(device)
    loss_func = nn.MSELoss()
    

    for iter, single_batch in enumerate(train_loader):
        single_batch = single_batch[0].to(device)
        # ---------Feed Forward---------
        # Extract just the generated images for now
        mean,log_var,img_gen_batch = model.forward(single_batch)

        #---------Back Prop---------
        # Loss is calculated by the batch's mean
        
        flat_sample = torch.flatten(single_batch,start_dim=1)
        print(f"flat_sample shape: {flat_sample.shape}")

        img_gen_batch = torch.flatten(img_gen_batch,start_dim=-1)
        print(f"image_gen_batch shape: {img_gen_batch.shape}")

        kl_div = -0.5 * torch.sum(
            1 + log_var - mean.pow(2) - log_var.exp(),
            dim=1
        ).mean()
        loss = loss_func(img_gen_batch, flat_sample) + BETA*kl_div

        optimizer.zero_grad()
        print(f"batch num {iter}: {loss.item()} at epoch: {epoch_num+1}")
        loss.backward()
        optimizer.step()

        
    torch.save(model.state_dict(), "vae_model.pth")
        

In [6]:
# Initialize the model
my_vae = VAE().to(device)
optim = torch.optim.Adam(params=my_vae.parameters(),
                         lr = LEARN_RATE, weight_decay=W_DECAY)

train_loader,test_loader = create_data()


for i in range(EPOCH_NUM):
    train_vae(model=my_vae, train_loader=train_loader,
            optimizer=optim,epoch_num=i)



torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 0: 0.6099346876144409 at epoch: 1


  return F.mse_loss(input, target, reduction=self.reduction)


torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 1: 0.6088641881942749 at epoch: 1
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 2: 0.6079088449478149 at epoch: 1
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 3: 0.6068063974380493 at epoch: 1
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 4: 0.6078454852104187 at epoch: 1
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 5: 0.6089862585067749 at epoch: 1
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 6: 0.6063881516456604 at epoch: 1
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_g

  return F.mse_loss(input, target, reduction=self.reduction)


flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 28: 0.3188537359237671 at epoch: 2
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 29: 0.3302154541015625 at epoch: 2
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 30: 0.3225833773612976 at epoch: 2
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 31: 0.32387644052505493 at epoch: 2
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 32: 0.3176783621311188 at epoch: 2
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape: torch.Size([1, 16, 784])
batch num 33: 0.3253556191921234 at epoch: 2
torch.Size([16, 784])
flat_sample shape: torch.Size([16, 784])
image_gen_batch shape:

Gen Test

In [7]:
number_img_T = my_vae.generate() # The tensor representing the image
print(number_img_T.shape)

transform_img = transforms.ToPILImage(number_img_T)
number_img = transform_img(number_img_T)
number_img.show()


NotImplementedError: "normal_kernel_cpu" not implemented for 'Long'