In [2]:
import logging
import os
import time
from typing import Tuple

import numpy as np
import pandas as pd
import plotly.express as px
from keras.models import load_model
from scipy.stats import entropy

import utils as ut

In [None]:
@ut.timer
def load_pretrained_classifier(model_path):
    """Load a saved Keras classifier from disk.

    Args:
        model_path: Location of the saved model, passed unchanged to
            ``keras.models.load_model``.

    Returns:
        Whatever ``load_model`` returns for ``model_path``.
    """
    return load_model(model_path)

@ut.timer
def make_prediction(classifier, images: np.ndarray):
    """Classify ``images`` and pick the most likely class for each one.

    Args:
        classifier: Object exposing a ``predict`` method (a Keras model here).
        images: Batch of images fed directly to ``classifier.predict``.

    Returns:
        A ``(probabilities, labels)`` tuple. ``probabilities`` is the raw
        ``predict`` output. ``labels`` is its argmax along axis 1, i.e. the
        index of the highest-scoring class per row.
    """
    probabilities = classifier.predict(images)
    return probabilities, np.argmax(probabilities, axis=1)

@ut.timer
def inception_score(pred, num_classes: int, eps: float = 1e-16) -> float:
    """Compute the Inception Score (IS) of a batch of class-probability rows.

    IS = exp( mean_x KL( p(y|x) || p(y) ) ).

    p(y|x) is one row of ``pred``. p(y) is the marginal class distribution,
    estimated as the column-wise mean of ``pred`` over the batch. Comparing
    each row against the batch marginal, rather than a uniform prior, is what
    makes IS penalise mode collapse. Confident predictions that all land in a
    single class give IS close to 1, not close to ``num_classes``.

    Args:
        pred: Array-like of shape (num_images, num_classes). Each row is
            assumed to already be a probability distribution (e.g. softmax
            output); rows are not re-normalised here.
        num_classes: Expected number of columns in ``pred``; used to validate
            its shape.
        eps: Small constant added inside the logarithms to avoid log(0).

    Returns:
        The inception score as a Python float. For well-formed probability
        rows it lies (up to ``eps`` effects) in ``[1, num_classes]``.

    Raises:
        ValueError: If ``pred`` is not a non-empty 2-D array, or its column
            count differs from ``num_classes``.
    """
    pred = np.asarray(pred, dtype=np.float64)
    if pred.ndim != 2 or pred.shape[0] == 0:
        raise ValueError(f"pred must be a non-empty 2-D array, got shape {pred.shape}")
    if pred.shape[1] != num_classes:
        raise ValueError(
            f"pred has {pred.shape[1]} columns but num_classes={num_classes}"
        )
    # Marginal class distribution p(y), estimated over the whole batch.
    # keepdims keeps it as shape (1, num_classes) so it broadcasts against pred.
    p_y = pred.mean(axis=0, keepdims=True)
    # Per-image KL(p(y|x) || p(y)), summed over classes.
    kl = pred * (np.log(pred + eps) - np.log(p_y + eps))
    avg_kl = kl.sum(axis=1).mean()
    return float(np.exp(avg_kl))


@ut.timer
def generate_images(generator, noise_dim, num_samples):
    """Sample latent noise and run it through ``generator``.

    Args:
        generator: Object exposing a ``predict`` method (the GAN generator).
        noise_dim: Width of each latent vector.
        num_samples: Number of latent vectors (and thus images) to produce.

    Returns:
        ``0.5 * generator.predict(noise) + 0.5``. This maps generator output
        from [-1, 1] onto [0, 1], assuming the generator emits values in
        [-1, 1] (e.g. a tanh output layer); that range is not checked here.
        Noise is drawn from NumPy's global RNG, so results depend on its
        current state.
    """
    latent = np.random.normal(0, 1, size=(num_samples, noise_dim))
    raw = generator.predict(latent)
    # Rescale from the generator's [-1, 1] range to [0, 1].
    return 0.5 * raw + 0.5


def plot_class_distribution(class_labels, num_classes=None):
    """Build a bar chart of how many images were assigned to each class.

    Args:
        class_labels: Sequence/array of non-negative integer class indices,
            e.g. the argmax labels returned by ``make_prediction``. It is
            flattened to 1-D.
        num_classes: Number of bars to draw (classes ``0 .. num_classes - 1``).
            Classes with no images still get a zero-height bar. If None, it
            is inferred as ``max(class_labels) + 1``, or 0 for empty input.

    Returns:
        A plotly bar figure; nothing is written to disk here.

    Raises:
        ValueError: If any label is negative, or is >= ``num_classes``.
    """
    labels = np.asarray(class_labels, dtype=np.intp).ravel()
    # Validate explicitly: a negative label would otherwise be silently
    # counted in the wrong bucket by negative indexing.
    if labels.size and labels.min() < 0:
        raise ValueError("class labels must be non-negative")
    if num_classes is None:
        num_classes = int(labels.max()) + 1 if labels.size else 0
    elif labels.size and labels.max() >= num_classes:
        raise ValueError(
            f"found class label {int(labels.max())} but num_classes={num_classes}"
        )
    # minlength guarantees one bucket per class, including empty ones.
    counts = np.bincount(labels, minlength=num_classes)
    data = {'Class': np.arange(num_classes), 'Number of Images': counts}
    fig = px.bar(data, x='Class', y='Number of Images', title='Generated Images per Class')
    return fig

@ut.timer
def save_plot(fig, plot_filepath, title, datetime):
    """Write ``fig`` to ``<plot_filepath>/<title>-<datetime>.png``.

    Args:
        fig: Figure object exposing ``write_image`` (a plotly figure here;
            static image export presumably needs plotly's image-export
            engine installed — confirm in the environment).
        plot_filepath: Output directory; created if it does not exist.
        title: File-name stem.
        datetime: Run tag/timestamp appended to the file name via an
            f-string. The name is kept for compatibility with existing
            callers even though it shadows the stdlib module name.

    Returns:
        The path of the written file.
    """
    filename = f"{title}-{datetime}.png"
    filepath = os.path.join(plot_filepath, filename)
    # write_image does not create missing directories. An empty
    # plot_filepath means "current directory", so there is nothing to create.
    if plot_filepath:
        os.makedirs(plot_filepath, exist_ok=True)
    fig.write_image(filepath)
    # The original referenced an undefined `logger`; use a module logger.
    logging.getLogger(__name__).info("Plot saved: %s", filepath)
    return filepath

@ut.timer
def plot_inception_score(iterations, inception_scores):
    """Build a line chart (with markers) of inception score per iteration.

    Args:
        iterations: x-axis values, one per recorded score.
        inception_scores: y-axis values, aligned with ``iterations``.

    Returns:
        A plotly line figure; nothing is written to disk here.
    """
    frame = pd.DataFrame(
        {
            'Iterations': iterations,
            'Inception Score': inception_scores,
        }
    )
    return px.line(
        frame,
        x='Iterations',
        y='Inception Score',
        markers=True,
        title='Inception Score vs Training Iterations',
    )


In [None]:
# Evaluate the trained generator:
#   1. classify a batch of generated images;
#   2. plot and save the per-class distribution;
#   3. report the batch's inception score;
#   4. plot and save the inception-score history.
# NOTE(review): `generator`, `iterations` and `inception_scores` are not
# defined in this section; they are assumed to come from the training code
# (e.g. an earlier cell) — confirm before running.
conf = ut.load_config()
# One timestamp per run, so both plots from this run share a file-name suffix.
run_timestamp = time.strftime("%Y%m%d-%H%M%S")
classifier = load_pretrained_classifier(model_path=conf.a3.paths.classifier_model)
gen_imgs = generate_images(generator,
                           conf.a3.gan_params.noise_dim,
                           conf.a3.gan_params.num_samples)
pred, class_labels = make_prediction(classifier, gen_imgs)
# The classifier's output width is the number of classes. Passing it lets
# classes that received no generated images still appear as empty bars.
num_classes = pred.shape[1]
fig = plot_class_distribution(class_labels, num_classes)
save_plot(fig, conf.a3.paths.training_inspection_plots, "classification-distribution", run_timestamp)
score = inception_score(pred, num_classes)
print(f"Inception score of generated samples: {score:.3f}")
fig = plot_inception_score(iterations, inception_scores)
save_plot(fig, conf.a3.paths.training_inspection_plots, "inception-score", run_timestamp)