In [5]:
import argparse
import functools
import os
import pickle
import re
import time

import pandas

In [6]:
# Version tag of this script; it is printed in the parameter report cell.
__version__ = "V2.0(Editor) 2023-07-12"

In [2]:
# Command-line parameters.
# parse_known_args() is used instead of parse_args() so that this cell also
# runs inside a Jupyter kernel: there sys.argv carries launcher arguments such
# as "-f <connection file>", which parse_args() rejects by exiting the process.
# Unrecognised arguments are collected separately and ignored.
parser = argparse.ArgumentParser()
parser.add_argument("--counts_cutoff", dest="counts_cutoff", required=False, type=int, default=2, help="=2,\t the cutoff_value of transcript in quantification file")
parser.add_argument("--annotation_col", dest="annotation_col", required=False, type=str, default=["gene_id", "transcript_id", "gene_name"], nargs="*", help="=[\"gene_id\", \"transcript_id\", \"gene_name\"],\t selected columns of annotation file")
parser.add_argument("--quantification_col", dest="quantification_col", required=False, type=str, default=["annot_gene_id", "annot_transcript_id","annot_transcript_name"], nargs="*", help="=[\"annot_gene_id\", \"annot_transcript_id\",\"annot_transcript_name\"],\t selected columns of quantification file")
parser.add_argument("--absolute_path", dest="absolute_path", required=False, action="store_true", default=False, help="use absolute path")
parser.add_argument("--file_path", dest="file_path", required=False, type=str, default="./", help="the dir of data, if absolute_path is False")
parser.add_argument("--input_sample_info", dest="input_sample_info", required=False, type=str, default="0000_sample_info.tsv", help="=0000_sample_info.tsv,\t the info file of input sample")
parser.add_argument("--output_df_filename", dest="output_df_filename", required=False, type=str, default="0001_total_info.tsv", help="=,\t the output df file of samples")
parser.add_argument("--output_pickle_filename", dest="output_pickle_filename", required=False, type=str, default="0001_total_info.pickle", help="=,\t the output pickle file of samples")

args, _unrecognized_args = parser.parse_known_args()

# Expose every parsed option as a module-level variable used by the later cells.
counts_cutoff = args.counts_cutoff
annotation_col = args.annotation_col
quantification_col = args.quantification_col
absolute_path = args.absolute_path
file_path = args.file_path
input_sample_info = args.input_sample_info
output_df_filename = args.output_df_filename
output_pickle_filename = args.output_pickle_filename

NameError: name 'argparse' is not defined

In [7]:
# Debug parameters: hard-coded values that overwrite the variables of the same
# names assigned from the argparse result, for interactive runs of the notebook.
counts_cutoff = 2

# Attribute names extracted from the annotation file / column names kept from
# the quantification file.
annotation_col = ["gene_id", "transcript_id", "gene_name",]
quantification_col = ["annot_gene_id", "annot_transcript_id","annot_transcript_name"]

absolute_path = False
# NOTE(review): machine-specific absolute directory; it ends with "/" because
# later cells build paths by plain string concatenation.
file_path = "F:/OneDrive/Master/Project/trans/data/"
input_sample_info = "0000_sample_info.tsv"
output_df_filename = "0001_total_info.tsv"
output_pickle_filename = "0001_total_info.pickle"

In [8]:
# Complete the paths: unless absolute paths were supplied, prefix the sample
# sheet and both output files with the data directory.
if not absolute_path:
    # Accept a data directory written without a trailing separator.
    _data_dir_prefix = file_path
    if _data_dir_prefix and not _data_dir_prefix.endswith(("/", "\\")):
        _data_dir_prefix += "/"

    def _prefix_once(filename):
        # Re-running this cell must not prepend the directory a second time.
        if _data_dir_prefix and filename.startswith(_data_dir_prefix):
            return filename
        return "{}{}".format(_data_dir_prefix, filename)

    input_sample_info = _prefix_once(input_sample_info)
    output_df_filename = _prefix_once(output_df_filename)
    output_pickle_filename = _prefix_once(output_pickle_filename)

In [9]:
# Derived parameters.
# Read the tab-separated sample sheet and take the sample names from its
# "GEO_accession" column.
sample_info = pandas.read_csv(input_sample_info, sep="\t")
sample_name_list = sample_info["GEO_accession"].tolist()

