# VQ-VAE Results

This notebook contains the results from the training, validation and testing of the VQ-VAE models. The interpretation of the tables and graphs can be found in the report.

In [None]:
import os
import json
import shutil
import pickle
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
import wandb
import cv2
from PIL import Image

# If the working directory lies somewhere inside the project folder but is not the
# project folder itself, move up one level (only a single level is climbed).
workdir = "Generative Diffusion Models for 3D Geometric Objects"
current_dir = os.getcwd()
inside_project = workdir in current_dir
at_project_root = os.path.basename(current_dir) == workdir
if inside_project and not at_project_root:
    os.chdir("..")

## 1. Validation 

Definition of the VQVAE models

In [None]:
# One entry per combination of the two numeric suffixes in the model name
# (presumably latent channels x codebook size - confirm against the training config).
model_names = [
    f"vqvae_autoencoder_{first_suffix}_{second_suffix}"
    for first_suffix in (3, 8)
    for second_suffix in (512, 2048, 8192)
]

First, the logged metrics collected during the training and validation of the models are downloaded from WandB.

In [None]:

# Initialize the wandb API
api = wandb.Api()

# Get the runs for a specific project
entity = "simonluder"
project_name = "MSE_P8"
runs = api.runs(f'{entity}/{project_name}')
run_names = [f"{model_name}_vqvae" for model_name in model_names]

# Loop through the runs and download the history of every run that belongs to one of the models
for run in runs:

    if run.name not in run_names:
        continue

    # The metrics are cached as CSV inside the model folder; runs that were already
    # downloaded are skipped so that re-running the cell is cheap.
    model_dir = f'holographic_pollen/{run.name.replace("_vqvae", "")}'
    metrics_path = f'{model_dir}/metrics.csv'
    if os.path.exists(metrics_path):
        continue

    print(f"Downloading: {run.name}")
    run_id = run.id

    try:
        # download metrics from the history artifact
        artifact = api.artifact(name=f'{entity}/{project_name}/run-{run_id}-history:latest', type='wandb-history')
        filename = artifact.file("temp/wandb/")

        # create csv (the model folder is created first so that to_csv cannot fail on a missing directory)
        df = pd.read_parquet(filename)
        os.makedirs(model_dir, exist_ok=True)
        df.to_csv(metrics_path, index=False)
    finally:
        # Always remove the temporary download, even when reading or writing failed;
        # ignore_errors covers the case where the download never created the folder.
        shutil.rmtree("temp/wandb/", ignore_errors=True)


After downloading, the metrics are loaded as a pandas DataFrame.

In [None]:
# Read the cached metrics of every model and tag each frame with a shortened
# model name ("_autoencoder" removed) so the runs stay distinguishable once stacked.
per_model_metrics = [
    pd.read_csv(f"holographic_pollen/{name}/metrics.csv").assign(
        model_name=name.replace("_autoencoder", "")
    )
    for name in model_names
]

# Stack all models into one frame with a fresh running index
metrics = pd.concat(per_model_metrics).reset_index(drop=True)

metrics.head(3)

Next, the validation loss of the different models is visualized. The loss is calculated as the MSE over all samples in the validation set.

In [None]:
# Validation reconstruction loss per step, one line per model, on a log-scaled y-axis.
# NOTE: the column key is spelled "reconstructon" here; it has to match the column
# name in the downloaded metrics, so do not "correct" it without checking the CSV.
fig, ax = plt.subplots(figsize=(12, 5))
sns.lineplot(
    data=metrics,
    x="step",
    y="val_epoch_reconstructon_loss",
    hue="model_name",
    ax=ax,
)
ax.set_title("Mean Squared Error on the validation set")
ax.set_yscale("log")
ax.set_ylabel("Mean Squared Error")
plt.show()

## 2. Testing

### 2.1 Quantitative Evaluation of the Reconstructed Images

Load the ground truth labels of the test set.

In [None]:
# Test-set labels. Later cells read the "dataset_id", "rec_path" and "label" columns of this frame.
# (The variable name is misspelled but is referenced by later cells, so it is kept as is.)
ground_truh_labels = pd.read_csv("labels_test.csv")

In [None]:

# Model folders whose per-image test scores are evaluated
files = [
    "vqvae_autoencoder_3_512",
    "vqvae_autoencoder_3_2048",
    "vqvae_autoencoder_3_8192",
    "vqvae_autoencoder_8_512",
    "vqvae_autoencoder_8_2048",
    "vqvae_autoencoder_8_8192",
]


