In [None]:
## import modules
import os
import importlib 
import numpy as np 

import matplotlib
import matplotlib.pyplot as plt

import time

import torch
import torchvision
from torchvision.transforms import v2

In [None]:
## import customized modules
import Modules.Models.SimpleGaussianVAE as SimpleGaussianVAE
from Modules.Data import Transforms
from Modules.Data import Dataset

In [None]:
## transforms applied to the raw (supervised-learning) MNIST samples
def _make_rawdata_transform():
    """Build a new pipeline: convert the sample to a float32 image tensor.

    ``scale = True`` rescales integer pixel values into the float range [0, 1].
    """
    return v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
    ])

# independent pipeline instances for the training and validation splits
train_rawdata_transform = _make_rawdata_transform()
validate_rawdata_transform = _make_rawdata_transform()

In [None]:
## transforms for the autoencoder data: every image is turned into a flat float32 vector
def _make_flat_image_transform():
    """Build a new pipeline: float32 image tensor (scaled to [0, 1]), then reshaped to 1-D."""
    return v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
        Transforms.Reshape((-1,)),  # project helper; assumed to reshape to the given shape (-1 -> flatten)
    ])

# Feature and target go through the same pipeline (the autoencoder target is
# presumably the input image itself -- see Dataset.AutoencoderDataset).
train_feature_transform = _make_flat_image_transform()
train_target_transform = _make_flat_image_transform()

validate_feature_transform = _make_flat_image_transform()
validate_target_transform = _make_flat_image_transform()

In [None]:
## load supervised learning raw data set

# Root directory of the torchvision MNIST data (downloaded there if missing).
# Set the MNIST_ROOT environment variable to point somewhere else on your
# machine instead of editing this cell; the default is the original path.
src_dataset_file_path = os.environ.get(
    "MNIST_ROOT", r"E:\Python\DataSet\TorchDataSet\MNIST"
)

train_rawdata = torchvision.datasets.MNIST(
    root = src_dataset_file_path,
    train = True,
    download = True,
    transform = train_rawdata_transform,
)

validate_rawdata = torchvision.datasets.MNIST(
    root = src_dataset_file_path,
    train = False,
    download = True,
    transform = validate_rawdata_transform,
)

# shape of one transformed sample; reused later to turn flat vectors back into images
rawdata_size = validate_rawdata[0][0].size()
print(f"raw_data size: {rawdata_size}")

In [None]:
## wrap each raw data set so that every sample yields a (feature, target) pair
## produced by the autoencoder feature/target transforms
train_data, validate_data = (
    Dataset.AutoencoderDataset(
        source_rawdata,
        feature_transform = feature_tf,
        target_transform = target_tf,
    )
    for source_rawdata, feature_tf, target_tf in (
        (train_rawdata, train_feature_transform, train_target_transform),
        (validate_rawdata, validate_feature_transform, validate_target_transform),
    )
)

print(f"train_data length: {len(train_data)}")
print(f"validate_data length: {len(validate_data)}")

In [None]:
# create data loaders
# Shuffling is off for both splits: batches come out in data-set order, so the
# batch/sample indices used in the inspection cells below always refer to the
# same samples.
train_batch_size = 128
validate_batch_size = 128

train_dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size = train_batch_size,
    shuffle = False,
)
validate_dataloader = torch.utils.data.DataLoader(
    validate_data,
    batch_size = validate_batch_size,
    shuffle = False,
)

In [None]:
## take a look at the data: show one feature/target pair from the first batch
check_data_idx = 0
check_dataloader = train_dataloader

# turns a flattened sample vector back into the raw image shape
data_vec_to_image = Transforms.Reshape(rawdata_size)

check_features, check_targets = next(iter(check_dataloader))

print(f"Feature batch shape: {check_features.size()}")
print(f"Target batch shape: {check_targets.size()}")

def _vec_to_display_image(sample_vec):
    """Un-flatten one sample and drop singleton dimensions so imshow gets a 2-D array."""
    return data_vec_to_image(sample_vec).squeeze().numpy()

check_feature = _vec_to_display_image(check_features[check_data_idx])
check_target = _vec_to_display_image(check_targets[check_data_idx])

preview_panels = (
    (check_feature, f"Feature, idx = {check_data_idx}"),
    (check_target, f"Target, idx = {check_data_idx}"),
)

plt.figure()
for panel_pos, (panel_image, panel_title) in enumerate(preview_panels, start = 1):
    plt.subplot(1, 2, panel_pos)
    plt.imshow(panel_image, cmap = "gray")
    plt.xticks([])
    plt.yticks([])
    plt.colorbar()
    plt.title(panel_title)

plt.tight_layout()
plt.show()

In [None]:
## source model file paths

# Directory holding the trained model files, and the timestamp suffix that
# identifies one training run. Override them with the VAE_MODEL_DIR /
# VAE_MODEL_STAMP environment variables instead of editing absolute paths.
src_model_dir_path = os.environ.get(
    "VAE_MODEL_DIR", r"E:\Python\GitRepos\VariationalAutoencoderPytorch\Results"
)
src_model_stamp = os.environ.get("VAE_MODEL_STAMP", "2024-06-29-22-52-54")

