## Imports

In [1]:
import torch
from torch import nn
#import torch.nn.functional as F
#import csv
#import pandas as pd

## Model Class

In [2]:
from architecture import VAE

## Training Parameters

In [3]:
# Device init
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

W_DECAY = 5e-3
LEARN_RATE = 5e-6
EPOCH_NUM = 5
BATCH_SIZE = 16

BETA = 1e-2 # for the KL divergence term

## Data Creation

In [None]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import torchvision.transforms.v2 as transforms

def create_data():
    # ---INITIALIZE DATASET ---
    #Convert pilimage dataset to a standart numpy dataset
    dataset = MNIST(
        root='./data',
        download=True,  # Add this to download the dataset if needed
        transform= transforms.ToTensor()
    )

    TRIM_LEN = int(25_000)  # 60,000 - 30,000 = 30,000 SAMPLES
    TRAIN_PORTION = 0.9 # 90% training 10% everything else
    TRAIN_LEN = int((len(dataset) - TRIM_LEN) * TRAIN_PORTION)
    
    # ---SPLIT DATASET---
    train_ds, test_ds, _ = random_split(
        dataset,  # Split the dataset, not the dataloader!
        [TRAIN_LEN, len(dataset) - TRIM_LEN - TRAIN_LEN, TRIM_LEN]
    )
    #print(f"train length: {len(train_ds)} test_length: {len(test_ds)}")
    
    # ---CREATE DATALOADERS from the split datasets---
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
    
    return train_loader, test_loader

## Training

In [5]:
def train_vae(model, train_loader, optimizer, epoch_num):
    # Method 1: One-liner
    #single_batch = next(iter(train_loader))[0].to(device)
    global BETA
    loss_func = nn.MSELoss()
    

    for iter, single_batch in enumerate(train_loader):
        single_batch = single_batch[0].to(device)
        # ---------Feed Forward---------
        # Extract just the generated images for now
        mean,log_var,img_gen_batch = model.forward(single_batch)

        #---------Back Prop---------
        # Loss is calculated by the batch's mean
        
        flat_sample = torch.flatten(single_batch,start_dim=1)
        img_gen_batch = torch.flatten(img_gen_batch,start_dim=-1)

        kl_div = -0.5 * torch.sum(
            1 + log_var - mean.pow(2) - log_var.exp(),
            dim=1
        ).mean()
        loss = loss_func(img_gen_batch, flat_sample) + BETA*kl_div

        optimizer.zero_grad()
        print(f"batch num {iter}: {loss.item()} at epoch: {epoch_num+1}")
        loss.backward()
        optimizer.step()
        if (BETA<1.0):
            #BETA*=1 + 1e-4
            BETA+=1e-2

        
    torch.save(model.state_dict(), "vae_model.pth")
        

In [6]:
# Initialize the model
my_vae = VAE().to(device)
optim = torch.optim.Adam(params=my_vae.parameters(),
                         lr = LEARN_RATE, weight_decay=W_DECAY)

train_loader,test_loader = create_data()


for i in range(EPOCH_NUM):
    train_vae(model=my_vae, train_loader=train_loader,
            optimizer=optim,epoch_num=i)

  return F.mse_loss(input, target, reduction=self.reduction)


batch num 0: 0.2665865421295166 at epoch: 1
batch num 1: 0.29770058393478394 at epoch: 1
batch num 2: 0.33216890692710876 at epoch: 1
batch num 3: 0.36173737049102783 at epoch: 1
batch num 4: 0.39477598667144775 at epoch: 1
batch num 5: 0.4262789487838745 at epoch: 1
batch num 6: 0.4572809338569641 at epoch: 1
batch num 7: 0.48933619260787964 at epoch: 1
batch num 8: 0.5221700072288513 at epoch: 1
batch num 9: 0.5500831604003906 at epoch: 1
batch num 10: 0.5839390158653259 at epoch: 1
batch num 11: 0.6155115962028503 at epoch: 1
batch num 12: 0.6458249092102051 at epoch: 1
batch num 13: 0.6814471483230591 at epoch: 1
batch num 14: 0.7093165516853333 at epoch: 1
batch num 15: 0.7422301769256592 at epoch: 1
batch num 16: 0.7742789387702942 at epoch: 1
batch num 17: 0.8063235282897949 at epoch: 1
batch num 18: 0.8384933471679688 at epoch: 1
batch num 19: 0.8699488043785095 at epoch: 1
batch num 20: 0.9007471203804016 at epoch: 1
batch num 21: 0.9329231977462769 at epoch: 1
batch num 22: 0

  return F.mse_loss(input, target, reduction=self.reduction)


batch num 5: 3.02345609664917 at epoch: 2
batch num 6: 3.0207715034484863 at epoch: 2
batch num 7: 3.0189101696014404 at epoch: 2
batch num 8: 3.022299289703369 at epoch: 2
batch num 9: 3.0202202796936035 at epoch: 2
batch num 10: 3.0202808380126953 at epoch: 2
batch num 11: 3.01908016204834 at epoch: 2
batch num 12: 3.0195775032043457 at epoch: 2
batch num 13: 3.0199098587036133 at epoch: 2
batch num 14: 3.0169289112091064 at epoch: 2
batch num 15: 3.01562762260437 at epoch: 2
batch num 16: 3.0209193229675293 at epoch: 2
batch num 17: 3.017883062362671 at epoch: 2
batch num 18: 3.017592430114746 at epoch: 2
batch num 19: 3.0166521072387695 at epoch: 2
batch num 20: 3.015714645385742 at epoch: 2
batch num 21: 3.012986183166504 at epoch: 2
batch num 22: 3.0184648036956787 at epoch: 2
batch num 23: 3.0127336978912354 at epoch: 2
batch num 24: 3.015124559402466 at epoch: 2
batch num 25: 3.010100841522217 at epoch: 2
batch num 26: 3.0119376182556152 at epoch: 2
batch num 27: 3.010684251785