In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
import os
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from pathlib import Path

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# (Re)load the project-local helper modules.
# On the first run of this cell the module names are not bound yet, so
# `importlib.reload(utils)` raises NameError and the modules are imported
# normally. On later runs they are reloaded, so edits to the .py files are
# picked up without restarting the kernel.
import importlib

try:
    importlib.reload(utils)
    importlib.reload(generators)
    importlib.reload(trainers)
    print("libs reloaded")
except NameError:
    # Only "not imported yet" is expected here. Any other failure (e.g. a
    # SyntaxError in an edited module) must surface instead of silently
    # falling back to the stale, already-cached module objects.
    import utils
    import generators
    import trainers
    print("libs imported")

# Short aliases for the project classes used in the rest of the notebook.
Prostate2D = utils.Prostate2D
VAE = generators.VAE
VAEGAN = generators.VAEGAN
TrainerVAE = trainers.TrainerVAE
TrainerVAEGAN = trainers.TrainerVAEGAN

In [None]:
# ---- Paths ------------------------------------------------------------------
DATA_DIR = '/content/drive/My Drive/Prostate_MRI/'
PROGRESS_DIR_VAE = Path.cwd() / "progress_vae"
PROGRESS_DIR_VAE.mkdir(parents=True, exist_ok=True)

PROGRESS_DIR_VAEGAN = Path.cwd() / "progress_vaegan"
PROGRESS_DIR_VAEGAN.mkdir(parents=True, exist_ok=True)

# ---- Training hyper-parameters ----------------------------------------------
Z_DIM = 256
BATCH_SIZE = 32
N_EPOCHS = 200
DECAY_LR_AFTER = 150  # epoch at which the linear LR decay starts (was 50)
LEARNING_RATE = 1e-3
KLD_ANNEALING_EPOCHS = 50
GAMMA = 1.0
ADA_TARGET = 0.6
ADA_LENGTH = 10_000

# EMA decay factor, parameterised as a half-life measured in images:
# ACCUM ** (EMA_KIMG * 1000 / BATCH_SIZE) == 0.5. Deriving it from BATCH_SIZE
# (instead of a hard-coded 32) keeps the half-life unchanged if the batch size
# is edited. NOTE(review): assumes the trainer applies ACCUM once per batch.
EMA_KIMG = 10
ACCUM = 0.5 ** (BATCH_SIZE / (EMA_KIMG * 1000))


def OPTIMIZER(parameters, lr):
    """Build the Adam optimizer shared by encoder, generator and discriminator.

    beta1 = 0 disables Adam's first-moment (momentum) averaging; beta2 = 0.9.
    """
    return torch.optim.Adam(parameters, lr=lr, betas=(0., 0.9))


def lambda_lr(the_epoch):
    """Learning-rate multiplier for epoch index ``the_epoch``.

    Returns 1.0 before ``DECAY_LR_AFTER``, then decays linearly so that it
    reaches 0.0 at ``N_EPOCHS``. The result is clamped at 0.0 so that running
    past ``N_EPOCHS`` (e.g. a resumed run) can never produce a negative
    learning rate.
    """
    if the_epoch < DECAY_LR_AFTER:
        return 1.0
    decay = 1 - float(the_epoch - DECAY_LR_AFTER) / (N_EPOCHS - DECAY_LR_AFTER)
    return max(0.0, decay)

# Indices handed to Prostate2D (presumably one per MRI volume/case).
# NOTE(review): index 1 is deliberately excluded; document why.
indx = np.array([0,2,3,4,5,6,7,8,9,10,11,12])
# An integer test_size holds out exactly 3 indices for validation;
# random_state makes the split reproducible.
train_indx, valid_indx = train_test_split(indx, test_size=3, random_state=0)

train_set = Prostate2D(train_indx, data_dir=DATA_DIR)
valid_set = Prostate2D(valid_indx, data_dir=DATA_DIR)

