In [None]:
## import modules
import os
import importlib 
import numpy as np 

import matplotlib
import matplotlib.pyplot as plt

import time

import torch
import torchvision
from torchvision.transforms import v2

In [None]:
## import customized modules
import Modules.Models.SimpleGaussianVAE as SimpleGaussianVAE
from Modules.Data import Transforms
from Modules.Data import Dataset
from Modules.TrainAndValidate import Loss
from Modules.TrainAndValidate import TrainAndValidate

In [None]:
## define transforms for supervised learning raw data
# Both raw splits use the same pipeline: convert the sample to a torchvision
# image tensor, then to float32 (scale=True rescales integer pixel values into
# [0, 1]). No augmentation is applied to either split.
def _make_rawdata_transform():
    """Build a fresh image -> float32 image-tensor transform."""
    return v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale = True)])

train_rawdata_transform = _make_rawdata_transform()
validate_rawdata_transform = _make_rawdata_transform()

In [None]:
## define transforms for autoencoder data
# Features and targets of both splits share one pipeline: image tensor ->
# float32 (scale=True rescales integer inputs into [0, 1]) -> reshape to (-1,).
# The reshape presumably flattens each image into a 1-D vector for the fully
# connected encoder/decoder; TODO confirm Transforms.Reshape semantics.
def _make_flat_image_transform():
    """Build a fresh transform mapping an image to a flat float32 tensor."""
    return v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
        Transforms.Reshape((-1,)),
    ])

train_feature_transform = _make_flat_image_transform()
train_target_transform = _make_flat_image_transform()

validate_feature_transform = _make_flat_image_transform()
validate_target_transform = _make_flat_image_transform()

In [None]:
## load supervised learning raw data set

# Dataset root directory. It defaults to the original author's local path.
# Set the MNIST_DATASET_ROOT environment variable to use another location
# without editing this cell.
src_dataset_file_path = os.environ.get(
    "MNIST_DATASET_ROOT", r"E:\Python\DataSet\TorchDataSet\MNIST"
)

# download=True fetches the files into the root directory if they are missing.
train_rawdata = torchvision.datasets.MNIST(
    root = src_dataset_file_path,
    train = True,
    download = True,
    transform = train_rawdata_transform,
)

validate_rawdata = torchvision.datasets.MNIST(
    root = src_dataset_file_path,
    train = False,
    download = True,
    transform = validate_rawdata_transform,
)

# Size of one transformed image; later cells use it to reshape flat vectors
# back into images.
rawdata_size = validate_rawdata[0][0].size()
print(f"raw_data size: {rawdata_size}")

In [None]:
## create autoencoder data set from raw data set
# Wrap each raw split so that every item becomes a (feature, target) pair,
# with each half passed through its own transform. (The pair structure is
# what the data-preview cell unpacks; see Dataset.AutoencoderDataset for
# exactly how the raw image is used.)
train_data = Dataset.AutoencoderDataset(
    train_rawdata,
    feature_transform = train_feature_transform,
    target_transform = train_target_transform,
)

validate_data = Dataset.AutoencoderDataset(
    validate_rawdata,
    feature_transform = validate_feature_transform,
    target_transform = validate_target_transform,
)

for split_name, split_data in (("train_data", train_data), ("validate_data", validate_data)):
    print(f"{split_name} length: {len(split_data)}")

In [None]:
# create data loaders
# Both loaders shuffle, so `next(iter(loader))` in later cells returns a
# different random batch on every call.

train_batch_size = 128
validate_batch_size = 128

train_dataloader = torch.utils.data.DataLoader(
    train_data, batch_size = train_batch_size, shuffle = True
)
validate_dataloader = torch.utils.data.DataLoader(
    validate_data, batch_size = validate_batch_size, shuffle = True
)

In [None]:
## take a look at the data
# Show one feature/target pair from a batch side by side, to check that the
# transforms and the vector -> image reshape behave as expected.
check_data_idx = 0
check_dataloader = train_dataloader

