In [None]:
import os
import os.path as osp
os.makedirs("plots", exist_ok=True)

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from tabulate import tabulate

import torch
from torch.nn.functional import one_hot

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Echo the runtime context; every path below is relative to the working directory.
print(f"Device: {device}\nPath: {os.getcwd()}")

In [None]:
# Take SEFA results from main and put them in the ablation results.
# Take the SEFA beta=0.0, 1 acq_sample, 1 train_sample, latent_dim=1 ablations
# and put them in the sensitivity results.
# Avoids complex evaluation code and checks.

main_folder = osp.join("experiments", "results", "main")
ablation_folder = osp.join("experiments", "results", "ablations")

for dataset in os.listdir(main_folder):
  # os.listdir returns every directory entry; only the "<dataset>.pt" result
  # files can be read by torch.load, so skip anything else (hidden files,
  # checkpoint folders, ...).
  if not dataset.endswith(".pt"):
    continue
  main_results = torch.load(osp.join(main_folder, dataset))
  ablation_results = torch.load(osp.join(ablation_folder, dataset))
  # Copy the full-model entry into the ablation file so it can be listed as
  # one more row next to the ablations.
  for section in ("metrics", "selections"):
    ablation_results[section]["sefa"] = main_results[section]["sefa"]
  torch.save(ablation_results, osp.join(ablation_folder, dataset))


sensitivity_folder = osp.join("experiments", "results", "ablations", "sensitivity")

# Key under which each sensitivity file stores the setting that was already run
# as an ablation (1 sample / 1 latent component, beta = 0.0 encoded as "0_0").
ablation_setting_keys = {
  "acq_sample": 1,
  "beta": "0_0",
  "num_latents": 1,
  "train_sample": 1,
}

for dataset in ["syn1", "syn2", "syn3"]:
  ablation_selections = torch.load(osp.join(ablation_folder, f"{dataset}.pt"))["selections"]
  for ablation_name, setting_key in ablation_setting_keys.items():
    sensitivity_path = osp.join(sensitivity_folder, dataset, f"{ablation_name}.pt")
    sensitivity = torch.load(sensitivity_path)
    sensitivity[setting_key] = ablation_selections[ablation_name]
    torch.save(sensitivity, sensitivity_path)

# Set up Fonts and Names

In [None]:
# Colour map and font settings shared by every figure below.

# Sample "Blues" at 1000 evenly spaced points and keep the first 900 of them,
# i.e. the top tenth of the map's range is discarded.
blues_rgba = matplotlib.colormaps["Blues"](np.linspace(0, 1, 1000))
custom_cmap = ListedColormap(blues_rgba[:900])

matplotlib.rcParams.update({
  "font.family": "serif",
  "mathtext.fontset": "dejavuserif",
})

In [None]:
# Display names for tables and plots, keyed by the identifiers used in the
# result files. Insertion order of the model/ablation dictionaries sets the
# row order of the printed tables.

model_names_dict = dict(
  acflow="ACFlow",
  dime="DIME",
  eddi="EDDI",
  fixed_mlp="Fixed MLP",
  gdfs="GDFS",
  gsmrl="GSMRL",
  opportunistic="Opportunistic RL",
  random="Random",
  vae="VAE",
  sefa="SEFA (ours)",
)

ablations_names_dict = dict(
  beta="$\\beta = 0$",
  acq_sample="1 Acquisition Sample",
  train_sample="1 Train Sample",
  deterministic="Deterministic Encoder",
  num_latents="Latent Dim = 1",
  feature_space_ablation="Feature Space",
  copula="WO Copula",
  no_normalize="WO Normalization",
  prob_weighting="WO Prob Weighting",
  sefa="SEFA (full)",
)

dataset_names_dict = dict(
  bank="Bank Marketing",
  california_housing="California Housing",
  cube="Cube",
  fashion_mnist="Fashion MNIST",
  metabric="METABRIC",
  miniboone="MiniBooNE",
  mnist="MNIST",
  syn1="Syn1",
  syn2="Syn2",
  syn3="Syn3",
  tcga="TCGA",
)

# Metric reported per dataset (used as the y-axis label of the trajectory
# plots): AUROC for the datasets listed here, accuracy for all others.
dataset_metrics_dict = {
  dataset_key: ("AUROC" if dataset_key in ("bank", "miniboone", "syn1", "syn2", "syn3") else "Accuracy")
  for dataset_key in dataset_names_dict
}

# Tables

