In [1]:
from cc_wrapper import CCWrapper
from cc_wrapper import SMALL_VAR, MID_VAR, ALL_VAR, LATEX_NAME, POSITIONS_LT, CHAMBER_CONFIGURATIONS
from cc_ground_truth import graph, edges

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import seaborn as sns

# True DAG with actuator / sensor / sensor-parameter division

In [2]:
EXP_FAMILY = "lt_interventions_standard_v1"
NODESIZE = 400
# Colour names for the three node/edge groups drawn by edge_explanation
# (index 0: actuators, 1: sensors, 2: sensor parameters).
COLORS = ['lightgreen', 'gold', 'powderblue']
# Resolve every CSS4 name to its hex code, failing loudly on a typo.
css_colors = []
for color in COLORS:
    if color not in mcolors.CSS4_COLORS:
        raise ValueError(f"Color '{color}' is not a valid CSS color name.")
    css_colors.append(mcolors.CSS4_COLORS[color])

def edge_explanation(save = None):
    """Draw the light-tunnel ground-truth DAG with nodes coloured by role.

    Nodes are split into actuators, sensors and sensor parameters (every
    variable in ``ALL_VAR`` that is neither an actuator nor a sensor) and
    filled with ``css_colors[0]``, ``[1]`` and ``[2]`` respectively. Edges are
    drawn in three groups, each in the matching ``COLORS`` entry:

    * edges touching at least one actuator,
    * edges whose endpoints are all sensors,
    * edges touching at least one sensor parameter (drawn last).

    Together the three groups cover every edge between variables in
    ``ALL_VAR``, so no ground-truth edge is left out of the figure.

    Args:
        save: File stem; the figure is written to ``./vis/{save}.pdf``. A
            trailing ``.pdf`` in ``save`` is dropped so the extension is not
            doubled. If ``None`` the figure is shown instead.
    """
    config = CHAMBER_CONFIGURATIONS[EXP_FAMILY]
    edg = edges(chamber=config['chamber'], configuration=config['configuration'])
    # Keep only edges between modelled variables; expected to be a no-op.
    edg_var = [edge for edge in edg if all(node in ALL_VAR for node in edge)]
    if not edg == edg_var:
        print("Filtering removed edges.")

    # Split nodes by role
    actuators = ['red', 'green', 'blue', 'pol_1', 'pol_2', 'l_11', 'l_12', 'l_21', 'l_22', 'l_31', 'l_32']
    sensors = ['current', 'angle_1', 'angle_2', 'ir_1', 'vis_1', 'ir_2', 'vis_2', 'ir_3', 'vis_3']
    # Keep ALL_VAR order (instead of set order) so the drawing is deterministic.
    sensor_parameters = [var for var in ALL_VAR if var not in actuators and var not in sensors]
    assert set(actuators + sensors) <= set(ALL_VAR), "every actuator/sensor must be in ALL_VAR"
    assert len(actuators) + len(sensors) + len(sensor_parameters) == len(ALL_VAR)

    # Group edges by the roles of the nodes they touch
    edg_actuators = [edge for edge in edg_var if any(node in actuators for node in edge)]
    # Sensor->sensor edges touch neither an actuator nor a sensor parameter,
    # so no other group contains them; they must be drawn explicitly.
    edg_sensors = [edge for edge in edg_var if all(node in sensors for node in edge)]
    edg_sensor_parameters = [edge for edge in edg_var if any(node in sensor_parameters for node in edge)]

    fig, ax = plt.subplots(figsize=(12,8))
    # Empty graph: nodes and edges are passed explicitly through
    # nodelist/edgelist, positions come from POSITIONS_LT.
    G = nx.DiGraph()

    # Add nodes, one colour per role
    node_kwargs = {'G': G,
                   'pos': POSITIONS_LT,
                   'node_size': NODESIZE,
                   'edgelist': [],
                   'ax': ax,
                   'with_labels': False,
                   'arrows': True}
    for idx, nodes in enumerate((actuators, sensors, sensor_parameters)):
        nx.draw_networkx(nodelist=nodes, node_color=css_colors[idx], **node_kwargs)
    nx.draw_networkx_labels(G=G,
                            pos=POSITIONS_LT,
                            labels={label: LATEX_NAME(label) for label in ALL_VAR},
                            ax=ax)

    # Add edges; node_size is passed so arrow heads stop at the node border.
    edge_kwargs = {'G': G, 'pos': POSITIONS_LT, 'ax': ax, 'node_size': NODESIZE}
    for idx, edge_group in ((0, edg_actuators), (1, edg_sensors), (2, edg_sensor_parameters)):
        nx.draw_networkx_edges(edgelist=edge_group, edge_color=COLORS[idx], **edge_kwargs)

    # Add legend below the plot
    handles = [mpatches.Patch(color=css_colors[idx], label=name)
               for idx, name in enumerate(('Actuator', 'Sensor', 'Sensor Parameter'))]
    ax.legend(handles=handles,
              loc='lower center',
              bbox_to_anchor=(0.5, -0.15),
              title='Category',
              ncol=3)
    fig.subplots_adjust(bottom=0.2)
    if save is not None:
        stem = str(save).removesuffix(".pdf")
        fig.savefig(f"./vis/{stem}.pdf", format="pdf", dpi=150, bbox_inches='tight')
    else:
        plt.show()


