In [1]:
import argparse
import functools
import os
import re
import time

import pandas

In [9]:
__version__ = "V1.2(Editor) 2023-07-12"

In [2]:
# Debug defaults for running the notebook interactively; the argparse cell that
# follows reassigns all five of these names from the command line.
absolute_path = False
file_path = "F:/OneDrive/Master/Project/trans/data/"  # NOTE(review): machine-specific local path
(input_sample_info_filename,
 input_df_filename,
 output_df_filename) = ("0000_sample_info.tsv", "0001_total_info.tsv", "0002_TSS_TES.tsv")

In [10]:
# Command-line parameters; each parsed value replaces the debug default of the same name.
# NOTE(review): inside a Jupyter kernel sys.argv carries ipykernel's own options
# (e.g. "-f <connection file>"), so parse_args() exits there; this cell targets the
# exported .py script.
parser = argparse.ArgumentParser()
parser.add_argument("--absolute_path", dest="absolute_path", action="store_true",
                    help="use absolute file path")
parser.add_argument("--file_path", dest="file_path", type=str, default="./",
                    help="=\"./\",\t the path of data directory, the end of this parameter should be '/'")
parser.add_argument("--input_sample_info_filename", dest="input_sample_info_filename", type=str,
                    default="0000_sample_info.tsv",
                    help="=\"0000_sample_info.tsv\",\t the sample info file")
parser.add_argument("--input_df_filename", dest="input_df_filename", type=str,
                    default="0001_total_info.tsv",
                    help="=\"0001_total_info.tsv\",\t the output tsv file of 0001.py")
parser.add_argument("--output_df_filename", dest="output_df_filename", type=str,
                    default="0002_TSS_TES.tsv",
                    help="=\"0002_TSS_TES.tsv\",\t the output filename of this script")

args = parser.parse_args()
absolute_path = args.absolute_path
file_path = args.file_path
input_sample_info_filename = args.input_sample_info_filename
input_df_filename = args.input_df_filename
output_df_filename = args.output_df_filename

SyntaxError: invalid syntax (2254374255.py, line 3)

In [3]:
# Resolve the input/output filenames against file_path unless --absolute_path was given.
# os.path.join works whether or not file_path ends with a separator (plain string
# concatenation produced e.g. "data0000_sample_info.tsv" for file_path="data").
if not absolute_path:
    input_sample_info_filename = os.path.join(file_path, input_sample_info_filename)
    input_df_filename = os.path.join(file_path, input_df_filename)
    output_df_filename = os.path.join(file_path, output_df_filename)

In [None]:
# Print a run header: script name, version, current time and the resolved file paths.
# __file__ is only defined when this runs as a .py script; inside a notebook kernel
# referencing it directly raises NameError, so fall back to a placeholder.
script_name = globals().get("__file__", "<interactive>")
print('\n')
print("[Script]{}".format(script_name))
print("[Version]{}".format(__version__))
print("[Date]{}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
print("[Paraments]input_sample_info_filename: {}".format(input_sample_info_filename))
print("[Paraments]input_df_filename: {}".format(input_df_filename))
print("[Paraments]output_df_filename: {}".format(output_df_filename))
print('\n')

In [4]:
# Functions: a logging decorator and the per-disease TSS/TES category statistics
def log(function):
    """
    Decorator that prints start/finish banners around each call of `function`.

    The start banner shows the function name, the current time and every keyword
    argument: values whose exact type is int, str, bool or list are printed
    verbatim, anything else (e.g. a DataFrame) is shown as "<...>". Positional
    arguments are not printed. The wrapped function's return value is passed
    through unchanged.

    functools.wraps copies __name__, __doc__ etc. onto the wrapper so that
    introspection and help() still describe the decorated function.
    """
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        print('\n')
        print("[Function]{} start.".format(function.__name__))
        print("\t[Time]{}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))))
        for key, value in kwargs.items():
            # exact type match (not isinstance): only plain scalars/lists are echoed
            if type(value) in (int, str, bool, list):
                print("\t[Paraments]{}: {}".format(key, value))
            else:
                print("\t[Paraments]{}: <...>".format(key))
        result = function(*args, **kwargs)
        print("[{}]{} finished.".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())),
                                        function.__name__))
        return result
    return wrapper


def _count_gene_ends(gene_df):
    """
    Summarise the expressed transcripts of one gene.

    Returns (strand, start_counts, end_counts, transcript_counts): the gene's
    strand, the number of distinct start sites, the number of distinct end sites
    and the number of transcripts (rows).

    Raises KeyError when the rows do not share a single '+' or '-' strand.
    """
    strands = set(gene_df["strand"].to_list())
    if strands == {'+'}:
        start_column, end_column = "transcript_start", "transcript_end"
    elif strands == {'-'}:
        # on the '-' strand the start site is taken from transcript_end and the end
        # site from transcript_start (assumes genomic coordinates with start < end)
        start_column, end_column = "transcript_end", "transcript_start"
    else:
        raise KeyError("[错误]strand: {}".format(list(strands)))
    strand = next(iter(strands))
    return (strand,
            gene_df[start_column].nunique(dropna=False),
            gene_df[end_column].nunique(dropna=False),
            gene_df.shape[0])