In [None]:
# Print the run parameters, one "<template>" / value pair per line, framed by
# blank lines.
_report_lines = [
    ("[Version]{}", __version__),
    ("[Date]{}", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))),
    ("[Paraments]counts_cutoff: {}", counts_cutoff),
    ("[Paraments]annotation_col: {}", annotation_col),
    ("[Paraments]quantification_col: {}", quantification_col),
    ("[Paraments]input_sample_info: {}", input_sample_info),
    ("[Paraments]sample_name_list: {}", sample_name_list),
    ("[Paraments]output_df_filename: {}", output_df_filename),
    ("[Paraments]output_pickle_filename: {}", output_pickle_filename),
]

print('\n')
for _template, _value in _report_lines:
    print(_template.format(_value))
print('\n')

In [10]:
# Class
# 声明一个类，存储所有gene的信息
class Total(object):
    """Registry of every gene collected across all samples.

    Attributes:
        gene_dict: {gene_id: Gene object} for every stored gene.
        sample_list: list of sample names; kept as passed by the caller.
        start_dict: {start: [gene_id, ...]} lookup used to find stored genes
            that share a start coordinate without scanning gene_dict.
    """

    def __init__(self, sample_list, gene_dict=None):
        """Create the registry.

        Args:
            sample_list: list that will hold the names of the loaded samples.
            gene_dict: optional {gene_id: Gene} mapping to start from. A new
                dict is created when omitted, so instances never share the
                same default dictionary.
        """
        self.gene_dict = {} if gene_dict is None else gene_dict
        self.sample_list = sample_list
        self.start_dict = {}

    def __check_gene_id(self, Gene):
        """Classify a gene against the genes already stored.

        Returns:
            str: '0' if Gene.gene_id is already stored;
                 '1' if neither the id nor the start coordinate is stored;
                 '2' if the start is stored but no stored gene with that start
                     has the same strand, chr and end;
                 otherwise the gene_id of the stored gene with identical
                 start, end, strand and chr.
        """
        if Gene.gene_id in self.gene_dict:
            return '0'
        if Gene.start not in self.start_dict:
            return '1'

        for stored_id in self.start_dict[Gene.start]:
            stored_gene = self.gene_dict.get(stored_id)
            if (stored_gene.strand == Gene.strand
                    and stored_gene.chr == Gene.chr
                    and stored_gene.end == Gene.end):
                return stored_gene.gene_id
        return '2'

    def add_gene(self, Gene):
        """Store a gene unless it (or an identical locus) is already stored.

        Returns:
            str: "True" if the gene was added, "False" if its gene_id was
            already stored, or the gene_id of the stored gene that has the
            same chr, strand, start and end (nothing is added in that case).
        """
        check_exist = self.__check_gene_id(Gene)
        if check_exist == '0':
            return "False"
        if check_exist == '1':
            self.gene_dict[Gene.gene_id] = Gene
            self.start_dict[Gene.start] = [Gene.gene_id]
            return "True"
        if check_exist == '2':
            self.gene_dict[Gene.gene_id] = Gene
            self.start_dict[Gene.start].append(Gene.gene_id)
            return "True"
        return check_exist

    def get_df(self, sample_name_list=None):
        """Collect the per-gene tables of all stored genes into one DataFrame.

        Args:
            sample_name_list: sample names passed on to every gene's get_df();
                they also become the trailing columns of an empty result.

        Returns:
            pandas.DataFrame: the per-gene frames stacked row-wise, genes in
            reverse insertion order (the order the previous pop()-based loop
            produced). With no stored genes, an empty frame with the columns
            chr, strand, source, gene_id, gene_name, transcript_id,
            transcript_name, transcript_start, transcript_end and one column
            per sample name.
        """
        sample_names = [] if sample_name_list is None else list(sample_name_list)
        columns = ["chr", "strand", "source", "gene_id", "gene_name",
                   "transcript_id", "transcript_name", "transcript_start", "transcript_end"]
        columns = columns + sample_names

        gene_id_list = list(self.gene_dict.keys())
        remain_num = len(gene_id_list)

        # Gather every per-gene frame first and concatenate once: concatenating
        # inside the loop copies the growing frame again on every step.
        frames = []
        for gene_id in reversed(gene_id_list):
            if remain_num % 30 == 0:
                print("{}".format(remain_num), end='\r')  # progress: genes left
            frames.append(self.gene_dict[gene_id].get_df(sample_names))
            remain_num -= 1

        if not frames:
            return pandas.DataFrame(columns=columns)
        return pandas.concat(frames, axis=0)