# inverse of the flattening transform: flat vector -> image-shaped tensor
data_vec_to_image = Transforms.Reshape(rawdata_size)

check_features, check_targets = next(iter(check_dataloader))

print(f"Feature batch shape: {check_features.size()}")
print(f"Target batch shape: {check_targets.size()}")

check_feature = data_vec_to_image(check_features[check_data_idx]).squeeze().numpy()
check_target = data_vec_to_image(check_targets[check_data_idx]).squeeze().numpy()

preview_fig, preview_axes = plt.subplots(1, 2)
for panel_ax, panel_img, panel_kind in zip(preview_axes, (check_feature, check_target), ("Feature", "Target")):
    panel_im = panel_ax.imshow(panel_img, cmap = "gray")
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])
    preview_fig.colorbar(panel_im, ax = panel_ax)
    panel_ax.set_title(f"{panel_kind}, idx = {check_data_idx}")

preview_fig.tight_layout()
plt.show()

In [None]:
## create model
importlib.reload(SimpleGaussianVAE)

# One batch is drawn only to discover the length of the flattened input vector.
prev_features, prev_targets = next(iter(train_dataloader))
nof_features = prev_features.size(-1)

# Latent dimensionality. It is 2 so that the manifold and latent-space cells
# can plot the codes directly.
latent_dim = 2

## create encoder model
# Trunk: one fully connected layer, nof_features -> 400, LeakyReLU.
encoder_prev_layer_descriptors = [
    dict(nof_layers = 1, in_features = nof_features, out_features = 400, activation = torch.nn.LeakyReLU),
]
# Head(s) producing the latent Gaussian parameters: latent_dim outputs with no
# activation. Presumably the encoder builds mean and (log-)variance heads from
# this descriptor; TODO confirm in SimpleGaussianVAE.
encoder_gaussparam_layer_descriptors = [
    dict(nof_layers = 1, out_features = latent_dim, activation = None),
]

encoder = SimpleGaussianVAE.SimpleGaussVAEFCEncoder(
    prev_layer_descriptors = encoder_prev_layer_descriptors,
    gaussparam_layer_descriptors = encoder_gaussparam_layer_descriptors,
)

print("Encoder:")
print(encoder)
print("\n")

## create latent space distribution generation model
# Draws latent codes from the encoder's diagonal-Gaussian parameters.
distrib_sample = SimpleGaussianVAE.DiagGaussSample()

print("Distribution Sampling:")
print(distrib_sample)
print("\n")

## create decoder model
# latent_dim -> 400 (LeakyReLU) -> nof_features (Sigmoid, presumably so that
# reconstructed pixels lie in (0, 1)).
decoder_layer_descriptors = [
    dict(nof_layers = 1, in_features = latent_dim, out_features = 400, activation = torch.nn.LeakyReLU),
    dict(nof_layers = 1, out_features = nof_features, activation = torch.nn.Sigmoid),
]

decoder = SimpleGaussianVAE.SimpleGaussVAEFCDecoder(
    layer_descriptors = decoder_layer_descriptors
)

print("Decoder:")
print(decoder)

In [None]:
## quickly validate model can run
# Smoke test: push one training batch through encoder -> sampler -> decoder on
# the CPU, with gradients disabled, and print the reconstruction shape.
vae_parts = (encoder, distrib_sample, decoder)

for vae_part in vae_parts:
    vae_part.to("cpu")

with torch.no_grad():
    check_features, check_targets = next(iter(train_dataloader))
    for vae_part in vae_parts:
        vae_part.eval()
    check_distrib_params = encoder(check_features)
    check_codes = distrib_sample(check_distrib_params)
    check_preds = decoder(check_codes)
    print(check_preds.size())

In [None]:
## get device
# Prefer CUDA, then Apple's MPS (Metal) backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

In [None]:
## training parameters
importlib.reload(Loss)

learning_rate = 1E-3
nof_epochs = 100

