# 🖼️ Generating Figure 3

The following code generates Figure 3 of the paper by plotting, for each species, the performance of the unconditioned model against that of the conditioned model on SatBird (MAE, lower is better) and sPlotOpen (AUROC, higher is better).

In [1]:
from os import listdir
from os.path import isfile, join

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

from torchmetrics.classification import MultilabelAUROC

# SatBird

### Non-songbirds

In [None]:
uncond_path = "SatBird/preds_nonsongbirds"
cond_path = "SatBird/preds_nonsongbirds_conditioned"
indices = np.load("SatBird/nonsongbird_indices.npy")

target_path = "SatBird/satbird_usa_summer_targets.pkl"
targets = np.load(target_path, allow_pickle=True)

# One row of absolute errors per prediction file. The unconditioned predictions
# are compared with the ``indices`` subset of the target straight away; the
# conditioned predictions are compared with the full target and only subset by
# ``indices`` after averaging (below).
uncond_rows, cond_rows = [], []
sample_files = (name for name in listdir(uncond_path) if isfile(join(uncond_path, name)))
for name in sample_files:
    sample_target = targets[name[:-4]]  # key = file name without its 4-char extension
    uncond_rows.append(np.abs(sample_target[indices] - np.load(join(uncond_path, name))))
    cond_rows.append(np.abs(sample_target - np.load(join(cond_path, name))))

# Mean absolute error per non-songbird species, averaged over all files.
non_uncond_mae_list = np.stack(uncond_rows, 0).mean(0)
non_cond_mae_list = np.stack(cond_rows, 0).mean(0)[indices]

### Songbirds

In [None]:
uncond_path = "SatBird/preds_songbirds"
cond_path = "SatBird/preds_songbirds_conditioned"
indices = np.load("SatBird/songbird_indices.npy")

target_path = "SatBird/satbird_usa_summer_targets.pkl"
targets = np.load(target_path, allow_pickle=True)

# One row of absolute errors per prediction file. The unconditioned predictions
# are compared with the ``indices`` subset of the target straight away; the
# conditioned predictions are compared with the full target and only subset by
# ``indices`` after averaging (below).
uncond_rows, cond_rows = [], []
sample_files = (name for name in listdir(uncond_path) if isfile(join(uncond_path, name)))
for name in sample_files:
    sample_target = targets[name[:-4]]  # key = file name without its 4-char extension
    uncond_rows.append(np.abs(sample_target[indices] - np.load(join(uncond_path, name))))
    cond_rows.append(np.abs(sample_target - np.load(join(cond_path, name))))

# Mean absolute error per songbird species, averaged over all files.
uncond_mae_list = np.stack(uncond_rows, 0).mean(0)
cond_mae_list = np.stack(cond_rows, 0).mean(0)[indices]

# sPlotOpen

### Non-trees

In [None]:
uncond_path = "sPlotOpen/splot_predictions/preds_nontree"
cond_path = "sPlotOpen/splot_predictions/preds_nontree_conditioned"
species_indices = np.load("sPlotOpen/splot_global_species_indices_ge_100_occ_v2.npy").astype(int) 

nontree_indices = pd.read_csv("sPlotOpen/species_merge_duplicates_v2.csv")["isTree"]
nontree_indices = nontree_indices.index[~nontree_indices].tolist()

common_indices = np.intersect1d(species_indices, nontree_indices)
common_indices = np.where(np.isin(species_indices, common_indices))[0]

# target_path = "sPlotOpen/merged_species_occurrences_v2.npy"
target_path = "sPlotOpen/sPlotOpen_targets_v2.pkl"
targets = np.load(target_path, allow_pickle=True)

nontree_uncond_pred_list = []
nontree_cond_pred_list = []
labels = []

for i, f in enumerate(listdir(uncond_path)):
    if isfile(join(uncond_path, f)):
        target = targets[int(f[:-4])]
        uncond_pred = np.load(join(uncond_path, f))
        cond_pred = np.load(join(cond_path, f))
        nontree_uncond_pred_list.append(uncond_pred)
        nontree_cond_pred_list.append(cond_pred)
        labels.append(target[common_indices])
        
nontree_uncond_pred_list = np.array(nontree_uncond_pred_list)
nontree_cond_pred_list = np.array(nontree_cond_pred_list)
tree_labels = np.array(labels)

In [None]:
non_zero_indices = tree_labels.sum(0) != 0

# Unconditioned
metric = MultilabelAUROC(num_labels=tree_labels[:, non_zero_indices].shape[1], average=None)
metric.update(torch.tensor(nontree_uncond_pred_list[:, non_zero_indices]),
              torch.tensor(tree_labels[:, non_zero_indices]))
nontree_uncond_aucs = metric.compute()
print(nontree_uncond_aucs.mean())

# Conditioned
metric = MultilabelAUROC(num_labels=tree_labels[:, non_zero_indices].shape[1], average=None)
metric.update(torch.tensor(nontree_cond_pred_list[:, non_zero_indices]),
              torch.tensor(tree_labels[:, non_zero_indices]))
nontree_cond_aucs = metric.compute()
print(nontree_cond_aucs.mean())

### Trees

In [None]:
uncond_path = "sPlotOpen/splot_predictions/preds_tree"
cond_path = "sPlotOpen/splot_predictions/preds_tree_conditioned"
species_indices = np.load("sPlotOpen/splot_global_species_indices_ge_100_occ_v2.npy").astype(int) 

tree_indices = pd.read_csv("sPlotOpen/species_merge_duplicates_v2.csv")["isTree"]
tree_indices = tree_indices.index[tree_indices].tolist()

common_indices = np.intersect1d(species_indices, tree_indices)
common_indices = np.where(np.isin(species_indices, common_indices))[0]

# target_path = "sPlotOpen/merged_species_occurrences_v2.npy"
target_path = "sPlotOpen/sPlotOpen_targets_v2.pkl"
targets = np.load(target_path, allow_pickle=True)

tree_uncond_pred_list = []
tree_cond_pred_list = []
labels = []

for i, f in enumerate(listdir(uncond_path)):
    if isfile(join(uncond_path, f)):
        target = targets[int(f[:-4])]
        uncond_pred = np.load(join(uncond_path, f))
        cond_pred = np.load(join(cond_path, f))
        tree_uncond_pred_list.append(uncond_pred)
        tree_cond_pred_list.append(cond_pred)
        labels.append(target[common_indices])
        
tree_uncond_pred_list = np.array(tree_uncond_pred_list)
tree_cond_pred_list = np.array(tree_cond_pred_list)
tree_labels = np.array(labels)

In [None]:
non_zero_indices = tree_labels.sum(0) != 0

# Unconditioned
metric = MultilabelAUROC(num_labels=tree_labels[:, non_zero_indices].shape[1], average=None)
metric.update(torch.tensor(tree_uncond_pred_list[:, non_zero_indices]),
              torch.tensor(tree_labels[:, non_zero_indices]))
tree_uncond_aucs = metric.compute()
print(tree_uncond_aucs.mean())

# Conditioned
metric = MultilabelAUROC(num_labels=tree_labels[:, non_zero_indices].shape[1], average=None)
metric.update(torch.tensor(tree_cond_pred_list[:, non_zero_indices]),
              torch.tensor(tree_labels[:, non_zero_indices]))
tree_cond_aucs = metric.compute()
print(tree_cond_aucs.mean())

## Plotting

In [None]:
s = 4
font_size = 13
plt.rcParams.update({'font.size': font_size})

bbox_props = dict(facecolor="white", edgecolor="gray", boxstyle="round,pad=0.3", alpha=0.5)

fontsize_txt = 12
cmap = plt.cm.viridis

# Create subplots
fig, axes = plt.subplots(1, 4, figsize=(16, 5))

xlim_min = 0
xlim_max = 0.25

x_txt_up = 0.01
y_txt_up = 0.225
x_txt_low = 0.09
y_txt_low = 0.02

axes[0].scatter(cond_mae_list, uncond_mae_list, s=s, color=sns.color_palette("Paired")[9])
axes[0].set_xlabel("conditioned CISO (MAE$\\!\\downarrow\!\!$)")
axes[0].set_ylabel("unconditioned CISO (MAE$\\!\\downarrow\!\!$)")
axes[0].set_xlim(xlim_min, xlim_max)
axes[0].set_ylim(xlim_min, xlim_max)
x = np.linspace(xlim_min, xlim_max, 100)
axes[0].plot(x, x, linestyle='--', color='black')
axes[0].grid()
axes[0].set_aspect('equal', adjustable='box')
axes[0].set_title("Songbirds")

axes[0].text(x_txt_low , y_txt_low, "Unconditioned better", fontsize=fontsize_txt, rotation=0, bbox=bbox_props)
axes[0].text(x_txt_up, y_txt_up, "Conditioned better", fontsize=fontsize_txt, rotation=0, bbox=bbox_props)

axes[1].scatter(non_cond_mae_list, non_uncond_mae_list, s=s, color=sns.color_palette("Paired")[8])
axes[1].set_xlabel("conditioned CISO (MAE$\\!\\downarrow\!\!$)")
axes[1].set_ylabel("unconditioned CISO (MAE$\\!\\downarrow\!\!$)")
axes[1].set_xlim(xlim_min, xlim_max)
axes[1].set_ylim(xlim_min, xlim_max)
axes[1].plot(x, x, linestyle='--', color='black')
axes[1].grid()
axes[1].set_aspect('equal', adjustable='box')
axes[1].set_title("Non-Songbirds")

axes[1].text(x_txt_low , y_txt_low, "Unconditioned better", fontsize=fontsize_txt, rotation=0, bbox=bbox_props)
axes[1].text(x_txt_up, y_txt_up, "Conditioned better", fontsize=fontsize_txt, rotation=0, bbox=bbox_props)

xlim_min = 0.6
xlim_max = 1

x_txt_up = 0.62
y_txt_up = 0.96
x_txt_low = 0.77
y_txt_low = 0.63

axes[2].scatter(tree_cond_aucs, tree_uncond_aucs, s=s, color=sns.color_palette("Paired")[3])
axes[2].set_xlabel("conditioned CISO (AUC$\\!\\uparrow\!\!$)")
axes[2].set_ylabel("unconditioned CISO (AUC$\\!\\uparrow\!\!$)")
axes[2].set_xlim(xlim_min, xlim_max)
axes[2].set_ylim(xlim_min, xlim_max)
x = np.linspace(xlim_min, xlim_max, 100)
axes[2].plot(x, x, linestyle='--', color='black')
axes[2].grid()
axes[2].set_xticks([0.6, 0.7, 0.8, 0.9, 1])
axes[2].set_yticks([0.6, 0.7, 0.8, 0.9, 1])
axes[2].set_aspect('equal', adjustable='box')
axes[2].set_title("Trees")

axes[2].text(x_txt_low , y_txt_low, "Conditioned better", fontsize=fontsize_txt, rotation=0, bbox=bbox_props)
axes[2].text(x_txt_up, y_txt_up, "Unconditioned better", fontsize=fontsize_txt, rotation=0, bbox=bbox_props)

# axes[3].scatter(nontree_cond_aucs, nontree_uncond_aucs, s=s, color=sns.color_palette("Paired")[2])
occurrences = tree_labels[:, non_zero_indices].sum(0)/max(tree_labels[:, non_zero_indices].sum(0))
sc = axes[3].scatter(nontree_cond_aucs, nontree_uncond_aucs, s=s, c=occurrences, cmap=cmap)
axes[3].set_xlabel("conditioned CISO (AUC$\\!\\uparrow\!\!$)")
axes[3].set_ylabel("unconditioned CISO (AUC$\\!\\uparrow\!\!$)")
axes[3].set_xlim(xlim_min, xlim_max)
axes[3].set_ylim(xlim_min, xlim_max)
axes[3].plot(x, x, linestyle='--', color='black')
axes[3].grid()
axes[3].set_xticks([0.6, 0.7, 0.8, 0.9, 1])
axes[3].set_yticks([0.6, 0.7, 0.8, 0.9, 1])
axes[3].set_aspect('equal', adjustable='box')
axes[3].set_title("Non-Trees")

axes[3].text(x_txt_low , y_txt_low, "Conditioned better", fontsize=fontsize_txt, rotation=0, bbox=bbox_props)
axes[3].text(x_txt_up, y_txt_up, "Unconditioned better", fontsize=fontsize_txt, rotation=0, bbox=bbox_props)

# Add colorbars to indicate occurrences scale
cbar = fig.colorbar(sc, ax=axes, orientation="horizontal", fraction=0.5, pad=0.1)
cbar.set_label("Occurrences")

plt.tight_layout()
plt.savefig("cond_vs_uncond_distribution.pdf", bbox_inches='tight')
plt.show()