In [1]:
# Hi-C class script
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import zscore
import argparse
import itertools

data_path = "/Users/ninaxiong/projects/HST508_Fall2022_final_project/data/ENCODE_bulk_rna_seq.csv"
gene_loc_path = "/Users/ninaxiong/projects/HST508_Fall2022_final_project/data/gene_locations.tsv"
tad_path = "/Users/ninaxiong/projects/HST508_Fall2022_final_project/data/TAD_strong_boundary_start_end.csv"
chrlen_path = "/Users/ninaxiong/projects/HST508_Fall2022_final_project/data/chr_lengths"

def extract_data(data_path,gene_loc_path,tad_path,excl_chrom=['chrM','chrX','chrY']):
    """
    excl_chrom: list of strings corresponding to chromosomes to be excluded
    data_path: path to single cell tpm data
    gene_loc_path: path to file containing information on gene locations
    ===
    return a dataframe with genes as rows and samples as columns
    make sure the genes are contained in the chromosome location info file (take intersection)
    """
    # read in data
    sc_df = pd.read_csv(data_path,index_col=0)
    sc_df = sc_df.loc[[idx for idx in sc_df.index if 'ENSMUSG' in idx]]
    gloc = pd.read_csv(gene_loc_path,sep="\t",index_col=0)
    tad = pd.read_csv(tad_path)
    # index manipulation
    sc_df_idx = [idx.split(".")[0] for idx in sc_df.index]
    sc_df.index = sc_df_idx
    gloc_idx = [idx.split(".")[0] for idx in gloc.index]
    gloc.index = gloc_idx
    # get rid of genes with 0 exp across all samples
    sc_df_filtered = sc_df.loc[np.sum(sc_df,axis=1)!=0]
    # get rid of chromosomes in exclusion list
    gloc_filtered = gloc[gloc['seqname'].isin(set(gloc.seqname).difference(excl_chrom))]
    tad_filtered = tad[tad['chrom'].isin(set(gloc.seqname).difference(excl_chrom))]
    # get intersecting genes in both
    gene_list = set(sc_df_filtered.index).intersection(gloc_filtered.index)
    sc_df_filtered = sc_df_filtered.loc[gene_list]
    gloc_filtered = gloc_filtered.loc[gene_list]
    return sc_df_filtered, gloc_filtered, tad_filtered

def chromosome_gene_dict(gloc_data):
    """
    return dictionary where keys are chromosomes and values are gene lists
    """
    chromosomes = set(gloc_data.seqname)
    cg_dict = dict(zip(chromosomes, [None]*len(chromosomes)))
    for chrom in chromosomes:
        gloc = gloc_data[gloc_data.seqname==chrom]
        cg_dict[chrom] = list(gloc.index)
    return cg_dict

def get_genes_in_interval(chrom,start,end,gloc):
    """
    get genes in an interval on a chromosome
    """
    gloc_chr = gloc[gloc['seqname']==chrom]
    gloc_chr = gloc_chr[(gloc_chr['start'] >= start) & (gloc_chr['end'] < end)]
    return list(gloc_chr.index)

def tad_gene_dict(tad_locs,gloc,filter_bar=5):
    tg_dict = dict(zip(list(range(len(tad_locs))),[[]]*len(tad_locs)))
    for i in range(len(tad_locs)):
        data = tad_locs.loc[i]
        tg_dict[i] = get_genes_in_interval(data['chrom'],data['start'],data['end'],gloc)
    if filter_bar > 0:
        tg_dict = {k:v for k,v in tg_dict.items() if len(v)>=filter_bar}
    return tg_dict

def log2norm_tpm(tpm_data):
    """
    returns log2-normalized tpm data
    """
    return zscore(np.log2(tpm_data+1),axis=1)

def get_genes_from_chromosome(chr_name,tpm_data,tad_data,gloc_data):
    """
    chr_name: e.g. 'chr1', the string corresponding to the chromosome you want to extract data on
    tpm_data: dataframe of tpms
    gene_loc_data: dataframe of gene locations
    ===
    returns filtered dataframes corresponding to chromosome of interest
    """
    gloc_filtered = gloc_data[gloc_data['seqname']==chr_name]
    tad_filtered = tad_data[tad_data['chrom']==chr_name]
    genes = gloc_filtered.index
    return tpm_data.loc[genes], gloc_filtered, tad_filtered

def get_chr_lengths(path_to_file="../data/chr_lengths"):
    """
    return a dict of chromosomes and their lengths
    """
    chr_lengths = {}
    infile = open(path_to_file)
    for line in infile:
        line = line.strip().split()
        chr = line[0]
        length = int(line[1])
        chr_lengths[chr] = length
    infile.close()
    return chr_lengths