In [None]:
# Figure 3.2 in thesis
# `save` is a file stem: edge_explanation appends ".pdf" itself, so the stem
# must not carry the extension (otherwise the file ends in ".pdf.pdf").
edge_explanation(save="causal_chamber_ground_truth")

In [None]:
# Not used in thesis: summary statistics of all variables in the uniform
# reference experiment.
EXP_FAMILY = "lt_interventions_standard_v1"
ccw = CCWrapper()
ccw.set_exp_family(EXP_FAMILY)
ccw.set_variables(ALL_VAR)
df = ccw.fetch_experiments(experiments=["uniform_reference"])[0]
df.describe()

# True DAG coloured by the SMALL / MID / ALL variable-set division

In [5]:
EXP_FAMILY = "lt_interventions_standard_v1"
NODESIZE = 400
# Colour names for the three groups drawn by draw_true_dag
# (index 0: SMALL_VAR, 1: MID_VAR only, 2: ALL_VAR only).
COLORS = ['burlywood', 'thistle', 'yellowgreen']
# Resolve every CSS4 name to its hex code, failing loudly on a typo.
css_colors = []
for color in COLORS:
    if color not in mcolors.CSS4_COLORS:
        raise ValueError(f"Color '{color}' is not a valid CSS color name.")
    css_colors.append(mcolors.CSS4_COLORS[color])

