In [1]:
import os
import os.path as op

import pandas as pd
import matplotlib.pyplot as plt
import ptitprince as pt
import seaborn as sns

# Select seaborn's "whitegrid" style for every figure drawn in this notebook.
sns.set(style="whitegrid")

In [2]:
def raincloud_plot(dx, dy, data, ax, order=("case", "control")):
    """Draw a vertical "raincloud" plot of ``dy`` split by the categorical ``dx``.

    Four layers are drawn onto ``ax``: a half violin (ptitprince), the raw
    observations as a jittered strip, a narrow box plot, and a red seaborn
    point plot connecting the group summaries.

    Parameters
    ----------
    dx : str
        Column of ``data`` with the grouping variable (x axis).
    dy : str
        Column of ``data`` with the values to plot (y axis).
    data : pandas.DataFrame
        Long-form table containing the ``dx`` and ``dy`` columns.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    order : sequence of str, optional
        Levels of ``dx`` to plot, left to right. Defaults to
        ``("case", "control")``.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax`` (returned for convenience; callers may ignore it).
    """
    ort = "v"
    order = list(order)

    # Layer 1: half violin showing the distribution shape.
    pt.half_violinplot(
        x=dx,
        y=dy,
        order=order,
        data=data,
        bw=0.1,
        cut=0.0,
        scale="area",
        width=0.6,
        dodge=False,
        inner=None,
        orient=ort,
        ax=ax,
    )
    # Layer 2: raw observations. ``jitter=True`` requests seaborn's default
    # jitter width. The former ``jitter=1`` only meant that on seaborn <= 0.11
    # (which special-cased ``== 1``); newer seaborn treats it as a literal
    # +/-1 spread that smears points across neighbouring categories.
    sns.stripplot(
        x=dx,
        y=dy,
        order=order,
        data=data,
        edgecolor="white",
        dodge=False,
        size=3,
        jitter=True,
        zorder=0,
        orient=ort,
        ax=ax,
    )
    # Layer 3: narrow, semi-transparent box plot on top of the points.
    sns.boxplot(
        x=dx,
        y=dy,
        order=order,
        data=data,
        width=0.15,
        zorder=10,
        dodge=True,
        showcaps=True,
        boxprops={"alpha": 0.6, "zorder": 10},
        showfliers=True,
        whiskerprops={"linewidth": 2, "zorder": 10},
        saturation=1,
        orient=ort,
        ax=ax,
    )
    # Layer 4: red point plot linking the per-group estimates.
    sns.pointplot(
        x=dx,
        y=dy,
        order=order,
        data=data,
        orient=ort,
        color="r",
        ax=ax,
    )
    # Lift every line artist on the axes (box plot lines, point plot) above
    # the violin and strip layers.
    plt.setp(ax.lines, zorder=100)
    return ax

In [3]:
# Human-readable panel titles, keyed by ROI column name in the per-dataset TSVs.
# Column-name suffixes: "lh"/"rh" = left/right hemisphere; vmPFC has none.
roi_dict = dict(
    # vmPFC clusters
    vmPFC1="vmPFC: Cluster 1",
    vmPFC2="vmPFC: Cluster 2",
    vmPFC3="vmPFC: Cluster 3",
    vmPFC4="vmPFC: Cluster 4",
    vmPFC5="vmPFC: Cluster 5",
    vmPFC6="vmPFC: Cluster 6",
    # insula subdivisions
    insulaDlh="L Insula: Dorsal Anterior",
    insulaPlh="L Insula: Posterior",
    insulaVlh="L Insula: Ventral Anterior",
    insulaDrh="R Insula: Dorsal Anterior",
    insulaPrh="R Insula: Posterior",
    insulaVrh="R Insula: Ventral Anterior",
    # hippocampus long-axis segments
    hippocampus3solF1lh="L Hippocampus: Anterior",
    hippocampus3solF2lh="L Hippocampus: Intermediate",
    hippocampus3solF3lh="L Hippocampus: Posterior",
    hippocampus3solF1rh="R Hippocampus: Anterior",
    hippocampus3solF2rh="R Hippocampus: Intermediate",
    hippocampus3solF3rh="R Hippocampus: Posterior",
    # striatum subdivisions
    striatumMatchCDlh="L Striatum: Caudal (Dorsal)",
    striatumMatchCVlh="L Striatum: Caudal (Ventral)",
    striatumMatchDLlh="L Striatum: Dorsolateral",
    striatumMatchDlh="L Striatum: Dorsal",
    striatumMatchRlh="L Striatum: Rostral",
    striatumMatchVlh="L Striatum: Ventral",
    striatumMatchCDrh="R Striatum: Caudal (Dorsal)",
    striatumMatchCVrh="R Striatum: Caudal (Ventral)",
    striatumMatchDLrh="R Striatum: Dorsolateral",
    striatumMatchDrh="R Striatum: Dorsal",
    striatumMatchRrh="R Striatum: Rostral",
    striatumMatchVrh="R Striatum: Ventral",
    # amygdala clusters
    amygdala1lh="L Amygdala: Cluster 1",
    amygdala2lh="L Amygdala: Cluster 2",
    amygdala3lh="L Amygdala: Cluster 3",
    amygdala1rh="R Amygdala: Cluster 1",
    amygdala2rh="R Amygdala: Cluster 2",
    amygdala3rh="R Amygdala: Cluster 3",
)