# all three parts of one run share the same directory and timestamp
src_encoder_file_path = os.path.join(src_model_dir_path, f"encoder_model_{src_model_stamp}.pt")
src_distrib_file_path = os.path.join(src_model_dir_path, f"distrib_model_{src_model_stamp}.pt")
src_decoder_file_path = os.path.join(src_model_dir_path, f"decoder_model_{src_model_stamp}.pt")

In [None]:
## load models

# The loaded objects are used directly as modules further down (.to(), .eval(),
# called on tensors), so these files hold whole pickled models, not state dicts.
# - weights_only = False: required for that on PyTorch >= 2.6, where the default
#   became True. SECURITY: this unpickles arbitrary objects -- only load trusted files.
# - map_location = "cpu": lets checkpoints saved from a GPU load on a CPU-only
#   machine (the models are evaluated on CPU later in this notebook anyway).
encoder = torch.load(src_encoder_file_path, map_location = "cpu", weights_only = False)
distrib_sample = torch.load(src_distrib_file_path, map_location = "cpu", weights_only = False)
decoder = torch.load(src_decoder_file_path, map_location = "cpu", weights_only = False)

for model_label, loaded_model in (
    ("Encoder:", encoder),
    ("Distribution Sampling:", distrib_sample),
    ("Decoder:", decoder),
):
    print(model_label)
    print(loaded_model)
    print()

In [None]:
## output directory for the saved result plots
# Built with os.path.join: a ".\Results" literal relies on the invalid "\R"
# string escape (SyntaxWarning on Python 3.12+), and a backslash is not a path
# separator on POSIX systems.
dst_dir_path = os.path.join(".", "Results")
dst_plot_subdir_name = "Plots"

dst_plot_subdir_path = os.path.join(dst_dir_path, dst_plot_subdir_name)
# exist_ok avoids the check-then-create race and makes the cell safe to re-run
os.makedirs(dst_plot_subdir_path, exist_ok = True)

print(dst_plot_subdir_path)

In [None]:
## check result: encode -> sample -> decode one batch, then plot selected samples

# set to train_dataloader to inspect training samples instead
check_dataloader = validate_dataloader

# turns a flattened sample vector back into the raw image shape
data_vec_to_image = Transforms.Reshape(rawdata_size)

# evaluate on CPU, with all three parts in inference (eval) mode
encoder = encoder.to("cpu")
distrib_sample.to("cpu")
decoder = decoder.to("cpu")

encoder.eval()
distrib_sample.eval()
decoder.eval()

check_fig_prefix = "FCSoftSparsity1E-1_SingleResult"

# sample indices inside the selected batch to plot (e.g. range(3));
# each must be smaller than that batch's size
check_data_idxs = [7]

# which batch of check_dataloader to inspect (0 = first batch)
check_batch_idx = 0

## generate check codes
with torch.no_grad():

    # Advance ONE iterator to the requested batch. Creating a new iterator on
    # every step (next(iter(loader)) inside the loop) restarts the loader each
    # time and would always return batch 0.
    check_batches = iter(check_dataloader)
    for _ in range(check_batch_idx + 1):
        check_features, check_targets = next(check_batches)

    # outputs computed under no_grad carry no autograd history, so no .detach() is needed
    check_distrib_params = encoder(check_features)
    check_codes = distrib_sample(check_distrib_params)
    check_preds = decoder(check_codes)

    ## plot encoded representations for all the samples in the batch
    # rows are samples, columns are code features (assumes 2-D codes -- TODO confirm)
    plt.figure()
    plt.imshow(check_codes.numpy(),
               cmap = "gray",
               aspect = "auto",
              )
    plt.title("code")
    plt.xlabel("feature #")
    plt.ylabel("sample #")
    plt.colorbar()
    plt.show()

    ## plot input / prediction / ground truth for each selected sample
    for check_data_idx in check_data_idxs:

        check_feature = data_vec_to_image(check_features[check_data_idx]).squeeze().numpy()
        check_pred = data_vec_to_image(check_preds[check_data_idx]).squeeze().numpy()
        check_target = data_vec_to_image(check_targets[check_data_idx]).squeeze().numpy()

        check_fig_name = check_fig_prefix + f"_Result_{check_batch_idx}_{check_data_idx}"
        check_png_file_name = check_fig_name + ".png"
        check_png_file_path = os.path.join(dst_plot_subdir_path, check_png_file_name)

        fig, axes = plt.subplots(1, 3, figsize = (12, 4))
        check_panels = (
            (check_feature, "Input data"),
            (check_pred, "Autoencoder prediction"),
            (check_target, "Ground truth"),
        )
        for ax, (panel_image, panel_title) in zip(axes, check_panels):
            ax.imshow(panel_image, cmap = "gray", interpolation = "none")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(panel_title)

        fig.tight_layout()
        fig.savefig(check_png_file_path, bbox_inches = "tight")

        plt.show()
        # release the figure so plotting many indices does not pile up open figures
        plt.close(fig)

        print("Plot saved to: " + check_png_file_path)