In [2]:
import random

import numpy as np
import torch
import torchvision
from torch import Tensor

from src.plots import plot_vae_classifier_training_result, plot_image_label, plot_image_label_two
from src.vae.mnist_vae import VaeAutoencoderClassifier
from src.image_classifier.image_classifier import MNISTClassifier
from src.sampling import split_dirichlet

In [3]:
# Fix every RNG the notebook may touch so "Restart Kernel -> Run All" reproduces the same
# VAE training, sampling and (assuming split_dirichlet draws from NumPy/`random`) user split.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

training_data = torchvision.datasets.MNIST(root='../data/MNIST_train', train=True, download=True, transform=torchvision.transforms.ToTensor())
testing_data = torchvision.datasets.MNIST(root='../data/MNIST_test', train=False, download=True, transform=torchvision.transforms.ToTensor())

print(training_data)
print(testing_data)

# `.data` holds raw pixels; dividing by 255 puts them in the [0, 1] range required for FID.
# NOTE: `input` shadows the Python builtin, but later cells reference it by this name.
input = training_data.data[:60000] / 255.0
labels = training_data.targets[:60000]

In [6]:
# Train the joint VAE + classifier with a 2-D encoding on `training_data`.
# train_model returns the trained model followed by five metric histories
# (judging by their names: total loss, classifier accuracy, classifier loss, VAE loss, KL loss).
vae_classifier = VaeAutoencoderClassifier(dim_encoding=2)

(
    vae_classifier_model,
    total_losses,
    classifier_accuracy_li,
    classifier_loss_li,
    vae_loss_li,
    kl_loss_li,
) = vae_classifier.train_model(training_data, batch_size=64, alpha=1.0, beta=20.0, epochs=20)

In [9]:
# Sample five images (with their labels) from the trained VAE and plot them;
# both tensors are detached and moved to the CPU before NumPy conversion.
image_tensor, label_tensor = vae_classifier.generate_data(n_samples=5)
plot_image_label(*(t.cpu().detach().numpy() for t in (image_tensor, label_tensor)))

In [10]:
# Move logged metrics to the CPU and convert them to NumPy for plotting.
def _to_numpy_list(values):
    """Return a list where each torch Tensor in ``values`` is detached, moved to CPU and
    converted to a NumPy array.

    Non-tensor entries (e.g. plain Python floats) are kept as-is rather than dropped, so
    every converted series keeps one entry per logged value and stays aligned with the
    other metric lists (such as ``total_losses``, which is plotted unconverted).
    """
    return [v.detach().cpu().numpy() if isinstance(v, Tensor) else v for v in values]


np_classifier_accuracy_li = _to_numpy_list(classifier_accuracy_li)
np_classifier_loss_li = _to_numpy_list(classifier_loss_li)
np_vae_loss_li = _to_numpy_list(vae_loss_li)
np_kl_loss_li = _to_numpy_list(kl_loss_li)


In [11]:
# Plot the training histories alongside the trained model and the real (normalized) images
# and labels. Per-metric series are the NumPy-converted lists; total_losses is passed as
# returned by train_model.
plot_vae_classifier_training_result(
    vae_model_classifier=vae_classifier_model,
    input=input,
    labels=labels,
    total_losses=total_losses,
    vae_loss_li=np_vae_loss_li,
    kl_loss_li=np_kl_loss_li,
    classifier_loss_li=np_classifier_loss_li,
    classifier_accuracy_li=np_classifier_accuracy_li,
)

In [12]:
# Train a plain MNIST classifier on the real training set and report its test accuracy.
# The same classifier is reused below to score the VAE's generated images.
classifier = MNISTClassifier(num_classes=10, input_size=28 * 28)  # flattened 28x28 images
classifier.train_model(training_data, epochs=1, batch_size=100)

accuracy = classifier.test_model(testing_data)
print("Test accuracy: ", accuracy)

In [13]:
# Score the real-data classifier on 10,000 VAE-generated images, passing the generated
# labels `y` alongside the images.
x, y = vae_classifier.generate_data(n_samples=10_000)
print(y.shape)

assert len(x) == len(y)
print("Number of images: ", len(x))

accuracy = classifier.test_model_syn_img_label(x, y)
print("Accuracy: ", accuracy)

In [15]:
# Build an imbalanced (non-IID) MNIST shard via a Dirichlet split across users, train a fresh
# VAE on one user's shard, and compare the label distribution of its generated images with the
# label distribution of that shard.
#
# A separate dataset object is loaded and mutated here so the full `training_data` from the
# data-loading cell is left intact (re-running earlier cells must not silently train on a shard).
# Raw pixels go into `raw_images` rather than `input`, so the [0, 1]-normalized `input` used by
# the plotting and FID cells is not clobbered with uint8 values.
NUM_CLASSES = 10

imbalanced_data = torchvision.datasets.MNIST(root='../data/MNIST_train', train=True, download=True, transform=torchvision.transforms.ToTensor())
raw_images = imbalanced_data.data      # raw uint8 pixels, the format the dataset's .data holds
raw_labels = imbalanced_data.targets

users_data = split_dirichlet(dataset=imbalanced_data, num_users=4, is_cfar=False, beta=0.5)

total_input = []
total_labels = []
total_counts = []
# NOTE(review): users_data is iterated as a mapping of user key -> sample indices (as in the
# original loop); the per-user lists below follow that iteration order.
for user_key in users_data:
    indices = torch.as_tensor([int(i) for i in users_data[user_key]], dtype=torch.long)
    user_labels = raw_labels[indices]
    total_input.append(raw_images[indices])
    total_labels.append(user_labels)
    # Per-class sample counts for this user, as a plain list of ints.
    total_counts.append(torch.bincount(user_labels, minlength=NUM_CLASSES).tolist())

user_idx = 0
input_tensor = total_input[user_idx]
label_tensor = total_labels[user_idx]
assert input_tensor.shape[0] == label_tensor.shape[0]

plot_image_label_two(input_tensor.cpu().detach().numpy(), label_tensor.cpu().detach().numpy())

# Restrict the freshly loaded dataset to this user's shard so train_model only sees it.
imbalanced_data.data = input_tensor
imbalanced_data.targets = label_tensor
assert len(imbalanced_data) == sum(total_counts[user_idx])

# Train VAE on imbalanced dataset; its training histories are not needed here.
vae_imbalanced = VaeAutoencoderClassifier(dim_encoding=2)

# sufficient epochs make the generated data distribution similar to the given input
vae_imbalanced.train_model(
    imbalanced_data,
    batch_size=64,
    alpha=1.0,
    beta=20.0,
    epochs=10,
)

gen_image, gen_output = vae_imbalanced.generate_data(n_samples=sum(total_counts[user_idx]))
# Each generated row is treated as a score vector over classes; count its arg-max class.
gen_counts = [0] * NUM_CLASSES
for probabilities in gen_output:
    gen_counts[int(torch.argmax(probabilities))] += 1

# plot generated data
plot_image_label(gen_image.cpu().detach().numpy(), gen_output.cpu().detach().numpy())

print("Input counts: ", total_counts[user_idx])
print("Generated counts: ", gen_counts)

In [None]:
# NOTE(review): this cell is disabled. Before re-enabling it:
#   - `frechet_inception_distance` is not imported anywhere in this notebook; add its import
#     to the import cell.
#   - make sure `input` still holds the [0, 1]-normalized images from the data-loading cell
#     (not raw uint8 pixels) when this runs, since the FID comparison depends on that range.
#   - `input = input[:500]` rebinds the global, so re-running this cell keeps the truncated data.
# # compute FID score
# syn_input, _ = vae_classifier.generate_data(n_samples=500)
# input = input[:500]
# 
# input_rgb = input.view(-1, 1, 28, 28).repeat(1, 3, 1, 1)
# syn_input_rgb = syn_input.view(-1, 1, 28, 28).repeat(1, 3, 1, 1)
# 
# # compute FID score (worst: 131, best: 85)
# # 0 score only possible if absolutely identical
# fid_score = frechet_inception_distance(input_rgb, syn_input_rgb)
# print("Frechet Inception Distance: ", fid_score)