# 定义一个类，以存储基因的相关信息
#   具有属性：染色体号，source，正负链，start/end，包含的转录本id
#   存储：转录本名称及范围、外显子范围
#   具有方法：
#       1.添加外显子
#       2.添加转录本
#       3.统计该基因的转录本数量

# 需要建立一个df, 可以实现gene_id与transcript_id的互查

# 类
class Gene(object):
    """One gene together with its transcripts, exons and per-sample counts.

    Attributes:
        chr, gene_id, gene_name, source, strand: stored as given.
        start, end: gene range, converted to int.
        transcript_dict: {transcript_id: {"transcript_name": <name>,
                                          "range": [<start>, <end>],
                                          "exon_range": {<start>: [<end>, ...], ...},
                                          <sample_name>: <counts>, ...}}
        exon_dict: {<start>: [<end>, ...], ...} for every exon of the gene.
    """

    def __init__(self, chr, gene_id, gene_name, source,
                 strand, start, end,
                 transcript_dict, exon_dict):
        self.chr = chr
        self.gene_id = gene_id
        self.gene_name = gene_name
        self.source = source
        self.strand = strand
        self.start = int(start)
        self.end = int(end)
        self.transcript_dict = transcript_dict
        self.exon_dict = exon_dict

    def __check_exon_exist(self, exon_start, exon_end):
        """Return 0 if the start is unknown, 1 if only the end is unknown,
        2 if this exact exon (start and end) is already stored."""
        if exon_start not in self.exon_dict:
            return 0
        if exon_end not in self.exon_dict[exon_start]:
            return 1
        return 2

    def add_exon(self, exon_start, exon_end):
        """Store an exon of this gene.

        Returns:
            str: "success add" if the exon was stored, "existed_exon" if the
            same start/end pair was already stored.
        """
        exon_exist_mark = self.__check_exon_exist(exon_start, exon_end)
        if exon_exist_mark == 0:
            self.exon_dict[exon_start] = [exon_end]
            return "success add"
        if exon_exist_mark == 1:
            self.exon_dict[exon_start].append(exon_end)
            return "success add"
        return "existed_exon"

    def __check_transcript_exist(self, transcript_id, start, end, exon_list):
        """Find out whether a transcript is already stored.

        Returns:
            str: transcript_id itself if that id is stored; otherwise the id
            of a stored transcript with the same range and the same exons;
            otherwise '3' if some stored transcript shares the range but none
            has the same exons, or '1' if no stored transcript shares the range.
        """
        if transcript_id in self.transcript_dict:
            return transcript_id

        # Compare complete (start, end) pairs so that exons sharing a start
        # are ordered the same way on both sides.
        wanted_exons = sorted((exon_start, exon_end) for exon_start, exon_end in exon_list)

        range_recorded = False
        for exist_transcript_id, record in self.transcript_dict.items():
            if start != record["range"][0] or end != record["range"][1]:
                continue
            range_recorded = True

            exist_exons = sorted((exist_start, exist_end)
                                 for exist_start, exist_end_list in record["exon_range"].items()
                                 for exist_end in exist_end_list)
            if exist_exons == wanted_exons:
                return exist_transcript_id
            # Exons differ: keep looking, another stored transcript with the
            # same range may still match.

        return '3' if range_recorded else '1'

    def add_transcript(self, transcript_id, transcript_name, start, end, exon_list, sample_name, sample_counts):
        """Store a transcript, or only its counts if it is already stored.

        Args:
            exon_list: list of [exon_start, exon_end] pairs of the transcript.
            sample_name, sample_counts: the counts are stored under the
                sample name inside the transcript's record.

        Returns:
            bool: True if a new transcript record was created, False if the
            counts were written to an already stored transcript.
        """
        exist_mark = self.__check_transcript_exist(transcript_id, start, end, exon_list)
        if exist_mark == '1' or exist_mark == '3':
            # Convert exon_list to the {<start>: [<end>, ...]} layout.
            exon_range = {}
            for exon_start, exon_end in exon_list:
                exon_range.setdefault(exon_start, []).append(exon_end)

            self.transcript_dict[transcript_id] = {"transcript_name": transcript_name,
                                                   "range": [start, end],
                                                   "exon_range": exon_range,
                                                   sample_name: sample_counts
                                                   }
            return True

        # NOTE(review): the counts are overwritten, not summed, when several
        # input transcripts of one sample resolve to the same stored
        # transcript -- confirm this is intended.
        self.transcript_dict[exist_mark][sample_name] = sample_counts
        return False

    def get_df(self, sample_name_list=None):
        """Return one row per stored transcript of this gene.

        Args:
            sample_name_list: sample names to add as count columns; a sample
                without counts for a transcript gets 0.

        Returns:
            pandas.DataFrame indexed by transcript_id with the columns chr,
            strand, source, gene_id, gene_name, transcript_id, transcript_name,
            transcript_start, transcript_end and one column per sample name.
        """
        sample_names = [] if sample_name_list is None else list(sample_name_list)

        columns = ["chr", "strand", "source", "gene_id", "gene_name",
                   "transcript_id", "transcript_name", "transcript_start", "transcript_end"]
        columns = columns + sample_names
        df = pandas.DataFrame(index=list(self.transcript_dict.keys()),
                              columns=columns)

        for transcript_id, record in self.transcript_dict.items():
            df.at[transcript_id, "chr"] = self.chr
            df.at[transcript_id, "strand"] = self.strand
            df.at[transcript_id, "source"] = self.source
            df.at[transcript_id, "gene_id"] = self.gene_id
            df.at[transcript_id, "gene_name"] = self.gene_name
            df.at[transcript_id, "transcript_id"] = transcript_id
            df.at[transcript_id, "transcript_name"] = record["transcript_name"]
            df.at[transcript_id, "transcript_start"] = record["range"][0]
            df.at[transcript_id, "transcript_end"] = record["range"][1]

            for sample_name in sample_names:
                df.at[transcript_id, sample_name] = record.get(sample_name, 0)

        return df



