In [23]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import pandas as pd
from tqdm import tqdm
import multiprocessing as mp
plt.style.use('bmh')

# Work from the project root. The chdir is relative, so running it a second
# time (e.g. re-executing this cell) would look for the project directory
# inside itself and fail; skip it when the working directory is already there.
PROJECT_DIR = os.path.join('project', 'learning_causal_discovery')
if not os.path.normpath(os.getcwd()).endswith(os.sep + PROJECT_DIR):
    os.chdir(PROJECT_DIR)

In [None]:
# get transistor meta-data
# Fetches transistors.csv (per-transistor geometry table) into .cache/sim_data.
# `wget -N` only downloads again when the server copy is newer than the local
# one; the `cd` fails if .cache/sim_data does not exist yet.
# NOTE(review): the cell that loads transistors.csv must read it from this
# same directory.
!cd .cache/sim_data && wget -N https://s3-us-west-2.amazonaws.com/ericmjonas-public/data/neuroproc/transistors.csv

In [27]:
# Stride used to down-sample the loaded traces along axis 1 (`[:, ::interval]`).
interval = 10
# Number of columns each perturbation trace is padded to before down-sampling.
stepLimit = 2_000
# Directory containing the cached simulation outputs and perturbation configs.
root_dir = '.cache/sim_data'

In [28]:
# helper functions
def get_cmap_slice(cmap, start, stop, n=256, name='my_slice'):
    '''
    Build a new colormap from the sub-range [start, stop] of an existing one.

    `cmap` is sampled at `cmap.N` evenly spaced points between `start` and
    `stop`; those colours become the control points of a
    LinearSegmentedColormap called `name` with `n` entries.
    '''
    sample_points = np.linspace(start, stop, cmap.N)
    sampled_colors = cmap(sample_points)
    return colors.LinearSegmentedColormap.from_list(name, sampled_colors, N=n)

# get geometry information of transistors
# The download cell saves transistors.csv inside root_dir; a copy in the parent
# of root_dir is also accepted. If neither exists, read_csv reports the
# root_dir path.
_ts_info_candidates = [
    os.path.join(root_dir, 'transistors.csv'),
    os.path.join(os.path.dirname(root_dir), 'transistors.csv'),
]
ts_info = pd.read_csv(
    next((path for path in _ts_info_candidates if os.path.isfile(path)), _ts_info_candidates[0]))

def _load_perturb_config(game):
    '''
    Load ``perturb_config.pkl`` for ``game`` from ``root_dir``.

    ``encoding='latin1'`` only affects 8-bit strings pickled by Python 2, so it
    is safe for Python 3 pickles too. The file handle is closed on return.
    NOTE(review): pickle can execute arbitrary code -- only load trusted files.
    '''
    with open(os.path.join(root_dir, f"{game}/perturb_config.pkl"), "rb") as fh:
        return pickle.load(fh, encoding='latin1')


def _single_source_strength(idx, game, n_transistors):
    '''
    Effect of the perturbation of transistor ``idx`` on every transistor.

    Entry ``i`` of the returned list is the summed absolute difference between
    the perturbed trace and the regular trace of transistor ``i`` from the
    divergence point to the end of the perturbation window. It is only counted
    when the traces agree before the divergence point and differ after it;
    otherwise, and for ``i == idx``, the entry is 0.
    '''
    unique_perturb = _load_perturb_config(game)
    orig = np.load(os.path.join(root_dir, f'{game}/HR/Regular_3510_step_256_rec_2e3.npy'),
                   mmap_mode='r')[:, ::interval]
    perturb = np.load(os.path.join(root_dir, f"{game}/HR/Adaptive_3510_step_256_tidx_{idx}.npy"),
                      mmap_mode='r')
    # Drop the last column and repeat the second-to-last one until the trace is
    # stepLimit columns long, then down-sample it the same way as `orig`.
    padded_perturb = np.concatenate(
        (perturb[:, :-1], np.tile(perturb[:, -2].reshape(-1, 1), stepLimit - perturb.shape[1] + 1)),
        axis=1)[:, ::interval]

    window, level = unique_perturb[idx][0], unique_perturb[idx][1]
    # First index where the regular trace of `idx` differs from 1 ('high'
    # perturbations) or 0 (otherwise). Loop-invariant, so computed once.
    baseline = 1 if level == 'high' else 0
    div_point = np.where(orig[idx] != baseline)[0][0]
    # Columns of perturbation window `window` inside the down-sampled regular trace.
    window_start = window * stepLimit // interval
    window_end = (window + 1) * stepLimit // interval
    # Position of div_point inside the window-local padded perturbed trace.
    offset = div_point - window_start
    # NOTE(review): div_point is searched over the whole regular trace, not from
    # window_start; if it can fall before the window, offset is negative and the
    # slices below no longer line up -- confirm this cannot happen.

    strengths = []
    for i in range(n_transistors):
        if i == idx:
            strengths.append(0)
            continue
        perturbed_after = padded_perturb[i][offset:]
        regular_after = orig[i, div_point:window_end]
        perturbed_before = padded_perturb[i][:offset]
        regular_before = orig[i, window_start:div_point]
        # Short-circuit order kept: the "before" comparison only runs when the
        # traces actually differ after the divergence point.
        if (not (perturbed_after == regular_after).all()
                and (perturbed_before == regular_before).all()):
            # NOTE(review): with an unsigned trace dtype this subtraction would
            # wrap around -- confirm the dtype of the .npy files.
            strengths.append(np.abs(perturbed_after - regular_after).sum())
        else:
            strengths.append(0)
    return strengths