def draw_true_dag(save = None):
    """Plot the ground-truth DAG with nodes and edges grouped by variable set.

    Node groups are SMALL_VAR, MID_VAR without SMALL_VAR, and ALL_VAR without
    MID_VAR, filled with css_colors[0], [1] and [2]. Edge groups use COLORS[0],
    [1] and [2]:

    * both endpoints in SMALL_VAR,
    * both endpoints in MID_VAR and at least one outside SMALL_VAR,
    * both endpoints in ALL_VAR and at least one outside MID_VAR.

    (Assuming SMALL_VAR is a subset of MID_VAR, which is a subset of ALL_VAR,
    these groups partition the edges — NOTE(review): confirm in cc_wrapper.)
    Prints a notice if restricting the edge list to ALL_VAR dropped anything.

    Args:
        save: when given, the figure is written to ./vis/{save}.pdf; otherwise
            it is displayed with plt.show().
    """
    cfg = CHAMBER_CONFIGURATIONS[EXP_FAMILY]
    raw_edges = edges(chamber=cfg['chamber'], configuration=cfg['configuration'])
    kept = [e for e in raw_edges if all(n in ALL_VAR for n in e)]  # Should not change anything
    if not raw_edges == kept:
        print("Filtering removed edges.")

    small_var = SMALL_VAR
    mid_minus_small_var = [v for v in MID_VAR if v not in SMALL_VAR]
    all_minus_mid_var = [v for v in ALL_VAR if v not in MID_VAR]
    node_groups = (small_var, mid_minus_small_var, all_minus_mid_var)

    def _inside(edge, pool):
        # Every endpoint of `edge` belongs to `pool`.
        return all(n in pool for n in edge)

    def _touches(edge, pool):
        # At least one endpoint of `edge` belongs to `pool`.
        return any(n in pool for n in edge)

    edge_groups = (
        [e for e in kept if _inside(e, small_var)],
        [e for e in kept if _inside(e, MID_VAR) and _touches(e, mid_minus_small_var)],
        [e for e in kept if _inside(e, ALL_VAR) and _touches(e, all_minus_mid_var)],
    )

    _, ax = plt.subplots(figsize=(12,8))
    G = nx.DiGraph()
    node_kwargs = dict(G=G, pos=POSITIONS_LT, node_size=NODESIZE, edgelist=[],
                       ax=ax, with_labels=False, arrows=True)
    for idx, group in enumerate(node_groups):
        nx.draw_networkx(nodelist=group, node_color=css_colors[idx], **node_kwargs)
    nx.draw_networkx_labels(G=G,
                            pos=POSITIONS_LT,
                            labels={v: LATEX_NAME(v) for v in ALL_VAR},
                            ax=ax)

    edge_kwargs = dict(G=G, pos=POSITIONS_LT, ax=ax, node_size=NODESIZE)
    for idx, group in enumerate(edge_groups):
        nx.draw_networkx_edges(edgelist=group, edge_color=COLORS[idx], **edge_kwargs)

    legend_handles = [mpatches.Patch(color=css_colors[idx], label=name)
                      for idx, name in enumerate(('Small', 'Mid', 'All'))]
    plt.legend(handles=legend_handles,
               loc='lower center',
               bbox_to_anchor=(0.5, -0.15),
               title='Category',
               ncol=3)
    plt.subplots_adjust(bottom=0.2)
    if save is None:
        plt.show()
    else:
        plt.savefig(f"./vis/{save}.pdf", format="pdf", dpi=150, bbox_inches='tight')


In [None]:
# Not used in thesis: display (without saving) the DAG coloured by variable set.
draw_true_dag(save=None)

# Pairwise comparison plot

In [7]:
EXP_FAMILY = "lt_interventions_standard_v1"
MARKER = "Experiment"