In [11]:
# Function
def log(function):
    """Decorator that prints a start/finish report around a function call.

    Before the call it prints the function name, the current local time and
    every keyword argument: values whose type is exactly int, str, bool or
    list are printed in full, all others as "<...>". Positional arguments are
    not reported. After the call it prints a time-stamped "finished" line.

    Returns:
        The wrapped function; it returns whatever ``function`` returns and
        keeps its name and docstring.
    """
    time_format = "%Y-%m-%d %H:%M:%S"

    @functools.wraps(function)  # keep __name__/__doc__ of the decorated function
    def wrapper(*args, **kwargs):
        print('\n')
        print("[Function]{} start.".format(function.__name__))
        print("\t[Time]time: {}".format(time.strftime(time_format, time.localtime(time.time()))))
        for key, value in kwargs.items():
            if type(value) in (int, str, bool, list):
                print("\t[Paraments]{}: {}".format(key, value))
            else:
                print('\t[Paraments]{}: <...>'.format(key))
        result = function(*args, **kwargs)
        print("[{}]{} finished.".format(time.strftime(time_format, time.localtime(time.time())),
                                        function.__name__))
        return result
    return wrapper


def get_file_path(file_path, sample_name_list):
    """Locate the annotation and quantification file of every sample.

    For each sample the directory "<file_path>raw_data/<sample>" is listed; a
    file whose name contains "annotation" is stored as the annotation file,
    otherwise a file whose name contains "quantification" is stored as the
    quantification file. If several files match, the last one in sorted name
    order is kept.

    Args:
        file_path: str, data directory; a missing trailing "/" (or "\\") is
            tolerated.
        sample_name_list: list of sample names (sub-directories of raw_data).

    Returns:
        dict: {<sample>: {"annotation": <path>, "quantification": <path>}, ...};
        a key is absent when no matching file exists for that sample.
    """
    base_dir = file_path
    if base_dir and not base_dir.endswith(("/", "\\")):
        base_dir += "/"

    sample_file_location = {sample: {} for sample in sample_name_list}
    for sample in sample_name_list:
        sample_dir = "{}raw_data/{}".format(base_dir, sample)
        # Sorted so the choice among several matching files does not depend on
        # the order in which the file system lists them.
        for filename in sorted(os.listdir(sample_dir)):
            if "annotation" in filename:
                file_kind = "annotation"
            elif "quantification" in filename:
                file_kind = "quantification"
            else:
                continue
            sample_file_location[sample][file_kind] = "{}/{}".format(sample_dir, filename)

    return sample_file_location