In [None]:
def print_mean_acquisition_metric_table(table_type, ablation):
  """Print a two-band table of mean metric and standard error per dataset.

  For each dataset and model, the stored metric tensor is averaged over its
  last axis with the first column skipped. The mean and standard error of
  those averages are then taken across the leading axis.

  Args:
    table_type: "latex", "markdown" or "python" (tabulate "fancy_grid").
    ablation: if truthy, rows are the entries of ablations_names_dict read from
      the ablation results; otherwise the entries of model_names_dict read
      from the main results.

  Raises:
    ValueError: if table_type is not recognised (checked after loading data).
  """
  dataset_rows = [
    ["cube", "bank", "california_housing", "miniboone"],
    ["mnist", "fashion_mnist", "metabric", "tcga"],
  ]

  if ablation:
    results_subdir, tmp_model_names, model_ablation = "ablations", ablations_names_dict, "Ablation"
  else:
    results_subdir, tmp_model_names, model_ablation = "main", model_names_dict, "Model"
  folder_path = osp.join("experiments", "results", results_subdir)

  # print_data[dataset][model] = (mean, standard error).
  print_data = {}
  for dataset in dataset_rows[0] + dataset_rows[1]:
    dataset_data = torch.load(osp.join(folder_path, f"{dataset}.pt"))["metrics"]
    per_model = {}
    for model in tmp_model_names:
      avg_metric = np.mean(dataset_data[model][:, 1:].numpy(), axis=-1)
      per_model[model] = (
        np.mean(avg_metric),
        np.std(avg_metric, ddof=1) / (len(avg_metric)**0.5),
      )
    print_data[dataset] = per_model

  def header_cells(row):
    # First column title followed by the display name of each dataset.
    return [f"{model_ablation}"] + [dataset_names_dict[dataset] for dataset in row]

  def value_cells(model, row, plus_minus, wrap):
    # One "mean ± stderr" string per dataset, optionally wrapped (e.g. in "$").
    cells = []
    for dataset in row:
      mean, std_err = print_data[dataset][model]
      cells.append(f"{wrap}{mean:.3f} {plus_minus} {std_err:.3f}{wrap}")
    return cells

  # ~~~~~~~~~~~~~~ Latex Table ~~~~~~~~~~~~~~
  if table_type == "latex":
    print("\n".join([
      "\\begin{table}",
      "  \\caption{WRITE THIS}",
      "  \\label{WRITE_THIS}",
      "  \\centering",
      "  \\begin{tabular}{" + "c"*(len(dataset_rows[0])+1) + "}",
    ]))
    for band, row in enumerate(dataset_rows):
      print("    \\toprule" if band == 0 else "    \\midrule")
      print("    " + " & ".join(header_cells(row)) + " \\\\")
      print("    \\midrule")
      for model, display_name in tmp_model_names.items():
        print("    " + " & ".join([display_name] + value_cells(model, row, "\\pm", "$")) + " \\\\")
    print("\n".join([
      "    \\bottomrule",
      "  \\end{tabular}",
      "\\end{table}",
    ]))

  # ~~~~~~~~~~~~~~ Markdown Table ~~~~~~~~~~~~~~
  elif table_type == "markdown":
    for band, row in enumerate(dataset_rows):
      print("| " + " | ".join(header_cells(row)) + " |")
      if band == 0:
        print("| --- |" + " --- |"*len(row))
      for model, display_name in tmp_model_names.items():
        print("| " + " | ".join([display_name] + value_cells(model, row, "\\pm", "$")) + " |")
      if band == 0:
        # Empty spacer row between the two bands.
        print("|   |" + "   |"*len(row))

  # ~~~~~~~~~~~~~~ Python Table ~~~~~~~~~~~~~~
  elif table_type == "python":
    table = []
    for band, row in enumerate(dataset_rows):
      names = header_cells(row)
      table.append(names)
      for model, display_name in tmp_model_names.items():
        table.append([display_name] + value_cells(model, row, "\u00B1", ""))
      if band == 0:
        table.append([""] * len(names))
    print(tabulate(table, tablefmt="fancy_grid"))

  else:
    raise ValueError(f"Table type {table_type} not recognized. Must be latex, markdown or python")

In [None]:
# Main results: LaTeX table with one row per entry of model_names_dict,
# read from experiments/results/main.
print_mean_acquisition_metric_table(table_type="latex", ablation=False)

In [None]:
# Ablations table: same layout, one row per entry of ablations_names_dict,
# read from experiments/results/ablations.
print_mean_acquisition_metric_table(table_type="latex", ablation=True)

In [None]:
# Synthetic Counts table.

def is_mask_complete(mask, known_features):
  """Flag the samples whose mask has every feature in known_features set.

  Args:
    mask: 2-D tensor indexed [sample, feature]; non-zero means acquired.
    known_features: indices of the features that must all be acquired.

  Returns:
    1-D numpy float array with 1.0 where the sample is complete, else 0.0.
  """
  return torch.all(mask[:, known_features], dim=-1).float().detach().cpu().numpy()


def run_once(selections, known_features):
  """Count the acquisitions each sample needs to cover known_features.

  Args:
    selections: integer tensor indexed [acquisition step, sample] holding the
      feature index chosen at that step; the first dimension is also used as
      the total number of features.
    known_features: indices of the features that must all be acquired.

  Returns:
    1-D numpy array with, per sample, the 1-based acquisition number at which
    the last required feature was selected (number of steps + 1 if that never
    happens).
  """
  num_features = selections.shape[0]
  batch_size = selections.shape[-1]
  counts = np.zeros(batch_size)
  # Keep the mask on the same device as the selections: is_mask_complete
  # already moves its result to the CPU, so GPU-resident selections are
  # expected to work, but a CPU mask would make torch.max fail for them.
  mask = torch.zeros((batch_size, num_features), device=selections.device)

  for i in range(num_features):
    mask = torch.max(mask, one_hot(selections[i], num_features).float())
    # Every step after which the mask is still incomplete costs one acquisition.
    counts += (1.0 - is_mask_complete(mask, known_features))
  return counts + 1.0  # Add one for the selection that completes mask.


