In [5]:
!rm -rf VAE-MedMNIST-Geodesics
!git clone https://github.com/SimoGalva/VAE-MedMNIST-Geodesics.git


Cloning into 'VAE-MedMNIST-Geodesics'...
remote: Enumerating objects: 20, done.[K
remote: Counting objects: 100% (20/20), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 20 (delta 4), reused 19 (delta 3), pack-reused 0 (from 0)[K
Receiving objects: 100% (20/20), 7.76 KiB | 7.76 MiB/s, done.
Resolving deltas: 100% (4/4), done.


In [6]:
!pip install -q medmnist matplotlib tqdm

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO

import os
import sys

# Make the cloned repository importable as the `src` package.
# The clone cell puts the repo in the current working directory, so resolve it
# from there instead of hard-coding Colab's `/content`. Insert it at the front
# of sys.path so an unrelated `src` package already on the path cannot shadow
# the project's. Skip the insert on re-runs to avoid duplicate entries.
REPO_DIR = os.path.abspath("VAE-MedMNIST-Geodesics")
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

from src.vae import VAE, loss as vae_loss

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (on all devices). Default parameter initialisation of
# the model built later presumably draws from it, so weights become repeatable.
# Some GPU kernels are nondeterministic, so CUDA runs may still not be
# bit-identical.
SEED = 0
torch.manual_seed(SEED)

# dataset
data_flag = 'chestmnist'
info = INFO[data_flag]
# Look up the dataset class named in INFO (presumably medmnist.ChestMNIST here).
DataClass = getattr(medmnist, info['python_class'])

transform = transforms.Compose([
    transforms.ToTensor(),
    # Round intensities to 3 decimals. ToTensor already maps 8-bit images into
    # [0, 1], so abs() is presumably only defensive — TODO confirm intent.
    transforms.Lambda(lambda x: torch.round(torch.abs(x), decimals=3)),
])

train_dataset = DataClass(split='train', transform=transform, download=True)
# A dedicated, seeded generator makes the shuffle order reproducible on its own,
# regardless of how many draws other code has taken from the global RNG.
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,  # discard the final incomplete batch; all batches have 256 samples
    generator=torch.Generator().manual_seed(SEED),
)

In [9]:
# Model: the project's VAE, sized for single-channel 28x28 inputs.
# NOTE(review): `nhid=16` is presumably the latent dimensionality — confirm in src/vae.py.
vae = VAE((1, 28, 28), nhid=16)
vae = vae.to(device)

# AdamW with decoupled weight decay 0.01. The learning rate is small: in the
# logged run the loss was still falling steadily when the output stops at epoch 37.
optimizer = optim.AdamW(
    params=vae.parameters(),
    lr=1e-5,
    weight_decay=0.01,
)

In [10]:
max_epochs = 50  # smaller for notebook demo

# Fail fast with a clear message instead of a ZeroDivisionError at the end of
# the first epoch: with drop_last=True, a dataset smaller than one batch yields
# no batches at all.
if len(train_loader) == 0:
    raise ValueError("train_loader yields no batches; check dataset size vs. batch_size")

# Mean training loss per epoch, kept so the learning curve can be inspected or
# plotted without re-parsing the printed log.
loss_history = []

for epoch in range(max_epochs):
    vae.train()
    total_loss = 0.0
    for X, _ in train_loader:
        X = X.to(device)
        # The project's VAE forward returns (reconstruction, mean, log-variance).
        X_hat, mean, logvar = vae(X)
        batch_loss = vae_loss(X, X_hat, mean, logvar)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()

    # Average of per-batch losses over the epoch.
    epoch_loss = total_loss / len(train_loader)
    loss_history.append(epoch_loss)
    print(f"epoch {epoch} loss: {epoch_loss:.6f}")

epoch 0 loss: 11321.473247
epoch 1 loss: 9400.275986
epoch 2 loss: 7586.530535
epoch 3 loss: 6383.803618
epoch 4 loss: 5692.178277
epoch 5 loss: 5233.969001
epoch 6 loss: 4898.563773
epoch 7 loss: 4660.949688
epoch 8 loss: 4481.834976
epoch 9 loss: 4340.147889
epoch 10 loss: 4225.808307
epoch 11 loss: 4127.807531
epoch 12 loss: 4047.994101
epoch 13 loss: 3976.873838
epoch 14 loss: 3915.935513
epoch 15 loss: 3856.561805
epoch 16 loss: 3799.880040
epoch 17 loss: 3744.262906
epoch 18 loss: 3691.116860
epoch 19 loss: 3643.589228
epoch 20 loss: 3600.152248
epoch 21 loss: 3559.620700
epoch 22 loss: 3523.411173
epoch 23 loss: 3485.394109
epoch 24 loss: 3451.131068
epoch 25 loss: 3411.620888
epoch 26 loss: 3363.156005
epoch 27 loss: 3304.556954
epoch 28 loss: 3249.891077
epoch 29 loss: 3197.800412
epoch 30 loss: 3155.687960
epoch 31 loss: 3113.454907
epoch 32 loss: 3073.611528
epoch 33 loss: 3039.096290
epoch 34 loss: 3002.227452
epoch 35 loss: 2971.143118
epoch 36 loss: 2948.713037
epoch 37 l

In [11]:
# Persist only the learned parameters (the state_dict), not the pickled module.
# To reload, rebuild the model with the same constructor arguments, then call
# `load_state_dict(torch.load("VAE.pt"))` on it.
torch.save(obj=vae.state_dict(), f="VAE.pt")
print("saved VAE.pt")

saved VAE.pt
