In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.model_selection import train_test_split
import numpy as np 
from torchvision.utils import make_grid
import imageio

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [3]:
import importlib
try: 
    importlib.reload(utils)
    importlib.reload(generators)
    importlib.reload(trainers)
    print("libs reloaded")
except:
    import utils  
    import generators 
    import trainers
    print("libs imported")
  
#Prostate2D = utils.Prostate2D
VAE = generators.VAE
VAEGAN = generators.VAEGAN
TrainerVAEGAN = trainers.TrainerVAEGAN

libs imported


In [5]:
WORKING_DIR = Path(r"C:\Users\marti\OneDrive - TU Eindhoven\Documenten\Master\Q3\Capita Selecta\Project")
DATA_DIR = WORKING_DIR / "Data"

PROGRESS_DIR_VAEGAN = WORKING_DIR / "progress_vaegan"
PROGRESS_DIR_VAEGAN.mkdir(parents=True, exist_ok=True)

N = 15
N_train = 15-3
Z_DIM = 256
BATCH_SIZE = 32
N_EPOCHS = 200
DECAY_LR_AFTER = 50
LEARNING_RATE = 1e-3
KLD_ANNEALING_EPOCHS = 50
GAMMA = 1.0
ADA_TARGET = 0.6
ADA_LENGTH = 10_000
ACCUM = 0.5 ** (32 / (10 * 1000))
NO_VALIDATION_PATIENTS = 2
IMAGE_SIZE = (64,64)
TOLERANCE = -1e-8
MINIMUM_VALID_LOSS = 10 
seed=0

exp = str(N_EPOCHS) + '_epochs_' + str(Z_DIM) + "_zdim" + "_ema"
EXPERIMENT_DIR_VAEGAN = PROGRESS_DIR_VAEGAN / exp
EXPERIMENT_DIR_VAEGAN.mkdir(parents=True, exist_ok=True)

def lambda_lr(the_epoch):
    return (
        1.0
        if the_epoch < DECAY_LR_AFTER
        else 1 - float(the_epoch - DECAY_LR_AFTER) / (N_EPOCHS - DECAY_LR_AFTER)
    )

patients = [
    path
    for path in DATA_DIR.glob("*")
    if not any(part.startswith(".") for part in path.parts)
]

train_indx, valid_indx = train_test_split(patients, random_state=seed, train_size=N_train)

partition = {
    "train": train_indx,
    "validation": valid_indx,
}

# load training data and create DataLoader with batching and shuffling
train_set = utils.ProstateMRDataset(partition["train"], IMAGE_SIZE)
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