# Trainable parameters of all three sub-models, with duplicates removed in case
# a tensor is shared between modules. dict.fromkeys keeps the first occurrence
# and preserves module order. A set would order parameters by object hash,
# which differs from run to run; torch.optim requires a deterministic
# parameter ordering so optimizer state is reproducible.
train_parameters = list(dict.fromkeys(
    list(encoder.parameters()) + list(distrib_sample.parameters()) + list(decoder.parameters())
))

# Reconstruction loss. It is presumably a Gaussian likelihood with a fixed
# sigma; TODO confirm in Loss.GaussSimilarityLoss.
loss_func = Loss.GaussSimilarityLoss(gauss_sigma = 1.0)

# Latent-distribution loss (KL divergence to a unit Gaussian) and its rate.
# The rate is presumably the weight applied to that term inside
# TrainAndValidate; verify there.
distrib_loss_rate = 1
distrib_loss_func = Loss.UnitGaussKullbackLeiblerDivergenceLoss()

optimizer = torch.optim.Adam(train_parameters, lr = learning_rate)

# Multiply the learning rate by 0.1 when the stepped metric (the training loop
# passes the validation loss) has not improved for 10 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer = optimizer,
    mode = "min",
    factor = 0.1,
    patience = 10,
    threshold = 1E-4,
    min_lr = 0,
)

# The training loop stops once the learning rate drops below this value.
stop_lr = 1E-6

In [None]:
## training loop
importlib.reload(TrainAndValidate)

encoder = encoder.to(device)
distrib_sample = distrib_sample.to(device)
decoder = decoder.to(device)

# Per-epoch history. Only the first `end_nof_epochs` entries are filled in.
learning_rates = torch.zeros((nof_epochs,))
train_losses = torch.zeros((nof_epochs,))
validate_losses = torch.zeros((nof_epochs,))

# Number of epochs that actually completed. It is updated only after an epoch
# has finished, so it equals nof_epochs when the loop runs to the end. It equals
# the stopping epoch's index when the learning-rate criterion breaks early.
# Updating it at the top of the loop would drop the last epoch from the plots
# of a full run.
end_nof_epochs = 0

for i_epoch in range(nof_epochs):
    print(f" ------ Epoch {i_epoch} ------ ")

    cur_lr = optimizer.param_groups[0]['lr']

    # Stop once ReduceLROnPlateau has decayed the learning rate below stop_lr.
    if cur_lr < stop_lr:
        break

    print(f"current lr = {cur_lr}")
    learning_rates[i_epoch] = cur_lr

    cur_train_loss = TrainAndValidate.train_one_epoch(
        encoder_model = encoder,
        decoder_model = decoder,
        distrib_model = distrib_sample,
        train_loader = train_dataloader,
        data_loss_func = loss_func,
        optimizer = optimizer,
        distrib_loss_rate = distrib_loss_rate,
        distrib_loss_func = distrib_loss_func,
        device = device,
    )

    # validate_one_epoch names its loader parameter `train_loader`; the
    # validation loader is passed through it.
    cur_validate_loss = TrainAndValidate.validate_one_epoch(
        encoder_model = encoder,
        decoder_model = decoder,
        distrib_model = distrib_sample,
        train_loader = validate_dataloader,
        data_loss_func = loss_func,
        distrib_loss_rate = distrib_loss_rate,
        distrib_loss_func = distrib_loss_func,
        device = device,
    )

    train_losses[i_epoch] = cur_train_loss
    validate_losses[i_epoch] = cur_validate_loss

    scheduler.step(cur_validate_loss)

    end_nof_epochs = i_epoch + 1

    print("\n")

In [None]:
## plot learning rate and losses
# Top panel: learning rate per epoch. Bottom panel: train/validation losses.
# Both use a log y-axis and only the epochs that actually ran.
history_fig, (lr_ax, loss_ax) = plt.subplots(2, 1)

lr_ax.plot(learning_rates[:end_nof_epochs], label = "learning rate")
lr_ax.set_yscale("log")
lr_ax.legend()

