In [1]:
import os
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import cuda, device, Tensor, save, load, stack, zeros, vstack, squeeze, tensor, clamp_
from src.plots import plot_vae_training_result, plot_image, plot_image_label_two
from src.vae.mnist_vae import ConditionalVae
from src.image_classifier.exq_net_v1 import ExquisiteNetV1

# Compute device shared by every later cell: the GPU when PyTorch can see one,
# otherwise the CPU.
# NOTE: this rebinds the name `device` (imported above as the torch.device
# class) to a device *instance*; building it through `torch.device` keeps this
# line valid even after the name has been rebound.
device = torch.device('cuda') if cuda.is_available() else torch.device('cpu')

In [2]:
# Number of raw samples kept aside for the training-result plot.
num_data = 1

# MNIST train / test splits, downloaded on first use.
# Samples fetched through the dataset objects go through the ToTensor
# transform; the underlying `.data` tensor is left untouched.
# NOTE(review): the original note says no further normalization is needed
# because the DataLoader inside `train_model` takes care of it -- presumably
# through this ToTensor transform; confirm against `train_model`.
training_data = torchvision.datasets.MNIST(
    root='../../data/MNIST_train',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
testing_data = torchvision.datasets.MNIST(
    root='../../data/MNIST_test',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Raw (untransformed, not normalized) images and their labels, sliced
# straight from the dataset's backing tensors.
input = training_data.data[0:num_data]
labels_li = training_data.targets[0:num_data]

# # not normalized (values 0 ~ 255)
# print(training_data.data)

In [3]:
# Hyper-parameters of the conditional VAE run. They are also part of the
# checkpoint file name, so changing any of them trains (and caches) a new model.
model = "cvae"
dataset = "mnist"
batch_size = 64
epoch = 20
learning_rate = 0.001

model_path = f"../../models/{model}_{dataset}_{batch_size}_{epoch}_{learning_rate}.pt"

if os.path.exists(model_path):
    # Reuse the cached model instead of retraining.
    # map_location remaps the stored tensors onto the device chosen for this
    # session, so a checkpoint written on a GPU machine can still be opened on
    # a CPU-only one.
    # NOTE(review): the checkpoint is a whole pickled module, and torch.load
    # unpickles it -- only open files produced by this notebook. On
    # PyTorch >= 2.6 torch.load defaults to weights_only=True; loading a full
    # module there additionally requires weights_only=False.
    cvae = load(model_path, map_location=device)
else:
    # Create the checkpoint directory up front so that a missing folder is
    # noticed before the training run, not when `save` is called after it.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    cvae = ConditionalVae(dim_encoding=3).to(device)

    vae_model, vae_loss_li, kl_loss_li = cvae.train_model(
        training_data=training_data,
        batch_size=batch_size,
        epochs=epoch,
        learning_rate=learning_rate
    )
    save(cvae, model_path)

    # Move the KL losses to the CPU and convert them to numpy arrays for
    # plotting. Entries that are not tensors are skipped, as before.
    np_kl_loss_li = [
        kl_loss.cpu().detach().numpy()
        for kl_loss in kl_loss_li
        if isinstance(kl_loss, Tensor)
    ]

    # plot results
    plot_vae_training_result(
        input=input,
        labels=labels_li,
        vae_model=vae_model,
        vae_loss_li=vae_loss_li,
        kl_loss_li=np_kl_loss_li
    )

In [4]:
# Inspect the first training sample as served by the dataset (i.e. after the
# dataset's transform has been applied).
first_sample = training_data[0]
input, label = first_sample

print("Input shape: ", input.shape)
print(input)

print("Label: ", label)

# Draw the image in the first slot of a 1x5 subplot grid, without axes.
plt.figure()
plt.subplot(1, 5, 1)
plt.axis('off')
pixels = input.detach().numpy()
plt.imshow(np.squeeze(pixels), cmap='gray')

In [5]:
# Reconstruct the first training sample with the conditional VAE and show it.
input, label = training_data[0]
input = input.to(device)
# Build the label tensor directly on the target device instead of creating it
# on the CPU and copying it over.
label = tensor(label, device=device)

# Pure inference: no_grad keeps autograd from recording the forward pass.
with torch.no_grad():
    output = cvae(input, label)
print("Reconstructed shape: ", output.shape)
print(output)

# Draw the reconstruction in the first slot of a 1x5 subplot grid.
plt.figure()
plt.subplot(151)
plt.axis('off')
squeezed_img = np.squeeze(output.cpu().detach().numpy())
plt.imshow(squeezed_img, cmap='gray')

In [6]:
# Ask the CVAE for one randomly sampled image with target label 5 and show it.
# Pure inference: no_grad keeps autograd from recording the generation pass.
with torch.no_grad():
    image = cvae.generate_data(n_samples=1, target_label=5)
print("Randomly sampled shape: ", image.shape)
print(image)

# Draw the sample in the first slot of a 1x5 subplot grid.
plt.figure()
plt.subplot(151)
plt.axis('off')
squeezed_img = np.squeeze(image.cpu().detach().numpy())
plt.imshow(squeezed_img, cmap='gray')

In [7]:
# # simple classifier for performance evaluation
# model = "classifier"
# dataset = "mnist"
# batch_size = 64
# epoch = 10
# 
# classifier_path = f"../../models/{model}_{dataset}_{batch_size}_{epoch}.pt"
# 
# # if os.path.exists(classifier_path):
# #     classifier = load(classifier_path)
# # else:
# classifier = MNISTClassifier(input_size=784, num_classes=10)
# classifier.train_model(training_data, batch_size=batch_size, epochs=epoch)
# accuracy = classifier.test_model(testing_data)
# print("Test accuracy: ", accuracy)
# save(classifier, classifier_path)

In [15]:
# Generate a synthetic, labelled image set with the CVAE for classifier training.
#   data_count - total number of images to generate
#   ratios     - share of `data_count` per class; position i belongs to label i
#   batch_size - number of images requested per generate_data call
data_count = 60000
ratios = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
batch_size = 60

images_li = []
labels_li = []
# Pure inference: no_grad keeps autograd from recording the generation passes.
with torch.no_grad():
    for label_idx, ratio in enumerate(ratios):
        # round() rather than int(): a float product that lands just below a
        # whole number would otherwise be truncated and lose one sample.
        num_samples_to_generate = round(data_count * ratio)

        # Full batches, plus one smaller batch for any remainder so that no
        # requested sample is silently dropped.
        full_batches, remainder = divmod(num_samples_to_generate, batch_size)
        chunk_sizes = [batch_size] * full_batches
        if remainder:
            chunk_sizes.append(remainder)

        for chunk_size in chunk_sizes:
            image = cvae.generate_data(n_samples=chunk_size, target_label=label_idx).cpu().detach()
            images_li.append(image)
            labels_li.append(torch.full((chunk_size,), label_idx))

# Summary only: number of batches and the shape of the first one. The raw
# tensor values are not printed, as that would flood the cell output.
print(len(images_li))
print(images_li[0].shape)

print(len(labels_li))
print(labels_li[0].shape)

In [29]:
import random

# Shuffle at *sample* level.
# Every generated batch holds images of a single label, so shuffling only the
# order of the batches would still feed the classifier label-homogeneous
# batches. Instead, flatten everything, permute the individual samples, and
# cut the result back into batches of the original size.
samples_per_batch = images_li[0].shape[0]
all_images = torch.cat(images_li, dim=0)
all_labels = torch.cat(labels_li, dim=0)

# One shared permutation keeps every image aligned with its label.
sample_order = list(range(all_images.shape[0]))
random.shuffle(sample_order)
sample_order = torch.tensor(sample_order, dtype=torch.long)

images_li = list(torch.split(all_images[sample_order], samples_per_batch))
labels_li = list(torch.split(all_labels[sample_order], samples_per_batch))

print(labels_li[0])

In [30]:
# Fit a fresh ExquisiteNetV1 (10 classes, single-channel input) on the
# generated image batches only.
classifier = ExquisiteNetV1(img_channels=1, class_num=10)
classifier.to(device)
classifier.train_model_syn_image(
    labels_li=labels_li,
    input_li=images_li,
    learning_rate=0.01,
    epochs=20,
)

In [31]:
# Evaluate the classifier on the real MNIST test split.
# NOTE: `batch_size` here is whatever value the most recently executed cell
# assigned to it.
accuracy, loss, f1_macro, f1_micro = classifier.test_inference(testing_data, batch_size)

# One metric per line: accuracy, loss, macro F1, micro F1.
print(accuracy, loss, f1_macro, f1_micro, sep="\n")

# # test on synthetic data
# data_count = 10000
# ratios = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
# images = []
# labels = []
# for label_idx, ratio in enumerate(ratios):
#     num_samples_to_generate = int(data_count * ratio)
#     images.append(
#         cvae.generate_data(n_samples=num_samples_to_generate, target_label=label_idx).cpu().detach()
#     )
#     label = zeros((num_samples_to_generate, 10), device=device)
#     label[:, label_idx] = 1
#     labels.append(label.cpu().detach())
# final_images = vstack(images)
# final_labels = vstack(labels)
# 
# testing_data = torchvision.datasets.MNIST(root='../../data/MNIST_train', train=False, download=True,
#                                                   transform=torchvision.transforms.ToTensor())
# testing_data.data = squeeze(final_images, dim=1)
# testing_data.targets = final_labels.argmax(dim=1)
# 
# accuracy, loss, f1_macro, f1_micro = classifier.test_inference(testing_data, batch_size)
# 
# print(accuracy)
# print(loss)
# print(f1_macro)
# print(f1_micro)