def pairplot(experiments: list[str], 
             colors: list[str], 
             variables: list[str], 
             alpha: float = 0.1,
             datapoints_per_frame: int = 250,
             smoother: bool = False, 
             legend: bool = True, 
             save = None):
    """Seaborn pair plot of ``variables`` across one or more experiments.

    Args:
        experiments: experiment names fetched from ``EXP_FAMILY``; each becomes
            one hue level in the ``MARKER`` column.
        colors: CSS4 colour names, one per experiment (same order).
        variables: variables to plot; axis labels use ``LATEX_NAME``.
        alpha: scatter transparency.
        datapoints_per_frame: rows requested per experiment.
        smoother: overlay a LOWESS fit on each off-diagonal panel. Only drawn
            when exactly one experiment is plotted; otherwise a notice is printed.
        legend: draw a figure-level legend below the grid.
        save: if given, save to ``./vis/{save}.pdf`` instead of showing.

    Raises:
        AssertionError: if ``experiments`` and ``colors`` differ in length.
        ValueError: if a colour is not a CSS4 colour name.
    """
    assert len(experiments)==len(colors)
    # Validate colors: https://matplotlib.org/stable/gallery/color/named_colors.html
    css_colors = []
    for color in colors:
        if color not in mcolors.CSS4_COLORS:
            raise ValueError(f"Color '{color}' is not a valid CSS color name.")
        css_colors.append(mcolors.CSS4_COLORS[color])
    # Set up CausalChamber Wrapper
    ccw = CCWrapper()
    ccw.set_exp_family(EXP_FAMILY)
    ccw.set_variables(variables)
    # Fetch data
    dataframes = ccw.fetch_experiments(
        experiments=experiments,
        sizes=[datapoints_per_frame]*len(experiments),
    )  # Returns list
    # Tag every row with its experiment in the MARKER (hue) column. `assign`
    # leaves the fetched frames untouched, and ignore_index renumbers rows:
    # each fetched frame has its own index, so a plain concat would contain
    # duplicate index labels.
    total_data = pd.concat(
        [df.assign(**{MARKER: exp}) for df, exp in zip(dataframes, experiments)],
        ignore_index=True,
    )
    # Rename variable columns to LaTeX names; the hue column keeps MARKER.
    total_data = total_data.rename(
        columns={col: f"{LATEX_NAME(col)}" for col in total_data.columns if col != MARKER}
    )
    # Plot: every panel is 3 inches square.
    n_vars = len(variables)
    fig_size = 3 * n_vars
    color_palette = dict(zip(experiments, css_colors))
    pp = sns.pairplot(data=total_data, 
                      plot_kws={'alpha': alpha}, 
                      hue=MARKER, 
                      palette=color_palette,
                      diag_kind='hist', 
                      diag_kws={'bins': 20},
                      height=fig_size / n_vars, 
                      aspect=1)
    pp.fig.set_size_inches(fig_size, fig_size)
    # Add a LOWESS smoother to each off-diagonal subplot
    if smoother:
        if len(experiments) == 1:
            n_rows, n_cols = pp.axes.shape
            for i in range(n_rows):
                for j in range(n_cols):
                    if i != j:  # Skip diagonal
                        sns.regplot(x=pp.x_vars[j], y=pp.y_vars[i], 
                                    data=total_data, 
                                    lowess=True, 
                                    scatter=False, 
                                    ci=None,
                                    line_kws={'color': 'red'},
                                    ax=pp.axes[i, j])
        else:
            print("Smoother is only drawn for a single experiment; skipping it.")

    for ax in pp.axes.flatten():
        ax.set_xlabel(ax.get_xlabel(), fontsize=25, labelpad=20)
        ax.set_ylabel(ax.get_ylabel(), fontsize=25, labelpad=20)
        ax.xaxis.label.set_rotation(0)
        ax.yaxis.label.set_rotation(0)
    # Replace seaborn's hue legend with a figure-level one below the grid.
    handles = [mpatches.Patch(color=color, label=exp) for color, exp in zip(colors, experiments)]
    hue_legend = getattr(pp, "_legend", None)
    if hue_legend is not None:
        hue_legend.remove()
    if legend:
        fig_legend = pp.fig.legend(handles=handles, 
                                   labels=experiments, 
                                   loc='lower center', 
                                   ncol=len(experiments), 
                                   fontsize='xx-large', 
                                   title=MARKER, 
                                   title_fontsize='xx-large',
                                   borderpad=1,    
                                   labelspacing=1,
                                   handlelength=1.5,  # Length of color patch
                                   handleheight=1.5)  # Height of color patch
        # Legend's frame
        fig_legend.get_frame().set_linewidth(2)
        # Make space for legend
        pp.fig.subplots_adjust(top=0.92, bottom=0.12)
    if save is not None:
        plt.savefig(f"./vis/{save}.pdf", format="pdf", dpi=150, bbox_inches='tight')
    else:
        plt.show()


In [None]:
# Figure 3.3 in thesis: uniform reference experiment, SMALL_VAR plus pol_1 and pol_2.
pairplot(
    experiments=['uniform_reference'],
    colors=['black'],
    variables=[*SMALL_VAR, "pol_1", "pol_2"],
    alpha=0.1,
    datapoints_per_frame=500,
    smoother=False,
    legend=False,
    save="Pairplot-UniformReference-SmallVarplusTheta",
)

In [None]:
# Not used in thesis: all MID_VAR variables in the uniform reference experiment.
pairplot(
    experiments=['uniform_reference'],
    colors=['black'],
    variables=MID_VAR,
    smoother=False,
    legend=False,
)

