In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from fmow_utils import extrinsic_factors_fmow
from matplotlib.patches import Rectangle
from metadata import MetadataBias, str2int, validate_dict

# FMOW classes of interest for the demo. In this notebook only the disabled
# pickle-loading cell (commented out) filters on these; the active CSV path does not.
demo_classes = [
    "airport", "border_checkpoint", "dam", "factory_or_powerplant",
    "hospital", "military_facility", "nuclear_powerplant",
    "oil_or_gas_facility", "place_of_worship", "port", "prison",
    "stadium", "electric_substation", "road_bridge",
]

# Which split partition to analyse; CSVs are read from splits/<split_name>_<split>.csv
split_name = "dev"
# country_code = "RUS"

# Sub-splits that are concatenated into a single frame for the analysis
tvt_splits = ["train", "val"]

In [None]:
# Read every requested sub-split of `split_name` and stack them into one frame.
dfs = [pd.read_csv(f"splits/{split_name}_{split}.csv") for split in tvt_splits]

# ignore_index=True gives the combined frame a fresh 0..N-1 index. Without it, each
# CSV's index restarts at 0, so row labels repeat and label-based lookups/alignment
# on `df` become ambiguous.
df = pd.concat(dfs, ignore_index=True)

In [None]:
# Disabled alternative loading path (kept for reference; every line is commented out).
# NOTE(review): re-enabling it would overwrite the CSV-based `df` loaded above, and it
# needs `get_fmow_boxes`, which is not imported in this notebook.
# NOTE(review): pd.read_pickle can execute arbitrary code; only use it on trusted files.
# load precomputed table of FMOW labels and metadata
# df = pd.read_pickle("../trainval_labels_factors.pkl").reset_index(drop=True)
# df["class"] = df["class"].astype("category")
# df["split"] = df.split.astype("category")

# df = df[df.country_code == "USA"]
# df = df[df["class"].isin(demo_classes)]
# # xywh
# boxes = get_fmow_boxes(df)
# img_sizes = np.column_stack((df.img_width.to_numpy(), df.img_height.to_numpy()))

# # quick check for missing classes if we filter down to USA
# us_classes = list(df["class"].unique())
# missing = [c for c in demo_classes if c not in us_classes]
# missing

In [None]:
# Intrinsic (dataset-agnostic) factors are disabled: they need `boxes`/`img_sizes`,
# which only the commented-out pickle path produces.
# int_fmow, int_categorical = intrinsic_factors_xywh(boxes, img_sizes)

# FMOW-specific extrinsic factors plus a per-factor "is categorical" flag dict.
# Called before the class labels are read, matching the original cell's order.
ext_result = extrinsic_factors_fmow(df)
ext_fmow, ext_categorical = ext_result

# The class label is itself treated as a (categorical) factor.
class_labels = df["class"].to_numpy()
cls_fmow = {"class": class_labels}
cls_categorical = {"class": True}

In [None]:
# Combine the class label with the extrinsic factors. "class" comes first in the dict
# literal, so it is factor 0 (later cells rely on this, e.g. md.names[1:]).
factors = {**cls_fmow, **ext_fmow}
is_categorical = {**cls_categorical, **ext_categorical}

# Every factor needs a categorical flag. Check this BEFORE re-ordering below:
# otherwise a missing flag surfaces as a bare KeyError from the comprehension and
# this check never gets a chance to report which factor is missing.
missing_flags = [k for k in factors if k not in is_categorical]
assert not missing_flags, f"no categorical flag for factor(s): {missing_flags}"

# match insertion order --- done in MetadataBias class as well
is_categorical = {key: is_categorical[key] for key in factors}

# Keep the original (string) class labels for plotting/export, then map
# non-numeric variables to integers.
orig_class = factors["class"]
factors = str2int(factors)

# Re-check after conversion in case str2int changes the key set.
assert all(k in is_categorical for k in factors)
# make sure each factor has the same number of entries
validate_dict(factors)

In [None]:
# Build the metadata-bias analyzer (from the local `metadata` module) from the
# integer-coded factors and their categorical flags built in the previous cell.
md = MetadataBias(factors, is_categorical)

# NOTE(review): the lines below are leftover scratch/debug calls; consider deleting.
# mi_joint = md.compute_mutual_information(num_neighbors=5)
# mi_onehot = md.mutual_information_by_class(num_neighbors=5)

# factors.keys() == is_categorical.keys()
# zip(md.names, )
# list(md.is_categorical.keys())
# np.unique(md.data[:, md.names.index("cloud_cover")]).shape
# print(md.is_categorical)
# md.data.shape

In [None]:
# Per-class mutual information against the "class" factor. The next cell plots its rows
# against the sorted class labels and its columns against md.names[1:]
# (NOTE(review): presumably one-hot per class vs. each other factor — confirm in `metadata`).
mi_onehot = md.mutual_information_by_class(class_var="class")

In [None]:
# MI with one-hot classes: rows are classes, columns are the non-class factors.
# np.unique returns the sorted distinct labels.
# NOTE(review): assumes mi_onehot rows follow that same sorted order (i.e. str2int
# assigns codes in sorted order) — confirm in `metadata`.
classes = np.unique(orig_class)

f, ax = plt.subplots(figsize=(24, 8))
sns_plot = sns.heatmap(
    mi_onehot,
    cmap="viridis",
    vmin=0,
    vmax=1,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.5, "label": "Normalized Mutual Information"},
    xticklabels=md.names[1:],
    yticklabels=classes,
    annot=True,
    ax=ax,
)
# x ticks are factor names and y ticks are class labels, so label the axes to match.
ax.set_xlabel("Factor")
ax.set_ylabel("Class")
ax.set_title("Mutual information between each class and each factor")
plt.tight_layout(pad=0)

In [None]:
# MI between factors (joint class variable)
# Pairwise MI over the factors in md.names order. The next cell slices it as a 2-D array
# and shows only its upper triangle (NOTE(review): i.e. assumes it is square and
# symmetric — confirm in `metadata`).
mi = md.compute_mutual_information()

In [None]:
f, ax = plt.subplots(figsize=(12, 8))

# Only the strict upper triangle is shown: mask the lower triangle AND the diagonal.
# np.tril keeps the main diagonal, so a single call covers both.
mask = np.tril(np.ones_like(mi, dtype=bool))

# Column 0 and the last row are fully masked, so drop both. The plotted grid then
# has as many rows as yticklabels (md.names[:-1]) and as many columns as
# xticklabels (md.names[1:]), assuming mi has one row/column per entry of md.names.
# Values are clipped at 1 to stay within the colour scale (vmax=1).
mi_upper = np.minimum(mi[:-1, 1:], 1)
sns_plot = sns.heatmap(
    mi_upper,
    mask=mask[:-1, 1:],
    cmap="viridis",
    vmin=0,
    vmax=1,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.5, "label": "Normalized Mutual Information"},
    xticklabels=md.names[1:],
    yticklabels=md.names[:-1],
    annot=True,
    ax=ax,
)
ax.set_title("Pairwise mutual information between factors")

# Highlight the top row: each factor's MI with md.names[0] ("class", which is inserted
# first when `factors` is built). Its width must equal the number of PLOTTED columns;
# mi.shape[0] is one column too wide, so the box's right edge fell outside the axes.
ax.add_patch(
    Rectangle((0, 0), mi_upper.shape[1], 1, fill=False, edgecolor="k", lw=4)
)
plt.tight_layout(pad=0)

In [None]:
# Export both MI tables as JSON for use outside the notebook.
# Files: <split_name>_<joined splits>_balance_data.json and ..._balance_classwise.json
balance_data = {
    "mutual_information": mi.tolist(),
    "factors": md.names,
}

balance_classwise_data = {
    "mutual_information": mi_onehot.tolist(),
    "classes": classes.tolist(),
    # NOTE(review): mi_onehot's columns were plotted against md.names[1:]; confirm
    # consumers expect the full name list (including "class") here.
    "factors": md.names,
}

tvt_str = "".join(tvt_splits)
export_targets = {
    f"{split_name}_{tvt_str}_balance_data.json": balance_data,
    f"{split_name}_{tvt_str}_balance_classwise.json": balance_classwise_data,
}
for target_path, payload in export_targets.items():
    with open(target_path, "w") as fp:
        json.dump(payload, fp)

# balance_rollup = np.sum(np.array(balance_data["mutual_information"])[0, 1:] > 0.5)
# print(balance_rollup)

In [None]:
# Render the class-wise export payload (the cell's last expression is displayed).
balance_classwise_data