def get_strength(kwargs):
    '''
    Compute cause-effect strengths from the cached perturbation simulations.

    Parameters
    ----------
    kwargs : dict
        ``{'idx': ..., 'game': str}``. It is a single dict so that the function
        can be passed directly to ``Pool.map``. ``idx`` may be:

        * an int (or numpy integer): entry ``i`` of the result is the effect
          of the perturbation of transistor ``idx`` on transistor ``i``;
        * a list of ints, or ``'all'`` (every key of the game's
          ``perturb_config.pkl``): each source is evaluated in a worker
          process, and entry ``i`` of the result is the mean of that source's
          per-transistor effects (including its own zero entry), or 0 when
          ``i`` is not a source.

    Returns
    -------
    list
        One value per transistor (3510 entries).

    Raises
    ------
    AssertionError
        If ``idx`` is neither an integer, a list, nor ``'all'``.
    '''
    n_transistors = 3510
    idx, game = kwargs['idx'], kwargs['game']
    if isinstance(idx, (int, np.integer)):
        return _single_source_strength(idx, game, n_transistors)

    if isinstance(idx, str) and idx == 'all':
        idx = list(_load_perturb_config(game).keys())
    else:
        assert isinstance(idx, list), 'idx must be an integer or a list of integers'
    if not idx:
        return [0] * n_transistors

    with mp.Pool(mp.cpu_count()) as pool:
        per_source = pool.map(get_strength, [{'idx': i, 'game': game} for i in idx])
    # pool.map returns results in the order of `idx`, so pair every source with
    # its own mean instead of assuming `idx` is sorted and duplicate-free.
    mean_strength = dict(zip(idx, np.mean(per_source, axis=-1).tolist()))
    return [mean_strength.get(i, 0) for i in range(n_transistors)]