def load_annotation(input_filename, additional_info, only_type="transcript"):
    """Read a tab-separated annotation (GTF-style) file into a DataFrame.

    Args:
        input_filename: str, path of the annotation file.
        additional_info: list of attribute names to extract from the last
            column (e.g. "gene_id"); each becomes one output column, with ""
            where a row lacks the attribute.
        only_type: keep only rows whose third column equals this value;
            "all" keeps every row.

    Returns:
        pandas.DataFrame with the columns chr, source, type, start, end,
        strand followed by the requested attributes. All values are strings.
        Rows whose chr starts with "ERCC" or "chrUn", or ends with "random",
        are dropped, as are blank lines and lines starting with "#".

    Raises:
        ValueError: if a kept data row has fewer than 9 tab-separated columns.
    """
    additional_info_list = list(additional_info)
    columns_name = ["chr", "source", "type", "start", "end", "strand"] + additional_info_list

    rows = []
    with open(input_filename, 'r') as file:
        for line_number, raw_line in enumerate(file, start=1):
            raw_line = raw_line.rstrip('\r\n')
            # Header/comment lines and blank lines carry no annotation record.
            if not raw_line or raw_line.startswith('#'):
                continue
            fields = raw_line.split('\t')

            # Keep only the requested feature type.
            if only_type != "all" and (len(fields) < 3 or fields[2] != only_type):
                continue

            # Skip ERCC entries and entries not placed on a chromosome.
            chrom = fields[0]
            if chrom[0:4] == "ERCC" or chrom[0:5] == "chrUn" or chrom[-6:] == "random":
                continue

            if len(fields) < 9:
                raise ValueError("{}: line {} has {} tab-separated columns, expected at least 9".format(
                    input_filename, line_number, len(fields)))

            # Parse the attribute column: 'key "value"; key "value"; ...'.
            # Splitting on the first space only keeps values containing spaces
            # intact; empty pieces (e.g. after the final ';') are ignored.
            attributes = {}
            for item in fields[-1].split(';'):
                item = item.strip(' ')
                if not item:
                    continue
                key, _, value = item.partition(' ')
                attributes[key] = value.strip(' ').strip('"')

            # Columns 6 and 8 (indexes 5 and 7; score and frame in the GTF
            # layout) are not kept.
            row = [fields[0], fields[1], fields[2], fields[3], fields[4], fields[6]]
            row.extend(attributes.get(name, "") for name in additional_info_list)
            rows.append(row)

    return pandas.DataFrame(rows, columns=columns_name)


def load_quantification(input_filename,
                        sample_name,
                        cutoff_value=1,
                        col=["annot_gene_id", "annot_transcript_id", "gene_novelty", "transcript_novelty"]):
    """Load a quantification table and keep the rows that reach the cutoff.

    Args:
        input_filename: str, path of the tab-separated quantification file.
        sample_name: str, new name given to the last column of the file,
            which is used as the counts column.
        cutoff_value: rows are kept when their counts are >= this value.
        col: list of column names to keep in addition to the counts column.

    Returns:
        pandas.DataFrame with the columns listed in ``col`` followed by the
        counts column named ``sample_name``; original row labels are kept.
    """
    table = pandas.read_csv(input_filename, sep='\t')
    header = table.columns.to_list()

    # Positions are resolved against the original header; the counts column
    # is always the last one of the file.
    keep_positions = [header.index(name) for name in col]
    keep_positions.append(len(header) - 1)

    table = table.rename(columns={header[-1]: sample_name})
    table = table.iloc[:, keep_positions]

    reaches_cutoff = table[sample_name] >= cutoff_value
    return table.loc[reaches_cutoff, :]


def merge_quantification_annotation(quantification, annotation):
    """Attach annotation columns to the quantified transcripts.

    The quantification table is left-joined to the annotation table on
    annot_transcript_id == transcript_id, so annotation rows without a
    quantified transcript never appear. Quantified transcripts that found no
    annotation row are then removed, and the annotation's gene_id and
    transcript_id columns are dropped because they duplicate the
    quantification columns.

    Returns:
        pandas.DataFrame: the merged table.
    """
    merged = quantification.merge(annotation,
                                  how="left",
                                  left_on="annot_transcript_id",
                                  right_on="transcript_id")
    # A missing transcript_id marks a transcript that found no annotation row.
    annotated = merged[merged["transcript_id"].notna()]
    return annotated.drop(columns=["gene_id", "transcript_id"])


def get_gene_info(df):
    """Count distinct transcript start and end sites per gene.

    Args:
        df: pandas.DataFrame with the columns annot_gene_id, strand, start
            and end, one row per transcript.

    Returns:
        pandas.DataFrame indexed by gene id with the columns start, end,
        strand and transcripts. On the '+' strand "start"/"end" count the
        distinct values of the start/end columns; on the '-' strand the two
        are swapped. "transcripts" is the number of rows of the gene.

    Raises:
        KeyError: if a gene's rows do not share one strand that is '+' or '-'.
    """
    grouped = df.groupby("annot_gene_id")
    gene_info = pandas.DataFrame(index=list(set(df["annot_gene_id"].to_list())),
                                 columns=["start", "end", "strand", "transcripts"])

    group_total = len(grouped)
    for position, gene_id in enumerate(grouped.groups.keys(), start=1):
        # progress display
        print("\t[Process]{}/{}".format(position, group_total), end='\r')

        members = grouped.get_group(gene_id)
        strand = list(set(members["strand"].to_list()))
        if len(strand) != 1 or strand[0] not in ('-', '+'):
            raise KeyError("[错误]strand: {}".format(strand))
        strand = strand[0]

        distinct_low = len(set(members["start"].to_list()))
        distinct_high = len(set(members["end"].to_list()))
        if strand == '-':
            # Minus strand: the larger coordinate is the transcript start.
            num_start, num_end = distinct_high, distinct_low
        else:
            num_start, num_end = distinct_low, distinct_high

        gene_info.at[gene_id, "start"] = num_start
        gene_info.at[gene_id, "end"] = num_end
        gene_info.at[gene_id, "strand"] = strand
        gene_info.at[gene_id, "transcripts"] = members.shape[0]

    return gene_info

