In [None]:
## import modules
import os
import importlib 
import numpy as np 

import matplotlib
import matplotlib.pyplot as plt

import time

import torch
import torchvision
from torchvision.transforms import v2

In [None]:
## import customized modules
import Modules.Models.FCAutoencoder as FCAutoencoder
from Modules.Data import Transforms
from Modules.Data import Dataset
from Modules.TrainAndValidate import Loss
from Modules.TrainAndValidate import TrainAndValidate
from Modules.TrainAndValidate import TieFCAutoencoder

In [None]:
## transforms applied to the supervised-learning raw data
# Training and validation use the same two steps: convert the sample to an
# image tensor, then cast it to float32 with value scaling enabled.
train_rawdata_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
    ]
)
validate_rawdata_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
    ]
)

In [None]:
## transforms for the autoencoder data
# Feature pipelines: image tensor -> scaled float32 -> mask noise -> reshape to (-1,).
# Target pipelines:  image tensor -> scaled float32 -> reshape to (-1,) (no noise).
# NOTE(review): RandomSetConstPxls(0.25, 0) presumably sets a random 25% of the
# pixels to 0 - confirm against Modules.Data.Transforms.
# Disabled alternative corruption, kept for reference:
#     Transforms.AddGaussianNoise(0.0, 0.1)     # add gaussian noise
#     Transforms.ClipChannelValues(0.0, 1.0)    # clip pixel values to [0, 1]

# --- training ---
train_feature_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
        Transforms.RandomSetConstPxls(0.25, 0),  # mask noise
        Transforms.Reshape((-1,)),
    ]
)
train_target_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
        Transforms.Reshape((-1,)),
    ]
)

# --- validation (same steps, separate transform instances) ---
validate_feature_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
        Transforms.RandomSetConstPxls(0.25, 0),  # mask noise
        Transforms.Reshape((-1,)),
    ]
)
validate_target_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True),
        Transforms.Reshape((-1,)),
    ]
)

In [None]:
## load supervised learning raw data set

# Root directory for the MNIST files. The default below is specific to the
# original author's machine; set the MNIST_DATA_ROOT environment variable to
# point somewhere else without editing the notebook.
src_dataset_file_path = os.environ.get(
    "MNIST_DATA_ROOT",
    r"E:\Python\DataSet\TorchDataSet\MNIST",
)

# training split (downloaded into the root directory if it is missing)
train_rawdata = torchvision.datasets.MNIST(
    root = src_dataset_file_path,
    train = True,
    download = True,
    transform = train_rawdata_transform,
)

# test split, used for validation in this notebook
validate_rawdata = torchvision.datasets.MNIST(
    root = src_dataset_file_path,
    train = False,
    download = True,
    transform = validate_rawdata_transform,
)

# size of one transformed image (first validation sample); later cells use it
# to reshape flattened vectors back into images
rawdata_size = validate_rawdata[0][0].size()
print(f"raw_data size: {rawdata_size}")

In [None]:
## wrap the raw data sets as autoencoder data sets
# Each wrapper receives one transform for the (noisy) feature and one for the
# (clean) target.

train_data = Dataset.AutoencoderDataset(
    train_rawdata,
    feature_transform = train_feature_transform,
    target_transform = train_target_transform,
)
validate_data = Dataset.AutoencoderDataset(
    validate_rawdata,
    feature_transform = validate_feature_transform,
    target_transform = validate_target_transform,
)

# report the number of samples in each data set
for dataset_name, dataset_obj in (
    ("train_data", train_data),
    ("validate_data", validate_data),
):
    print(f"{dataset_name} length: {len(dataset_obj)}")

In [None]:
## create data loaders

train_batch_size = 512
validate_batch_size = 512

# Both loaders shuffle the sample order; the preview and result cells take the
# first batch of a loader, so they see different samples on each run.
train_dataloader = torch.utils.data.DataLoader(
    dataset = train_data,
    batch_size = train_batch_size,
    shuffle = True,
)

validate_dataloader = torch.utils.data.DataLoader(
    dataset = validate_data,
    batch_size = validate_batch_size,
    shuffle = True,
)