def get_syn_results(syn_num, ablation):
  """Compute, per model, how many acquisitions each test sample needs.

  Test samples are split by the sign of their last feature, and each half is
  scored against its own set of required feature indices with run_once.

  Args:
    syn_num: synthetic dataset number (1, 2 or 3).
    ablation: if truthy, use the ablation results and ablations_names_dict;
      otherwise the main results and model_names_dict.

  Returns:
    Dict mapping model key to {"less_than": array, "more_than": array}, each
    array stacking one run_once result per repeat along axis 0.

  Raises:
    ValueError: if syn_num is not 1, 2 or 3 (checked after loading the files).
  """
  if ablation:
    results_subdir, tmp_model_names = "ablations", ablations_names_dict
  else:
    results_subdir, tmp_model_names = "main", model_names_dict
  folder_path = osp.join("experiments", "results", results_subdir)

  selections = torch.load(osp.join(folder_path, f"syn{syn_num}.pt"))["selections"]

  X_test = torch.load(osp.join("datasets", "data", f"syn{syn_num}", "X_test_std.pt"))
  last_feature = X_test[:, -1]
  pos_ids = torch.where(last_feature >= 0.0)[0]
  neg_ids = torch.where(last_feature < 0.0)[0]

  # (features required when the last feature is < 0, when it is >= 0).
  required_features = {
    1: ([0, 1, 10], [2, 3, 4, 5, 10]),
    2: ([0, 1, 10], [6, 7, 8, 9, 10]),
    3: ([2, 3, 4, 5, 10], [6, 7, 8, 9, 10]),
  }
  if syn_num not in required_features:
    raise ValueError(f"Unknown synthetic dataset number: {syn_num}, should be 1, 2 or 3.")
  neg_features = np.array(required_features[syn_num][0])
  pos_features = np.array(required_features[syn_num][1])

  counts = {}
  for model_name in tmp_model_names:
    model_selections = selections[model_name]
    # One run per repeat (can be parallelised but isn't slow so no need).
    repeats = range(len(model_selections))
    neg_arr = np.stack(
      [run_once(model_selections[rpt][:, neg_ids], neg_features) for rpt in repeats], axis=0)
    pos_arr = np.stack(
      [run_once(model_selections[rpt][:, pos_ids], pos_features) for rpt in repeats], axis=0)
    counts[model_name] = {"less_than": neg_arr, "more_than": pos_arr}
  return counts

In [None]:
# Create the main table.

def print_syn_acquisition_count_table(table_type, ablation):
  """Print the mean required acquisitions (± standard error) on Syn1-3.

  For each synthetic dataset and model, both halves returned by
  get_syn_results are concatenated along the last axis and averaged over it.
  The mean and standard error are then taken over the remaining axis.

  Args:
    table_type: "latex", "markdown" or "python" (tabulate "fancy_grid").
    ablation: if truthy, rows are ablations; otherwise models.

  Raises:
    ValueError: if table_type is not recognised (checked after the counts
      have been computed).
  """
  if ablation:
    tmp_model_names, model_ablation = ablations_names_dict, "Ablation"
  else:
    tmp_model_names, model_ablation = model_names_dict, "Model"

  syn_ids = [1, 2, 3]

  # syn_counts[i][model_name] = (mean, standard error).
  syn_counts = {}
  for i in syn_ids:
    syn_results = get_syn_results(i, ablation)
    per_model = {}
    for model_name in tmp_model_names:
      halves = syn_results[model_name]
      per_repeat = np.mean(
        np.concatenate([halves["less_than"], halves["more_than"]], axis=-1), axis=-1)
      per_model[model_name] = (
        np.mean(per_repeat),
        np.std(per_repeat, ddof=1) /(len(per_repeat)**0.5),
      )
    syn_counts[i] = per_model

  def value_cells(model_name, plus_minus, wrap):
    # One "mean ± stderr" string per synthetic dataset.
    cells = []
    for i in syn_ids:
      mean, std_err = syn_counts[i][model_name]
      cells.append(f"{wrap}{mean:.3f} {plus_minus} {std_err:.3f}{wrap}")
    return cells

  # ~~~~~~~~~~~~~~ Latex Table ~~~~~~~~~~~~~~
  if table_type == "latex":
    print("\n".join([
      "\\begin{table}",
      "  \\caption{WRITE THIS}",
      "  \\label{tbl:WRITE_THIS}",
      "  \\centering",
      "  \\begin{tabular}{cccc}",
      "    \\toprule",
      f"    {model_ablation} & Syn1 & Syn2 & Syn3 \\\\",
      "    \\midrule",
    ]))
    for model_name, display_name in tmp_model_names.items():
      print("    " + " & ".join([display_name] + value_cells(model_name, "\\pm", "$")) + " \\\\")
    print("\n".join([
      "    \\bottomrule",
      "  \\end{tabular}",
      "\\end{table}",
    ]))

  # ~~~~~~~~~~~~~~ Markdown Table ~~~~~~~~~~~~~~
  elif table_type == "markdown":
    print(f"| {model_ablation} | Syn1 | Syn2 | Syn3 |")
    print("| --- | --- | --- | --- |")
    for model_name, display_name in tmp_model_names.items():
      print("| " + " | ".join([display_name] + value_cells(model_name, "\\pm", "$")) + " |")

  # ~~~~~~~~~~~~~~ Python Table ~~~~~~~~~~~~~~
  elif table_type == "python":
    table = [[f"{model_ablation}", "Syn1", "Syn2", "Syn3"]]
    table.extend(
      [display_name] + value_cells(model_name, "\u00B1", "")
      for model_name, display_name in tmp_model_names.items()
    )
    print(tabulate(table, tablefmt="fancy_grid"))

  else:
    raise ValueError(f"Table type {table_type} not recognized.")