def _load_test_scores(model_dir):
    """Load the pickled test scores of one model as a DataFrame.

    A "model" column holding the shortened model name ("_autoencoder" removed)
    is added so the frames can be told apart after concatenation.

    NOTE: pickle can execute arbitrary code while loading; only read score
    files that were produced by this project.
    """
    with open(f"holographic_pollen/{model_dir}/test/lpips_scores.pkl", "rb") as handle:
        scores = pd.DataFrame.from_dict(pickle.load(handle))
    scores["model"] = model_dir.replace("_autoencoder", "")
    return scores


# One long frame with the scores of all models (the per-model index is kept unchanged)
df_errors = pd.concat([_load_test_scores(model_dir) for model_dir in files])

Mean squared error per image, shown as one boxplot per model

In [None]:
# Distribution of the per-image MSE for every model; outliers are drawn as small crosses.
fig, ax = plt.subplots(figsize=(8, 4))
sns.boxplot(
    data=df_errors,
    x='model',
    y='mse_scores',
    flierprops={'marker': 'x', 'markersize': 4},
    ax=ax,
)
ax.set_title('Mean squared error on generated images for the testset')
ax.set_ylabel('Mean squared error')
ax.set_xlabel('model')
plt.xticks(rotation=90)
plt.show()

# Quartiles of the same scores as a table
mse_quartiles = df_errors.groupby("model")["mse_scores"].quantile([0.25, 0.5, 0.75])
print(mse_quartiles)

LPIPS per image as boxplot per model

In [None]:
# Distribution of the per-image LPIPS score for every model; outliers are drawn as small crosses.
fig, ax = plt.subplots(figsize=(8, 4))
sns.boxplot(
    data=df_errors,
    x='model',
    y='lpips_scores',
    flierprops={'marker': 'x', 'markersize': 4},
    ax=ax,
)
ax.set_title('LPIPS error on generated images for the testset')
ax.set_ylabel('LPIPS error')
ax.set_xlabel('model')
plt.xticks(rotation=90)
plt.show()

# Quartiles of the same scores as a table
lpips_quartiles = df_errors.groupby("model")["lpips_scores"].quantile([0.25, 0.5, 0.75])
print(lpips_quartiles)

### 2.2 Qualitative Evaluation of the Reconstructed Images

In [None]:
# Read the test log of every model and split each logged file name into its
# directory part ("dataset_id") and its base name ("rec_path").
per_model_logs = []
for model_name in model_names:
    logs = pd.read_json(f"holographic_pollen/{model_name}/test/test_logs.json")
    logs = logs.assign(model_name=model_name)
    logs["dataset_id"] = logs["filenames"].apply(os.path.dirname)
    logs["rec_path"] = logs["filenames"].apply(os.path.basename)
    per_model_logs.append(logs)

# Stack all models into one frame with a fresh running index
test_results = pd.concat(per_model_logs).reset_index(drop=True)

# Shortened model name ("_autoencoder" removed), used as subplot title further below
test_results["model"] = test_results["model_name"].str.replace("_autoencoder", "", regex=False)

test_results.head(3)

Below, the reconstructions of one randomly selected test sample by the different VQ-VAE variants are visualized and compared with the ground truth image.

In [None]:
# Pick one random test file and gather its reconstruction from every model. The index is
# reset so that the row number can be used directly as the subplot position.
# (No random seed is set, so a different file is drawn on every run.)
random_filename = test_results["filenames"].sample(1).values[0]
samples = test_results.loc[test_results["filenames"] == random_filename].reset_index()

fig, ax = plt.subplots(1, len(samples)+1, figsize=(14, 5))

# Location of the original image. Forward slashes are used because the backslash spelling
# contained invalid escape sequences ("\m", "\d", "\P"), which newer Python versions warn about.
gt_path = os.path.join("Z:/marvel/marvel-fhnw/data/Poleno", samples.at[0, "dataset_id"], samples.at[0, "rec_path"])

# cv2.imread returns the channels in BGR order while imshow interprets them as RGB -> convert
ax[0].imshow(cv2.cvtColor(cv2.imread(gt_path), cv2.COLOR_BGR2RGB))
ax[0].set_title("ground_truth", size=10)
ax[0].tick_params(left = False, right = False , labelleft = False , labelbottom = False, bottom = False)

