In [94]:
import os

import torchvision
# NOTE: `round` below is torch.round and shadows Python's builtin round().
from torch import (
    Tensor, cat, cuda, device, full, int64, load, no_grad, round, save,
    squeeze, vstack, where, zeros, zeros_like,
)

from src.image_classifier.exq_net_v1 import ExquisiteNetV1
from src.plots import plot_image, plot_vae_training_result
from src.vae.mnist_vae import ConditionalVae

# Run on the GPU when CUDA is available, otherwise on the CPU.
# NOTE: this rebinds `device` (imported from torch) to a torch.device instance.
if cuda.is_available():
    device = device('cuda')
else:
    device = device('cpu')

In [87]:
# Load MNIST. `input` / `labels` hold the first `num_data` raw training tensors,
# which the VAE cell passes to its result plot.
num_data = 60000

mnist_to_tensor = torchvision.transforms.ToTensor()
training_data = torchvision.datasets.MNIST(
    root='../../data/MNIST_train', train=True, download=True, transform=mnist_to_tensor
)
testing_data = torchvision.datasets.MNIST(
    root='../../data/MNIST_test', train=False, download=True, transform=mnist_to_tensor
)

head = slice(0, num_data)
input = training_data.data[head]  # NOTE: shadows builtin input(); name kept because the VAE cell reads it
labels = training_data.targets[head]

print(training_data.data[0])

In [89]:
# VAE hyperparameters. They are encoded in the checkpoint filename, so changing
# any of them makes this cell look for (or train) a different checkpoint.
model = "cvae"
dataset = "mnist"
batch_size = 64
epoch = 10
learning_rate = 0.01

model_path = f"../../models/{model}_{dataset}_{batch_size}_{epoch}_{learning_rate}.pt"

if os.path.exists(model_path):
    # The checkpoint is a whole pickled module written by `save(vae, ...)` below, so
    # full unpickling is needed: torch>=2.6 defaults to weights_only=True, which
    # rejects it. Only load checkpoints you produced yourself (pickle can run code).
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    vae = load(model_path, map_location=device, weights_only=False)
else:
    vae = ConditionalVae(dim_encoding=3).to(device)

    vae_model, vae_loss_li, kl_loss_li = vae.train_model(
        training_data=training_data,
        batch_size=batch_size,
        epochs=epoch,
        learning_rate=learning_rate
    )
    # `save` opens the file directly and does not create missing directories.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    save(vae, model_path)

    # Move tensor losses to the CPU before converting them to NumPy. Keep any
    # non-tensor entries (e.g. Python floats) instead of silently dropping them,
    # so the KL curve stays aligned with vae_loss_li.
    np_kl_loss_li = [
        kl.cpu().detach().numpy() if isinstance(kl, Tensor) else kl
        for kl in kl_loss_li
    ]

    # plot results
    plot_vae_training_result(
        input=input,
        labels=labels,
        vae_model=vae_model,
        vae_loss_li=vae_loss_li,
        kl_loss_li=np_kl_loss_li
    )

In [90]:
# images = vae.generate_data(n_samples=1, target_label=0)
# print(images.shape)
# print(images[0])
# print(images[0].max().item())
# print(images[0].min().item())

In [91]:
# Preview five real training digits (dividing by 1.0 yields a float tensor) and
# report the mean raw pixel value of the first 1000 images. NOTE(review): this looks
# like the brightness reference for the synthetic-pixel scaling below — confirm.
plot_image(training_data.data[:5] / 1.0)
# torch.mean only accepts floating-point/complex tensors, and the raw MNIST pixel
# tensor is presumably uint8, so cast before averaging (otherwise RuntimeError).
print(training_data.data[:1000].float().mean().item())

In [112]:
# train classifier on generated images:
# sample a synthetic training set from the conditional VAE (`ratio * data_count`
# images per digit label) and install it as `training_data` for the classifier cell.
data_count = 60000
ratios = [0.1] * 10  # per-label fraction of data_count, for labels 0..9
pixel_threshold = 5e-03  # generated intensities below this are zeroed
pixel_scale = 200  # NOTE(review): MNIST pixels span 0..255 — confirm 200 (not 255) is intended

global_model = vae