In [None]:
# Required-acquisition counts on Syn1-3 for the main models, as a LaTeX table.
print_syn_acquisition_count_table(table_type="latex", ablation=False)

In [None]:
# Required-acquisition counts on Syn1-3 for the ablations, as a LaTeX table.
print_syn_acquisition_count_table(table_type="latex", ablation=True)

# Acquisition Trajectories Plots

In [None]:
# In format [left most location, top most location].
# Consumed by plot_on_axis, which passes them to ax.inset_axes as the inset's
# left and top edges; the bottom edge is fixed at 0.1 and the right edge at 0.98.

axin_positions = {
  "cube": [0.35, 0.72],
  "bank": [0.35, 0.65],
  "california_housing": [0.44, 0.54],
  "miniboone": [0.32, 0.72],
  "mnist": [0.45, 0.65],
  "fashion_mnist": [0.38, 0.75],
  "metabric": [0.42, 0.64],
  "tcga": [0.4, 0.71],
}

# [lower, upper] y-limits applied to each dataset's inset axes in plot_on_axis.
axin_ylims = {
  "cube": [0.96, 0.967],
  "bank": [0.90, 0.94],
  "california_housing": [0.63, 0.76],
  "miniboone": [0.960, 0.977],
  "mnist": [0.75, 0.92],
  "fashion_mnist": [0.72, 0.81],
  "metabric": [0.70, 0.772],
  "tcga": [0.87, 0.935],
}

In [None]:
def plot_on_axis(ax, dataset, trajectories_model_list, ablation):
  """Draw the metric-vs-acquisition curves of one dataset, with a zoomed inset.

  Args:
    ax: matplotlib axes to draw on.
    dataset: dataset key (result file name without ".pt").
    trajectories_model_list: model (or ablation) keys to plot, in legend order.
    ablation: if truthy, read the ablation results and ablation names;
      otherwise the main results and model names.
  """
  if ablation:
    results_subdir, tmp_model_names = "ablations", ablations_names_dict
  else:
    results_subdir, tmp_model_names = "main", model_names_dict
  data_path = osp.join("experiments", "results", results_subdir)

  # Inset axes: left/top come from the per-dataset table, bottom/right are fixed.
  inset_left, inset_top = axin_positions[dataset][0], axin_positions[dataset][1]
  inset_bottom = 0.1
  axin = ax.inset_axes([inset_left, inset_bottom, 0.98-inset_left, inset_top-inset_bottom])

  markers = ["o", "s", "D", "^", "P"]

  plot_data = torch.load(osp.join(data_path, f"{dataset}.pt"))["metrics"]

  for model_idx, model in enumerate(trajectories_model_list):
    marker = markers[model_idx % len(markers)]  # Cycle when there are more models than markers.
    curves = plot_data[model][:, 1:].numpy()
    num_features = curves.shape[-1]
    # Mean and standard error taken across repeats.
    mean = np.mean(curves, axis=0)
    err = np.std(curves, axis=0, ddof=1) / (len(curves)**0.5)
    lower, upper = mean-err, mean+err
    x = np.arange(num_features)+1

    ax.plot(x, mean, label=tmp_model_names[model], marker=marker, markersize=6, linestyle="-.")
    ax.fill_between(x, lower, upper, alpha=0.4)
    axin.plot(x, mean, marker=marker, markersize=6, linestyle="-.")
    axin.fill_between(x, lower, upper, alpha=0.4)

  # Ticks and limits: every second tick once there are more than 16 features.
  tick_step = 2 if num_features > 16 else 1
  ticks = np.arange(0, num_features+1, tick_step)

  ax.tick_params(axis="both", labelsize=11)
  ax.set_xticks(ticks)
  ax.set_xlim(0.4, num_features+0.6)
  ax.grid()

  # The inset starts a third of the way through the acquisitions.
  start_feature = int(num_features/3)
  axin.tick_params(axis="both", labelsize=9)
  axin.set_xticks(ticks)
  axin.set_xlim(start_feature-0.4, num_features+0.4)
  axin.set_ylim(*axin_ylims[dataset])
  axin.grid()

  # Titles and axis labels
  ax.set_title(dataset_names_dict[dataset], fontsize=19)
  ax.set_xlabel("Acquisition No.", fontsize=14)
  ax.set_ylabel(dataset_metrics_dict[dataset], fontsize=14)



