In [None]:
import sys  # noqa
import os
import subprocess
import logging
import warnings
from multiprocessing import Manager, Pool, Process
import pickle as pkl

import numpy as np
import dask.dataframe as dd
import gseapy as gp
import pandas as pd
import polars as pl
from tqdm import tqdm
from pandarallel import pandarallel
import seaborn as sns
import matplotlib.pyplot as plt
from statannotations.Annotator import Annotator
from scipy import stats

import scanpy as sc
import decoupler as dc
from anndata import AnnData
import anndata2ri
import rpy2.rinterface_lib.callbacks as rcb
import rpy2.robjects as ro

from downstream.rna_seq import volcano, convert_ensg_to_symbol, pyPPI, pyTCGA, plot_gene,plot_boxplot_gene,get_contrast
from downstream.editing_reditools import get_dict_reditools, plot_site_counts, get_samp_df, plot_editing_overview_shared_site, plot_editing_overview_all_site, get_melt_df_outer, plot_editing_events_pk_read, get_file_info, plot_editing_events, get_REDIT_count_df, check_editing_group, get_samp_df_frequency
# Hide DeprecationWarnings raised by the imported scientific stack.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Thread budget exported for multi-threaded libraries.
THREADS = 60
# NOTE(review): polars is imported above, before this variable is set; Polars
# sizes its thread pool when the pool is first initialised, so this may have no
# effect. Consider setting it before `import polars` — confirm.
os.environ["POLARS_MAX_THREADS"] = str(THREADS)

# Timestamped INFO-level logging to stdout so messages show in cell output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
# Lazy %-style argument: formatting happens only if the record is emitted.
logger.info("work_dir: %s", os.getcwd())

def do_something_background(myfunction, *args, **kwargs):
    """Start ``myfunction(*args, **kwargs)`` in a daemon child process.

    The call returns immediately without waiting for the child. The started
    ``multiprocessing.Process`` is returned so callers can ``join()`` it or
    inspect ``exitcode``; callers that ignore the return value are unaffected.

    Because the process is a daemon, it is terminated when the parent exits.
    Under the 'spawn'/'forkserver' start methods, ``myfunction`` and its
    arguments must be picklable.
    """
    my_process = Process(target=myfunction, args=args, kwargs=kwargs)
    my_process.daemon = True
    my_process.start()
    return my_process

def run_command_to_log(command, log_path='log_file.log'):
    """Run ``command`` through the shell, appending stdout/stderr to ``log_path``.

    The log file is opened and closed here, so no handle outlives the call.
    Returns the ``subprocess.CompletedProcess``.
    """
    with open(log_path, 'a') as log:
        return subprocess.run(command, shell=True, stdout=log, stderr=log)


def do_some_command_background(command, log_file=None):
    """Run a shell ``command`` in a daemon child process without waiting.

    Parameters
    ----------
    command : str
        Shell command line (executed with ``shell=True``).
    log_file : file object, optional
        Caller-owned handle that receives stdout and stderr. When omitted,
        output is appended to ``log_file.log`` in the current directory. The
        child opens and closes that file itself. Previously the parent opened
        it and never closed it, and an open file object cannot be pickled for
        the 'spawn' start method. A caller-supplied handle is passed through
        unchanged as before; that only works with the 'fork' start method.

    Returns
    -------
    multiprocessing.Process
        The started process, so callers may ``join()`` it if they need to.
    """
    if log_file is None:
        target = run_command_to_log
        kwargs = {'command': command}
    else:
        target = subprocess.run
        kwargs = {
            'args': command,
            'shell': True,
            'stderr': log_file,
            'stdout': log_file}
    my_process = Process(target=target, kwargs=kwargs)
    my_process.daemon = True
    my_process.start()
    return my_process