loss_ax.plot(train_losses[:end_nof_epochs], label = "train loss")
loss_ax.plot(validate_losses[:end_nof_epochs], label = "validate loss")
loss_ax.set_yscale("log")
loss_ax.legend()

plt.show()

In [None]:
## check result
# Run one batch through encoder -> sampler -> decoder on the CPU in eval mode,
# then show the latent codes and a few reconstructions next to their
# inputs/targets. Set check_dataloader = train_dataloader to inspect training
# data instead.
check_dataloader = validate_dataloader

# flat vector -> image-shaped tensor
data_vec_to_image = Transforms.Reshape(rawdata_size)

encoder = encoder.to("cpu")
distrib_sample = distrib_sample.to("cpu")
decoder = decoder.to("cpu")

for vae_part in (encoder, distrib_sample, decoder):
    vae_part.eval()

# positions within the batch of the samples to display
check_data_idxs = range(3)


def _show_reconstruction(idx, feature_img, pred_img, target_img):
    """Draw the feature, prediction and target images of one sample in a 1x3 row."""
    row_fig, row_axes = plt.subplots(1, 3)
    panels = zip(row_axes, ("Feature", "Prediction", "Target"), (feature_img, pred_img, target_img))
    for panel_ax, panel_kind, panel_img in panels:
        panel_ax.imshow(panel_img, cmap = "gray")
        panel_ax.set_xticks([])
        panel_ax.set_yticks([])
        panel_ax.set_title(f"{panel_kind}, idx = {idx}")
    row_fig.tight_layout()
    plt.show()


## generate check codes
with torch.no_grad():
    # Everything computed under no_grad carries no autograd history, so the
    # tensors can be converted to numpy without detach().
    check_features, check_targets = next(iter(check_dataloader))

    check_distrib_params = encoder(check_features)
    check_codes = distrib_sample(check_distrib_params)
    check_preds = decoder(check_codes)

    ## plot encoded representations for all the samples in the batch
    code_fig, code_ax = plt.subplots()
    code_im = code_ax.imshow(check_codes, cmap = "gray", aspect = "auto")
    code_ax.set_title("code")
    code_ax.set_xlabel("feature #")
    code_ax.set_ylabel("sample #")
    code_fig.colorbar(code_im, ax = code_ax)
    plt.show()

    ## plot check data idxs
    for check_data_idx in check_data_idxs:
        check_feature = data_vec_to_image(check_features[check_data_idx]).squeeze().numpy()
        check_pred = data_vec_to_image(check_preds[check_data_idx]).squeeze().numpy()
        check_target = data_vec_to_image(check_targets[check_data_idx]).squeeze().numpy()
        _show_reconstruction(check_data_idx, check_feature, check_pred, check_target)

In [None]:
## prepare output directories
# Build paths with os.path.join rather than a ".\Results" literal. In a
# non-raw string "\R" is an invalid escape sequence (SyntaxWarning on recent
# Python). A hard-coded backslash is also a separator only on Windows; on
# Linux/macOS it would create a directory literally named ".\Results".
dst_dir_path = os.path.join(".", "Results")
dst_plot_subdir_name = "Plots"

# prefix for every figure file written into the plot sub-directory
plot_fig_prefix = r"VAE"

dst_plot_subdir_path = os.path.join(dst_dir_path, dst_plot_subdir_name)
# makedirs also creates dst_dir_path; exist_ok makes the cell safe to re-run.
os.makedirs(dst_plot_subdir_path, exist_ok = True)

print(dst_plot_subdir_path)

In [None]:
## save model and model parameters
# One timestamp for the whole save, so the encoder, distribution and decoder
# files written together share the same suffix and can be matched up later.
# Calling time.strftime once per model could straddle a second boundary.
# Note: torch.save(module) pickles the whole object, so loading those files
# later requires the Modules package to be importable. The *_state.pt files
# hold only the state_dict.
save_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