def make_trajectories_plots(trajectories_dataset_list, trajectories_model_list, ablation):
  """Plot one trajectory panel per dataset on a 2x4 grid and save it as a PDF.

  Args:
    trajectories_dataset_list: dataset keys, one per panel in row-major order.
    trajectories_model_list: model (or ablation) keys drawn in every panel.
    ablation: selects ablation results/names and the "_ablation" file suffix.
  """
  if ablation:
    tmp_model_names, file_suffix = ablations_names_dict, "_ablation"
  else:
    tmp_model_names, file_suffix = model_names_dict, ""

  fig, ax = plt.subplots(nrows=2, ncols=4, figsize=[18, 6.5])
  panels = ax.flatten()
  for i, dataset in enumerate(trajectories_dataset_list):
    plot_on_axis(panels[i], dataset, trajectories_model_list, ablation)
  fig.tight_layout()

  # One shared legend below the grid, using the line handles of the last panel.
  legend_handles = panels[-1].get_lines()
  legend_labels = [tmp_model_names[m] for m in trajectories_model_list]
  fig.legend(
    legend_handles,
    legend_labels,
    loc="upper center",
    bbox_to_anchor=(0.5, -0.005),
    ncol=len(trajectories_model_list),
    fontsize=13,
  )

  plt.savefig(osp.join("plots", f"trajectories{file_suffix}.pdf"), bbox_inches="tight")

In [None]:
# Datasets shown in the trajectory figures, in panel order (row-major over the
# 2x4 grid created by make_trajectories_plots).
trajectories_dataset_list = [
  "cube", "bank", "california_housing", "miniboone", 
  "mnist", "fashion_mnist", "metabric", "tcga",
]

In [None]:
# Main-model trajectories; written to plots/trajectories.pdf.
make_trajectories_plots(
  trajectories_dataset_list=trajectories_dataset_list,
  trajectories_model_list=["sefa", "dime", "fixed_mlp", "gdfs", "opportunistic"],
  ablation=False,
)

In [None]:
# Ablation trajectories; written to plots/trajectories_ablation.pdf.
make_trajectories_plots(
  trajectories_dataset_list=trajectories_dataset_list,
  trajectories_model_list=["sefa", "prob_weighting", "deterministic", "acq_sample", "train_sample"],
  ablation=True,
)

# Heat Maps of Individual Acquisition Trajectories

In [None]:
def draw_rectangle_axis(ax, left, down, right, up, linestyle="-", color="lime"):
  """Outline the rectangle spanning [left, right] x [down, up] on ax.

  The outline is drawn as four separate line segments (top, bottom, left,
  right), each with the given color and linestyle.
  """
  edges = (
    ([left, right], [up, up]),      # top
    ([left, right], [down, down]),  # bottom
    ([left, left], [up, down]),     # left
    ([right, right], [up, down]),   # right
  )
  for xs, ys in edges:
    ax.plot(xs, ys, color=color, linestyle=linestyle)


def hm_trajectories_on_axis(ax, selections, aspect=0.4, max_features=6, max_trajectories=5, weight=0.015, feature_labels=None):
  """Draw an acquisition-proportion heat map with sample trajectories on top.

  HM means heatmap, so we have a heat map behind sample trajectories. Cell
  (feature, step) holds the fraction of trajectories that selected that
  feature at that step.

  Args:
    ax: matplotlib axes to draw on.
    selections: tensor of shape [repeat, num_features, batch] holding the
      feature index chosen at each acquisition step.
    aspect: aspect ratio passed to imshow.
    max_features: number of acquisition steps shown on the x-axis.
    max_trajectories: maximum number of randomly chosen trajectories to draw.
    weight: scales the transparency of the trajectory lines.
    feature_labels: y tick labels; defaults to 1..num_features.
  """
  # Local generator with a fixed seed: it yields the same permutation as
  # seeding NumPy's global RNG with 7819, but does not reset the global random
  # state as a side effect of every call.
  rng = np.random.RandomState(7819)
  selections = selections.numpy()  # Shape: [repeat, num_features, batch]
  num_features = selections.shape[1]
  hm = np.stack(
    [np.mean(selections == f, axis=(0, 2)) for f in range(num_features)], axis=0)
  ax.imshow(hm, cmap=custom_cmap, aspect=aspect, origin="lower", vmin=0.0, vmax=1.0)

  trajectories = np.reshape(np.transpose(selections, (1, 0, 2)), (num_features, -1))  # Just flattens trajectories across repeats and batches.
  trajectories = trajectories[:, rng.permutation(trajectories.shape[1])]
  max_trajectories = min(max_trajectories, trajectories.shape[1])
  # Nothing to draw (and nothing to divide by) when there are no trajectories.
  if max_trajectories > 0:
    trajectories = trajectories[:max_features, :max_trajectories]
    # Fewer lines get a higher opacity; clamp at 1.0 because matplotlib rejects
    # alpha values above 1 (the defaults alone would give 1.5).
    alpha = min(1.0, weight*500/max_trajectories)
    ax.plot(trajectories, color="red", linewidth=2.5, alpha=alpha)
  ax.set_xlim(-0.5, max_features-0.5)
  ax.set_xticks(ticks=np.arange(max_features), labels=np.arange(max_features)+1)
  if feature_labels is None:
    feature_labels = np.arange(num_features)+1
  ax.set_yticks(ticks=np.arange(num_features), labels=feature_labels)