# Turn on rpy2's global pandas conversion and anndata2ri's AnnData conversion
# so objects can be passed to/from R cells.
# NOTE(review): `ro.pandas2ri` relies on the submodule already being imported
# (possibly by anndata2ri); `from rpy2.robjects import pandas2ri` would be explicit.
ro.pandas2ri.activate()
anndata2ri.activate()
# Enable %%R cell magics.
%load_ext rpy2.ipython
# Re-import edited project modules (e.g. downstream.*) before each cell runs.
%load_ext autoreload
%autoreload 2
# Register tqdm's progress-bar methods on pandas objects.
tqdm.pandas()
# Template for an R cell that imports/exports variables:
# %%R -i data -i data_tod -i genes -i cells -i soupx_groups -o out

In [None]:
import matplotlib.lines as mlines
import importlib
import matplotlib.pyplot as plt
import pandas as pd
import os
import seaborn as sns
from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats
# pydeseq2 inference backend; it gets its own CPU count rather than THREADS.
DESEQ2_N_CPUS = 16
inference = DefaultInference(n_cpus=DESEQ2_N_CPUS)

In [None]:
import matplotlib.colors as mcolors
import matplotlib.cm as cm
# Per-site editing frequencies for the samples in `columns` (how='inner').
samp_df = get_samp_df_frequency(reditools_res, how='inner')

# Pairwise means. The first two columns are named after the sample prefix with
# its trailing two characters (e.g. "-1") removed, i.e. per-condition means;
# 'ctl_1' / 'ctl_2' average the 1st/3rd and 2nd/4th samples across conditions.
for _new_col, (_left, _right) in (
        (columns[0][:-2], (columns[0], columns[1])),
        (columns[2][:-2], (columns[2], columns[3])),
        ('ctl_1', (columns[0], columns[2])),
        ('ctl_2', (columns[1], columns[3]))):
    samp_df[_new_col] = (samp_df[_left] + samp_df[_right]) / 2

# log2 ratios: between condition means, between the cross-condition averages,
# and second-vs-first sample within each condition.
for _new_col, (_num, _den) in (
        ('log2fc', (columns[2][:-2], columns[0][:-2])),
        ('log2fc_ctl', ('ctl_2', 'ctl_1')),
        ('log2fc_rep_ctl', (columns[1], columns[0])),
        ('log2fc_case_ctl', (columns[3], columns[2]))):
    samp_df[_new_col] = np.log2(samp_df[_num] / samp_df[_den])

# Two-colour map used by the plots below (N=100 discrete levels).
cmap_name = 'purple_red'
cmap = mcolors.LinearSegmentedColormap.from_list(
    cmap_name, ['#d3363e', '#313f7b'], N=100)

In [None]:
# Mean editing frequency of the columns[0] condition (x) vs the columns[2]
# condition (y), coloured by |log2FC| between the two condition means.
samp_df['log2fc_cmap'] = np.abs(samp_df['log2fc'])

# One normaliser drives both the point colours and the colour bar, so the legend
# matches what is drawn. The colour bar previously used vmin=0 while the points
# used vmin=min(|log2FC|).
norm = mcolors.Normalize(
    vmin=samp_df['log2fc_cmap'].min(),
    vmax=samp_df['log2fc_cmap'].max())
samp_df['color'] = samp_df['log2fc_cmap'].apply(
    lambda x: cmap(norm(x)))

plt.figure(figsize=(6.3, 6))
plt.scatter(data=samp_df,
            x=columns[0][:-2],
            y=columns[2][:-2],
            color=samp_df["color"], s=1, alpha=1, rasterized=True)
# y = x reference: sites on the line have equal frequency in both conditions.
plt.plot([0, 1], [0, 1], linestyle=(0, (5, 5)), color='black',
         linewidth=1, alpha=0.5)
plt.xlabel(columns[0][:-2])
plt.ylabel(columns[2][:-2])
plt.xlim(0, 0.95)
plt.ylim(0, 0.95)
cbar = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca(),
                    cax=plt.axes([1.05, 0.55, 0.03, 0.2]))
cbar.set_label('log2FC')
plt.tight_layout()
plt.savefig('scatter_case.pdf', dpi=600, bbox_inches='tight')

