In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import cmcrameri.cm as cmc
# Seaborn "white" theme with three overrides: dotted grid lines
# (dash pattern: offset 0, 1pt on / 1pt off) and thin gray axes edges.
sns.set_style("white", {"grid.linestyle": (0, (1,1)), 
                        "axes.edgecolor": "gray",
                        "axes.linewidth": 0.5})
import os
# Notebook-wide matplotlib defaults: square 13x13 figures and a Helvetica sans-serif font.
# NOTE(review): presumably matplotlib falls back to another sans-serif font when
# Helvetica is not installed - confirm on the target machine.
plt.rcParams["figure.figsize"] = [13., 13.]
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Helvetica"
# Keyword arguments that make_plot unpacks into ax.set_title(title, **title_font).
title_font = {"size": 15, "weight": "bold", "y":1.05, "horizontalalignment":"center", "verticalalignment":"center"}
# make_plot reads only label_font["size"] (font size of the concept tick labels);
# "labelspacing" is not read by any cell visible in this notebook.
label_font = {"size": 20, "labelspacing": 0.1}

#### Description
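Radar (spider) plots that visualise TCAV (Testing with Concept Activation Vectors) scores for a set of concepts, drawn against the random-concept baseline.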

In [None]:
# Colour palette shared by every figure; make_plot picks its colours by index
# (cmap(3) for the random baseline, cmap(2) for the concept scores).
cmap = cmc.batlowS
# Bare expression: shown as the cell output so the palette can be inspected.
cmap

In [None]:
# Version tag of the analysis run; it selects the sub-folder that holds the
# aggregated pickles read below and receives the exported SVG figures.
v = 4.02
load_path =  r"../analysis/tcav/plots/%s/" %v

# Create the folder (including missing parents) if it does not exist yet.
# This stays best-effort, but only filesystem errors are caught and they are
# reported instead of being silently discarded, so a wrong path is visible
# here rather than as a confusing failure in a later cell.
try:
    os.makedirs(load_path, exist_ok=True)
except OSError as err:
    print("Could not create directory %s: %s" % (load_path, err))

In [None]:
def _close_loop(values):
    """Return a new list with the first element repeated at the end.

    Repeating the first point makes a polar line/band end where it started.
    The input sequence is copied, never modified.
    """
    items = list(values)
    return items + items[:1]


def make_plot(mu: list, mad: list, concepts: list, quantiles: list,  r_mu: float, r_mad: float, title: str):
    """Draw a radar (spider) chart of per-concept scores against a random baseline.

    Relies on the notebook-level globals ``cmap`` (colours), ``title_font``
    (keyword arguments for the title) and ``label_font`` (tick-label size).

    Parameters
    ----------
    mu : sequence of float
        Central score of each concept (the calling cells pass the "median" values).
    mad : sequence of float
        Spread of each concept; the shaded concept band spans ``mu - mad`` to
        ``mu + mad`` (the calling cells pass the "b_dev" values).
    concepts : sequence of str
        Axis labels, one per entry of ``mu``.
    quantiles : sequence of float
        Accepted for backward compatibility only; it is not drawn.
    r_mu : float
        Score of the random baseline, drawn as a constant-radius line.
    r_mad : float
        Spread of the random baseline; the shaded baseline area spans
        ``0`` to ``r_mu + r_mad``.
    title : str
        Figure title.

    Returns
    -------
    The matplotlib figure that contains the polar axes.

    Raises
    ------
    ValueError
        If ``concepts`` is empty or ``mu``/``mad`` do not have one value per concept.

    Notes
    -----
    The input sequences are copied before the polygon is closed, so the
    caller's lists are left untouched and the function can be called
    repeatedly with the same objects.
    """
    n_concepts = len(concepts)
    if n_concepts == 0:
        raise ValueError("make_plot needs at least one concept")
    if len(mu) != n_concepts or len(mad) != n_concepts:
        raise ValueError(
            "mu and mad must have one value per concept (got %d, %d and %d concepts)"
            % (len(mu), len(mad), n_concepts)
        )

    # Evenly spaced spokes; every series repeats its first point to close the loop.
    theta = _close_loop(np.linspace(0, 2 * np.pi, n_concepts, endpoint=False).tolist())
    labels = _close_loop(concepts)
    center = _close_loop(mu)
    spread = _close_loop(mad)
    mad_l = [m - d for m, d in zip(center, spread)]
    mad_h = [m + d for m, d in zip(center, spread)]

    # The random baseline is the same value on every spoke.
    baseline = [r_mu] * len(theta)
    baseline_upper = [r_mu + r_mad] * len(theta)

    fig, ax = plt.subplots(subplot_kw=dict(projection="polar"))
    ax.grid(linewidth=2)

    ## RANDOM
    ax.fill_between(theta, 0.0, baseline_upper, alpha=0.15, color=cmap(3))
    ax.plot(theta, baseline, color=cmap(3),  alpha= 0.9, label="Random",
            linewidth=3, marker="o", markersize = 12)
    ### CONCEPT
    ax.fill_between(theta, mad_l, mad_h, alpha=0.55, color=cmap(2))
    ax.plot(theta, center, color=cmap(2), label="Concept Scores",
            linewidth=3.5, marker="o", markersize=14)

    plt.setp(ax.spines.values(), linewidth=0)
    plt.setp(ax.get_yticklabels(), fontsize=17)
    # Put the first concept at the top and lay the others out clockwise.
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetagrids(np.degrees(theta), labels)
    for i, (label, angle) in enumerate(zip(ax.get_xticklabels(), theta)):
        # The closed loop yields two ticks at angle 0 carrying the same text;
        # the first one is shrunk to size 0 so the label is shown only once.
        label.set_size(0 if i == 0 else label_font["size"])
        # Anchor each label on the side that faces away from the plot centre.
        # isclose instead of ==: k * (2*pi/n) is not guaranteed to equal pi exactly.
        if np.isclose(angle, 0.0) or np.isclose(angle, np.pi):
            label.set_horizontalalignment("center")
        elif 0 < angle < np.pi:
            label.set_horizontalalignment("left")
        else:
            label.set_horizontalalignment("right")

    ax.set_ylim(0, 1.05)
    # Place the radial tick labels half-way between the first two spokes.
    ax.set_rlabel_position(180 / n_concepts)
    ax.set_title(title, **title_font)
    ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
    return fig

In [None]:
# Keys of the concepts to plot. The cells below pull these entries out of the
# loaded statistics in this order, which is also the order of the radar axes.
to_use = [
    "mental",
    "sex_fm",
    "infection",
    "income",
    "managers",
    "agriculture",
    "operators",
]

## Sign-based

In [None]:
## load data
# NOTE: pickle.load can execute arbitrary code - only open files produced by
# this project's own pipeline.
with open(os.path.join(load_path, "agg_pos.pkl"), "rb") as f:
    stat = pickle.load(f)

# Statistics of the random baseline, stored under the "random" key.
r_mu = stat["random"]["median"]
r_mad = stat["random"]["b_dev"]

# Keep only the selected concepts, in the order given by `to_use`.
# A concept missing from the file raises a KeyError naming that concept.
concept_dict = {c: stat[c] for c in to_use}

In [None]:
# TODO(review): the axes are labelled with the raw concept keys. An earlier
# draft listed display names here (e.g. "Mental and Behavioural Diagnosis",
# "High Income", "Managerial Position", "Work in Agriculture"), but that list
# no longer matched the selected concepts.
concepts = list(concept_dict)
entries = list(concept_dict.values())
mu = [entry["median"] for entry in entries]
mad = [entry["b_dev"] for entry in entries]
q = [entry["std"] for entry in entries]

title = "Concept Influences on the increase of the 'Death' outcome"

fig = make_plot(
    mu=mu,
    mad=mad,
    concepts=concepts,
    quantiles=q,
    r_mu=r_mu,
    r_mad=r_mad,
    title=title,
)
fig.tight_layout()
plt.show()

# Export the figure next to the input data.
fig.savefig(load_path + "tcav_d.svg", format="svg")

In [None]:
## load data
# NOTE: pickle.load can execute arbitrary code - only open files produced by
# this project's own pipeline.
with open(os.path.join(load_path, "agg_neg.pkl"), "rb") as f:
    stat = pickle.load(f)

# Statistics of the random baseline, stored under the "random" key.
r_mu = stat["random"]["median"]
r_mad = stat["random"]["b_dev"]

# Keep only the selected concepts, in the order given by `to_use`.
concept_dict = {c: stat[c] for c in to_use}

# One list per statistic, aligned with `concepts`.
concepts = list(concept_dict)
mu = [x["median"] for x in concept_dict.values()]
mad = [x["b_dev"] for x in concept_dict.values()]
q = [x["std"] for x in concept_dict.values()]

title = "Concept Influences on the increase of the 'Survive' outcome"

fig = make_plot(mu=mu, mad=mad, r_mu=r_mu, r_mad=r_mad, title=title, concepts=concepts, quantiles=q)
fig.tight_layout()
plt.show()

# Export the figure next to the input data.
fig.savefig(os.path.join(load_path, "tcav_s.svg"), format="svg")


## Magnitude-based

In [None]:
# TODO(review): the axes are labelled with the raw concept keys. An earlier
# draft listed display names here (e.g. "Mental and Behavioural Diagnosis",
# "High Income", "Managerial Position", "Work in Agriculture"), but that list
# no longer matched the selected concepts.
## load data
# NOTE: pickle.load can execute arbitrary code - only open files produced by
# this project's own pipeline.
with open(os.path.join(load_path, "M_agg_pos.pkl"), "rb") as f:
    stat = pickle.load(f)

# Statistics of the random baseline, stored under the "random" key.
r_mu = stat["random"]["median"]
r_mad = stat["random"]["b_dev"]

# Keep only the selected concepts, in the order given by `to_use`.
concept_dict = {c: stat[c] for c in to_use}

# One list per statistic, aligned with `concepts`.
concepts = list(concept_dict)
mu = [x["median"] for x in concept_dict.values()]
mad = [x["b_dev"] for x in concept_dict.values()]
q = [x["std"] for x in concept_dict.values()]

title = "(M) Concept Influences on the increase of the 'Death' outcome"

fig = make_plot(mu=mu, mad=mad, r_mu=r_mu, r_mad=r_mad, title=title, concepts=concepts, quantiles=q)
fig.tight_layout()
plt.show()

# NOTE(review): this output name starts with a lower-case "m_", whereas the
# input file ("M_agg_pos.pkl") and the sibling export ("M_tcav_s.svg") use an
# upper-case "M_". Left unchanged in case something downstream expects this
# exact file name - confirm and align.
fig.savefig(os.path.join(load_path, "m_tcav_d.svg"), format="svg")

In [None]:
## load data
# NOTE: pickle.load can execute arbitrary code - only open files produced by
# this project's own pipeline.
with open(os.path.join(load_path, "M_agg_neg.pkl"), "rb") as f:
    stat = pickle.load(f)

# Statistics of the random baseline, stored under the "random" key.
r_mu = stat["random"]["median"]
r_mad = stat["random"]["b_dev"]

# Keep only the selected concepts, in the order given by `to_use`.
concept_dict = {c: stat[c] for c in to_use}

# One list per statistic, aligned with `concepts`.
concepts = list(concept_dict)
mu = [x["median"] for x in concept_dict.values()]
mad = [x["b_dev"] for x in concept_dict.values()]
q = [x["std"] for x in concept_dict.values()]

title = "(M) Concept Influences on the increase of the 'Survive' outcome"

fig = make_plot(mu=mu, mad=mad, r_mu=r_mu, r_mad=r_mad, title=title, concepts=concepts, quantiles=q)
fig.tight_layout()
plt.show()

# Export the figure next to the input data.
fig.savefig(os.path.join(load_path, "M_tcav_s.svg"), format="svg")
