In [43]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from model import VAE, vae_loss, generate_samples_from_vae

In [44]:
# Seed torch's global RNG before building the model so weight initialisation
# (and the DataLoader shuffle order in the training cell, which draws from the
# same global RNG) is reproducible across Restart & Run All. Full bitwise
# determinism on GPU would additionally require deterministic cuDNN settings.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 20
model = VAE(latent_dim=latent_dim).to(device)

In [None]:
# Train the VAE on MNIST.
# NOTE: this cell reuses the `model` built in the setup cell, so re-running it
# continues training from the current weights (with a freshly created optimizer).
batch_size=16
num_epochs=300
learning_rate=1e-3
save_model=True

train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Summed (not batch-averaged) BCE. The epoch total below is divided by the
# dataset size to report a per-sample loss; this assumes vae_loss also sums over
# the batch (TODO confirm in model.py).
criterion = nn.BCELoss(reduction='sum')

model.train()
for epoch in range(num_epochs):
    train_loss = 0.0
    # tqdm infers the bar length from len(train_loader); no enumerate/total needed.
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for data, _ in progress:  # labels are discarded: the loss uses only the images
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar, criterion)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader.dataset):.4f}")

Epoch 1/300: 100%|████████████████████████████████████████████████████████████████| 3750/3750 [00:25<00:00, 147.09it/s]
Epoch 2/300:   0%|▎                                                                 | 16/3750 [00:00<00:24, 150.94it/s]

Epoch 1/300, Train Loss: 123.8666


Epoch 2/300: 100%|████████████████████████████████████████████████████████████████| 3750/3750 [00:24<00:00, 151.50it/s]
Epoch 3/300:   1%|▌                                                                 | 30/3750 [00:00<00:25, 148.26it/s]

Epoch 2/300, Train Loss: 105.7426


Epoch 3/300:  37%|███████████████████████▊                                        | 1398/3750 [00:09<00:15, 150.16it/s]

In [None]:
NUM_SAMPLES = 10  # number of samples requested from generate_samples_from_vae
# NOTE(review): the training cell leaves the model in train() mode; confirm
# whether generate_samples_from_vae switches to eval() / torch.no_grad() itself.
generate_samples_from_vae(model, device, latent_dim, NUM_SAMPLES)

In [None]:
# Persist only the learned weights (state_dict), not the whole module object.
# Reload later by building a VAE with the same latent_dim and calling
# model.load_state_dict(torch.load("vae_mnist.pth")).
if save_model:
    torch.save(model.state_dict(), "vae_mnist.pth")