# Ranked bar plot of log2FC. Heights and colours come from the same sorted
# 'log2fc' series; the colours were previously taken from 'log2fc_ctl', so they
# did not correspond to the bars drawn.
plt.figure(figsize=(5, 1.5))
ranked_fc = samp_df['log2fc'].sort_values(ascending=False)
norm = plt.Normalize(ranked_fc.min(), ranked_fc.max())  # normalise signed log2FC
colors = cmap(norm(ranked_fc))
plt.bar(
    range(len(ranked_fc)),
    ranked_fc,
    color=colors,
    width=1, rasterized=True)
# Bars sit at 0..n-1 with width 1, so the data span is [-0.5, n - 0.5].
plt.xlim(-0.5, len(ranked_fc) - 0.5)
plt.xticks([])
# UP / DOWN = log2FC beyond +/-0.5.
up_count = int((samp_df['log2fc'] > 0.5).sum())
down_count = int((samp_df['log2fc'] < -0.5).sum())
# Avoid ZeroDivisionError when no site passes the DOWN threshold.
odds_ratio = up_count / down_count if down_count else float('inf')
plt.text(
    0.95,
    0.95,
    f'UP: {up_count}\nDOWN: {down_count}\nOdds Ratio: {odds_ratio:.2f}',
    horizontalalignment='right',
    verticalalignment='top',
    transform=plt.gca().transAxes,
    fontsize=6,
    bbox=None)
plt.ylabel('log2FC')
plt.tight_layout()
plt.savefig('log2fc_case.pdf')

In [None]:
# Replicate concordance within the columns[0]/columns[1] condition, coloured by
# |log2FC| between the two samples.
samp_df['log2fc_cmap'] = np.abs(samp_df['log2fc_rep_ctl'])

# Shared by the point colours and the colour bar so the legend matches the
# drawn colours (the colour bar previously used a separate vmin=0 norm).
norm = mcolors.Normalize(
    vmin=samp_df['log2fc_cmap'].min(),
    vmax=samp_df['log2fc_cmap'].max())
samp_df['color'] = samp_df['log2fc_cmap'].apply(
    lambda x: cmap(norm(x)))

plt.figure(figsize=(6.3, 6))
plt.scatter(data=samp_df,
            x=columns[0],
            y=columns[1],
            color=samp_df["color"], s=1, alpha=1, rasterized=True)
# y = x reference line.
plt.plot([0, 1], [0, 1], linestyle=(0, (5, 5)), color='black',
         linewidth=1, alpha=0.5)
plt.xlabel(columns[0])
plt.ylabel(columns[1])
plt.xlim(0, 0.95)
plt.ylim(0, 0.95)
cbar = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca(),
                    cax=plt.axes([1.05, 0.55, 0.03, 0.2]))
cbar.set_label('log2FC')
plt.tight_layout()
plt.savefig('scatter_rep_ctl.pdf', dpi=600, bbox_inches='tight')

# Ranked bar plot; heights and colours use the same sorted series.
plt.figure(figsize=(5, 1.5))
ranked_fc = samp_df['log2fc_rep_ctl'].sort_values(ascending=False)
norm = plt.Normalize(ranked_fc.min(), ranked_fc.max())  # normalise signed log2FC
colors = cmap(norm(ranked_fc))
plt.bar(
    range(len(ranked_fc)),
    ranked_fc,
    color=colors,
    width=1, rasterized=True)
# Bars sit at 0..n-1 with width 1, so the data span is [-0.5, n - 0.5].
plt.xlim(-0.5, len(ranked_fc) - 0.5)
plt.xticks([])
# UP / DOWN = log2FC beyond +/-0.5.
up_count = int((samp_df['log2fc_rep_ctl'] > 0.5).sum())
down_count = int((samp_df['log2fc_rep_ctl'] < -0.5).sum())
# Avoid ZeroDivisionError when no site passes the DOWN threshold.
odds_ratio = up_count / down_count if down_count else float('inf')
plt.text(
    0.95,
    0.95,
    f'UP: {up_count}\nDOWN: {down_count}\nOdds Ratio: {odds_ratio:.2f}',
    horizontalalignment='right',
    verticalalignment='top',
    transform=plt.gca().transAxes,
    fontsize=6,
    bbox=None)
