In [None]:
## import modules
import os
import importlib 
import numpy as np 

import matplotlib
import matplotlib.pyplot as plt

import time

import torch
import torchvision
from torchvision.transforms import v2

In [None]:
from Modules.Data import Transforms
from Modules.Data import Dataset

In [None]:
## Transforms applied to the raw supervised-learning samples.
# Each pipeline runs v2.ToImage() and then converts to float32 with value
# scaling enabled (scale=True). The train and validation pipelines are kept as
# two separate objects so that either one can be changed on its own.
train_rawdata_transform = v2.Compose(
    [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
)
validate_rawdata_transform = v2.Compose(
    [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
)

In [None]:
## Transforms used by the autoencoder data sets.
# Features (network input) and targets (reconstruction reference) each get
# their own pipeline per split. All four currently apply the same two steps:
# v2.ToImage() followed by a float32 conversion with scale=True.

# --- training split ---
train_feature_transform = v2.Compose(
    [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
)
train_target_transform = v2.Compose(
    [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
)

# --- validation split ---
validate_feature_transform = v2.Compose(
    [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
)
validate_target_transform = v2.Compose(
    [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
)

In [None]:
## load supervised learning raw data set

# Root directory that holds (or will receive) the MNIST files.
# A machine-specific absolute path forces every user to edit the notebook, so
# the location can be overridden through the MNIST_DATASET_ROOT environment
# variable. When the variable is not set, the original default path is used.
src_dataset_file_path = os.environ.get(
    "MNIST_DATASET_ROOT",
    r"E:\Python\DataSet\TorchDataSet\MNIST",
)

# Training split; download=True lets torchvision fetch the files into `root`
# when they are not already present.
train_rawdata = torchvision.datasets.MNIST(
    root = src_dataset_file_path,
    train = True,
    download = True,
    transform = train_rawdata_transform,
)

# Test split, used as the validation set in this notebook.
validate_rawdata = torchvision.datasets.MNIST(
    root = src_dataset_file_path,
    train = False,
    download = True,
    transform = validate_rawdata_transform,
)

# Size of one transformed sample (first element of the (image, label) pair).
rawdata_size = validate_rawdata[0][0].size()
print(f"raw_data size: {rawdata_size}")

In [None]:
## Wrap the raw data sets into autoencoder data sets.
# Each split hands its raw data set to Dataset.AutoencoderDataset together
# with the feature/target transforms defined for that split.
train_data = Dataset.AutoencoderDataset(
    train_rawdata,
    feature_transform=train_feature_transform,
    target_transform=train_target_transform,
)
validate_data = Dataset.AutoencoderDataset(
    validate_rawdata,
    feature_transform=validate_feature_transform,
    target_transform=validate_target_transform,
)

# Report the number of samples in each split.
print(f"train_data length: {len(train_data)}")
print(f"validate_data length: {len(validate_data)}")

In [None]:
## Data loaders for both splits.

# Number of samples per batch.
train_batch_size = 512
validate_batch_size = 512

# shuffle=False keeps the sample order fixed, so a given position in a batch
# refers to the same sample from run to run.
train_dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=train_batch_size,
    shuffle=False,
)
validate_dataloader = torch.utils.data.DataLoader(
    validate_data,
    batch_size=validate_batch_size,
    shuffle=False,
)

In [None]:
## Take a look at the data: show one feature/target pair from the first batch.
check_data_idx = 0
check_dataloader = train_dataloader

# Only the first batch of the loader is needed.
check_features, check_targets = next(iter(check_dataloader))

print(f"Feature batch shape: {check_features.size()}")
print(f"Target batch shape: {check_targets.size()}")

# Drop singleton dimensions and convert to NumPy arrays for imshow.
check_feature = check_features[check_data_idx].squeeze().numpy()
check_target = check_targets[check_data_idx].squeeze().numpy()

plt.figure()

# (subplot position, image, title) for the two side-by-side panels.
check_panels = (
    (1, check_feature, f"Feature, idx = {check_data_idx}"),
    (2, check_target, f"Target, idx = {check_data_idx}"),
)
for panel_pos, panel_img, panel_title in check_panels:
    plt.subplot(1, 2, panel_pos)
    plt.imshow(panel_img, cmap="gray")
    plt.xticks([])
    plt.yticks([])
    plt.colorbar()
    plt.title(panel_title)

plt.tight_layout()
plt.show()

In [None]:
## load models

# Number of latent dimensions. It is used further down to build the codes that
# are fed to the decoder - presumably it must match the saved models; verify
# when switching to other checkpoints.
code_dim = 128

encoder_file_path = r".\Results\encoder_model_2024-06-20-19-49-15.pt"
decoder_file_path = r".\Results\decoder_model_2024-06-20-19-49-15.pt"

# The loaded objects are printed and called directly as models, i.e. the files
# hold whole pickled model objects rather than state dicts:
#   * weights_only=False is required for that. PyTorch >= 2.6 defaults to
#     weights_only=True, which refuses to unpickle arbitrary classes.
#   * map_location="cpu" lets the checkpoints load on a machine that lacks the
#     device they were saved from; this notebook runs the models on the CPU.
# SECURITY: unpickling can execute arbitrary code - only load files you trust.
encoder = torch.load(encoder_file_path, map_location="cpu", weights_only=False)
decoder = torch.load(decoder_file_path, map_location="cpu", weights_only=False)

# Print both architectures for a quick visual check.
print("Encoder:")
print(encoder)
print("\n")
print("Decoder:")
print(decoder)

In [None]:
## output directory for the generated figures

# Raw string: "\R" is not a valid escape sequence, and Python 3.12+ emits a
# SyntaxWarning for it in a normal string literal. The value is unchanged.
dst_dir_path = r".\Results"
dst_plot_subdir_path = "Plots"

# From here on dst_plot_subdir_path is the full path <dst_dir_path>/Plots.
dst_plot_subdir_path = os.path.join(dst_dir_path, dst_plot_subdir_path)

# exist_ok=True makes the cell safely re-runnable and avoids the race between
# checking for the directory and creating it.
os.makedirs(dst_plot_subdir_path, exist_ok=True)

print(dst_plot_subdir_path)

In [None]:
## Check the reconstruction quality on one batch.

# Data source for the check; switch to the training loader if desired.
# check_dataloader = train_dataloader
check_dataloader = validate_dataloader

# Run both networks on the CPU ...
encoder = encoder.to("cpu")
decoder = decoder.to("cpu")

# ... and in evaluation mode (use .train() instead to inspect training-mode
# behaviour).
encoder.eval()
decoder.eval()

# Positions inside the batch that get an input / prediction / target figure.
check_data_idxs = [50]

# Common prefix of every figure file written by this notebook.
check_fig_prefix = "SimpleConv"

with torch.no_grad():

    # Encode and decode the first batch of the selected loader.
    check_features, check_targets = next(iter(check_dataloader))
    check_features = check_features.detach()
    check_targets = check_targets.detach()

    check_codes = encoder(check_features).detach()
    check_preds = decoder(check_codes).detach()

    ## Figure 1: latent codes of the whole batch shown as one image.
    check_fig_name = check_fig_prefix + "_Code"
    check_png_file_name = check_fig_name + ".png"
    check_png_file_path = os.path.join(dst_plot_subdir_path, check_png_file_name)

    fig = plt.figure(figsize=(6, 4))
    plt.imshow(
        check_codes,
        cmap="gray",
        aspect="auto",
        interpolation="none",
    )
    plt.title("Latent representation")
    plt.xlabel("feature #")
    plt.ylabel("sample #")
    plt.colorbar(label="value")
    fig.savefig(check_png_file_path, bbox_inches="tight", dpi=300)
    plt.show()

    print(check_png_file_path)

    ## Figure 2..n: one figure per selected sample.
    for check_data_idx in check_data_idxs:

        # Squeeze away singleton dimensions so imshow receives a plain image.
        check_feature = check_features[check_data_idx].squeeze().numpy()
        check_pred = check_preds[check_data_idx].squeeze().numpy()
        check_target = check_targets[check_data_idx].squeeze().numpy()

        check_fig_name = check_fig_prefix + f"_Result_{check_data_idx}"
        check_png_file_name = check_fig_name + ".png"
        check_png_file_path = os.path.join(dst_plot_subdir_path, check_png_file_name)

        fig = plt.figure(figsize=(5, 2))

        # (subplot position, image, title) for the three side-by-side panels.
        result_panels = (
            (1, check_feature, "Input data"),
            (2, check_pred, "Prediction"),
            (3, check_target, "Ground truth"),
        )
        for panel_pos, panel_img, panel_title in result_panels:
            plt.subplot(1, 3, panel_pos)
            plt.imshow(panel_img, cmap="gray", interpolation="none")
            plt.xticks([])
            plt.yticks([])
            plt.title(panel_title)

        plt.tight_layout()

        fig.savefig(check_png_file_path, bbox_inches="tight", dpi=300)

        plt.show()

        print("Plot saved to: " + check_png_file_path)

In [None]:
## Check the learned decoder features: decode one code per latent dimension.

# Value placed in the single active latent dimension of each probe code.
check_code_amp = 1E15
# check_code_amp = 1

# Grid layout: floor(sqrt(code_dim)) rows and as many columns as are needed to
# give every latent dimension its own cell.
plot_nof_rows = int(np.sqrt(code_dim))
plot_nof_cols = int(np.ceil(code_dim / plot_nof_rows))

image_shape = rawdata_size

encoder = encoder.to("cpu")
decoder = decoder.to("cpu")

encoder.eval()
decoder.eval()

# Row i is zero everywhere except for check_code_amp in column i.
check_codes = check_code_amp * torch.eye(code_dim)

## Plot the decoded probe codes on a grid.
with torch.no_grad():

    check_decode_res = decoder(check_codes)
    check_decode_imgs = check_decode_res

    # check_fig_prefix and dst_plot_subdir_path come from earlier cells.
    check_fig_name = check_fig_prefix + "_DecoderFeatures"
    check_png_file_name = check_fig_name + ".png"
    check_png_file_path = os.path.join(dst_plot_subdir_path, check_png_file_name)

    fig = plt.figure(figsize=(plot_nof_cols, plot_nof_rows))
    plot_gs = matplotlib.gridspec.GridSpec(
        plot_nof_rows,
        plot_nof_cols,
        figure=fig,
        wspace=0.1,
        hspace=0.1,
    )

    # Fill the grid row by row; cells beyond code_dim simply stay empty.
    for cur_code_idx in range(code_dim):
        i_row, i_col = divmod(cur_code_idx, plot_nof_cols)

        cur_decode_img = check_decode_imgs[cur_code_idx, ...].detach().numpy()
        cur_decode_img = np.squeeze(cur_decode_img)

        plt.subplot(plot_gs[i_row, i_col])
        plt.imshow(cur_decode_img, cmap="gray", interpolation="none")
        plt.xticks([])
        plt.yticks([])

    fig.savefig(check_png_file_path, bbox_inches="tight", dpi=300)
    plt.show()

    print(check_png_file_path)