# One column per model, placed to the right of the ground truth
for i, row in samples.iterrows():
    rel_path = f'holographic_pollen/{row["model_name"]}/test/images/{row["rec_path"]}'
    img_generated = cv2.cvtColor(cv2.imread(rel_path), cv2.COLOR_BGR2RGB)
    ax[i+1].imshow(img_generated)
    ax[i+1].set_title(row["model"], size=10)
    ax[i+1].tick_params(left = False, right = False , labelleft = False , labelbottom = False, bottom = False)

plt.tight_layout()
plt.show()

In [None]:
# Take the first ground-truth row of every label, order the rows by label
# and keep only the file names ("rec_path") of those rows.
sample_filenames = (
    ground_truh_labels
    .groupby("label")
    .head(1)
    .sort_values(by="label")["rec_path"]
    .values
)

Visualization of the first image per label from the test dataset. The first column represents the ground truth.

In [None]:
# Add ground truth labels to the dataframe.
# A "label" column left over from an earlier run of this cell is dropped first; otherwise the
# merge would create "label_x"/"label_y" and the label lookups below would fail on re-run.
test_results = pd.merge(
    test_results.drop(columns="label", errors="ignore"),
    ground_truh_labels[['dataset_id', 'rec_path', 'label']],
    on=['dataset_id', 'rec_path'],
    how='left',
)

# Filter the test results for the selected samples
samples = test_results[test_results["rec_path"].isin(sample_filenames)].reset_index()

# Only files that actually have reconstructions can be plotted. The number of rows to draw is
# taken from this list (and not from the set of labels, which unmatched rows of the left merge
# would inflate with NaN), because the list is what the plotting loop indexes.
available_filenames = set(samples["rec_path"])
plot_filenames = [name for name in sample_filenames if name in available_filenames]

# Number of samples to display
num_samples = len(plot_filenames)

# Calculate the number of subplots required
samples_per_subplot = 8  # Number of samples per subplot
num_subplots = (num_samples + samples_per_subplot - 1) // samples_per_subplot  # Calculate the number of subplots needed
# One column for the ground truth plus one per reconstruction of the first sample
num_cols = (len(samples[samples["rec_path"] == plot_filenames[0]]) if plot_filenames else 0) + 1
num_rows = samples_per_subplot

# Iterate through each subplot
for subplot_idx in range(num_subplots):
    # Create a figure for each subplot (squeeze=False keeps `axes` two-dimensional in every case)
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(2 * num_cols, 2.3 * num_rows), squeeze=False)

    # Iterate through the samples for the current subplot
    for row_idx in range(num_rows):
        sample_idx = subplot_idx * samples_per_subplot + row_idx
        if sample_idx >= num_samples:
            # Last figure: hide the rows without a sample instead of showing empty axes
            for unused_ax in axes[row_idx:].ravel():
                unused_ax.axis("off")
            break

        sample_filename = plot_filenames[sample_idx]
        # Get the subset of samples for the current filename (re-indexed so the row number is the column offset)
        sample_subset = samples[samples["rec_path"] == sample_filename].reset_index(drop=True)

        # Display the ground truth image
        gt_path = os.path.join("Z:/marvel/marvel-fhnw/data/Poleno", sample_subset.at[0, "dataset_id"], sample_subset.at[0, "rec_path"])
        # cv2.imread returns the channels in BGR order while imshow interprets them as RGB -> convert
        ground_truth_image = cv2.cvtColor(cv2.imread(gt_path), cv2.COLOR_BGR2RGB)
        axes[row_idx, 0].imshow(ground_truth_image)
        axes[row_idx, 0].set_title("Ground Truth", size=12)
        axes[row_idx, 0].tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

        # Display the generated images for each sample
        for col_idx, row in sample_subset.iterrows():
            rel_path = f'holographic_pollen/{row["model_name"]}/test/images/{row["rec_path"]}'
            generated_image = cv2.cvtColor(cv2.imread(rel_path), cv2.COLOR_BGR2RGB)
            axes[row_idx, col_idx + 1].imshow(generated_image)
            axes[row_idx, col_idx + 1].set_title(row["model"], size=12)
            axes[row_idx, col_idx + 1].tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

        # Set the label as subtitle for each row (placed left of the figure, vertically centred on the row)
        fig.text(-0.1, (num_rows - row_idx - 0.5) / num_rows, sample_subset.at[0, "label"], ha='center', va='center', fontsize=12, fontweight='bold')

    # Adjust layout and show the plot
    plt.tight_layout()
    plt.show()