# Figure-title names for each dataset code used in the file names.
dset_dict = dict(
    ALC="Alcohol (ALC) Dataset",
    ATS="Methamphetamine and Dexamphetamine (ATS) Dataset",
    CANN="Cannabis (CANN) Dataset",
    COC="Cocaine (COC) Dataset",
)

# Full names for each metric code used in the file names.
metric_dict = dict(
    REHO="Regional Homogeneity (ReHo)",
    FALFF="Fractional Amplitude of Low Frequency Fluctuations (fALFF)",
)

In [5]:
# Build one 6x6 overview figure per dataset / GSR / ComBat / metric and save it as PNG.
# Grid layout: rows 0-2 = amygdala, hippocampus, insula (left hemisphere in
# columns 0-2, right hemisphere in columns 3-5); rows 3 and 4 = left and right
# striatum; row 5 = the six vmPFC clusters.
dsets = ["ALC", "ATS", "CANN", "COC"]
gsrs = ["-gsr", ""]
combats = ["-combat", ""]
metrics = ["REHO", "FALFF"]
seeds = ["amygdala", "hippocampus", "insula", "striatum", "vmPFC"]

for dset in dsets:
    participants_df = pd.read_csv(f"{dset}-sumstats_table.txt", delimiter="\t")
    group_df = participants_df[["participant_id", "group"]]
    for gsr in gsrs:
        for combat in combats:
            for metric in metrics:
                metric_lb = "ReHo" if metric == "REHO" else "fALFF"
                data_df = pd.read_csv(f"{dset}-{metric}{gsr}{combat}.tsv", delimiter="\t")
                # Attach each participant's group label to their ROI values.
                data = pd.merge(data_df, group_df, on='participant_id')

                fig, axes_arr = plt.subplots(6, 6)
                fig.set_size_inches(15, 15)

                for seed in seeds:
                    hemispheres = [""] if seed == "vmPFC" else ["lh", "rh"]
                    for hemis in hemispheres:
                        if seed == "vmPFC":
                            rois = ["vmPFC1", "vmPFC2", "vmPFC3", "vmPFC4", "vmPFC5", "vmPFC6"]
                            axes_lst = axes_arr[5, :]
                        elif seed == "striatum":
                            rois = [
                                f"striatumMatchCD{hemis}",
                                f"striatumMatchCV{hemis}",
                                f"striatumMatchDL{hemis}",
                                f"striatumMatchD{hemis}",
                                f"striatumMatchR{hemis}",
                                f"striatumMatchV{hemis}",
                            ]
                            # Six striatal ROIs per hemisphere fill a whole row each.
                            axes_lst = axes_arr[3 if hemis == "lh" else 4, :]
                        else:
                            # Three-ROI seeds share one row, split by hemisphere.
                            row, rois = {
                                "amygdala": (
                                    0,
                                    [f"amygdala1{hemis}", f"amygdala2{hemis}", f"amygdala3{hemis}"],
                                ),
                                "hippocampus": (
                                    1,
                                    [
                                        f"hippocampus3solF1{hemis}",
                                        f"hippocampus3solF2{hemis}",
                                        f"hippocampus3solF3{hemis}",
                                    ],
                                ),
                                "insula": (
                                    2,
                                    [f"insulaD{hemis}", f"insulaP{hemis}", f"insulaV{hemis}"],
                                ),
                            }[seed]
                            axes_lst = axes_arr[row, :3] if hemis == "lh" else axes_arr[row, 3:]

                        for ax, roi in zip(axes_lst, rois):
                            raincloud_plot("group", roi, data, ax)
                            ax.set_xlabel("")
                            if seed != "vmPFC":
                                # Only the bottom (vmPFC) row keeps its group tick labels.
                                ax.set_xticklabels([])
                            ax.set_ylabel("")
                            ax.set_title(roi_dict[roi])

                # One shared y-label for the whole grid instead of per-panel labels.
                fig.supylabel(f"{metric_dict[metric]}\n", fontsize=14)
                # Keyed by (GSR applied?, ComBat applied?): an empty suffix means "not applied".
                title = {
                    (False, False): f"{dset_dict[dset]} GSR=False ComBat=False. {metric_lb} vs Group\n",
                    (False, True): f"{dset_dict[dset]} GSR=False ComBat=True. {metric_lb} vs Group\n",
                    (True, False): f"{dset_dict[dset]} GSR=True ComBat=False. {metric_lb} vs Group\n",
                    (True, True): f"{dset_dict[dset]} GSR=True ComBat=True. {metric_lb} vs Group\n",
                }[(gsr != "", combat != "")]

                fig.suptitle(title, fontsize=14)
                fig.tight_layout()
                fig.savefig(f"{dset}-{metric}{gsr}{combat}.png", bbox_inches="tight", dpi=300)
                plt.close(fig)