def plot_cause_effect(target_idx, save_dir=None, games=('DonkeyKong', 'Pitfall', 'SpaceInvaders')):
    '''
    Plot transistor positions coloured by cause-effect strength, one panel per game.

    Parameters
    ----------
    target_idx : int, list of int, or 'all'
        Forwarded to ``get_strength``. When it is neither a list nor a string,
        that transistor is additionally marked in blue.
    save_dir : str, optional
        Path prefix of the output; when given, the figure is saved to
        ``f"{save_dir}.svg"`` and missing parent directories are created.
    games : sequence of str
        Games to plot, one panel each.

    Strengths are divided by their maximum over all games, so every panel
    shares the same 0-1 colour scale.
    '''
    games = list(games)
    effect = {game: get_strength({'idx': target_idx, 'game': game}) for game in games}
    max_effect = np.max([value for game_effect in effect.values() for value in game_effect])
    # squeeze=False keeps `axs` two-dimensional even for a single game.
    fig, axs = plt.subplots(1, len(games), figsize=(len(games) * 10, 10),
                            sharex=True, sharey=True, squeeze=False)
    cmap = get_cmap_slice(plt.get_cmap("Reds"), 0.2, 1.0)
    mark_target = not isinstance(target_idx, (list, str))
    for ax, (game, game_effect) in zip(axs[0], effect.items()):
        # Normalize the causal effect
        causal_effect = np.array(game_effect)
        if max_effect != 0:
            causal_effect = causal_effect / max_effect
        # Heatmap of causal effect
        im = ax.scatter(ts_info.x, ts_info.y, c=causal_effect, cmap=cmap, edgecolor='none', vmin=0, vmax=1)
        if mark_target:
            ax.scatter(ts_info.x.iloc[target_idx], ts_info.y.iloc[target_idx], c='blue', s=10**2, edgecolor='none')
        ax.tick_params(axis='x', labelsize=12)
        ax.tick_params(axis='y', labelsize=12)
        ax.set_title(game, fontdict={'family': 'Serif'}, fontsize=30)

    fig.set_facecolor('w')
    fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.8, wspace=0.05)
    fig.text(0.45, 0.05, "X Position (um)", fontdict={'family': 'Serif'}, size=20)
    fig.text(0.07, 0.5, "Y Position (um)", rotation='vertical', fontdict={'family': 'Serif'}, size=20)
    cbar = fig.colorbar(im, cax=fig.add_axes([0.82, 0.1, 0.02, 0.8]))
    cbar.set_label('Cause Effect Strength', fontdict={'family': 'Serif'}, size=20)
    # A tick every 0.1 across the whole vmin=0 .. vmax=1 range, both ends included.
    cbar.set_ticks(np.linspace(0.0, 1.0, 11))
    cbar.ax.tick_params(labelsize=10)
    if save_dir is not None:
        out_path = f"{save_dir}.svg"
        parent = os.path.dirname(out_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        fig.savefig(out_path)


def plot_transistors(target_idx, save_dir=None):
    '''
    Plot every transistor position and highlight the selected ones.

    Parameters
    ----------
    target_idx : int or iterable of int
        Transistor(s) to highlight. A single integer is drawn in blue; the
        indices of an iterable get colours cycled from a fixed palette. Every
        highlighted transistor gets a legend entry.
    save_dir : str, optional
        Path prefix of the output; when given, the figure is saved to
        ``f"{save_dir}.svg"`` and missing parent directories are created.

    Raises
    ------
    TypeError
        If ``target_idx`` is a string.
    '''
    import itertools

    if isinstance(target_idx, str):
        raise TypeError('target_idx must be an integer or an iterable of integers')
    single_target = isinstance(target_idx, (int, np.integer))
    targets = [target_idx] if single_target else list(target_idx)
    # specific colors for target idx; cycled so any number of targets works
    palette = itertools.cycle(['blue'] if single_target else ['#DE1334', '#FF414C', '#6752FF'])

    fig = plt.figure(figsize=(10, 10))
    cmap = get_cmap_slice(plt.get_cmap("Reds"), 0.2, 1.0)
    target_set = set(targets)
    others = [i for i in range(3510) if i not in target_set]
    # Background: all non-target transistors, coloured with a constant value of 0.
    plt.scatter(ts_info.x[others], ts_info.y[others], c=[0] * len(others), cmap=cmap, edgecolor='none')
    for idx, color in zip(targets, palette):
        plt.scatter(ts_info.x[idx], ts_info.y[idx], c=color, s=10**2, edgecolor='none', label=f'transistor {idx}')
    plt.xticks([])
    plt.yticks([])
    plt.legend()
    if save_dir is not None:
        out_path = f"{save_dir}.svg"
        parent = os.path.dirname(out_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        fig.savefig(out_path)


In [None]:
# Figure 2: cause-effect map of transistor 990 in each game, saved as SVG.
plot_cause_effect(
    save_dir='/home/charon/project/nmos_inference/figures/Figure 2. Causal Effect/transistor_990',
    target_idx=990,
)

In [None]:
# Figure 2: cause-effect map of transistor 3057 in each game, saved as SVG.
plot_cause_effect(
    save_dir='/home/charon/project/nmos_inference/figures/Figure 2. Causal Effect/transistor_3057',
    target_idx=3057,
)

In [None]:
# Figure 2: cause-effect map of transistor 1 in each game, saved as SVG.
plot_cause_effect(
    save_dir='/home/charon/project/nmos_inference/figures/Figure 2. Causal Effect/transistor_1',
    target_idx=1,
)