@log
def load_sample(total, annotation_location, quantification_location, sample_name, 
                counts_cutoff, annotation_col, quantification_col):
    """Merge one sample's annotation and quantification data into ``total``.

    Args:
        total: Total, the object collecting the genes of all samples.
        annotation_location: str, annotation file of this sample.
        quantification_location: str, quantification file of this sample.
        sample_name: str, name of this sample.
        counts_cutoff: int, minimum counts (>= value) for a transcript of the
            quantification file to be kept.
        annotation_col: list, attribute names extracted from the annotation
            file; the code below reads "gene_id", "transcript_id" and
            "gene_name".
        quantification_col: list, columns kept from the quantification file;
            the code below reads "annot_gene_id", "annot_transcript_id" and
            "annot_transcript_name".

    Returns:
        Total: the same ``total`` object, now also holding this sample's
        genes, exons, transcripts and counts, with ``sample_name`` appended to
        its sample_list.
    """

    total = total
    annotation_location = annotation_location
    quantification_location = quantification_location
    sample_name = sample_name
    counts_cutoff = counts_cutoff
    annotation_col = annotation_col
    quantification_col = quantification_col

    # Transcript information of this sample: annotation rows of type
    # "transcript" joined to the quantified transcripts that reach the cutoff.
    df_annotation = load_annotation(input_filename=annotation_location,
                            additional_info=annotation_col,
                            only_type="transcript")
    df_quantification = load_quantification(input_filename=quantification_location,
                                            sample_name=sample_name,
                                            cutoff_value=counts_cutoff,
                                            col=quantification_col)
    df_temp = merge_quantification_annotation(df_quantification, df_annotation)
    # The annotation loader returns strings; coordinates and counts are
    # converted to int here.
    df_temp[sample_name] = df_temp[sample_name].astype(int)
    df_temp["start"] = df_temp["start"].astype(int)
    df_temp["end"] = df_temp["end"].astype(int)


    # Gene information of this sample: annotation rows of type "gene",
    # restricted to the genes that have at least one kept transcript.
    gene_info = load_annotation(input_filename=annotation_location,
                            additional_info=annotation_col,
                            only_type="gene")
    expressed_gene = list(set(df_temp["annot_gene_id"].to_list()))
    expressed_gene = pandas.DataFrame(expressed_gene)
    # Left join: a kept gene without a "gene" row in the annotation yields
    # missing values, on which the int conversion below fails.
    gene_info = pandas.merge(expressed_gene, gene_info, left_on=0, right_on="gene_id", how="left")
    gene_info["start"] = gene_info["start"].astype(int)
    gene_info["end"] = gene_info["end"].astype(int)

    # Register the genes of this sample.
    exist_gene_id = {}  # genes already stored under another id: {<new_gene_id>: <existing_gene_id>}
    for i in gene_info.index:
        chr = gene_info.at[i, "chr"]
        gene_id = gene_info.at[i, "gene_id"]
        gene_name = gene_info.at[i, "gene_name"]
        source = gene_info.at[i, "source"]
        strand = gene_info.at[i, "strand"]
        start = gene_info.at[i, "start"]
        end = gene_info.at[i, "end"]
        
        # gene_mark is "True", "False" or a gene_id; a gene_id is the id of the
        # stored gene that the current Gene object duplicates.
        gene_mark = total.add_gene(Gene(chr=chr, gene_id=gene_id, gene_name=gene_name, source=source,
                                   strand=strand, start=start, end=end,
                                   transcript_dict={}, exon_dict={}))
        if gene_mark == "True" or gene_mark == "False":
            pass
        else:
            exist_gene_id[gene_id] = gene_mark


    # Exon information of this sample: annotation rows of type "exon",
    # restricted (inner join) to the kept transcripts.
    exon_info = load_annotation(input_filename=annotation_location,
                                additional_info=annotation_col,
                                only_type="exon")
    expressed_gene = list(set(df_temp["annot_transcript_id"].to_list()))
    expressed_gene = pandas.DataFrame(expressed_gene)
    exon_info = pandas.merge(expressed_gene, exon_info, left_on=0, right_on="transcript_id")
    exon_info["start"] = exon_info["start"].astype(int)
    exon_info["end"] = exon_info["end"].astype(int)

    # Register the exons of this sample, transcript by transcript.
    exon_info_group = exon_info.groupby("transcript_id")
    for transcript_id in exon_info_group.groups.keys():
        temp_df = exon_info_group.get_group(transcript_id)

        gene_id = temp_df.at[temp_df.index[0], "gene_id"]
        # The genes of this sample were registered above, so the gene of an
        # exon is expected to be in one of two states:
        #   1. its gene_id is stored
        #   2. its gene_id is not stored, but an identical gene is stored
        #      under another id (same gene information, different id)
        if gene_id in exist_gene_id.keys():
            # The gene is a duplicate: use the id it was mapped to.
            gene_id = exist_gene_id.get(gene_id)
        else:
            # The gene is stored under its own id.
            pass
        
        # Add the exons to that gene.
        for i in temp_df.index:
            total.gene_dict[gene_id].add_exon(temp_df.at[i,"start"], temp_df.at[i,"end"])


    # Register the transcripts of this sample from df_temp: every gene
    # receives the transcript id, name, start, end, exon list and the counts
    # of this sample.
    temp_transcript_info = df_temp.set_index("annot_transcript_id")
    exon_info_group = exon_info.groupby("transcript_id")
    for transcript_id in temp_transcript_info.index:
        # Raises KeyError for a kept transcript that has no exon rows.
        temp_exon_info = exon_info_group.get_group(transcript_id)

        transcript_name = temp_transcript_info.at[transcript_id, "annot_transcript_name"]
        start = temp_transcript_info.at[transcript_id, "start"]
        end = temp_transcript_info.at[transcript_id, "end"]
        sample_name = sample_name
        sample_counts = temp_transcript_info.at[transcript_id, sample_name]
        exon_list = [[temp_exon_info.at[i, "start"], temp_exon_info.at[i, "end"]] for i in temp_exon_info.index]
        
        # Resolve duplicated genes to the id they are stored under.
        gene_id = temp_transcript_info.at[transcript_id, "annot_gene_id"]
        if gene_id in exist_gene_id.keys():
            gene_id = exist_gene_id.get(gene_id)
        else:
            pass

        total.gene_dict[gene_id].add_transcript(transcript_id=transcript_id,
                                                transcript_name=transcript_name,
                                                start=start, end=end,
                                                exon_list=exon_list,
                                                sample_name=sample_name,
                                                sample_counts=sample_counts)


    # Record the sample name in the Total object.
    sample_name = sample_name
    sample_list = total.sample_list
    sample_list.append(sample_name)
    total.sample_list = sample_list

    return total