plt.ylabel('log2FC')
plt.tight_layout()
plt.savefig('log2fc_rep_ctl.pdf')

In [None]:
# Replicate concordance within the columns[2]/columns[3] condition, coloured by
# |log2FC| between the two samples.
samp_df['log2fc_cmap'] = np.abs(samp_df['log2fc_case_ctl'])

# Shared by the point colours and the colour bar so the legend matches the
# drawn colours (the colour bar previously used a separate vmin=0 norm).
norm = mcolors.Normalize(
    vmin=samp_df['log2fc_cmap'].min(),
    vmax=samp_df['log2fc_cmap'].max())
samp_df['color'] = samp_df['log2fc_cmap'].apply(
    lambda x: cmap(norm(x)))

plt.figure(figsize=(6.3, 6))
plt.scatter(data=samp_df,
            x=columns[2],
            y=columns[3],
            color=samp_df["color"], s=1, alpha=1, rasterized=True)
# y = x reference line.
plt.plot([0, 1], [0, 1], linestyle=(0, (5, 5)), color='black',
         linewidth=1, alpha=0.5)
plt.xlabel(columns[2])
plt.ylabel(columns[3])
plt.xlim(0, 0.95)
plt.ylim(0, 0.95)
cbar = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca(),
                    cax=plt.axes([1.05, 0.55, 0.03, 0.2]))
cbar.set_label('log2FC')
plt.tight_layout()
plt.savefig('scatter_rep_case.pdf', dpi=600, bbox_inches='tight')

# Ranked bar plot; heights and colours use the same sorted series.
plt.figure(figsize=(5, 1.5))
ranked_fc = samp_df['log2fc_case_ctl'].sort_values(ascending=False)
norm = plt.Normalize(ranked_fc.min(), ranked_fc.max())  # normalise signed log2FC
colors = cmap(norm(ranked_fc))
plt.bar(
    range(len(ranked_fc)),
    ranked_fc,
    color=colors,
    width=1, rasterized=True)
# Bars sit at 0..n-1 with width 1, so the data span is [-0.5, n - 0.5].
plt.xlim(-0.5, len(ranked_fc) - 0.5)
plt.xticks([])
# UP / DOWN = log2FC beyond +/-0.5.
up_count = int((samp_df['log2fc_case_ctl'] > 0.5).sum())
down_count = int((samp_df['log2fc_case_ctl'] < -0.5).sum())
# Avoid ZeroDivisionError when no site passes the DOWN threshold.
odds_ratio = up_count / down_count if down_count else float('inf')
plt.text(
    0.95,
    0.95,
    f'UP: {up_count}\nDOWN: {down_count}\nOdds Ratio: {odds_ratio:.2f}',
    horizontalalignment='right',
    verticalalignment='top',
    transform=plt.gca().transAxes,
    fontsize=6,
    bbox=None)
plt.ylabel('log2FC')
plt.tight_layout()
plt.savefig('log2fc_case_ctl.pdf')

In [None]:
# 'ctl_1' vs 'ctl_2' (cross-condition averages of the 1st/3rd and 2nd/4th
# samples), coloured by |log2FC| between them.
samp_df['log2fc_cmap'] = np.abs(samp_df['log2fc_ctl'])

# Shared by the point colours and the colour bar so the legend matches the
# drawn colours (the colour bar previously used a separate vmin=0 norm).
norm = mcolors.Normalize(
    vmin=samp_df['log2fc_cmap'].min(),
    vmax=samp_df['log2fc_cmap'].max())
samp_df['color'] = samp_df['log2fc_cmap'].apply(
    lambda x: cmap(norm(x)))