In [6]:
# Export one-row (1x6) raincloud figures per seed as EPS files under tables/.
# Striatum has six ROIs per hemisphere, so it gets one figure per hemisphere
# (file suffix "lh"/"rh"). Amygdala, hippocampus and insula put left ROIs in
# panels 0-2 and right ROIs in panels 3-5 of a single figure. vmPFC has one
# set of six clusters.
# NOTE: the PostScript backend cannot render transparency, so the
# semi-transparent boxes drawn by raincloud_plot come out opaque in the EPS
# files (that is what the "PostScript backend does not support transparency"
# log lines refer to).
dsets = ["ALC", "ATS", "CANN", "COC"]
gsrs = ["-gsr", ""]
combats = ["-combat", ""]
metrics = ["REHO", "FALFF"]
seeds = ["amygdala", "hippocampus", "insula", "striatum", "vmPFC"]
TABLES_DIR = "tables"

# savefig does not create missing folders; make a fresh checkout run end-to-end.
os.makedirs(TABLES_DIR, exist_ok=True)


def _seed_rois(seed, hemis):
    """Return the ROI column names for ``seed`` in hemisphere ``hemis`` ("lh"/"rh"; ignored for vmPFC)."""
    if seed == "amygdala":
        return [f"amygdala{i}{hemis}" for i in (1, 2, 3)]
    if seed == "hippocampus":
        return [f"hippocampus3solF{i}{hemis}" for i in (1, 2, 3)]
    if seed == "insula":
        return [f"insula{part}{hemis}" for part in ("D", "P", "V")]
    if seed == "striatum":
        return [f"striatumMatch{part}{hemis}" for part in ("CD", "CV", "DL", "D", "R", "V")]
    if seed == "vmPFC":
        return [f"vmPFC{i}" for i in range(1, 7)]
    raise ValueError(f"Unknown seed: {seed!r}")


def _figure_groups(seed):
    """Return ``(file_suffix, rois)`` pairs, one per 1x6 figure drawn for ``seed``."""
    if seed == "vmPFC":
        return [("", _seed_rois(seed, ""))]
    if seed == "striatum":
        # Six ROIs per hemisphere -> a separate figure (and file) per hemisphere.
        return [(hemis, _seed_rois(seed, hemis)) for hemis in ("lh", "rh")]
    # Three ROIs per hemisphere -> left and right share one figure.
    return [("", _seed_rois(seed, "lh") + _seed_rois(seed, "rh"))]


for dset in dsets:
    participants_df = pd.read_csv(f"{dset}-sumstats_table.txt", delimiter="\t")
    group_df = participants_df[["participant_id", "group"]]
    for gsr in gsrs:
        for combat in combats:
            for metric in metrics:
                metric_lb = "ReHo" if metric == "REHO" else "fALFF"
                data_df = pd.read_csv(f"{dset}-{metric}{gsr}{combat}.tsv", delimiter="\t")
                # Attach each participant's group label to their ROI values.
                data = pd.merge(data_df, group_df, on="participant_id")

                # An empty suffix means the preprocessing step was not applied.
                title = (
                    f"{dset_dict[dset]} GSR={gsr != ''} ComBat={combat != ''}. "
                    f"{metric_lb} vs Group\n"
                )
                # NOTE(review): non-ComBat figures are written under both prefixes
                # "1" and "2", ComBat figures under "3" (presumably table numbers) —
                # kept exactly as before; confirm against the manuscript.
                table_ids = ["1", "2"] if combat == "" else ["3"]

                for seed in seeds:
                    for suffix, rois in _figure_groups(seed):
                        fig, axes_lst = plt.subplots(1, 6)
                        fig.set_size_inches(15, 3)
                        for roi_i, (ax, roi) in enumerate(zip(axes_lst, rois)):
                            raincloud_plot("group", roi, data, ax)
                            ax.set_xlabel("")
                            # Only the left-most panel is labelled; use the same
                            # display name as the title ("ReHo"/"fALFF"), not the
                            # upper-case file key.
                            ax.set_ylabel(metric_lb if roi_i == 0 else "")
                            ax.set_title(roi_dict[roi])

                        fig.suptitle(title, fontsize=14)
                        fig.tight_layout()
                        for table_id in table_ids:
                            fig.savefig(
                                op.join(
                                    TABLES_DIR,
                                    f"{dset}-{table_id}{gsr}{combat}-{metric}-{seed}{suffix}.eps",
                                ),
                                bbox_inches="tight",
                            )
                        plt.close(fig)

                    

The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript back