In [1]:
import os.path as op
import json

import pandas as pd
from nimare.dataset import Dataset
import plotly.graph_objects as go
import plotly.express as px

import numpy as np

In [2]:
# Location of the input tables (CogAt term list/labels, map counts) and of
# the per-task dataset pickles produced by the IBMA pipeline.
data_dir = "../data"
results_dir = "../results/ibma"

In [4]:
def _wrap_label(s, index=3):
    parts = s.split(' ', index)
    return ' '.join(parts[:index]) + '<br>' + parts[index]

In [5]:
# Task identifiers in reverse CSV order (presumably so the first CSV row is
# drawn at the top of the horizontal bar charts below).
cogat_df = pd.read_csv(op.join(data_dir, "cogat_terms.csv"))
task_names = list(reversed(cogat_df["cogat_nm"].tolist()))
n_tasks = len(task_names)

# Lookup from task identifier to the display label used on the plots.
with open(op.join(data_dir, "cogat_terms.json")) as f:
    cogat_dict = json.load(f)

In [6]:
p_data_df = pd.read_csv(op.join(data_dir, "maps_count.csv"))

# One row per task, one column per map type, values from the 'count' column.
pivot_org_df = p_data_df.pivot(
    index='task', columns='map', values='count'
)[["Z", "T", "Other", "Not Available"]]

# Rows follow task_names; ids are then replaced by their display labels,
# with labels containing 3+ spaces wrapped onto two lines.
df_sorted_org_df = pivot_org_df.loc[task_names]
readable_labels = [cogat_dict[task_id] for task_id in df_sorted_org_df.index]
df_sorted_org_df.index = [
    label if label.count(' ') < 3 else _wrap_label(label)
    for label in readable_labels
]

In [7]:
# Fill colors for the Z, T, Other and Not Available columns, in that order.
colors = ["#393E46", '#6D9886', '#F2E7D5', '#F7F7F7']

fig = px.bar(
    df_sorted_org_df,
    x=["Z", "T", "Other", "Not Available"],
    orientation='h',
    color_discrete_sequence=colors,
    # Wide-form input otherwise labels the axis "value" and the legend "variable".
    labels={"value": "Number of maps", "variable": "Map type"},
    title="Map types per task",
)
# NOTE(review): bars are stacked on a log x-axis, so segment lengths do not
# reflect their counts, and NaN/0 counts (missing pivot cells) are not drawn.
fig.update_layout(
    height=n_tasks*30,
    xaxis_type="log",
    # The index has no name after relabelling, so px would title this "index".
    yaxis=dict(automargin=True, title=dict(text="")),
)
fig.show()

In [11]:
# Per-task number of dataset ids before ("Unfiltered") and after ("Filtered")
# cleaning, one entry per task in task_names order.
data_dict = {
    "line_x": [],
    "line_y": [],
    "Filtered": [],
    "Unfiltered": [],
    "colors": [],  # NOTE(review): never populated in this notebook
    "dset_type": [],  # NOTE(review): never populated in this notebook
    "CogAt": [],
}
for task_name in task_names:
    task_dir = op.join(results_dir, task_name)
    # The .pkl.gz files are pickles; only load results you produced/trust.
    dset = Dataset.load(op.join(task_dir, f"{task_name}_dset-raw.pkl.gz"))
    clean_dset = Dataset.load(op.join(task_dir, f"{task_name}_dset-clean.pkl.gz"))

    n_ids = len(dset.ids)
    n_clean_ids = len(clean_dset.ids)

    # append, not extend: extend(str) would add one entry per character.
    data_dict["CogAt"].append(task_name)
    data_dict["Unfiltered"].append(n_ids)
    data_dict["Filtered"].append(n_clean_ids)

    # Unfiltered -> filtered segment per task; presumably for a line trace
    # where None separates consecutive segments.
    data_dict["line_x"].extend([n_ids, n_clean_ids, None])
    data_dict["line_y"].extend([task_name, task_name, None])

In [14]:
# Display labels in task_names order (same order as the data_dict lists),
# wrapped onto two lines when they contain 3+ spaces.
task_names_org = [cogat_dict[task_name] for task_name in task_names]
task_names_wrap = [
    task_name if task_name.count(" ") < 3 else _wrap_label(task_name)
    for task_name in task_names_org
]
# Labels and values must line up one-to-one, or bars get mislabelled.
assert len(data_dict["Filtered"]) == len(task_names_wrap)
assert len(data_dict["Unfiltered"]) == len(task_names_wrap)

fig = go.Figure()
fig.add_trace(
    go.Bar(
        y=task_names_wrap,
        x=data_dict["Filtered"],
        name="Filtered Data",
        marker_color="green",
        orientation="h",
    )
)
fig.add_trace(
    go.Bar(
        y=task_names_wrap,
        x=data_dict["Unfiltered"],
        name="Unfiltered Data",
        marker_color="red",
        orientation="h",
    )
)

fig.update_layout(
    barmode="group",
    height=n_tasks * 30,
    xaxis_type="log",
    title="Dataset size before and after filtering",
    xaxis_title="Number of dataset IDs",
    legend_title_text="Dataset",
    # Room for the <br>-wrapped task labels, as in the map-type figure.
    yaxis=dict(automargin=True),
)
fig.show()