In [None]:
# Not used in thesis: reference vs. red/green "mid" interventions.
pairplot(
    experiments=['uniform_reference', 'uniform_red_mid', 'uniform_green_mid'],
    colors=['grey', 'red', 'green'],
    variables=['red', 'green', 'blue', 'ir_1', 'vis_1'],
    smoother=False,
)

In [11]:
import numpy as np
from scipy.stats import ks_2samp
EXP_FAMILY = "lt_interventions_standard_v1"

def ks_test(experiments: list[str], variable: str):
    """Compare the mean-centred distributions of one variable in two experiments.

    Each sample is shifted to zero mean before ``scipy.stats.ks_2samp``, so the
    test ignores differences in location. Prints the KS p-value and the
    difference of standard deviations (first minus second); returns None.

    Raises:
        AssertionError: unless exactly two experiments are given.
    """
    assert len(experiments)==2
    wrapper = CCWrapper()
    wrapper.set_exp_family(EXP_FAMILY)
    wrapper.set_variables([variable])
    frames = wrapper.fetch_experiments(experiments=experiments)  # Returns list
    first, second = (frame.values.flatten() for frame in (frames[0], frames[1]))
    _, p_value = ks_2samp(first - np.mean(first), second - np.mean(second))
    std_diff = np.std(first) - np.std(second)
    print(f"{variable} on {experiments[0]} vs. {experiments[1]}: P-value: {p_value}; Std-Diff {std_diff}")

In [None]:
# Not used in thesis: KS tests for ir_1, then vis_1, on every pair of the
# reference / red-mid / green-mid experiments.
for sensor in ("ir_1", "vis_1"):
    for experiment_pair in (['uniform_reference', 'uniform_red_mid'],
                            ['uniform_reference', 'uniform_green_mid'],
                            ['uniform_red_mid', 'uniform_green_mid']):
        ks_test(experiments=experiment_pair, variable=sensor)

In [None]:
# Not used in thesis: reference vs. strong red and pol_1 interventions.
pairplot(
    experiments=['uniform_reference', 'uniform_red_strong', 'uniform_pol_1_strong'],
    colors=['grey', 'red', 'blue'],
    variables=MID_VAR,
    smoother=False,
)

In [None]:
# Not used in thesis: reference vs. the l_11 "mid" intervention on SMALL_VAR.
pairplot(
    experiments=['uniform_reference', 'uniform_l_11_mid'],
    colors=['blue', 'red'],
    variables=SMALL_VAR,
    smoother=False,
)

In [None]:
# Not used in thesis: single experiment, so the LOWESS smoother is drawn.
pairplot(
    experiments=['uniform_reference'],
    colors=['grey'],
    variables=['red', 'green', 'blue', 'l_11', 'l_12', 'ir_1', 'vis_1'],
    smoother=True,
)

# Marginal standard deviations

In [16]:
EXP_FAMILY = "lt_interventions_standard_v1"