plt.figure(figsize=(6.3, 6))
plt.scatter(data=samp_df,
            x='ctl_1',
            y='ctl_2',
            color=samp_df["color"], s=1, alpha=1, rasterized=True)
# y = x reference line.
plt.plot([0, 1], [0, 1], linestyle=(0, (5, 5)), color='black',
         linewidth=1, alpha=0.5)
plt.xlabel('ctl_1')
plt.ylabel('ctl_2')
plt.xlim(0, 0.95)
plt.ylim(0, 0.95)
cbar = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca(),
                    cax=plt.axes([1.05, 0.55, 0.03, 0.2]))
cbar.set_label('log2FC')
plt.tight_layout()
plt.savefig('scatter_ctl.pdf', dpi=600, bbox_inches='tight')

# Ranked bar plot; heights and colours use the same sorted series.
plt.figure(figsize=(5, 1.5))
ranked_fc = samp_df['log2fc_ctl'].sort_values(ascending=False)
norm = plt.Normalize(ranked_fc.min(), ranked_fc.max())  # normalise signed log2FC
colors = cmap(norm(ranked_fc))
plt.bar(
    range(len(ranked_fc)),
    ranked_fc,
    color=colors,
    width=1, rasterized=True)
# Bars sit at 0..n-1 with width 1, so the data span is [-0.5, n - 0.5].
plt.xlim(-0.5, len(ranked_fc) - 0.5)
plt.xticks([])
# UP / DOWN = log2FC beyond +/-0.5.
up_count = int((samp_df['log2fc_ctl'] > 0.5).sum())
down_count = int((samp_df['log2fc_ctl'] < -0.5).sum())
# Avoid ZeroDivisionError when no site passes the DOWN threshold.
odds_ratio = up_count / down_count if down_count else float('inf')
plt.text(
    0.95,
    0.95,
    f'UP: {up_count}\nDOWN: {down_count}\nOdds Ratio: {odds_ratio:.2f}',
    horizontalalignment='right',
    verticalalignment='top',
    transform=plt.gca().transAxes,
    fontsize=6,
    bbox=None)
plt.ylabel('log2FC_ctl')
plt.tight_layout()
plt.savefig('log2fc_ctl.pdf')

In [None]:
# Editing index per condition, split by replicate, from the outer-joined
# long-format table; one Mann-Whitney comparison between the two conditions.
plt.figure(figsize=(6, 6))
melt_df_outer = get_melt_df_outer(reditools_res)
ax = sns.boxplot(
    data=melt_df_outer, x='Condition', y='Frequency', hue='Repeat',
    showfliers=False, width=0.5, fill=False, gap=.1,
    palette=['#747474', '#aaabae'])
ax.set_ylim(0, 0.6)

# Compare the two condition groups (sample names minus the 2-char suffix).
pairs = [(columns[0][:-2], columns[2][:-2])]
annotator = Annotator(
    ax, pairs, data=melt_df_outer, x='Condition', y='Frequency')
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside')
annotator.apply_and_annotate()

plt.ylabel('Editing Index')
plt.tight_layout()
plt.savefig('editing_index_outer.pdf')

In [None]:
# Editing index per condition/replicate over the inner-joined sites, with one
# Mann-Whitney comparison between the two conditions.
plt.figure(figsize=(6, 6))
# Select into a new name instead of overwriting samp_df: the old
# `samp_df = samp_df[columns]` dropped the log2fc* columns, so the scatter cells
# above failed if re-run after this one.
samp_freq = samp_df[columns]
samp_df_melt = samp_freq.melt(
    var_name='Samp',
    value_name='Frequency',
    ignore_index=False)
# Sample names end in "-<rep>": last char is the replicate, the rest the condition.
samp_df_melt['Repeat'] = samp_df_melt.Samp.str[-1]
samp_df_melt['Condition'] = samp_df_melt.Samp.str[:-2]
ax = sns.boxplot(
    data=samp_df_melt,
    x='Condition',
    y='Frequency',
    hue='Repeat',
    showfliers=False,
    width=0.5,
    fill=False, gap=.1, palette=['#747474', '#aaabae'])