@log
def save_df(total, file_path):
    """Convert a Total object to a DataFrame and write it as a TSV file.

    Args:
        total: Total, the object holding the collected data.
        file_path: str, path of the output file.

    Returns:
        pandas.DataFrame: the table built by total.get_df(total.sample_list),
        sorted by chr and, within each chr, by gene_id.
    """
    df = total.get_df(total.sample_list)
    # One multi-key sort: two consecutive single-key sorts with the default
    # (non-stable) algorithm do not guarantee that the gene_id order survives
    # the second sort.
    df = df.sort_values(by=["chr", "gene_id"])
    df.to_csv(file_path, sep='\t')

    return df

In [81]:
# Resolve the annotation/quantification file paths of every sample.
# NOTE(review): this rebinds `file_path` from the data-directory string to the
# dict returned by get_file_path(), so re-running this cell without first
# re-running the parameter cells passes that dict back in as `file_path`.
file_path = get_file_path(file_path=file_path,
                          sample_name_list=sample_name_list)

In [82]:
# Create the empty container that will collect the genes of all samples.
total = Total(sample_list=[], gene_dict={})

In [83]:
# Load the data of every sample into `total`, one sample at a time.
for sample_name in sample_name_list:
    # At this point `file_path` is the {sample: {"annotation": ..., "quantification": ...}}
    # dict produced by get_file_path(); a missing key raises KeyError here.
    quantification_file_location = file_path[sample_name]["quantification"]
    annotation_file_location = file_path[sample_name]["annotation"]

    total = load_sample(total,
                        annotation_location=annotation_file_location,
                        quantification_location=quantification_file_location,
                        sample_name=sample_name,
                        counts_cutoff=counts_cutoff,
                        annotation_col=annotation_col,
                        quantification_col=quantification_col)



[Function]load_sample start.
	[Time]time: 2023-07-12 09:38:45
	[Paraments]annotation_location: F:/OneDrive/Master/Project/trans/data/raw_data/GSM6783527/GSM6783527_ENCFF384PHN_transcriptome_annotations_GRCh38.gtf
	[Paraments]quantification_location: F:/OneDrive/Master/Project/trans/data/raw_data/GSM6783527/GSM6783527_ENCFF816TJN_transcript_quantifications_GRCh38.tsv
	[Paraments]sample_name: GSM6783527
	[Paraments]counts_cutoff: 2
	[Paraments]annotation_col: ['gene_id', 'transcript_id', 'gene_name']
	[Paraments]quantification_col: ['annot_gene_id', 'annot_transcript_id', 'annot_transcript_name']
[2023-07-12 09:39:52]load_sample finished.


[Function]load_sample start.
	[Time]time: 2023-07-12 09:39:52
	[Paraments]annotation_location: F:/OneDrive/Master/Project/trans/data/raw_data/GSM6782551/GSM6782551_ENCFF856FNN_transcriptome_annotations_GRCh38.gtf
	[Paraments]quantification_location: F:/OneDrive/Master/Project/trans/data/raw_data/GSM6782551/GSM6782551_ENCFF217QQW_transcript_quantific

In [84]:
# Save the results: the merged table as TSV, and the Total object itself as a
# pickle file.
df = save_df(total, output_df_filename)
with open(output_pickle_filename, 'wb') as file:
    pickle.dump(total, file)



[Function]save_df start.
	[Time]time: 2023-07-12 09:41:26
[2023-07-12 09:42:37]save_df finished.


In [None]:
# debug-找几个gene_id，确定原始数据与合并后数据是否一致

---