In [None]:
## take a look at the data
check_data_idx = 0
check_dataloader = train_dataloader

# reshapes a flattened sample back to the raw image size
data_vec_to_image = Transforms.Reshape(rawdata_size)

# fetch one batch and report its shapes
check_features, check_targets = next(iter(check_dataloader))

print(f"Feature batch shape: {check_features.size()}")
print(f"Target batch shape: {check_targets.size()}")

check_feature = data_vec_to_image(check_features[check_data_idx]).squeeze().numpy()
check_target = data_vec_to_image(check_targets[check_data_idx]).squeeze().numpy()

# show the noisy feature (left) next to its clean target (right)
plt.figure()

for i_panel, (panel_img, panel_name) in enumerate(
    [(check_feature, "Feature"), (check_target, "Target")],
    start = 1,
):
    plt.subplot(1, 2, i_panel)
    plt.imshow(panel_img, cmap = "gray")
    plt.xticks([])
    plt.yticks([])
    plt.colorbar()
    plt.title(f"{panel_name}, idx = {check_data_idx}")

plt.tight_layout()
plt.show()

In [None]:
## create model
# read the length of one flattened sample from a training batch
prev_features, prev_targets = next(iter(train_dataloader))
nof_features = prev_features.size(-1)

# size of the encoded representation
code_dim = 64

## encoder: nof_features -> code_dim
encoder_layer_descriptors = [
    dict(
        nof_layers = 1,
        in_features = nof_features,
        out_features = code_dim,
        activation = torch.nn.LeakyReLU,
    ),
]
encoder = FCAutoencoder.SimpleFCNetwork(layer_descriptors = encoder_layer_descriptors)

print("Encoder:")
print(encoder)

print("\n")

## decoder: code_dim -> nof_features (mirror of the encoder)
decoder_layer_descriptors = [
    dict(
        nof_layers = 1,
        in_features = code_dim,
        out_features = nof_features,
        activation = torch.nn.LeakyReLU,
    ),
]
decoder = FCAutoencoder.SimpleFCNetwork(layer_descriptors = decoder_layer_descriptors)

print("Decoder:")
print(decoder)

In [None]:
## tie weights of encoder and decoder
# reload first so that edits made to the TieFCAutoencoder module are picked up
importlib.reload(TieFCAutoencoder)
TieFCAutoencoder.tie_weight_sym_fc_autoencoder(encoder, decoder)

# show both models after tying
for tied_model in (encoder, decoder):
    print(tied_model)

In [None]:
## smoke test: one forward pass through encoder and decoder on the CPU
encoder.to("cpu")
decoder.to("cpu")
encoder.eval()
decoder.eval()

with torch.no_grad():
    check_features, check_targets = next(iter(train_dataloader))
    check_codes = encoder(check_features)
    check_preds = decoder(check_codes)
    # size of the decoder output for this batch
    print(check_preds.size())

In [None]:
## get device
# preference order: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

In [None]:
## training parameters
learning_rate = 1E-3
nof_epochs = 100

# Collect the parameters of both models exactly once, in a stable order.
# Duplicates are dropped by object identity (presumably the tied encoder and
# decoder can expose the same parameter object - confirm against
# TieFCAutoencoder). A plain list is used instead of list(set(...)) because
# set iteration order is not consistent between runs, and optimizers need a
# parameter collection with deterministic ordering.
train_parameters = []
seen_parameter_ids = set()
for cur_parameter in list(encoder.parameters()) + list(decoder.parameters()):
    if id(cur_parameter) in seen_parameter_ids:
        continue
    seen_parameter_ids.add(id(cur_parameter))
    train_parameters.append(cur_parameter)

# reconstruction loss between decoder output and clean target
loss_func = torch.nn.MSELoss()

optimizer = torch.optim.Adam(train_parameters, lr = learning_rate)

# multiply the learning rate by 0.1 when the monitored loss has not improved
# for `patience` epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer = optimizer,
    mode = "min",
    factor = 0.1,
    patience = 10,
    threshold = 1E-4,
    min_lr = 0,
)

# the training loop stops once the learning rate falls below this value
stop_lr = 1E-5

In [None]:
## training loop
encoder = encoder.to(device)
decoder = decoder.to(device)

# per-epoch history; only the first `end_nof_epochs` entries are filled
learning_rates = torch.zeros((nof_epochs,))
train_losses = torch.zeros((nof_epochs,))
validate_losses = torch.zeros((nof_epochs,))

# number of epochs that ran to completion (used to slice the history buffers)
end_nof_epochs = 0

for i_epoch in range(nof_epochs):
    print(f" ------ Epoch {i_epoch} ------ ")

    cur_lr = optimizer.param_groups[0]['lr']

    # early stop: the scheduler has decayed the learning rate below stop_lr
    if cur_lr < stop_lr:
        break

    print(f"current lr = {cur_lr}")
    learning_rates[i_epoch] = cur_lr

    cur_train_loss = TrainAndValidate.train_one_epoch(
        encoder_model = encoder,
        decoder_model = decoder,
        train_loader = train_dataloader,
        data_loss_func = loss_func,
        optimizer = optimizer,
        device = device,
    )

    cur_validate_loss = TrainAndValidate.validate_one_epoch(
        encoder_model = encoder,
        decoder_model = decoder,
        validate_loader = validate_dataloader,
        loss_func = loss_func,
        device = device,
    )

    train_losses[i_epoch] = cur_train_loss
    validate_losses[i_epoch] = cur_validate_loss

    # the plateau scheduler monitors the validation loss
    scheduler.step(cur_validate_loss)

    # Count the epoch only after its results are stored. Setting the counter
    # to i_epoch at the top of the loop dropped the last epoch from
    # [:end_nof_epochs] whenever the loop finished without an early stop.
    end_nof_epochs = i_epoch + 1

    print("\n")

In [None]:
## plot learning rate and losses of the training run
plt.figure()

# top panel: learning rate per completed epoch (log scale)
plt.subplot(2, 1, 1)
plt.plot(learning_rates[:end_nof_epochs], label = "learning rate")
plt.yscale("log")
plt.legend()

# bottom panel: training and validation loss per completed epoch (log scale)
plt.subplot(2, 1, 2)
for loss_history, loss_label in (
    (train_losses, "train loss"),
    (validate_losses, "validate loss"),
):
    plt.plot(loss_history[:end_nof_epochs], label = loss_label)
plt.yscale("log")
plt.legend()

plt.show()

In [None]:
## untie weights
# reload first so that edits made to the TieFCAutoencoder module are picked up
importlib.reload(TieFCAutoencoder)

# untie the encoder first, then the decoder
for tied_model in (encoder, decoder):
    TieFCAutoencoder.untie_weight_fc_models(tied_model)

# show both models after untying
for untied_model in (encoder, decoder):
    print(untied_model)

In [None]:
## fine tune learning parameters
finetune_learning_rate = 1E-3
finetune_nof_epochs = 50

# after untying, the parameters of both models are collected separately
finetune_train_parameters = list(encoder.parameters()) + list(decoder.parameters())
finetune_loss_func = torch.nn.MSELoss()

# Use the fine-tune learning rate. The previous code passed `learning_rate`
# from the first training stage, so changing finetune_learning_rate had no
# effect.
finetune_optimizer = torch.optim.Adam(
    finetune_train_parameters,
    lr = finetune_learning_rate,
)

# multiply the learning rate by 0.1 when the monitored loss has not improved
# for `patience` epochs
finetune_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer = finetune_optimizer,
    mode = "min",
    factor = 0.1,
    patience = 10,
    threshold = 1E-4,
    min_lr = 0,
)

# the fine-tune loop stops once the learning rate falls below this value
finetune_stop_lr = 1E-5

In [None]:
## fine tune training loop
encoder = encoder.to(device)
decoder = decoder.to(device)

# Per-epoch history, sized by the fine-tune epoch count. The previous code used
# nof_epochs from the first stage, which raises an index error as soon as
# finetune_nof_epochs is larger than nof_epochs.
finetune_learning_rates = torch.zeros((finetune_nof_epochs,))
finetune_train_losses = torch.zeros((finetune_nof_epochs,))
finetune_validate_losses = torch.zeros((finetune_nof_epochs,))

# number of fine-tune epochs that ran to completion
finetune_end_nof_epochs = 0

for i_epoch in range(finetune_nof_epochs):
    print(f" ------ Epoch {i_epoch} ------ ")

    cur_lr = finetune_optimizer.param_groups[0]['lr']

    # early stop: the scheduler has decayed the learning rate below the limit
    if cur_lr < finetune_stop_lr:
        break

    print(f"current lr = {cur_lr}")
    finetune_learning_rates[i_epoch] = cur_lr

    cur_train_loss = TrainAndValidate.train_one_epoch(
        encoder_model = encoder,
        decoder_model = decoder,
        train_loader = train_dataloader,
        data_loss_func = finetune_loss_func,
        optimizer = finetune_optimizer,
        device = device,
    )

    cur_validate_loss = TrainAndValidate.validate_one_epoch(
        encoder_model = encoder,
        decoder_model = decoder,
        validate_loader = validate_dataloader,
        loss_func = finetune_loss_func,
        device = device,
    )

    finetune_train_losses[i_epoch] = cur_train_loss
    finetune_validate_losses[i_epoch] = cur_validate_loss

    # the plateau scheduler monitors the validation loss
    finetune_scheduler.step(cur_validate_loss)

    # Count the epoch only after its results are stored, so that
    # [:finetune_end_nof_epochs] also includes the last epoch when the loop
    # finishes without an early stop.
    finetune_end_nof_epochs = i_epoch + 1

    print("\n")

In [None]:
## plot learning rate and losses of the fine-tune run
plt.figure()

# top panel: learning rate per completed epoch (log scale)
plt.subplot(2, 1, 1)
plt.plot(finetune_learning_rates[:finetune_end_nof_epochs], label = "learning rate")
plt.yscale("log")
plt.legend()

# bottom panel: training and validation loss per completed epoch (log scale)
plt.subplot(2, 1, 2)
for loss_history, loss_label in (
    (finetune_train_losses, "train loss"),
    (finetune_validate_losses, "validate loss"),
):
    plt.plot(loss_history[:finetune_end_nof_epochs], label = loss_label)
plt.yscale("log")
plt.legend()

plt.show()

In [None]:
## check result

# batch source for the check; use train_dataloader instead to inspect training data
check_dataloader = validate_dataloader

# reshapes a flattened sample back to the raw image size
data_vec_to_image = Transforms.Reshape(rawdata_size)

# run the check on the CPU with both models in evaluation mode
encoder = encoder.to("cpu")
decoder = decoder.to("cpu")

encoder.eval()
decoder.eval()

# indices (within the batch) of the samples to display
check_data_idxs = range(3)

with torch.no_grad():

    ## encode and reconstruct one batch
    check_features, check_targets = next(iter(check_dataloader))

    check_features = check_features.detach()
    check_targets = check_targets.detach()

    check_codes = encoder(check_features)
    check_preds = decoder(check_codes)

    check_codes = check_codes.detach()
    check_preds = check_preds.detach()

    ## plot the encoded representations of all samples in the batch
    plt.figure()
    plt.imshow(check_codes, cmap = "gray", aspect = "auto")
    plt.title("code")
    plt.xlabel("feature #")
    plt.ylabel("sample #")
    plt.colorbar()
    plt.show()

    ## plot noisy input, reconstruction and clean target for the selected samples
    for check_data_idx in check_data_idxs:

        check_feature = data_vec_to_image(check_features[check_data_idx]).squeeze().numpy()
        check_pred = data_vec_to_image(check_preds[check_data_idx]).squeeze().numpy()
        check_target = data_vec_to_image(check_targets[check_data_idx]).squeeze().numpy()

        plt.figure()

        for i_panel, (panel_img, panel_name) in enumerate(
            [
                (check_feature, "Feature"),
                (check_pred, "Prediction"),
                (check_target, "Target"),
            ],
            start = 1,
        ):
            plt.subplot(1, 3, i_panel)
            plt.imshow(panel_img, cmap = "gray")
            plt.xticks([])
            plt.yticks([])
            plt.title(f"{panel_name}, idx = {check_data_idx}")

        plt.tight_layout()
        plt.show()

