In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from gradient.slide_deck.shapes import Image, Placement, Text, TextContent
from gradient.slide_deck.slidedeck import (
    DEFAULT_GRADIENT_PRESENTATION_TEMPLATE_PATH,
    DefaultGradientSlideLayouts,
    SlideDeck,
)
from matplotlib.patches import Rectangle

# Object classes of interest for this demo (not referenced by the cells shown below).
demo_classes = """
    airport dam factory_or_powerplant hospital military_facility nuclear_powerplant
    oil_or_gas_facility place_of_worship prison stadium electric_substation road_bridge
""".split()
# Country filters tried in earlier runs: ["TUR", "SYR", "IRQ"], ["RUS"], ["USA"].

# Dataset split being reported on, and the train/val/test partitions its results cover.
split_name = "dev"
tvt_splits = ["train", "val"]
# Concatenated partition names ("trainval"); part of every input JSON filename below.
tvt_str = "".join(tvt_splits)

# The *_balance_*.json / *_diversity_*.json inputs are read from the working directory;
# they are expected to have been written by an earlier analysis step.

# Output directories, relative to the working directory: saved figures and the slide deck.
fig_dir, grad_dir = (Path.cwd() / subdir for subdir in ("figs", "gradient_shiz"))
for out_dir in (fig_dir, grad_dir):
    out_dir.mkdir(exist_ok=True)

In [None]:
# Class-wise balance (mutual information) results. The "Balance per class" cell re-reads
# the same file into `mi`, so this variable is not used by the cells shown here.
mi_classwise = json.loads(Path(f"{split_name}_{tvt_str}_balance_classwise.json").read_text())

### Diversity (overall)


In [None]:
# Evenness of each metadata factor, independent of class label (Simpson index per factor).
# Factors whose Simpson index falls below this value are counted as "low diversity".
LOW_DIVERSITY_THRESHOLD = 0.5

with open(f"{split_name}_{tvt_str}_diversity_data.json") as f:
    diversity = json.load(f)

# Index the frame by factor name so pandas labels the bars itself; this also raises
# if the number of factor names and Simpson values disagree instead of mislabeling bars.
plt_df = pd.DataFrame({"Simpson": diversity["simpson"]}, index=diversity["factors"])

# Explicit figure/axes (as in the other plotting cells) so we save exactly this figure.
fig_div_overall, ax = plt.subplots(figsize=(6, 5))
plt_df.plot(kind="bar", rot=90, ax=ax, ylabel="Diversity Index")
fig_div_overall.tight_layout(pad=0)
fig_div_overall.savefig(fig_dir / f"diversity_overall_{split_name}.png", dpi=150)

# Compute the low-diversity factors once; both the rollup count and the printout use it.
low_diversity_factors = [
    factor
    for factor, simpson in zip(diversity["factors"], diversity["simpson"])
    if simpson < LOW_DIVERSITY_THRESHOLD
]
diversity_rollup = len(low_diversity_factors)
diversity_rollup_text = f"Diversity: {diversity_rollup} factors with low diversity"

print(low_diversity_factors)

### Diversity (classwise)


In [None]:
# plot
with open(f"{split_name}_{tvt_str}_diversity_classwise.json") as f:
    div_classwise = json.load(f)

fig_div_classwise, ax = plt.subplots(figsize=(12, 6))
# cmap = sns.diverging_palette(10, 220, as_cmap=True)
sns_plot = sns.heatmap(
    div_classwise["classwise_shannon"],
    cmap="viridis_r",
    vmin=0,
    vmax=1,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.7, "label": "Evenness/Diversity Index"},
    xticklabels=[n for n in div_classwise["factors"] if n != "class"],
    yticklabels=div_classwise["classes"],
    annot=True,
)
plt.xlabel("Factors")
plt.ylabel("Class")
plt.title("")
plt.tight_layout(pad=0)
plt.savefig(f"{fig_dir}/diversity_classwise_{split_name}.png", dpi=150)

### Balance (mutual information)


