In [None]:
## import modules
import os
import importlib 
import numpy as np 

import matplotlib
import matplotlib.pyplot as plt

import time

import torch
import torchvision
from torchvision.transforms import v2

In [None]:
from Modules.Data import Transforms

In [None]:
## define transforms for supervised learning raw data
def _make_rawdata_transform():
    """Build the raw-data pipeline used for both splits.

    Steps: convert the PIL image to a tensor image, cast to float32 with
    value scaling, then apply the project's Transforms.Reshape((-1,)).
    A new Compose object is created on every call, so the train and
    validate pipelines remain separate objects that can diverge later
    (e.g. if augmentation is added to training only).
    """
    return v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        Transforms.Reshape((-1,)),
    ])


train_rawdata_transform = _make_rawdata_transform()
validate_rawdata_transform = _make_rawdata_transform()

In [None]:
## load supervised learning raw data set

# Root folder holding the MNIST files. Set the MNIST_DATA_DIR environment
# variable to point at your own copy; if it is unset, the author's original
# location is used. With download=True, torchvision fetches the files into
# this folder when they are not already present.
src_dataset_file_path = os.environ.get(
    "MNIST_DATA_DIR", r"E:\Python\DataSet\TorchDataSet\MNIST"
)

train_rawdata = torchvision.datasets.MNIST(
    root=src_dataset_file_path,
    train=True,
    download=True,
    transform=train_rawdata_transform,
)

validate_rawdata = torchvision.datasets.MNIST(
    root=src_dataset_file_path,
    train=False,
    download=True,
    transform=validate_rawdata_transform,
)

# Shape of one transformed sample, i.e. after Transforms.Reshape((-1,));
# presumably a flat vector -- confirm against the printed size.
rawdata_size = validate_rawdata[0][0].size()
print(f"raw_data size: {rawdata_size}")

In [None]:
# create data loader
# Both loaders read their dataset in a fixed order (shuffle=False), so the
# order of batches is deterministic from run to run.
train_batch_size, validate_batch_size = 512, 512

train_rawdataloader = torch.utils.data.DataLoader(
    train_rawdata,
    batch_size=train_batch_size,
    shuffle=False,
)
validate_rawdataloader = torch.utils.data.DataLoader(
    validate_rawdata,
    batch_size=validate_batch_size,
    shuffle=False,
)

In [None]:
## load model

# Earlier checkpoints from the same experiment, kept for reference:
#   Results/encoder_model_2024-06-19-20-32-01.pt
#   Results/encoder_model_2024-06-19-21-19-41.pt
# os.path.join keeps the relative path portable (the original used Windows
# backslashes only).
encoder_file_path = os.path.join(".", "Results", "encoder_model_2024-06-19-23-13-42.pt")

# The checkpoint holds a whole pickled encoder module (it is printed, moved and
# called directly below), not just a state_dict. Recent PyTorch releases default
# to weights_only=True, which refuses to unpickle arbitrary modules, so opt out
# explicitly. SECURITY: unpickling can execute arbitrary code -- only load
# checkpoint files you trust.
# map_location="cpu": this notebook runs the encoder on the CPU, and mapping at
# load time also lets a checkpoint saved from a CUDA session load on a machine
# without a GPU.
encoder = torch.load(encoder_file_path, map_location="cpu", weights_only=False)

print("Encoder:")
print(encoder)

In [None]:
## extract code
# Run the encoder over every batch of the selected loader and gather the
# resulting codes alongside the ground-truth labels, concatenated along dim 0.
check_rawdataloader = validate_rawdataloader

encoder = encoder.to("cpu")
encoder.eval()

code_batches = []
label_batches = []
with torch.no_grad():  # inference only: no autograd graph needed
    for batch_inputs, batch_labels in check_rawdataloader:
        code_batches.append(encoder(batch_inputs))
        label_batches.append(batch_labels)

check_codes = torch.cat(code_batches, dim=0)
check_labels = torch.cat(label_batches, dim=0)

print(check_codes.size())
print(check_labels.size())

In [None]:
## plot two code components, one colour per label
# Scatter code component `plot_x_code_idx` against `plot_y_code_idx` for every
# extracted sample, grouped by its true label, to inspect class separation.

plot_x_code_idx = 0
plot_y_code_idx = 1

# Fail early with a clear message if the chosen components do not exist.
code_dim = check_codes.size(1)
assert max(plot_x_code_idx, plot_y_code_idx) < code_dim, (
    f"code has only {code_dim} components; "
    f"got indices {plot_x_code_idx} and {plot_y_code_idx}"
)

plot_labels = torch.unique(check_labels)

print(plot_labels)

fig, ax = plt.subplots(figsize=(16, 16))
for cur_label in plot_labels:
    cur_mask = check_labels == cur_label  # compute the selection once per label
    cur_plot_xs = check_codes[cur_mask, plot_x_code_idx]
    cur_plot_ys = check_codes[cur_mask, plot_y_code_idx]
    ax.scatter(cur_plot_xs, cur_plot_ys, label=f"{cur_label}")
ax.set_xlabel(f"code[{plot_x_code_idx}]")
ax.set_ylabel(f"code[{plot_y_code_idx}]")
ax.set_title("Encoder codes of the check set, coloured by label")
ax.legend(title="label")
plt.show()