In [7]:
import numpy as np
import neptune
import random
import torch

from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision.utils import make_grid
from pathlib import Path
from tqdm import tqdm

import utils
import vae

# to ensure reproducible training/validation split
random.seed(41)

# find out if a GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# directorys with data and to store training checkpoints and logs
DATA_DIR = Path.cwd().parent.parent / "DevelopmentData"
CHECKPOINTS_DIR = Path.cwd() / "vae_model_weights"

In [3]:
# training settings and hyperparameters
NO_VALIDATION_PATIENTS = 2
IMAGE_SIZE = [64, 64]
BATCH_SIZE = 32
N_EPOCHS = 200
DECAY_LR_AFTER = 50
LEARNING_RATE = 1e-4
DISPLAY_FREQ = 10

# dimension of VAE latent space
Z_DIM = 256

# function to reduce the learning rate
def lr_lambda(the_epoch):
    """Function for scheduling learning rate"""
    return (
        1.0
        if the_epoch < DECAY_LR_AFTER
        else 1 - float(the_epoch - DECAY_LR_AFTER) / (N_EPOCHS - DECAY_LR_AFTER)
    )

# find patient folders in training directory
# excluding hidden folders (start with .)
patients = [
    path
    for path in DATA_DIR.glob("*")
    if not any(part.startswith(".") for part in path.parts)
]
random.shuffle(patients)

# split in training/validation after shuffling
partition = {
    "train": patients[:-NO_VALIDATION_PATIENTS],
    "validation": patients[-NO_VALIDATION_PATIENTS:],
}

# load training data and create DataLoader with batching and shuffling
dataset = utils.ProstateMRDataset(partition["train"], IMAGE_SIZE, valid=True) # in my experiments the augmentations
# did not help, so I set valid=True to disable them
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

# load validation data
valid_dataset = utils.ProstateMRDataset(partition["validation"], IMAGE_SIZE, valid=True)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

# initialise model, optimiser
vae_model = vae.VAE().to(device)
optimizer = torch.optim.Adam(vae_model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lr_lambda)

In [9]:
summary(vae_model, input_size=(BATCH_SIZE, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]))

Layer (type:depth-idx)                   Output Shape              Param #
VAE                                      [32, 1, 64, 64]           --
├─Encoder: 1-1                           [32, 256]                 --
│    └─ModuleList: 2-5                   --                        (recursive)
│    │    └─Block: 3-1                   [32, 64, 64, 64]          37,824
│    └─MaxPool2d: 2-2                    [32, 64, 32, 32]          --
│    └─ModuleList: 2-5                   --                        (recursive)
│    │    └─Block: 3-2                   [32, 128, 32, 32]         221,952
│    └─MaxPool2d: 2-4                    [32, 128, 16, 16]         --
│    └─ModuleList: 2-5                   --                        (recursive)
│    │    └─Block: 3-3                   [32, 256, 16, 16]         886,272
│    └─MaxPool2d: 2-6                    [32, 256, 8, 8]           --
│    └─Sequential: 2-7                   [32, 512]                 --
│    │    └─Flatten: 3-4                 [32

In [6]:
# Initialize Neptune experiment
run = neptune.init_run(
    project="Capita-Selecta-Group-8/vae-baseline",
    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJkN2VlMjY3Mi00OWU0LTQxMzUtOWJlNS0wNzM3OGIyNzk3ZWUifQ==",
)

# Track hyperparameters
run["parameters"] = {
    "lr": LEARNING_RATE,
    "bs": BATCH_SIZE,
    "epochs": N_EPOCHS,
    "input_sz": IMAGE_SIZE[0] * IMAGE_SIZE[1],
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu")
}

# training loop
for epoch in range(N_EPOCHS):
    current_train_loss = 0.0
    current_valid_loss = 0.0
    
    # training iterations
    for x_real, y_real in tqdm(dataloader, position=0):
        # needed to zero gradients in each iterations
        optimizer.zero_grad()
        output, mu, logvar = vae_model(x_real.to(device))  # forward pass
        loss = vae.vae_loss(output, y_real.to(device).float(), mu, logvar)
        loss.backward()  # backpropagate loss
        current_train_loss += loss.item()
        optimizer.step()  # update weights

    scheduler.step() # step the learning step scheduler
    
    # evaluate validation loss
    with torch.no_grad():
        vae_model.eval()
        for x_real, y_real in tqdm(valid_dataloader, position=0):
            output, mu, logvar = vae_model(x_real.to(device))  # forward pass
            loss = vae.vae_loss(output, y_real.to(device).float(), mu, logvar)
            current_valid_loss += loss.item()
        
        # write to neptune log
        run["train/loss"].append(current_train_loss)
        run["valid/loss"].append(current_train_loss)

        # save examples of real/fake images
        if (epoch + 1) % DISPLAY_FREQ == 0:
            x_recon = output
            img_grid = make_grid(
                torch.cat((x_recon[:5].cpu(), x_real[:5].cpu())), nrow=5, padding=12, pad_value=-1
            )
            run["images/real_fake"].log(
                neptune.types.File.as_image((np.clip(img_grid[0][np.newaxis], -1, 1) / 2 + 0.5).squeeze()), 
                description=f"Real_fake (Epoch {epoch + 1})"
            )
        
            noise = vae.get_noise(10, Z_DIM, device)
            image_samples = vae_model.generator(noise)
            img_grid = make_grid(
                torch.cat((image_samples[:5].cpu(), image_samples[5:].cpu())),
                nrow=5,
                padding=12,
                pad_value=-1,
            )
            run["images/samples"].log(
                neptune.types.File.as_image((np.clip(img_grid[0][np.newaxis], -1, 1) / 2 + 0.5).squeeze()), 
                description=f"Samples (Epoch {epoch + 1})"
            )

        vae_model.train()

weights_dict = {k: v.cpu() for k, v in vae_model.state_dict().items()}
torch.save(
    weights_dict,
    CHECKPOINTS_DIR / "vae_model.pth",
)



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/Capita-Selecta-Group-8/vae-baseline/e/VAE-45


        Convert the value to a supported type, such as a string or float, or use stringify_unsupported(obj)
        for dictionaries or collections that contain unsupported values.
        For more, see https://docs.neptune.ai/help/value_of_unsupported_type
100%|██████████| 34/34 [00:12<00:00,  2.64it/s]
100%|██████████| 5/5 [00:00<00:00,  5.13it/s]
100%|██████████| 34/34 [00:11<00:00,  2.93it/s]
100%|██████████| 5/5 [00:01<00:00,  3.17it/s]
100%|██████████| 34/34 [00:11<00:00,  2.96it/s]
100%|██████████| 5/5 [00:00<00:00,  5.28it/s]
100%|██████████| 34/34 [00:11<00:00,  2.97it/s]
100%|██████████| 5/5 [00:00<00:00,  5.30it/s]
100%|██████████| 34/34 [00:11<00:00,  3.00it/s]
100%|██████████| 5/5 [00:00<00:00,  5.58it/s]
100%|██████████| 34/34 [00:11<00:00,  2.96it/s]
100%|██████████| 5/5 [00:00<00:00,  5.19it/s]
100%|██████████| 34/34 [00:11<00:00,  3.00it/s]
100%|██████████| 5/5 [00:00<00:00,  5.55it/s]
100%|██████████| 34/34 [00:11<00:00,  3.00it/s]
100%|██████████| 5/5 [00:01<00:00,  

KeyboardInterrupt: 