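This notebook creates the class weight plots shown in the Methods chapter. For each task it reads the per-fold class weights from `../data/class_weights_overview`, averages them over folds per class-weight type, and saves one bar chart per task to a dated folder under `../plots`.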

In [1]:
import pandas as pd
import os
from datetime import date
from matplotlib import pyplot as plt

In [2]:
# input dir: one class-weight CSV per task ("class_weights_<task>.csv") lives here.
# os.pardir is the platform's parent-directory token, i.e. "..".
in_dir = os.path.join(os.pardir, "data", "class_weights_overview")
# tasks to plot, in the order the figures are produced
tasks = "morphology histology behavior site".split()
assert os.path.exists(in_dir)

In [3]:
# output dir: one dated folder per run, e.g. ../plots/2024-01-31_cw_visualization
# (str(date) is ISO format YYYY-MM-DD).
# os.makedirs (unlike os.mkdir) also creates a missing ../plots parent, and
# exist_ok=True makes re-running this cell on the same day a no-op instead of an error.
store_results_dir = os.path.join("..", "plots", f"{date.today()}_cw_visualization")
os.makedirs(store_results_dir, exist_ok=True)

In [4]:
# Legend labels for the class-weight types: the first three characters of
# "cw_type+fold" identify the type; "bl1"/"bl0" are relabelled with their μ value,
# any other prefix is used as-is.
CW_TYPE_LABELS = {"bl1": "μ=1", "bl0": "μ=0.15"}
# Tasks with many classes get a wide figure; the others a compact one.
WIDE_TASKS = {"morphology", "histology"}

# Font sizes must be configured BEFORE any figure is created: rcParams are read
# when labels, ticks and legends are built, so setting them after plotting (as
# before) left the first task's figure with default font sizes. These remain
# global settings, as in the original cell.
plt.rc("axes", titlesize=15, labelsize=15)
plt.rc("xtick", labelsize=20)
plt.rc("ytick", labelsize=20)
plt.rc("legend", fontsize=15)

for task in tasks:
    file = os.path.join(in_dir, f"class_weights_{task}.csv")
    df = pd.read_csv(file, header=0, index_col=0)

    # add column with shortened class weight names for grouping + legend in plot
    df["cw_type_short"] = df["cw_type+fold"].str[:3].replace(CW_TYPE_LABELS)

    # drop the long-name column so only the class-weight columns are averaged
    df = df.drop(columns="cw_type+fold")

    # mean over folds; transpose so index = class labels, columns = cw_type_short
    pivot_df = df.groupby("cw_type_short").mean().transpose()

    # adjust plot size to number of classes per task
    fig_size = (15, 5) if task in WIDE_TASKS else (5, 4)

    ax = pivot_df.plot(kind="bar", figsize=fig_size, colormap="Paired")
    ax.set_xlabel("Classes")
    ax.set_ylabel("Class weight")
    ax.legend(loc="upper left")

    fig = ax.figure
    fig.tight_layout()

    # save file to output folder, then free the figure (many figures in a loop)
    fig.savefig(os.path.join(store_results_dir, f"{task}_cws.png"))
    print("Created class weight plot for task ", task)
    plt.close(fig)


Created class weight plot for task  morphology
Created class weight plot for task  histology
Created class weight plot for task  behavior
Created class weight plot for task  site