In [None]:
def full_syn_heatmap(syn_num, model_list, ablation):
  """Plot acquisition heat maps for one synthetic dataset and save them as a PDF.

  The figure has one column per model and two rows: test samples whose last
  feature is < 0 (top) and >= 0 (bottom). Lime boxes mark feature 11 at the
  first acquisition and one further block of features per row; a dashed line
  is drawn at that block's right edge.

  Args:
    syn_num: synthetic dataset number (1, 2 or 3).
    model_list: model (or ablation) keys, one per column.
    ablation: selects ablation results/names and the "_ablation" file suffix.

  Raises:
    ValueError: if syn_num is not 1, 2 or 3.
  """
  # (left, down, right, up) of the box around the row-specific features, for
  # the top row (x_11 < 0) and the bottom row (x_11 >= 0).
  feature_boxes = {
    1: [(0.5, -0.5, 2.5, 1.5), (0.5, 1.5, 4.5, 5.5)],
    2: [(0.5, -0.5, 2.5, 1.5), (0.5, 5.5, 4.5, 9.5)],
    3: [(0.5, 1.5, 4.5, 5.5), (0.5, 5.5, 4.5, 9.5)],
  }
  # Validate before touching any file, so an invalid number reports this
  # message rather than a missing-file error.
  if syn_num not in feature_boxes:
    raise ValueError(f"Unknown synthetic dataset number: {syn_num}, should be 1, 2 or 3.")

  if ablation:
    tmp_model_names = ablations_names_dict
    selection_path = osp.join("experiments", "results", "ablations")
  else:
    tmp_model_names = model_names_dict
    selection_path = osp.join("experiments", "results", "main")

  syn_x = torch.load(osp.join("datasets", "data", f"syn{syn_num}", "X_test_std.pt"))
  less_ids = torch.where(syn_x[:, -1] < 0.0)[0]
  more_ids = torch.where(syn_x[:, -1] >= 0.0)[0]

  # squeeze=False keeps ax two-dimensional even for a single model, so the
  # ax[row, column] indexing below always works.
  fig, ax = plt.subplots(nrows=2, ncols=len(model_list), figsize=[4*len(model_list), 6.5], squeeze=False)
  all_selections = torch.load(osp.join(selection_path, f"syn{syn_num}.pt"))["selections"]

  for m, model in enumerate(model_list):
    selections = all_selections[model]

    hm_trajectories_on_axis(ax[0, m], selections[:, :, less_ids], max_trajectories=400)
    hm_trajectories_on_axis(ax[1, m], selections[:, :, more_ids], max_trajectories=400)

    # Rectangles on feature 11 first. Then on remaining features.
    draw_rectangle_axis(ax[0, m], -0.5, 9.5, 0.5, 10.5)
    draw_rectangle_axis(ax[1, m], -0.5, 9.5, 0.5, 10.5)
    for row, (left, down, right, up) in enumerate(feature_boxes[syn_num]):
      draw_rectangle_axis(ax[row, m], left, down, right, up)
      ax[row, m].axvline(x=right, color="black", linestyle="--")

    ax[0, m].tick_params(axis="both", labelsize=12)
    ax[1, m].tick_params(axis="both", labelsize=12)
    ax[0, m].set_title(f"{tmp_model_names[model]} " + r"$x_{11} < 0.0$", fontsize=15)
    ax[1, m].set_title(f"{tmp_model_names[model]} " + r"$x_{11} \geq 0.0$", fontsize=15)

  fig.supxlabel("Acquisition", fontsize=18, y=0.05)
  fig.supylabel("Feature", fontsize=18, x=0.09)
  fig.suptitle(f"Syn{syn_num} Acquisition Proportions and Trajectories", y=0.96, fontsize=22)

  # Shared colour bar in its own axes to the right of the grid.
  cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
  cbar = fig.colorbar(ax[0, 0].get_images()[0], cax=cbar_ax)
  cbar.ax.set_yticks(ticks=[0, 0.2, 0.4, 0.6, 0.8, 1.0], labels=[0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=12)
  cbar_ax.set_ylabel("Acquisition Proportions", fontsize=18)

  plt.savefig(osp.join("plots", f"syn{syn_num}_heatmap{'_ablation' if ablation else ''}.pdf"), bbox_inches="tight")

In [None]:
# Main heatmaps for synthetic datasets.
# One figure per dataset, written to plots/syn{i}_heatmap.pdf.

for i in [1, 2, 3]:
  full_syn_heatmap(i, model_list=["sefa", "opportunistic", "dime", "gdfs"], ablation=False)

In [None]:
# Ablation heatmaps.
# One figure per dataset, written to plots/syn{i}_heatmap_ablation.pdf.

for i in [1, 2, 3]:
  full_syn_heatmap(i, model_list=["sefa", "feature_space_ablation", "deterministic", "acq_sample"], ablation=True)

In [None]:
# TCGA heatmaps

# Class index (as compared against the test labels in tcga_large_heatmap)
# mapped to the tumor location used as the panel title.
tumor_names = {
  0: "Breast",
  1: "Lung",
  2: "Kidney",
  3: "Brain",
  4: "Ovary",
  5: "Endometrium",
  6: "Head and Neck",
  7: "Central Nervous System",
  8: "Thyroid",
  9: "Prostate",
  10: "Colon",
  11: "Stomach",
  12: "Bladder",
  13: "Liver",
  14: "Cervix",
  15: "Bone Marrow",
  16: "Pancreas",
}

# Per tumor location: [acquisition number (1-based), gene name] pairs. Each pair
# is highlighted with a bounding box at that acquisition/gene cell of the
# class's heat map.
ordered_tcga_class_genes = {
  "Breast": [[1, "ST6GAL1"], [2, "DEF6"], [3, "C7orf51"]],
  "Lung": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "PON3"]],
  "Kidney": [[1, "ST6GAL1"], [2, "POU3F3"], [3, "KAAG1"], [4, "HOXA9"]],
  "Brain": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "DEF6"]],
  "Ovary": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "KAAG1"]],
  "Endometrium": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "LYPLAL1"], [4, "PON3"]],
  "Head and Neck": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "GRIA2"], [4, "HOXA9"]],
  "Central Nervous System": [[1, "ST6GAL1"], [2, "PON3"], [3, "SERPINB1"]],
  "Thyroid": [[1, "ST6GAL1"], [2, "PON3"], [3, "FOXE1"]],
  "Prostate": [[1, "ST6GAL1"], [2, "SERPINB1"], [3, "GRIA2"]],
  "Colon": [[1, "ST6GAL1"], [2, "SERPINB1"], [3, "LYPLAL1"], [4, "HOXA9"]],
  "Stomach": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "LYPLAL1"]],
  "Bladder": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "TMEM106A"]],
  "Liver": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "PON3"]],
  "Cervix": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "GRIA2"]],
  "Bone Marrow": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "DEF6"], [4, "GPR81"]],
  "Pancreas": [[1, "ST6GAL1"], [2, "DNASE1L3"], [3, "LYPLAL1"], [4, "C7orf51"]],
}