def slide_boundary(chr, start, end, num_iter=5, x=0.2):
    """
    shift a TAD boundary left or right [num_iter] times with a step size of [x] * TAD size 
    while retaining the size of the TAD. return a list of new boundaries and the
    corresponding distances from the original TAD.
    """
    chr_length = get_chr_lengths()[chr]
    # set step size
    tad_length = end - start
    step_size = 10000 # tad_length * x
    # default: shift right
    direction = 1
    # TAD at end: shift left
    if end + (step_size * num_iter) > chr_length:
        direction = -1
    new_boundaries = []
    distances = []
    step_size = direction * step_size
    distance_from_origin = step_size
    for i in range(num_iter):
        new_start, new_end = start + step_size, end + step_size
        new_boundaries.append((chr, new_start, new_end))
        distances.append(distance_from_origin)
        start, end = new_start, new_end
        distance_from_origin += step_size
    return new_boundaries, distances

def calc_tad_coexp(chr, start, end, gene_loc, tpm):
    """
    calculate the average pairwise correlation coefficient for a given genomic region
    """
    genes = get_genes_in_interval(chr, start, end, gene_loc)
    if len(genes) == 0:
        return None, 0
    tpm_subset = tpm.loc[genes,:]
    corr_df = tpm_subset.transpose().corr()
    avg_corr = np.mean(corr_df.to_numpy())
    return corr_df, avg_corr

def plot_corr_distance(title, distances, correlation, is_tad=True):
    """
    plot correlation as a function of distance (either distance from TAD boundary or distance between gene pairs)
    """
    plt.plot(distances, correlation)
    if is_tad:
        plt.title("Average pairwise correlation")
        plt.xlabel("Distance from TAD boundary in kb")
        plt.ylabel("Average correlation")
    else:
        plt.title("Pairwise correlation")
        plt.xlabel("Distance between genes in bp")
        plt.ylabel("Correlation")
    # plt.savefig(title + "_lineplot.png")
    plt.show()

def plot_tad_heatmap(title, corr_df):
    """
    plot heatmap of gene correlation within a TAD
    """
    sns.clustermap(corr_df)
    plt.show()

def get_highly_correlated_genes(corr_df, percentile=99):
    """
    given a correlation dataframe, return a list of the most highly correlated gene pairs.
    """
    corr_matrix = corr_df.to_numpy()
    values = corr_matrix[np.triu_indices_from(corr_matrix, 1)].flatten()
    threshold = np.percentile(values, percentile)
    genes = corr_df.index
    all_pairs = list(itertools.combinations(genes, 2))
    highly_correlated_pairs = []
    for pair in all_pairs:
        gene1, gene2 = pair[0], pair[1]
        corr = corr_df.loc[gene1,gene2]
        if corr > threshold:
            highly_correlated_pairs.append((gene1, gene2))
    return highly_correlated_pairs

def calc_gene_dist(same_chrom_gene_pair,gene_loc):
    """
    Calculate gene distance assuming genes are on the same chromosome
    Midpoint distance
    """
    gene1 = same_chrom_gene_pair[0]
    gene2 = same_chrom_gene_pair[1]
    gene1_srt = gene_loc.loc[gene1]['start']
    gene1_end = gene_loc.loc[gene1]['end']
    gene2_srt = gene_loc.loc[gene2]['start']
    gene2_end = gene_loc.loc[gene2]['end']
    return abs((gene1_srt+gene1_end)/2-(gene2_srt+gene2_end)/2)

def genes_in_same_tad(gene_pair,tg_dict):
    """
    given a pair of genes as a tuple, determine whether the genes are in the same TAD.
    """
    for v in tg_dict.values():
        if gene_pair[0] in v:
            if gene_pair[1] in v:
                return True
            else:
                return False
    return False

In [15]:
# read in data
tpm_data, gene_loc_data, tad_data = extract_data(data_path, gene_loc_path, tad_path)
# normalize tpm data
norm_tpm = log2norm_tpm(tpm_data)
# make dictionaries
cg_dict = chromosome_gene_dict(gene_loc_data)
tg_dict = tad_gene_dict(tad_data,gene_loc_data)

chromosome_list = ['chr'+str(i+1) for i in range(19)]

  sc_df_filtered = sc_df_filtered.loc[gene_list]
  gloc_filtered = gloc_filtered.loc[gene_list]