# save encoder model
dst_encoder_model_name = "encoder_model_" + save_timestamp
dst_encoder_model_file_name = dst_encoder_model_name + ".pt"
dst_encoder_modelstate_file_name = dst_encoder_model_name + "_state.pt"

dst_encoder_model_file_path = os.path.join(dst_dir_path, dst_encoder_model_file_name)
torch.save(encoder, dst_encoder_model_file_path)
print("encoder model saved to: " + dst_encoder_model_file_path)

dst_encoder_modelstate_file_path = os.path.join(dst_dir_path, dst_encoder_modelstate_file_name)
torch.save(encoder.state_dict(), dst_encoder_modelstate_file_path)
print("encoder model state saved to: " + dst_encoder_modelstate_file_path)


# save latent space distribution model
dst_distrib_model_name = "distrib_model_" + save_timestamp
dst_distrib_model_file_name = dst_distrib_model_name + ".pt"
dst_distrib_modelstate_file_name = dst_distrib_model_name + "_state.pt"

dst_distrib_model_file_path = os.path.join(dst_dir_path, dst_distrib_model_file_name)
torch.save(distrib_sample, dst_distrib_model_file_path)
print("distribution model saved to: " + dst_distrib_model_file_path)

# Save the state_dict as for the encoder and decoder. It may be empty if the
# sampler has no learnable parameters or buffers.
dst_distrib_modelstate_file_path = os.path.join(dst_dir_path, dst_distrib_modelstate_file_name)
torch.save(distrib_sample.state_dict(), dst_distrib_modelstate_file_path)
print("distribution model state saved to: " + dst_distrib_modelstate_file_path)


# save decoder model
dst_decoder_model_name = "decoder_model_" + save_timestamp
dst_decoder_model_file_name = dst_decoder_model_name + ".pt"
dst_decoder_modelstate_file_name = dst_decoder_model_name + "_state.pt"

dst_decoder_model_file_path = os.path.join(dst_dir_path, dst_decoder_model_file_name)
torch.save(decoder, dst_decoder_model_file_path)
print("decoder model saved to: " + dst_decoder_model_file_path)

dst_decoder_modelstate_file_path = os.path.join(dst_dir_path, dst_decoder_modelstate_file_name)
torch.save(decoder.state_dict(), dst_decoder_modelstate_file_path)
print("decoder model state saved to: " + dst_decoder_modelstate_file_path)

In [None]:
## plot manifold
# Decode a regular grid of points in the 2-D latent space and tile the decoded
# images. Assumes latent_dim == 2, since each grid point is an (x, y) code.
nof_latent_xs = 11
nof_latent_ys = 11

latent_xlims = [-2,2]
latent_ylims = [-2,2]

latent_xs_lin = np.linspace(latent_xlims[0], latent_xlims[-1], nof_latent_xs)
latent_ys_lin = np.linspace(latent_ylims[0], latent_ylims[-1], nof_latent_ys)

# With indexing="xy", grid row i has y = latent_ys_lin[i] and column j has
# x = latent_xs_lin[j]. Flattening row-major gives code index i * nof_latent_xs + j.
latent_xs, latent_ys = np.meshgrid(latent_xs_lin, latent_ys_lin, indexing = "xy")
latent_xys = np.stack([latent_xs, latent_ys], axis = -1)
latent_xys = latent_xys.reshape((-1,2))
latent_xys = torch.tensor(latent_xys, dtype = torch.float)

# flat vector -> image-shaped tensor
data_vec_to_image = Transforms.Reshape(rawdata_size)

plot_fig_name = plot_fig_prefix + f"_Manifold_X{latent_xlims[0]}_{latent_xlims[-1]}_Y_{latent_ylims[0]}_{latent_ylims[-1]}"
plot_png_file_name = plot_fig_name + ".png"
plot_png_file_path = os.path.join(dst_plot_subdir_path, plot_png_file_name)