# Training: reshuffle every epoch and drop the incomplete last batch.
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Validation: a fixed order makes the batches (including the incomplete last
# batch, since drop_last=False) identical every epoch. Validation metrics are
# then comparable across epochs instead of varying with the shuffle.
valid_loader = DataLoader(
    valid_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Architecture arguments shared by the trained model and its EMA copy, so the
# two cannot drift apart.
VAEGAN_KWARGS = dict(z_dim=Z_DIM, l=2, spade=True)

vaegan_model = VAEGAN(**VAEGAN_KWARGS).to(device)

# Second network holding the exponential moving average of vaegan_model's
# weights (updated by the trainer with factor ACCUM).
net_ema = VAEGAN(**VAEGAN_KWARGS).to(device)
# Start the average from the live model's weights, not from an independent
# random initialisation. With ACCUM close to 1, those unrelated initial weights
# would otherwise keep leaking into the EMA for many updates.
net_ema.load_state_dict(vaegan_model.state_dict())
# The EMA copy is never optimised directly.
net_ema.requires_grad_(False)
net_ema.eval()

# One optimizer per sub-network so each can be stepped independently.
optimizer_enc = OPTIMIZER(vaegan_model.encoder.parameters(), lr=LEARNING_RATE)
optimizer_gen = OPTIMIZER(vaegan_model.generator.parameters(), lr=LEARNING_RATE)
optimizer_disc = OPTIMIZER(vaegan_model.discriminator.parameters(), lr=LEARNING_RATE)

vaegan_trainer = TrainerVAEGAN(
    vaegan_model,
    optimizer_enc,
    optimizer_gen,
    optimizer_disc,
    KLD_ANNEALING_EPOCHS,
    PROGRESS_DIR_VAEGAN,
    train_loader,
    valid_loader,
    net_ema = net_ema,
    accum = ACCUM,
    ada_target = ADA_TARGET,
    ada_length = ADA_LENGTH,
    gamma = GAMMA,
    device = device,
    seed = 0,
)

In [None]:
# Train the VAE-GAN for N_EPOCHS. lambda_lr (defined in the config cell) is the
# per-epoch learning-rate multiplier: 1.0 until DECAY_LR_AFTER, then a linear
# decay. NOTE(review): how the trainer applies it (e.g. via LambdaLR) is not
# visible here.
vaegan_trainer.train(N_EPOCHS, lambda_lr=lambda_lr)

In [None]:
# Training diagnostics for the VAE-GAN run.
# ada_p_log (presumably the ADA augmentation probability) and the validation
# reconstruction loss have different units. They are therefore drawn on
# separate panels, rather than sharing one "Reconstruction loss" axis that
# would present the ADA curve as a training loss.
ada_p_log = vaegan_trainer.ada_p_log
valid_losses = vaegan_trainer.valid_losses

fig, (ax_ada, ax_valid) = plt.subplots(1, 2, figsize=(12, 4))

# Each x-axis is derived from its own log's length, so the plot does not fail
# with a length mismatch if a run was interrupted.
# NOTE(review): both logs are assumed to hold one entry per epoch.
ax_ada.plot(range(1, len(ada_p_log) + 1), ada_p_log)
ax_ada.set(xlabel='Number of epochs', ylabel='ADA p', title='ADA p (training)')

ax_valid.plot(range(1, len(valid_losses) + 1), valid_losses)
ax_valid.set(
    xlabel='Number of epochs',
    ylabel='Reconstruction loss',
    title='Validation reconstruction loss (VAE-GAN)',
)

fig.tight_layout()

plot_name = 'RECON_LOSS' + str(N_EPOCHS) + 'epochs_' + str(Z_DIM) + '_latentdim.png'
# EXPERIMENT_DIR is not defined anywhere in this notebook (NameError on a fresh
# kernel); save alongside the other VAE-GAN outputs instead.
path = PROGRESS_DIR_VAEGAN / plot_name

fig.savefig(path, dpi=200)