In [3]:
# for chromosome in chromosome_list:
#     tpm, gene_loc, tad = get_genes_from_chromosome(chromosome,norm_tpm,tad_data,gene_loc_data)
#     for t in tad.index:
#         if t not in tg_dict.keys():
#             continue
#         tad_chr, tad_start, tad_end = tad.loc[t,:]
#         tad_corr_df, tad_corr = calc_tad_coexp(tad_chr, tad_start, tad_end, gene_loc_data, tpm)
#         new_boundaries, slide_distances = slide_boundary(tad_chr, tad_start, tad_end, 20)
#         corr = [tad_corr]
#         distances = [0] + slide_distances
#         for new_chr, new_start, new_end in new_boundaries:
#             new_corr_df, new_corr = calc_tad_coexp(new_chr, new_start, new_end, gene_loc_data, tpm)
#             corr.append(new_corr)
#         tad_location = "{}_{}-{}".format(tad_chr, tad_start, tad_end)
#         plot_tad_heatmap(tad_location, tad_corr_df)
#         plot_corr_distance(tad_location, np.array(distances)/1000 , corr)

In [33]:
def high_low_corr_genes(corr_df, percentile, high=True):
    """
    if high, return all gene pairs greater than percentile
    if low, return all gene pairs less than percentile
    """
    corr_matrix = corr_df.to_numpy()
    values = corr_matrix[np.triu_indices_from(corr_matrix, 1)].flatten()
    threshold = np.percentile(values, percentile)
    genes = corr_df.index
    all_pairs = list(itertools.combinations(genes, 2))
    correlated_gene_pairs = []
    for pair in all_pairs:
        gene1, gene2 = pair[0], pair[1]
        corr = corr_df.loc[gene1,gene2]
        if high:
            if corr > threshold and corr < 0.99:
                correlated_gene_pairs.append((gene1,gene2))
        else:
            if corr < threshold:
                correlated_gene_pairs.append((gene1,gene2))
    return correlated_gene_pairs
    # pairs_in_tad = []
    # print(len(correlated_gene_pairs))
    # for pair in correlated_gene_pairs:
    #     if genes_in_same_tad(pair, tg_dict):
    #         pairs_in_tad.append(pair)
    # return len(pairs_in_tad) / len(correlated_gene_pairs) * 100

In [30]:
chromosome = "chr3"
tpm, gene_loc, tad = get_genes_from_chromosome(chromosome,norm_tpm,tad_data,gene_loc_data)
all_genes_corr_df = tpm.transpose().corr()
all_genes_corr_df.head()

Unnamed: 0,ENSMUSG00000089493,ENSMUSG00000027712,ENSMUSG00000102218,ENSMUSG00000040016,ENSMUSG00000049796,ENSMUSG00000028145,ENSMUSG00000040600,ENSMUSG00000102846,ENSMUSG00000070372,ENSMUSG00000028081,...,ENSMUSG00000053769,ENSMUSG00000104822,ENSMUSG00000105449,ENSMUSG00000027997,ENSMUSG00000037211,ENSMUSG00000049128,ENSMUSG00000097583,ENSMUSG00000091405,ENSMUSG00000106462,ENSMUSG00000105279
ENSMUSG00000089493,1.0,-0.607134,0.401431,-0.430653,0.294673,0.474112,-0.199498,0.337268,-0.015633,-0.422481,...,0.550158,0.279538,0.165525,-0.388498,-0.370888,-0.248973,0.538285,-0.238905,0.617816,-0.134472
ENSMUSG00000027712,-0.607134,1.0,-0.179539,0.675815,-0.148466,-0.584874,0.178609,-0.478927,0.557784,0.818824,...,-0.615915,-0.093476,-0.1276,0.760592,0.740099,0.590021,-0.509723,0.136095,-0.371516,-0.30066
ENSMUSG00000102218,0.401431,-0.179539,1.0,-0.338086,-0.176866,0.04735,-0.196988,0.26774,0.152257,0.068454,...,0.158865,0.192704,0.242292,0.025125,0.278828,-0.152617,0.303522,0.068005,-0.134568,-0.134568
ENSMUSG00000040016,-0.430653,0.675815,-0.338086,1.0,0.018216,-0.201007,0.558651,-0.258342,0.592005,0.703546,...,-0.292282,0.093897,0.175123,0.6697,0.64358,0.404859,-0.267811,0.179591,-0.093638,-0.380587
ENSMUSG00000049796,0.294673,-0.148466,-0.176866,0.018216,1.0,0.409292,-0.347019,-0.08988,0.23198,-0.323303,...,0.43467,-0.163697,0.161233,-0.273581,-0.164675,-0.383144,0.09418,-0.249552,0.433385,-0.238792


In [34]:
# top 99 percentile
high_low_corr_genes(all_genes_corr_df, 99, True)

22738


0.9411557744744481

In [32]:
# bottom 1 percentile
high_low_corr_genes(all_genes_corr_df, 1, False)

22738


0.07916263523616854