In [None]:
# Normalized mutual information (NMI) between every pair of metadata factors.
# Row/column 0 is the class label: it is the outlined row and the row used for the rollup.
# NMI with the class above this value is flagged as a shortcut-learning risk.
HIGH_MI_THRESHOLD = 0.5

with open(f"{split_name}_{tvt_str}_balance_data.json") as fp:
    mi = json.load(fp)

# Cap at 1 so the annotated values agree with the [0, 1] colour scale.
MI = np.minimum(np.array(mi["mutual_information"]), 1)
n_factors = MI.shape[0]

# Only the strict upper triangle is shown (the matrix is presumably symmetric):
# mask the diagonal and everything below it.
mask = np.tril(np.ones_like(MI, dtype=bool))

# Drop the first column (its upper-triangle part is empty) and the last row (it has no
# cells above the diagonal). Both axes then have n_factors - 1 entries, matching the
# tick labels. Previously an extra, always-masked, unlabeled row was drawn.
fig_balance_joint, ax = plt.subplots(figsize=(12, 8))
sns_plot = sns.heatmap(
    MI[:-1, 1:],
    mask=mask[:-1, 1:],
    ax=ax,
    cmap="viridis",
    vmin=0,
    vmax=1,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.5, "label": "Normalized Mutual Information"},
    xticklabels=mi["factors"][1:],
    yticklabels=mi["factors"][:-1],
    annot=True,
)
# Outline the first row (class vs. every other factor). It spans exactly the
# n_factors - 1 plotted columns; a width of n_factors overshoots and gets clipped.
ax.add_patch(Rectangle((0, 0), n_factors - 1, 1, fill=False, edgecolor="w", lw=4))
fig_balance_joint.tight_layout(pad=0)
fig_balance_joint.savefig(fig_dir / f"balance_joint_{split_name}.png", dpi=150)

# Number of factors whose NMI with the class label exceeds the threshold.
balance_rollup = int(np.sum(MI[0, 1:] > HIGH_MI_THRESHOLD))
balance_rollup_text = f"Balance: {balance_rollup} factors co-occurring with class label (at risk of shortcut)"

### Balance per class


In [None]:
# Mutual information between each individual class and each metadata factor.
with open(f"{split_name}_{tvt_str}_balance_classwise.json") as fp:
    mi = json.load(fp)
MI = np.array(mi["mutual_information"])

fig_balance_classwise, ax = plt.subplots(figsize=(12, 8))
sns_plot = sns.heatmap(
    MI,
    ax=ax,
    cmap="viridis",
    # NOTE(review): the colour scale is capped at 1 while the label says nats; any value
    # above 1 would saturate. Confirm whether these values are normalized upstream.
    vmin=0,
    vmax=1,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.5, "label": "Mutual Information [nat]"},
    xticklabels=mi["factors"][1:],
    yticklabels=mi["classes"],
    annot=True,
)
# Columns are metadata factors and rows are classes (see the tick labels above);
# the x axis was previously mislabeled "Class".
ax.set_xlabel("Factors")
ax.set_ylabel("Class")
fig_balance_classwise.tight_layout(pad=0)
fig_balance_classwise.savefig(fig_dir / f"balance_classwise_{split_name}.png", dpi=150)

In [None]:
# def generate_slide_from_fig(fig_path, fig_title, rollup_text, description):
#     # save_dir = os.path.join(working_directory, f"{fig_title}.png")
#     # fig.savefig(save_dir)
#     # fig_path = f"{working_directory}/{fig_name}"
#     kwargs = {
#         "title": fig_title,
#         "layout": DefaultGradientSlideLayouts.CONTENT_TITLE_ONLY,
#         "additional_shapes": [
#             Image(path=fig_path, placement=Placement.MANUAL, left=5, top=1.5, width=8, height=5, pad=True),
#             TextBox(
#                 content=TextContent(
#                     lines=[
#                         Text(description + "\n"),
#                         Text(rollup_text, bold=True),
#                     ],
#                 ),
#                 placement=Placement.MANUAL,
#                 left=0.5,
#                 top=1.5,
#                 width=4.5,
#                 height=5.0,
#             ),
#         ],
#     }
#     return kwargs