def plot_marginal_variance(experiments: list[str], colors: list[str], variables: list[str], log_scale: bool, save = None):
    """Grouped bar chart of each variable's marginal standard deviation.

    One bar group per variable (labelled with its LaTeX name), one bar per
    experiment.

    Args:
        experiments: experiment names in ``EXP_FAMILY``.
        colors: CSS4 colour names, applied to the experiments in order.
        variables: variables to fetch and plot; the figure is
            ``len(variables)/3`` inches wide.
        log_scale: use a logarithmic y axis.
        save: if given, save to ``./vis/{save}.pdf`` instead of showing.

    Raises:
        ValueError: if a colour is not a CSS4 colour name.
    """
    css_colors = []
    for color in colors:
        if color in mcolors.CSS4_COLORS:
            css_colors.append(mcolors.CSS4_COLORS[color])
        else:
            raise ValueError(f"Color '{color}' is not a valid CSS color name.")

    # Set up CausalChamber Wrapper
    ccw = CCWrapper()
    ccw.set_exp_family(EXP_FAMILY)
    ccw.set_variables(variables)
    # Fetch data
    dfs = ccw.fetch_experiments(experiments=experiments)

    # Per-experiment std Series indexed by LaTeX name (these become the x tick
    # labels); rename() returns a copy, so the fetched frames stay untouched.
    stdvs = {exp: df.rename(columns=LATEX_NAME).std() for exp, df in zip(experiments, dfs)}
    combined_stdv = pd.DataFrame(stdvs)

    # A single figure: only the bar chart's own figure is created.
    fig, ax = plt.subplots(figsize=(len(variables)/3,6))
    combined_stdv.plot(kind='bar', logy=log_scale, color=css_colors, ax=ax)

    ax.set_title('Marginal Standard Deviation')
    ax.set_xlabel('Variables')
    ax.set_ylabel('Standard deviation')
    ax.tick_params(axis='x', labelrotation=0)
    ax.legend(title='Experiments', loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=len(experiments))
    fig.tight_layout()
    if save is not None:
        fig.savefig(f"./vis/{save}.pdf", format="pdf", dpi=150, bbox_inches='tight')
    else:
        plt.show()


In [None]:
# Figure 4.9 in thesis
# Other experiment/colour pairs tried here: 'uniform_red_strong' ('tomato')
# and 'uniform_green_mid' ('palegreen').
plot_marginal_variance(
    experiments=['uniform_reference'],
    colors=['grey'],
    variables=MID_VAR,
    log_scale=True,
    save="marginal_variance_plot",
)

# Conditional dependence/independence

In [18]:
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
EXP_FAMILY = "lt_interventions_standard_v1"
# Experiments read by conditional_dependence; only the first frame returned is used.
experiments = ['uniform_reference']
def conditional_dependence(X: str, Y: str, Z: list[str], linear: bool, plot: bool = True, save = False, random_state = None):
    """Correlation between X and Y, optionally after regressing both on Z.

    Requests 10000 rows of the module-level ``experiments`` from
    ``EXP_FAMILY`` and uses the first frame. Its first half fits the
    regressions; its second half is used to compute residuals and the
    correlation.

    Args:
        X, Y: variable names.
        Z: conditioning variables, or ``False`` for the plain (marginal)
            correlation of X and Y on the evaluation half.
        linear: regress with OLS (``LinearRegression``) if True, otherwise
            with a 1000-tree ``RandomForestRegressor``.
        plot: if ``False`` return the correlation. Otherwise scatter-plot the
            residuals (or the raw evaluation-half values when ``Z is False``)
            and return None.
        save: file stem; unless ``False``/``None`` the plot is saved to
            ``./vis/{save}.pdf``.
        random_state: forwarded to ``RandomForestRegressor`` so random-forest
            results can be made reproducible; ignored for OLS.
    """
    conditioning = Z is not False
    ccw = CCWrapper()
    ccw.set_exp_family(EXP_FAMILY)
    ccw.set_variables([X, Y, *Z] if conditioning else [X, Y])
    # Fetch data
    df = ccw.fetch_experiments(
                experiments=experiments,
                sizes=[10000]
                )[0] # Returns list
    X_dat = df[[X]].values.flatten()
    Y_dat = df[[Y]].values.flatten()

    # First half fits the regressions, second half evaluates them.
    split_point = int(len(X_dat) * 0.5)
    X_train, X_test = X_dat[:split_point], X_dat[split_point:]
    Y_train, Y_test = Y_dat[:split_point], Y_dat[split_point:]

    if conditioning:
        Z_dat = df[Z].values
        Z_train, Z_test = Z_dat[:split_point], Z_dat[split_point:]

        def _make_model():
            # Fresh regressor for each target.
            if linear:
                return LinearRegression(fit_intercept=True)
            return RandomForestRegressor(n_estimators=1000, random_state=random_state)

        res_X = X_test - _make_model().fit(X=Z_train, y=X_train).predict(X=Z_test)
        res_Y = Y_test - _make_model().fit(X=Z_train, y=Y_train).predict(X=Z_test)
        x_plot, y_plot, label_suffix = res_X, res_Y, " Residuals"
    else:
        # Marginal case: nothing to regress on, use the raw evaluation values.
        x_plot, y_plot, label_suffix = X_test, Y_test, ""

    corr = np.corrcoef(x_plot, y_plot)[0, 1]
    if plot is False:
        return corr

    # Plot what was correlated (residuals, or raw values in the marginal case)
    plt.scatter(x_plot, y_plot, alpha=0.01)
    plt.xlabel(f'{LATEX_NAME(X)}{label_suffix}')
    plt.ylabel(f'{LATEX_NAME(Y)}{label_suffix}')
    plt.grid(True)
    if save is not False and save is not None:
        plt.savefig(f"./vis/{save}.pdf", format="pdf", dpi = 150, bbox_inches="tight")
        print(f"Figure saved under: {save}")

