In [1]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from monai.networks.layers.factories import Act, Norm
import monai
from losses import *

bs = 128
transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((32,32)),
            ]
        )
# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms, download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [2]:
config = {
        "lr": 0.00001,
        "latent_dim": 256,
        "kernel_size": 3,
        "dropout_rate": 0.1,
        "alpha": 1,
        "beta": 0.01,
        "norm": Norm.INSTANCE,
         "batch_size": 256,
        
            "val": 3,
            "channel": (32, 64, 64),
            "stride": (1, 2, 4),
            # "resnet_units_batch" : hp.choice("res6", res_d6),

        "num_resnets":  0,
           
        
    }

In [3]:
model = monai.networks.nets.VarAutoEncoder(
    dimensions=2,  
    kernel_size=config["kernel_size"],
    in_shape=[1, 32,32],
    out_channels=1,
    channels=config["channel"],
    strides=config["stride"],
    latent_size=config["latent_dim"],
    norm=config["norm"],
    dropout=config["dropout_rate"],
    num_res_units=config["num_resnets"],
)


In [4]:
if torch.cuda.is_available():
    model.cuda()

In [5]:
optimizer = optim.Adam(model.parameters())
dice = Dice()
loss_function = KLLoss(alpha=config["alpha"], beta=config["beta"])

In [6]:
for epoch in range(5):
    model.train()
    train_loss = 0
    train_dice = 0
    train_kl = 0
    
    
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.cuda()
        y = x.cuda()
        optimizer.zero_grad()
        #out is a tuple of (recon_batch, mu, logvar, z)
        out = model(x)
        loss, kl = loss_function(out, y)
        recon_batch = out[0]
        loss.backward()
        batch_dice = dice(recon_batch, y)
        train_loss += loss.item()
        train_dice += batch_dice.item()
        train_kl += kl.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}\t Dice {:.4f}\t KL {:.4f}'.format(
                epoch, batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(x), batch_dice.item() / len(x), kl.item()*config["latent_dim"] / len(x)))
    print('====> Epoch: {} Average loss: {:.4f}, mean Dice: {:.4f}, mean KL: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset), train_dice / len(train_loader.dataset), 
    train_kl*config["latent_dim"]/len(train_loader.dataset)))


====> Epoch: 0 Average loss: 0.0043, mean Dice: 0.0074, mean KL: 0.0142
====> Epoch: 1 Average loss: 0.0036, mean Dice: 0.0078, mean KL: 0.0140
====> Epoch: 2 Average loss: 0.0035, mean Dice: 0.0079, mean KL: 0.0134
====> Epoch: 3 Average loss: 0.0035, mean Dice: 0.0079, mean KL: 0.0127
====> Epoch: 4 Average loss: 0.0035, mean Dice: 0.0079, mean KL: 0.0120