In [None]:
def generate_slide_from_fig(fig_path, fig_title, rollup_text, description):
    """Build the keyword arguments for one figure slide, for ``deck.add_slide(**...)``.

    The slide uses the two-column content layout. The first placeholder receives a text
    block (spacer, description, then the rollup line in bold, all at 18 pt). The second
    placeholder receives the figure image with automatic placement.

    Args:
        fig_path: Path of the saved figure image to embed.
        fig_title: Slide title.
        rollup_text: Summary line rendered in bold after the description (may be "").
        description: Explanatory paragraph shown before the rollup line.

    Returns:
        dict with ``title``, ``layout`` and ``placeholder_fillings`` entries.
    """
    body_font_size = 18
    text_column = TextContent(
        lines=[
            Text("\n\n", fontsize=body_font_size),  # vertical spacing above the text
            Text(description + "\n", fontsize=body_font_size),
            Text(rollup_text, bold=True, fontsize=body_font_size),
        ],
    )
    figure_column = Image(path=fig_path, placement=Placement.AUTO)
    return dict(
        title=fig_title,
        layout=DefaultGradientSlideLayouts.CONTENT_TWO_COLUMNS,
        placeholder_fillings=[text_column, figure_column],
    )

In [None]:
# Generate and add to the slide deck
deck = SlideDeck(presentation_template_path=DEFAULT_GRADIENT_PRESENTATION_TEMPLATE_PATH)

fig_list = ["balance_classwise", "balance_joint", "diversity_classwise", "diversity_overall"]
fig_titles = ["Balance per class", "Balance", "Diversity per class", "Diversity"]

rollup_text = ["", balance_rollup_text, "", diversity_rollup_text]
description_text = [
    "A measure of co-occurrence of metadata factors with class labels.  Metadata factors that spuriously correlate with individual classes may allow a model to learn shortcut relationships rather than the salient properties of each class.  Larger normalized mutual information indicates undesirable opportunity for shortcut learning.",
    "A measure of co-occurrence of metadata factors with class labels.  Metadata factors that spuriously correlate with individual classes may allow a model to learn shortcut relationships rather than the salient properties of each class.  Larger normalized mutual information indicates undesirable opportunity for shortcut learning.",
    "A measure of the distribution of metadata factors in the dataset. A balanced dataset has an even distribution of class labels and generative factors.",
    "A measure of the distribution of metadata factors in the dataset. A balanced dataset has an even distribution of class labels and generative factors.",
]

for fig, title, rt, desc in zip(fig_list, fig_titles, rollup_text, description_text):
    fig_path = f"{fig_dir}/{fig}_{split_name}.png"
    deck.add_slide(**generate_slide_from_fig(fig_path, title, rt, desc))

deck.save(
    output_directory=grad_dir,
    name=f"{split_name}_report",
)

In [None]:
# update cache -- copied from increment-5-demo repo
# tvt_str = "trainval"
# for split_name in ["dev", "op"]:
#     # classwise balance
#     fn = Path(f"{split_name}_{tvt_str}_balance_classwise.json")
#     with fn.open() as f:
#         mi = json.load(f)
#     ds.outputs[split_name]["balance_classwise"] = mi

#     # classwise diversity
#     fn = Path(f"{split_name}_{tvt_str}_diversity_classwise.json")
#     with fn.open() as f:
#         mi = json.load(f)
#     ds.outputs[split_name]["diversity_classwise"] = mi

#     # balance
#     fn = Path(f"{split_name}_{tvt_str}_balance_data.json")
#     with fn.open() as f:
#         mi = json.load(f)
#     ds.outputs[split_name]["balance"] = mi

#     # classwise diversity
#     fn = Path(f"{split_name}_{tvt_str}_diversity_data.json")
#     with fn.open() as f:
#         mi = json.load(f)
#     ds.outputs[split_name]["diversity"] = mi

# new_fn = "new_cache.json"
# with open(new_fn, "w") as f:
#     json.dump(ds.outputs, f)