In [19]:
import itertools
import seaborn as sns
import matplotlib.pyplot as plt
def conditional_correlation_matrix(variables: list[str], conditioning_set: list[str], leave_out = None, take_in = None, linear: bool = True, save = False):
    """Heatmap of absolute (conditional) correlations between all variable pairs.

    For each unordered pair (x, y) of ``variables`` the absolute value of
    ``conditional_dependence(x, y, Z, linear, plot=False)`` is written to both
    off-diagonal cells; the diagonal stays 1.

    Args:
        variables: variables spanning the matrix; axes use their LaTeX names.
        conditioning_set: one of
            * a list of variables: condition every pair on exactly this list;
            * "remaining": condition each pair on
              ``(variables - leave_out) | take_in`` minus the pair itself;
            * "marginal": no conditioning (plain correlation).
            An empty conditioning set is treated like "marginal".
        leave_out: variables removed from the "remaining" set (default: none).
        take_in: extra variables added to the "remaining" set (default: none).
        linear: OLS residuals if True, random-forest residuals otherwise.
        save: file stem; unless ``False``/``None`` the heatmap is saved to
            ``./vis/{save}.pdf``.

    Raises:
        ValueError: if ``conditioning_set`` is a string other than
            "remaining" or "marginal".
    """
    leave_out = set() if leave_out is None else set(leave_out)
    take_in = set() if take_in is None else set(take_in)
    # Validate once, and build the loop-invariant part of "remaining".
    if isinstance(conditioning_set, str):
        if conditioning_set == "remaining":
            remaining = (set(variables) - leave_out) | take_in
        elif conditioning_set != "marginal":
            raise ValueError("Provide a valid conditioning set!")

    ltx_variables = [LATEX_NAME(var) for var in variables]
    conditional_corr_matrix = pd.DataFrame(np.ones((len(variables), len(variables))),
                                           index=ltx_variables, columns=ltx_variables)
    for x, y in itertools.combinations(variables, 2):
        if not isinstance(conditioning_set, str):
            cond_set = conditioning_set
        elif conditioning_set == "remaining":
            # Sorted: string-set iteration order changes between interpreter
            # runs, which would reorder fetched columns and regression features.
            cond_set = sorted(remaining - {x, y})
        else:
            cond_set = False
        # Nothing to condition on -> marginal correlation, instead of fitting
        # a regression on zero features.
        if cond_set is not False and len(cond_set) == 0:
            cond_set = False
        corr = conditional_dependence(X=x, Y=y, Z=cond_set, linear=linear, plot=False)
        conditional_corr_matrix.loc[LATEX_NAME(x), LATEX_NAME(y)] = abs(corr)
        conditional_corr_matrix.loc[LATEX_NAME(y), LATEX_NAME(x)] = abs(corr)

    fig, ax = plt.subplots()
    cmap = plt.colormaps.get_cmap("Greys")
    sns.heatmap(data=conditional_corr_matrix, cmap=cmap, vmin=0, vmax=1, annot=True, fmt=".2f", ax=ax)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    if save is not False and save is not None:
        fig.savefig(f"./vis/{save}.pdf", format="pdf", dpi = 150, bbox_inches="tight")
        print(f"Figure saved under: {save}")