decoder = decoder.to("cpu")
decoder.eval()
with torch.no_grad():
    manifold_ves = decoder(latent_xys)

    plot_nof_cols = nof_latent_xs
    plot_nof_rows = nof_latent_ys

    fig = plt.figure(figsize = (plot_nof_cols, plot_nof_rows))
    plot_gs = matplotlib.gridspec.GridSpec(plot_nof_rows, plot_nof_cols, figure = fig, wspace = 0.1, hspace = 0.1)
    for i_row in range(plot_nof_rows):
        # Codes in row i_row share y = latent_ys_lin[i_row], which increases with
        # i_row. Draw the row from the bottom of the figure upward so that y
        # grows toward the top. This matches the orientation of the latent-space
        # scatter plot; drawing row 0 at the top flipped the manifold vertically.
        plot_row = plot_nof_rows - 1 - i_row
        for i_col in range(plot_nof_cols):
            cur_code_idx = i_row * plot_nof_cols + i_col
            cur_manifold_img = data_vec_to_image(manifold_ves[cur_code_idx, ...]).squeeze().numpy()
            cur_ax = fig.add_subplot(plot_gs[plot_row, i_col])
            cur_ax.imshow(
                cur_manifold_img,
                cmap = "gray",
                interpolation = "none",
            )
            cur_ax.set_xticks([])
            cur_ax.set_yticks([])
    fig.savefig(plot_png_file_path, bbox_inches = 'tight', dpi = 300)
    plt.show()

print(plot_png_file_path)

In [None]:
# create raw data loaders
# These are unshuffled, so batches follow dataset order. They yield raw
# (image, label) samples rather than autoencoder pairs.
train_rawdataloader, validate_rawdataloader = (
    torch.utils.data.DataLoader(raw_split, batch_size = split_batch_size, shuffle = False)
    for raw_split, split_batch_size in (
        (train_rawdata, train_batch_size),
        (validate_rawdata, validate_batch_size),
    )
)

In [None]:
# extract latent space distribution
# Encode every image of the chosen raw split on the CPU. Collect the sampled
# latent codes and the matching class labels, then concatenate across batches.
check_rawdataloader = validate_rawdataloader

encoder = encoder.to("cpu")
distrib_sample = distrib_sample.to("cpu")

encoder.eval()
distrib_sample.eval()

code_batches = []
label_batches = []
with torch.no_grad():
    for inputs, labels in check_rawdataloader:
        # flatten each image to one vector per sample for the fully connected encoder
        flat_inputs = inputs.flatten(start_dim = 1)
        code_batches.append(distrib_sample(encoder(flat_inputs)))
        label_batches.append(labels)

check_codes = torch.cat(code_batches, dim = 0)
check_labels = torch.cat(label_batches, dim = 0)

print(check_codes.size())
print(check_labels.size())

In [None]:
## plot latent codes by class label
# Scatter the collected codes in the plane spanned by two latent dimensions,
# with one colour/legend entry per class label.

plot_fig_name = plot_fig_prefix + "_LatentSpace"
plot_png_file_name = plot_fig_name + ".png"
plot_png_file_path = os.path.join(dst_plot_subdir_path, plot_png_file_name)

plot_x_code_idx = 0
plot_y_code_idx = 1

plot_labels = torch.unique(check_labels)

print(plot_labels)

fig, latent_ax = plt.subplots(figsize = (12,12))
for cur_label in plot_labels:
    label_mask = check_labels == cur_label
    latent_ax.scatter(
        check_codes[label_mask, plot_x_code_idx],
        check_codes[label_mask, plot_y_code_idx],
        s = 10,
        alpha = 0.5,
        label = f"{cur_label}",
    )
latent_ax.set_xlabel(f"Latent space dim {plot_x_code_idx}")
latent_ax.set_ylabel(f"Latent space dim {plot_y_code_idx}")
latent_ax.legend()
fig.savefig(plot_png_file_path, bbox_inches='tight', dpi = 150)
plt.show()

print(plot_png_file_path)