@log
def get_disease_sample_info(df, disease, sample_disease_list):
    """
    Count genes and transcripts of the four TSS/PAS categories for one disease.

    input:
        df, pandas.DataFrame, table covering all samples; must contain the columns
            chr, strand, gene_id, gene_name, transcript_id, transcript_name,
            transcript_start, transcript_end plus one column per sample
        disease, str, disease label, used as the key of the returned dict
        sample_disease_list, list, names of the sample columns belonging to this disease
    change:
        keeps the transcripts with a non-zero, non-missing value in at least one of
        the given samples, then classifies each gene by whether it has one (TSS) or
        several (ATSS) distinct start sites and one (PAS) or several (APA) distinct
        end sites
    output:
        result, dict, {disease: {"<TYPE>@gene": n_genes, "<TYPE>@transcript": n_transcripts}}
            for TYPE in TSS_PAS, TSS_APA, ATSS_PAS, ATSS_APA; every count is 0 when no
            transcript is expressed in these samples
    raises:
        KeyError, if a gene's expressed transcripts are not all on one '+' or '-' strand
    """
    annotation_columns = ["chr", "strand", "gene_id", "gene_name", "transcript_id",
                          "transcript_name", "transcript_start", "transcript_end"]
    sample_columns = list(sample_disease_list)
    disease_df = df[annotation_columns + sample_columns]

    # Expressed = at least one sample value that is neither 0 nor missing (the same rule
    # as turning 0 into NaN and dropping rows that are NaN in every sample column).
    sample_values = disease_df[sample_columns]
    expressed = (sample_values.notna() & sample_values.ne(0)).any(axis=1)
    disease_df = disease_df[expressed]

    # One summary row per gene; explicit columns keep the frame well-formed (and the
    # counts below equal to 0) even when no transcript survives the filter.
    summary_columns = ["gene_id", "strand", "start_counts", "end_counts", "transcript_counts"]
    records = [(gene_id,) + _count_gene_ends(gene_df)
               for gene_id, gene_df in disease_df.groupby("gene_id")]
    gene_summary = pandas.DataFrame(records, columns=summary_columns).astype(
        {"start_counts": "int64", "end_counts": "int64", "transcript_counts": "int64"})

    starts = gene_summary["start_counts"]
    ends = gene_summary["end_counts"]
    category_masks = {
        "TSS_PAS": (starts == 1) & (ends == 1),    # single start, single end
        "TSS_APA": (starts == 1) & (ends > 1),     # single start, several ends
        "ATSS_PAS": (starts > 1) & (ends == 1),    # several starts, single end
        "ATSS_APA": (starts > 1) & (ends > 1),     # several starts, several ends
    }
    counts = {}
    for category, mask in category_masks.items():
        counts["{}@gene".format(category)] = int(mask.sum())
        counts["{}@transcript".format(category)] = int(gene_summary.loc[mask, "transcript_counts"].sum())

    return {disease: counts}

In [5]:
# Load the sample metadata table and the expression table (first column as row index)
sample_info = pandas.read_csv(input_sample_info_filename, sep='\t')
# replace spaces in disease labels with underscores; the labels later become part of row names
sample_info = sample_info.assign(
    disease=sample_info["disease"].map(lambda label: label.replace(' ', '_'))
)
sample_df = pandas.read_csv(input_df_filename, sep='\t', index_col=0)

In [6]:
# Map every disease label to the GEO accessions of its samples, e.g.
# {<disease_1>: [<sample_1>], <disease_2>: [<sample_2>, <sample_3>, ...], ...}
sample_info_group = sample_info.groupby("disease")
disease_dict = {
    disease_name: disease_rows["GEO_accession"].to_list()
    for disease_name, disease_rows in sample_info_group
}

In [7]:
# Build the statistics table: "all_gene"/"all_transcript" rows for all samples together,
# then one "<disease>_gene" and one "<disease>_transcript" row per disease, and one
# column per category.
# "all" is reserved for the all-sample statistics: a disease with that label would
# silently overwrite those counts in df_info.update() and duplicate the "all_*" rows.
if "all" in disease_dict:
    raise ValueError("disease label 'all' clashes with the all-sample statistics rows")
df_index = ["all_gene", "all_transcript"] + [
    "{}_{}".format(disease, x) for disease in disease_dict.keys() for x in ["gene", "transcript"]
]
df = pandas.DataFrame(index=df_index,
                      columns=["TSS_PAS", "TSS_APA", "ATSS_PAS", "ATSS_APA"])

# Collect the counts: first over every sample, then for each disease separately
df_info = get_disease_sample_info(df=sample_df,
                                  disease="all",
                                  sample_disease_list=sample_info["GEO_accession"].to_list())
for disease, disease_sample_list in disease_dict.items():
    df_info.update(get_disease_sample_info(df=sample_df,
                                           disease=disease,
                                           sample_disease_list=disease_sample_list))

# Copy counts into the table: key "<category>@<gene|transcript>" of disease D goes to
# cell [D_<gene|transcript>, <category>]
for disease, disease_counts in df_info.items():
    for key, counts in disease_counts.items():
        gene_type, g_or_t = key.split("@")
        df.at["{}_{}".format(disease, g_or_t), gene_type] = counts

In [8]:
# Write the statistics table as tab-separated text, row labels included as the first column
df.to_csv(output_df_filename, sep="\t", index=True)

In [None]:
# Final timestamped marker so the log shows that every cell ran to completion
finished_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print("[{}]All blocks finished.".format(finished_at))