In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

root_dir = "../"
sys.path.append(root_dir)
from counting_lib import target_idx, DIGIT_NAMES, CLASS_NAME
from datasets import CountingDataset
from notebooks.viz_counting_utils import viz_dist, viz_results, viz_local_results

# Dataset built with train=False (presumably the held-out split — confirm in
# datasets.CountingDataset), plus its `digits` attribute as a labelled table.
dataset = CountingDataset(train=False)
digits = pd.DataFrame(data=dataset.digits, columns=DIGIT_NAMES)

# <root>/results/counting and <root>/figures/counting; both are created on
# demand so the cells below can rely on them existing.
results_dir, figure_dir = (
    os.path.join(root_dir, _top, "counting") for _top in ("results", "figures")
)
os.makedirs(results_dir, exist_ok=True)
os.makedirs(figure_dir, exist_ok=True)

# Predictions table; the first CSV column is used as the index.
predictions_path = os.path.join(results_dir, "predictions.csv")
predictions = pd.read_csv(predictions_path, index_col=0)

# Global seaborn appearance for every figure drawn in this notebook.
sns.set_style("white")
sns.set_context("paper")

In [None]:
# Show `m` randomly chosen dataset items side by side, each titled with its
# rounded target value, and save the grid as PDF and PNG in figure_dir.
m = 6

# A dedicated, seeded generator keeps the selection (and therefore the saved
# figure) identical across "Restart & Run All"; change the seed for a new draw.
sample_seed = 0
rng = np.random.default_rng(sample_seed)
idx = rng.choice(len(dataset), m, replace=False)

# One panel per sample: the column count follows `m` instead of a literal 6,
# and squeeze=False keeps `axes` two-dimensional so the loop also works for m == 1.
fig, axes = plt.subplots(1, m, figsize=(16 / 1.5, 9 / 4), squeeze=False)
for ax, _idx in zip(axes[0], idx):
    _image, _digits = dataset[_idx]

    target_digit = np.round(_digits[target_idx]).astype(int)

    # permute(1, 2, 0) moves the first axis last — presumably channels-first
    # to channels-last, which imshow expects; TODO confirm the tensor layout.
    ax.imshow(_image.squeeze().permute(1, 2, 0))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"{CLASS_NAME} = {target_digit}")
fig.savefig(os.path.join(figure_dir, "samples.pdf"), bbox_inches="tight")
fig.savefig(os.path.join(figure_dir, "samples.png"), bbox_inches="tight")
plt.show()

In [None]:
# viz_dist is a project helper from notebooks.viz_counting_utils; judging by its
# name it visualizes distributions for the dataset and the loaded predictions,
# and it receives figure_dir presumably to save its output there — TODO confirm.
viz_dist(dataset, predictions, figure_dir)

In [None]:
# "global_model" sub-folders: results are read from results_dir, figures are
# written under figure_dir (that folder is created here if missing).
global_results_dir, global_figure_dir = (
    os.path.join(_base, "global_model") for _base in (results_dir, figure_dir)
)
os.makedirs(global_figure_dir, exist_ok=True)

# Same visualization for each variant, "linear" first and then "rbf"
# (presumably kernel names — confirm in viz_counting_utils.viz_results).
for _variant in ("linear", "rbf"):
    viz_results(global_results_dir, _variant, global_figure_dir)

In [None]:
# "global_cond_model" sub-folders: results are read from results_dir, figures
# go under figure_dir.
global_cond_results_dir = os.path.join(results_dir, "global_cond_model")
global_cond_figure_dir = os.path.join(figure_dir, "global_cond_model")
# Create THIS cell's figure folder. The earlier code re-created
# global_figure_dir here, leaving global_cond_figure_dir missing and making
# the cell depend on a variable defined in the previous cell.
os.makedirs(global_cond_figure_dir, exist_ok=True)

# Same visualization for each variant, "linear" first and then "rbf"
# (presumably kernel names — confirm in viz_counting_utils.viz_results).
viz_results(global_cond_results_dir, "linear", global_cond_figure_dir)
viz_results(global_cond_results_dir, "rbf", global_cond_figure_dir)

In [None]:
# Output folder for the "local_cond_model" figures, created if missing.
local_cond_figure_dir = os.path.join(figure_dir, "local_cond_model")
os.makedirs(local_cond_figure_dir, exist_ok=True)

# Input folders, both located under results_dir.
gradcam_dir = os.path.join(results_dir, "gradcam")
local_cond_results_dir = os.path.join(results_dir, "local_cond_model")

# Project helper; it is handed the two input folders and the output folder.
viz_local_results(local_cond_results_dir, gradcam_dir, local_cond_figure_dir)