In [None]:
# Figure 4.1 a) in thesis: OLS residual correlations between the light sensors
# (linear=True is the default).
conditional_correlation_matrix(
    variables=["ir_1", "ir_2", "ir_3", "vis_1", "vis_2", "vis_3"],
    conditioning_set=['red', 'green', 'blue',
                      'l_11', 'l_12', 'l_21', 'l_22', 'l_31', 'l_32',
                      'pol_1', 'pol_2'],
    save="light_sensors_OLS_conditional_correlation",
)

In [None]:
# Figure 4.1 b) in thesis
# Same variables and conditioning set as Figure 4.1 a); residuals come from
# random-forest regressions (linear=False) instead of OLS.
conditional_correlation_matrix(variables=["ir_1", "ir_2", "ir_3", "vis_1", "vis_2", "vis_3"], 
                               conditioning_set=['red', 'green', 'blue', 'l_11', 'l_12', 'l_21', 'l_22', 'l_31', 'l_32', 'pol_1', 'pol_2'],
                               linear=False,
                               save="light_sensors_RF_conditional_correlation")

In [None]:
# Figure 4.1 c) in thesis: random-forest residuals of ir_1 vs. vis_1 given
# red, green, blue, l_11 and l_12.
conditional_dependence(
    X="ir_1",
    Y="vis_1",
    Z=['red', 'green', 'blue', 'l_11', 'l_12'],
    linear=False,
    save="ir_1_versus_vis_1_conditional_residuals",
)

Why PC directs edges from light sensors to light sources

In [None]:
# Figure 4.2 a) in thesis: plain (unconditioned) absolute correlations.
conditional_correlation_matrix(
    variables=["red", "green", "blue", "current",
               "ir_1", "ir_2", "ir_3", "vis_1", "vis_2", "vis_3"],
    conditioning_set="marginal",
    save="marginal_correlation",
)

In [None]:
# Figure 4.2 b) in thesis: each pair conditioned on all other listed variables
# plus the extra take_in variables.
conditional_correlation_matrix(
    variables=["red", "green", "blue", "current",
               "ir_1", "ir_2", "ir_3", "vis_1", "vis_2", "vis_3"],
    conditioning_set="remaining",
    take_in={"l_11", "l_12", "l_21", "l_22", "l_31", "l_32",
             "pol_1", "pol_2", "angle_1", "angle_2"},
    save="conditional_correlation_based_on_all_other",
)

In [None]:
# Figure 4.3 a) in thesis: as Figure 4.2 b), but "green" is never conditioned on.
conditional_correlation_matrix(
    variables=["red", "green", "blue", "current",
               "ir_1", "ir_2", "ir_3", "vis_1", "vis_2", "vis_3"],
    conditioning_set="remaining",
    take_in={"l_11", "l_12", "l_21", "l_22", "l_31", "l_32",
             "pol_1", "pol_2", "angle_1", "angle_2"},
    leave_out={"green"},
    save="conditional_correlation_leave_out_green",
)

In [None]:
# Figure 4.3 b) in thesis: as Figure 4.2 b), but "blue" is never conditioned on.
conditional_correlation_matrix(
    variables=["red", "green", "blue", "current",
               "ir_1", "ir_2", "ir_3", "vis_1", "vis_2", "vis_3"],
    conditioning_set="remaining",
    take_in={"l_11", "l_12", "l_21", "l_22", "l_31", "l_32",
             "pol_1", "pol_2", "angle_1", "angle_2"},
    leave_out={"blue"},
    save="conditional_correlation_leave_out_blue",
)

In [None]:
# Figure 4.7 in thesis: marginal correlations of sensor 3 with the polarizer
# settings and angle readings.
conditional_correlation_matrix(
    variables=["ir_3", "vis_3", "pol_1", "pol_2", "angle_1", "angle_2"],
    conditioning_set="marginal",
    save="sensor3_polarizers_marginal_correlation",
)