plt.ylim(0, 1)
pairs = [(columns[0][:-2], columns[2][:-2])]
annotator = Annotator(
    ax,
    pairs,
    data=samp_df_melt,
    x='Condition',
    y='Frequency')
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside')
annotator.apply_and_annotate()
plt.ylabel('Editing Index')
plt.tight_layout()
plt.savefig('editing_index_inner.pdf')

In [None]:
# Editing events for the full (non-down-sampled) REDItools results.
plot_editing_events_pk_read(reditools_res, file_info)
plt.title('Editing Events ALL')
plt.tight_layout()
# Own file name: this cell previously wrote 'editing_events_downsample.pdf',
# which the down-sampled cell below then overwrote.
plt.savefig('editing_events_all.pdf')

In [None]:
# Same editing-events plot for the down-sampled REDItools results.
plot_editing_events_pk_read(res_downsample, file_info_downsample)
plt.gca().set_title('Editing Events Downsample')
plt.gcf().tight_layout()
plt.gcf().savefig('editing_events_downsample.pdf')

In [None]:
# Load the editing-index table whose path is in `aei`. The table goes into a
# new name: rebinding `aei` to the DataFrame made re-running this cell call
# read_csv on a DataFrame.
aei_df = pd.read_csv(aei)
aei_df = aei_df[aei_df['Sample'].isin(columns)].copy()
# Sample names look like "<a>-<b>-<rep>": condition = "<a>-<b>", repeat = "<rep>".
sample_parts = aei_df['Sample'].str.split('-')
aei_df['Repeat'] = sample_parts.str[2]
aei_df['Condition'] = sample_parts.str[0] + '-' + sample_parts.str[1]
aei_df = aei_df.sort_values('Condition', ascending=False)


def plot_info(
    df,
    x='Condition',
    y='A2GEditingIndex',
    pairs=None,
    ylim=(0, 2.5)):
    """Bar plot of ``y`` per ``x`` group with t-test significance stars.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot. It is now used for the bars too; previously the bars were
        drawn from the global ``aei`` regardless of ``df``.
    x, y : str
        Column names for the groups and the plotted value.
    pairs : list of tuple, optional
        Group pairs to test. Defaults to the two condition names derived from
        ``columns`` (resolved at call time, not at definition time).
    ylim : tuple, default (0, 2.5)
        y-axis limits applied after annotation.

    Returns
    -------
    matplotlib.axes.Axes
    """
    if pairs is None:
        pairs = [(columns[0][:-2], columns[3][:-2])]
    ax = sns.barplot(data=df, x=x, y=y, legend=False, width=0.5)
    annotator = Annotator(ax, pairs, data=df, x=x, y=y)
    annotator.configure(test='t-test_ind', text_format='star', loc='outside')
    annotator.apply_and_annotate()
    ax.set_ylim(*ylim)
    return ax


# Fresh figure so the first plot does not land on a figure left by another cell.
plt.figure()
plot_info(aei_df, y='A2GEditingIndex')
plt.tight_layout()
plt.savefig('A2GEditingIndex.pdf')
plt.figure()
plot_info(aei_df, y='A2TEditingIndex')
plt.tight_layout()
plt.savefig('A2TEditingIndex.pdf')

In [None]:

import venn
from pandas import CategoricalDtype
# Differential-expression cut-offs passed to get_contrast() and volcano():
# adjusted p-value threshold and log2 fold-change threshold.
padj, logfc = 0.05, 1


def get_data_and_keep_columns(path, columns):
    """Read a tab-separated table indexed by its first column and tidy it.

    Keeps only ``columns``, drops rows whose values sum to zero across them,
    maps the index through ``convert_ensg_to_symbol``, and keeps the first row
    for any index value that is duplicated after mapping.
    """
    table = pd.read_csv(path, sep='\t', index_col=0)[columns]
    table = table.loc[table.sum(axis=1) > 0]
    table.index = table.index.map(convert_ensg_to_symbol)
    return table[~table.index.duplicated(keep='first')]