In [None]:
## check learned (decoder) features
# Decode one code vector per code dimension (only that dimension is non-zero)
# and show the decoded images in a grid.

# grid layout: about sqrt(code_dim) rows, and enough columns to fit every code
plot_nof_rows = int(np.sqrt(code_dim))
plot_nof_cols = int(np.ceil(code_dim/plot_nof_rows))

image_shape = rawdata_size

encoder = encoder.to("cpu")
decoder = decoder.to("cpu")

encoder.eval()
decoder.eval()

# diagonal matrix: row i activates only code dimension i
# NOTE(review): the 1E15 scale presumably makes the decoder bias negligible
# next to the weight response - confirm the intent.
check_codes = 1E15 * torch.diag(torch.ones((code_dim,)))

## plot grid of learned features
with torch.no_grad():

    check_decode_res = decoder(check_codes)
    # one image per code dimension
    check_decode_imgs = Transforms.Reshape((code_dim,) + image_shape)(check_decode_res)

    fig = plt.figure(figsize = (plot_nof_cols, plot_nof_rows))
    plot_gs = matplotlib.gridspec.GridSpec(
        plot_nof_rows,
        plot_nof_cols,
        figure = fig,
        wspace = 0.1,
        hspace = 0.1,
    )
    for i_row in range(plot_nof_rows):
        for i_col in range(plot_nof_cols):
            cur_code_idx = i_row * plot_nof_cols + i_col
            # grid cells beyond the last code stay empty
            if cur_code_idx < code_dim:
                cur_decode_img = np.squeeze(
                    check_decode_imgs[cur_code_idx,...].detach().numpy()
                )
                plt.subplot(plot_gs[i_row,i_col])
                plt.imshow(cur_decode_img, cmap = "gray")
                plt.xticks([])
                plt.yticks([])
    plt.show()

In [None]:
## save model and model parameters

## create dst dir path
# Built with os.path.join so the separator matches the host OS. The literal
# ".\Results" is only a nested path on Windows; elsewhere it names a directory
# called ".\Results".
dst_dir_path = os.path.join(".", "Results")
os.makedirs(dst_dir_path, exist_ok = True)

# One timestamp for all files of this run, so the encoder and decoder files
# always share the same suffix (two separate strftime calls can fall into
# different seconds).
dst_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

# save encoder model: the whole model object and its state dict
dst_encoder_model_name = "encoder_model_" + dst_timestamp
dst_encoder_model_file_name = dst_encoder_model_name + ".pt"
dst_encoder_modelstate_file_name = dst_encoder_model_name + "_state.pt"

dst_encoder_model_file_path = os.path.join(dst_dir_path, dst_encoder_model_file_name)
torch.save(encoder, dst_encoder_model_file_path)
print("encoder model saved to: " + dst_encoder_model_file_path)

dst_encoder_modelstate_file_path = os.path.join(dst_dir_path, dst_encoder_modelstate_file_name)
torch.save(encoder.state_dict(), dst_encoder_modelstate_file_path)
print("encoder model state saved to: " + dst_encoder_modelstate_file_path)

# save decoder model: the whole model object and its state dict
dst_decoder_model_name = "decoder_model_" + dst_timestamp
dst_decoder_model_file_name = dst_decoder_model_name + ".pt"
dst_decoder_modelstate_file_name = dst_decoder_model_name + "_state.pt"

dst_decoder_model_file_path = os.path.join(dst_dir_path, dst_decoder_model_file_name)
torch.save(decoder, dst_decoder_model_file_path)
print("decoder model saved to: " + dst_decoder_model_file_path)

dst_decoder_modelstate_file_path = os.path.join(dst_dir_path, dst_decoder_modelstate_file_name)
torch.save(decoder.state_dict(), dst_decoder_modelstate_file_path)
print("decoder model state saved to: " + dst_decoder_modelstate_file_path)