# Gene name of each feature; the list position is used as the feature's row in
# the heat map (y tick labels and bounding-box lookup via .index()).
tcga_feature_labels = [
  "C7orf51",
  "DEF6",
  "DNASE1L3",
  "EFS",
  "FOXE1",
  "GPR81",
  "GRIA2",
  "GSDMC",
  "HOXA9",
  "KAAG1",
  "KLF5",
  "LOC283392",
  "LTBR",
  "LYPLAL1",
  "PON3",
  "POU3F3",
  "SERPINB1",
  "ST6GAL1",
  "TMEM106A",
  "ZNF583",
  "ZNF790",
]


def tcga_large_heatmap(tcga_feature_labels, reduced):
  """Plot per-class SEFA acquisition heat maps for TCGA and save them as a PDF.

  Args:
    tcga_feature_labels: gene names in feature order, used as y tick labels
      and to locate the rows of the highlighted genes.
    reduced: if truthy, draw classes 0, 1, 9 and 13 on a 1x4 grid
      (plots/tcga_reduced.pdf); otherwise all 17 classes on a 4x5 grid with
      the three unused panels hidden (plots/tcga_full.pdf).
  """
  selections = torch.load(osp.join("experiments", "results", "main", "tcga.pt"))["selections"]["sefa"]
  tcga_y = torch.load(osp.join("datasets", "data", "tcga", "y_test.pt"))

  if reduced:
    fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(16, 5))
    classes = [0, 1, 9, 13]
    aspect = 0.5
  else:
    fig, ax = plt.subplots(nrows=4, ncols=5, figsize=(16, 24))
    classes = list(range(17))
    aspect = 0.8
  panels = ax.flatten()

  for i, c in enumerate(classes):
    ids = torch.where(tcga_y == c)[0]
    hm_trajectories_on_axis(panels[i], selections[:, :, ids], aspect=aspect, max_trajectories=100, weight=0.008, feature_labels=tcga_feature_labels)
    panels[i].set_title(tumor_names[c], fontsize=18 if reduced else 16)

    # Draw bounding boxes around the (acquisition, gene) cells listed for this class.
    for relevant_feature in ordered_tcga_class_genes[tumor_names[c]]:
      acquisition_count = relevant_feature[0]
      feature_id = tcga_feature_labels.index(relevant_feature[1])
      # Acquisition numbers are 1-based while heat-map columns are 0-based.
      draw_rectangle_axis(panels[i], left=acquisition_count-1.5, down=feature_id-0.5, right=acquisition_count-0.5, up=feature_id+0.5, linestyle="-", color="lime")

  if not reduced:
    # 17 classes on a 4x5 grid leave the last three panels empty.
    ax[-1, -3].set_axis_off()
    ax[-1, -2].set_axis_off()
    ax[-1, -1].set_axis_off()

  # Axis labels: the original text was misspelled "Acquistions".
  if reduced:
    fig.supxlabel("Acquisitions", fontsize=18, y=0.01)
    fig.supylabel("Features", fontsize=18, x=0.05)
  else:
    fig.supxlabel("Acquisitions", fontsize=22, y=0.08)
    fig.supylabel("Features", fontsize=22, x=0.06)

  # Shared colour bar in its own axes to the right of the grid.
  cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
  cbar = fig.colorbar(panels[0].get_images()[0], cax=cbar_ax)
  cbar.ax.set_yticks(ticks=[0, 0.2, 0.4, 0.6, 0.8, 1.0], labels=[0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=14)
  cbar_ax.set_ylabel("Acquisition Proportions", fontsize=22)

  plt.savefig(osp.join("plots", f"tcga_{'reduced' if reduced else 'full'}.pdf"), bbox_inches="tight")

In [None]:
# Four-class figure (classes 0, 1, 9 and 13); written to plots/tcga_reduced.pdf.
tcga_large_heatmap(tcga_feature_labels, reduced=True)

In [None]:
# All 17 classes on a 4x5 grid; written to plots/tcga_full.pdf.
tcga_large_heatmap(tcga_feature_labels, reduced=False)

# Sensitivity Analyses

In [None]:
def create_syn_sensitivity_results(x, ablation):
  """Compute required-acquisition counts for every value of one hyperparameter.

  Args:
    x: synthetic dataset number (1, 2 or 3).
    ablation: name of the sensitivity file (without ".pt") to read from
      experiments/results/ablations/sensitivity/syn{x}.

  Returns:
    Dict mapping each key of the sensitivity file to an array built by
    stacking one run_once result per repeat along axis 0, with the samples
    whose last feature is < 0 followed by those with it >= 0 along the last
    axis.

  Raises:
    ValueError: if x is not 1, 2 or 3 (checked after loading the files).
  """
  syn_name = f"syn{x}"
  all_selections = torch.load(
    osp.join("experiments", "results", "ablations", "sensitivity", syn_name, f"{ablation}.pt"))
  X_test = torch.load(osp.join("datasets", "data", syn_name, "X_test_std.pt"))

  last_feature = X_test[:, -1]
  pos_ids = torch.where(last_feature >= 0.0)[0]
  neg_ids = torch.where(last_feature < 0.0)[0]

  # (features required when the last feature is < 0, when it is >= 0).
  required_features = {
    1: ([0, 1, 10], [2, 3, 4, 5, 10]),
    2: ([0, 1, 10], [6, 7, 8, 9, 10]),
    3: ([2, 3, 4, 5, 10], [6, 7, 8, 9, 10]),
  }
  if x not in required_features:
    raise ValueError(f"Unknown synthetic dataset number: {x}, should be 1, 2 or 3.")
  neg_features = np.array(required_features[x][0])
  pos_features = np.array(required_features[x][1])

  counts = {}
  for sensitivity_value, selections in all_selections.items():
    repeats = range(len(selections))
    neg_arr = np.stack(
      [run_once(selections[rpt][:, neg_ids], neg_features) for rpt in repeats], axis=0)
    pos_arr = np.stack(
      [run_once(selections[rpt][:, pos_ids], pos_features) for rpt in repeats], axis=0)
    counts[sensitivity_value] = np.concatenate([neg_arr, pos_arr], axis=-1)
  return counts

In [None]:
def make_sensitivity_plot(ablation):
  """Plot required acquisitions against one hyperparameter for Syn1-3.

  One panel per synthetic dataset shows the mean count with a standard-error
  band, sorted by hyperparameter value. The figure is saved to
  plots/{ablation}_sensitivity.pdf.

  Args:
    ablation: "beta", "acq_sample", "train_sample" or "num_latents". For
      "beta" the stored keys use "_" in place of the decimal point.
  """
  syn_counts = {
    f"syn{syn_num}": create_syn_sensitivity_results(syn_num, ablation)
    for syn_num in (1, 2, 3)
  }

  fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(18, 5))

  # Hand-picked y-limits per panel, only applied for the latent-size sweep.
  num_latents_ylims = ([3.9, 4.3], [3.9, 4.3], [5.0, 5.8])

  for i in range(3):
    panel = ax[i]
    x, means, errs = [], [], []
    for x_value, results in syn_counts[f"syn{i+1}"].items():
      per_repeat = np.mean(results, axis=-1)
      means.append(np.mean(per_repeat))
      errs.append(np.std(per_repeat, ddof=1)/(len(per_repeat)**0.5))
      if ablation == "beta":
        x_value = x_value.replace("_", ".")
      x.append(float(x_value))

    # Order the points by hyperparameter value before drawing.
    x = np.array(x)
    sorted_ids = np.argsort(x)
    x = x[sorted_ids]
    means = np.array(means)[sorted_ids]
    errs = np.array(errs)[sorted_ids]

    panel.fill_between(x, means-errs, means+errs, alpha=0.3, color="tab:blue")
    panel.plot(x, means, linewidth=2, color="tab:blue", linestyle="--", marker="o", markersize=8)

    if ablation == "beta":
      # symlog so that the value 0 can sit on an otherwise logarithmic axis.
      panel.set_xscale('symlog', linthresh=0.000001)
      panel.set_xlabel(r"$\beta$", fontsize=20)
    elif ablation == "acq_sample":
      panel.set_xscale("log")
      panel.set_xlabel("No. Acquisition Samples", fontsize=15)
    elif ablation == "train_sample":
      panel.set_xscale("log")
      panel.set_xlabel("No. Training Samples", fontsize=15)
    elif ablation == "num_latents":
      panel.set_ylim(num_latents_ylims[i])
      panel.set_xlabel("No. Latent Components", fontsize=15)

    panel.set_title(f"Syn{i+1}", fontsize=20)
    panel.set_ylabel("Required No. Acquisitions", fontsize=15)
    panel.grid()
    panel.tick_params(axis="both", labelsize=12)

  plt.savefig(osp.join("plots", f"{ablation}_sensitivity.pdf"), bbox_inches="tight")

In [None]:
# One sensitivity figure per swept hyperparameter.
for ablation in ("beta", "acq_sample", "train_sample", "num_latents"):
  make_sensitivity_plot(ablation)