# Raw counts (for DESeq2) and TPM for the selected samples, indexed by symbol.
data = get_data_and_keep_columns(
    expression_count,
    columns)
tpm_data = get_data_and_keep_columns(
    expression_tpm,
    columns)
# DESeq2 needs integer counts. Round rather than truncate so fractional counts
# are not biased downward; already-integer counts are unchanged.
data = data.round().astype(int)
# Fix the level order of Condition ('ctl' first).
cat_condition = CategoricalDtype(categories=['ctl', 'kd'], ordered=True)
meta_data['Condition'] = meta_data['Condition'].astype(cat_condition)
# NOTE(review): newer pydeseq2 releases deprecate `design_factors` in favour
# of `design="~Condition"` — confirm against the installed version.
dds = DeseqDataSet(
    counts=data.T,
    metadata=meta_data,
    design_factors=["Condition"],
    refit_cooks=True,
    inference=inference)
dds.deseq2()
res_expression = get_contrast(dds, 'Condition', 'ctl', 'kd', padj, logfc)
res_expression.to_csv(f'{name}_expression.tsv', sep='\t')
volcano(
    res_expression,
    pval_threshold=padj,
    fc_max=logfc)
plt.tight_layout()
plt.savefig('volcano_expression.pdf')

In [None]:
# Gene-level ranking statistic from DESeq2. NaN statistics (if any) are dropped
# so they are not passed to decoupler or GSEApy.
stat_rank = res_expression['stat'].dropna()

# PROGENy pathway activities (top 500 targets per pathway) estimated with MLM on
# the single 'stat' row.
progeny = dc.get_progeny(top=500)
pathway_acts, pathway_pvals = dc.run_mlm(
    mat=stat_rank.to_frame().T, net=progeny, verbose=True)
dc.plot_barplot(
    pathway_acts,
    'stat',
    top=25,
    vertical=False,
    figsize=(6, 3)
)
plt.tight_layout()
plt.savefig('pathway_acts_expression_progeny.pdf')


def _prerank_and_plot(rnk, gene_sets, term_index, enrichment_pdf,
                      dotplot_pdf, threads=4):
    """Run GSEApy prerank and save an enrichment plot and a dot plot.

    ``term_index`` selects (by label of ``res2d.Term``) the term drawn in the
    enrichment plot. ``res2d`` gains 'Term_raw' (original names), and 'Term' is
    stripped of any library prefix before '__'. Returns the prerank result.
    """
    result = gp.prerank(rnk=rnk, gene_sets=gene_sets, threads=threads,
                        verbose=True)
    result_terms = result.res2d.Term
    result.plot(terms=result_terms[term_index])
    plt.tight_layout()
    plt.savefig(enrichment_pdf, bbox_inches='tight')
    result.res2d['Term_raw'] = result.res2d.Term
    result.res2d['Term'] = result.res2d.Term.apply(
        lambda term: term.split('__')[1] if '__' in term else term)
    gp.dotplot(result.res2d, column='FDR q-val', top_term=10, cutoff=0.25)
    plt.tight_layout()
    plt.savefig(dotplot_pdf, bbox_inches='tight')
    return result


# NOTE(review): the GO enrichment plot is saved as 'go_first.pdf' but draws
# Term[1], not Term[0] — preserved as-is; confirm intent.
go_res = _prerank_and_plot(
    stat_rank,
    ['GO_Molecular_Function_2023', 'GO_Biological_Process_2023'],
    term_index=1,
    enrichment_pdf='go_first.pdf',
    dotplot_pdf='go.pdf')
pre_res = _prerank_and_plot(
    stat_rank,
    ['MSigDB_Hallmark_2020'],
    term_index=0,
    enrichment_pdf='hallmark_first.pdf',
    dotplot_pdf='hallmark.pdf')