# Dedicated names so the global `labels` read by the VAE cell is not clobbered.
synthetic_images = []
synthetic_labels = []
for label_idx, ratio in enumerate(ratios):
    num_samples_to_generate = int(data_count * ratio)
    with no_grad():  # pure sampling; no autograd graph needed
        output = global_model.generate_data(n_samples=num_samples_to_generate, target_label=label_idx).cpu().detach()

    output = where(output < pixel_threshold, zeros_like(output), output)
    output = (output * pixel_scale).int()
    plot_image(output[:5])
    # torchvision's MNIST turns each `.data` entry into an 8-bit PIL "L" image, so the
    # pixels must be uint8 (int32 data is misread). Clamp first so anything above
    # 255 saturates instead of wrapping around.
    synthetic_images.append(output.clamp(0, 255).byte())
    # Store class indices directly (int64) instead of one-hot rows reduced with argmax.
    synthetic_labels.append(full((num_samples_to_generate,), label_idx, dtype=int64))

# NOTE: overwrites the real MNIST training set in place; the classifier cell trains on it.
training_data.data = squeeze(vstack(synthetic_images), dim=1)
training_data.targets = cat(synthetic_labels)

print(training_data.data.shape)
print(training_data.targets.shape)
print(training_data.data[:1])

In [113]:
# Recreate the real MNIST test split, so the classifier is validated on genuine
# digits even if `testing_data` was modified by an earlier run.
real_test_kwargs = dict(
    root='../../data/MNIST_train',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
testing_data = torchvision.datasets.MNIST(**real_test_kwargs)

In [114]:
# Train the classifier on `training_data` (replaced by VAE samples in the generation
# cell above). `testing_data` is passed as the second argument — presumably for
# per-epoch evaluation; confirm against ExquisiteNetV1.train_model.
# Epochs and learning rate use dedicated names: reassigning `epoch` here would make a
# re-run of the VAE cell look for (and retrain) a different checkpoint.
batch_size = 64  # same value as the VAE cell; also read by the evaluation cell below
classifier_epochs = 5
classifier_learning_rate = 0.01

classifier = ExquisiteNetV1(class_num=10, img_channels=1)
classifier.to(device)
classifier.train_model(
    training_data,
    testing_data,
    batch_size=batch_size,
    learning_rate=classifier_learning_rate,
    epochs=classifier_epochs,
)

In [115]:
# Evaluate the classifier on the real MNIST test split, constructed afresh so any
# earlier in-place replacement of `testing_data` is discarded.
testing_data = torchvision.datasets.MNIST(
    root='../../data/MNIST_train',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

accuracy, loss, f1_macro, f1_micro = classifier.test_inference(testing_data, batch_size)
print("Test on real data")
for real_metric in (accuracy, loss, f1_macro, f1_micro):
    print(real_metric)
print("")


# test on synthetic data:
# draw a fresh VAE sample (`test_count` images split across labels by `test_ratios`),
# apply the same post-processing as the synthetic training set, and evaluate on it.
# Dedicated names so `data_count`, `ratios` and `labels` from earlier cells are not clobbered.
test_count = 10000
test_ratios = [0.1] * 10
synthetic_test_images = []
synthetic_test_labels = []
for label_idx, ratio in enumerate(test_ratios):
    num_samples_to_generate = int(test_count * ratio)
    with no_grad():  # pure sampling; no autograd graph needed
        output = global_model.generate_data(n_samples=num_samples_to_generate, target_label=label_idx).cpu().detach()

    output = where(output < 5e-03, zeros_like(output), output)
    # torchvision's MNIST builds 8-bit PIL images from `.data`, so store clamped uint8
    # pixels rather than int32 (which would be misread).
    output = (output * 200).int().clamp(0, 255).byte()
    synthetic_test_images.append(output)
    synthetic_test_labels.append(full((num_samples_to_generate,), label_idx, dtype=int64))

# Use a separate dataset object so `testing_data` keeps holding the real test split.
synthetic_testing_data = torchvision.datasets.MNIST(root='../../data/MNIST_train', train=False, download=True,
                                                    transform=torchvision.transforms.ToTensor())
synthetic_testing_data.data = squeeze(vstack(synthetic_test_images), dim=1)
synthetic_testing_data.targets = cat(synthetic_test_labels)

accuracy, loss, f1_macro, f1_micro = classifier.test_inference(synthetic_testing_data, batch_size)

print("Test on synthetic data")
print(accuracy)
print(loss)
print(f1_macro)
print(f1_micro)