# load validation data
valid_set = utils.ProstateMRDataset(partition["validation"], IMAGE_SIZE)
valid_loader = DataLoader(
    valid_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

In [6]:
OPTIMIZER = lambda parameters, lr : torch.optim.RMSprop(
    parameters,
    lr=lr,
    alpha=0.9,
    eps=1e-8,
    weight_decay=0,
    momentum=0,
    centered=False
)

vaegan_model = VAEGAN(z_dim=Z_DIM, l=2, spade=True).to(device)
net_ema = VAEGAN(z_dim=Z_DIM, l=2, spade=True).to(device)
net_ema.eval()

optimizer_enc = OPTIMIZER(vaegan_model.encoder.parameters(), lr=LEARNING_RATE)
optimizer_gen = OPTIMIZER(vaegan_model.generator.parameters(), lr=LEARNING_RATE)
optimizer_disc = OPTIMIZER(vaegan_model.discriminator.parameters(), lr=LEARNING_RATE)

vaegan_trainer = TrainerVAEGAN(
    net=vaegan_model,
    optimizer_enc=optimizer_enc,
    optimizer_gen=optimizer_gen,
    optimizer_disc=optimizer_disc,
    kld_annealing_epochs=KLD_ANNEALING_EPOCHS,
    progress_dir=EXPERIMENT_DIR_VAEGAN,
    train_loader=train_loader,
    valid_loader=valid_loader,
    CHECKPOINTS_DIR=EXPERIMENT_DIR_VAEGAN,
    TOLERANCE=TOLERANCE, 
    minimum_valid_loss=MINIMUM_VALID_LOSS,
    net_ema = None,
    accum = ACCUM,
    ada_target = ADA_TARGET,
    ada_length = ADA_LENGTH,
    gamma = GAMMA,
    device = device,
    seed = 0,
    early_stopping=False
)

In [7]:
vaegan_trainer.train(N_EPOCHS, lambda_lr=lambda_lr)

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Epoch #000: Rec_Loss/train = 0.643, KLD/train = 133.827, Discl_Loss/train = 1.634, Adv_Loss/train = 0.972 | Rec_Loss/valid = 0.469, KLD/valid = 10.536, Discl_Loss/valid = 1.470, Adv_Loss/valid = 0.042


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Epoch #001: Rec_Loss/train = 0.490, KLD/train = 10.979, Discl_Loss/train = 1.533, Adv_Loss/train = 0.430 | Rec_Loss/valid = 0.454, KLD/valid = 5.438, Discl_Loss/valid = 1.604, Adv_Loss/valid = 0.169


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Epoch #002: Rec_Loss/train = 0.497, KLD/train = 5.404, Discl_Loss/train = 1.525, Adv_Loss/train = 0.363 | Rec_Loss/valid = 0.478, KLD/valid = 2.869, Discl_Loss/valid = 1.487, Adv_Loss/valid = 0.742


  0%|          | 0/32 [00:00<?, ?it/s]

In [None]:
plt.plot(range(1,len(vaegan_trainer.train_losses[0])+1), vaegan_trainer.train_losses[0])
plt.plot(range(1,len(vaegan_trainer.train_losses[0])+1), vaegan_trainer.valid_losses[0])
plt.xlabel('Number of epochs')
plt.ylabel('Reconstruction loss')
plt.title('Reconstruction loss vaegan')
plt.legend(['Training set', 'Validation set'])
plot_name = 'RECON_LOSS_'+str(N_EPOCHS) + '_epochs_' + str(Z_DIM) + '_zdim.png'
path = EXPERIMENT_DIR_VAEGAN / plot_name

plt.savefig(path, dpi=200)

In [None]:
plt.plot(range(1,len(vaegan_trainer.train_losses[1])+1), vaegan_trainer.train_losses[1])
plt.plot(range(1,len(vaegan_trainer.train_losses[1])+1), vaegan_trainer.valid_losses[1])
plt.xlabel('Number of epochs')
plt.ylabel('KLD loss')
plt.title('KLD loss vaegan')
plt.legend(['Training set', 'Validation set'])
plot_name = 'KLD_LOSS_'+str(N_EPOCHS) + '_epochs_' + str(Z_DIM) + '_zdim.png'
path = EXPERIMENT_DIR_VAEGAN / plot_name

plt.savefig(path, dpi=200)

In [None]:
plt.plot(range(1,len(vaegan_trainer.train_losses[2])+1), vaegan_trainer.train_losses[2])
plt.plot(range(1,len(vaegan_trainer.train_losses[2])+1), vaegan_trainer.valid_losses[2])
plt.xlabel('Number of epochs')
plt.ylabel('Discl_Loss')
plt.title('Discl_Loss vaegan')
plt.legend(['Training set', 'Validation set'])
plot_name = 'Discl_Loss_'+str(N_EPOCHS) + '_epochs_' + str(Z_DIM) + '_zdim.png'
path = EXPERIMENT_DIR_VAEGAN / plot_name

plt.savefig(path, dpi=200)

In [None]:
plt.plot(range(1,len(vaegan_trainer.train_losses[3])+1), vaegan_trainer.train_losses[3])
plt.plot(range(1,len(vaegan_trainer.train_losses[3])+1), vaegan_trainer.valid_losses[3])
plt.xlabel('Number of epochs')
plt.ylabel('Adv_Loss')
plt.title('Adv_Loss vaegan')
plt.legend(['Training set', 'Validation set'])
plot_name = 'Adv_Loss_'+str(N_EPOCHS) + '_epochs_' + str(Z_DIM) + '_zdim.png'
path = EXPERIMENT_DIR_VAEGAN / plot_name

plt.savefig(path, dpi=200)

In [None]:
seed = 0

vaegan_model = VAEGAN(z_dim=Z_DIM, l=2, spade=True).to(device)

model_dir = EXPERIMENT_DIR_VAEGAN / "model.pth"
vaegan_model.load_state_dict(torch.load(model_dir))
vaegan_model.eval()

patients = [
    path
    for path in DATA_DIR.glob("*")
    if not any(part.startswith(".") for part in path.parts)
]

train_size = 13 

train_indx, valid_indx = train_test_split(patients, random_state=seed, train_size=train_size)

partition = {
    "train": train_indx,
    "validation": valid_indx,
}

training_set = utils.ProstateMRDataset(partition["train"], IMAGE_SIZE, empty_masks = False)
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

first_indx = np.random.choice(np.arange(len(train_loader.dataset)), size=5, replace=False)
second_indx = first_indx+10
for i in range(len(second_indx)):
    if second_indx[i] > len(train_loader.dataset): 
        second_indx[i] = len(train_loader.dataset)-1

first_image_t , first_mask_t = train_loader.dataset[first_indx]
second_image_t, second_mask_t = train_loader.dataset[second_indx]
first_image_t = first_image_t.to(device)
first_mask_t = first_mask_t.to(device)
second_image_t = second_image_t.to(device)
second_mask_t = second_mask_t.to(device)


mu_first, logvar_first = vaegan_model.encoder(first_image_t)
latent_z_first = utils.sample_z(mu_first, logvar_first)
            
mu_second, logvar_second = vaegan_model.encoder(second_image_t)
latent_z_second = utils.sample_z(mu_second, logvar_second)


In [None]:
weight = 0.9
gen_latent_z = latent_z_first + weight*(latent_z_second - latent_z_first)
generations  = vaegan_model.generator(gen_latent_z, first_mask_t)

img_grid = make_grid(
                torch.cat([
                    first_image_t.cpu(),
                    generations.cpu(),
                    second_image_t.cpu()
                ]), 
                nrow=5, 
                padding=12, 
                pad_value=-1, )
    
plt.imshow(img_grid.numpy()[0] / 2.0 + 0.5)
plt.imsave(EXPERIMENT_DIR_VAEGAN / f"interpolated_generations_{weight}.png", img_grid.numpy()[0] / 2.0 + 0.5)

In [None]:
first_indx = np.random.choice(np.arange(len(train_loader.dataset)), size=1, replace=False)
second_indx = first_indx+10

first_image_t , first_mask_t = train_loader.dataset[first_indx]
second_image_t, second_mask_t = train_loader.dataset[second_indx]
first_image_t = first_image_t.to(device)
first_mask_t = first_mask_t.to(device)
second_image_t = second_image_t.to(device)
second_mask_t = second_mask_t.to(device)


mu_first, logvar_first = vaegan_model.encoder(first_image_t)
latent_z_first = utils.sample_z(mu_first, logvar_first)
            
mu_second, logvar_second = vaegan_model.encoder(second_image_t)
latent_z_second = utils.sample_z(mu_second, logvar_second)

images = []
for i in range(0,21,1):
    gen_latent_z = latent_z_first + (i/20)*(latent_z_second - latent_z_first)
    generations  = vaegan_model.generator(gen_latent_z, first_mask_t) 
    
    images.append(np.array(generations.squeeze().detach().cpu()))

    
imageio.mimsave(EXPERIMENT_DIR_VAEGAN / 'generations.gif', images, fps=10)

#did not know how to display a gif in jupyter notebook so the code below is from ChatGTP 
from IPython.display import HTML
from base64 import b64encode

gif = open(EXPERIMENT_DIR_VAEGAN / 'generations.gif','rb').read()
#HTML(f"<img src='data:image/gif;base64,{b64encode(gif).decode('utf-8')}' />")
HTML(f"<img src='data:image/gif;base64,{b64encode(gif).decode('utf-8')}' style='width:600px;height:450px;'/>")