In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import zscore
import argparse
import itertools
from tqdm import tqdm
import random
import warnings

In [2]:
def extract_data(data_path,gene_loc_path,tad_path,hivar_pctl=None,excl_chrom=['chrM','chrX','chrY']):
    """
    load the expression, gene-location and TAD tables and restrict them to a shared set of genes/chromosomes

    data_path: path to a csv of tpm data (genes as rows, samples as columns, gene ids in the first column)
    gene_loc_path: path to a tab-separated file of gene locations (gene ids in the first column, needs a 'seqname' column)
    tad_path: path to a csv of TAD intervals (needs a 'chrom' column)
    hivar_pctl: optional number in [0,100]; if given, only genes at or above this variance percentile are kept.
                None, non-numeric or out-of-range values mean "no variance filter"
    excl_chrom: list of strings corresponding to chromosomes to be excluded
    ===
    return (tpm dataframe, gene location dataframe, TAD dataframe)
    the tpm and gene location dataframes contain exactly the same genes (the intersection), in the same sorted order
    """
    # read in data
    sc_df = pd.read_csv(data_path,index_col=0)
    # keep only rows whose id contains 'ENSMUSG'; a boolean mask keeps every row exactly once even if ids repeat
    sc_df = sc_df[sc_df.index.astype(str).str.contains('ENSMUSG',regex=False)]
    gloc = pd.read_csv(gene_loc_path,sep="\t",index_col=0)
    tad = pd.read_csv(tad_path)
    # index manipulation: drop everything after the first "." so ids from both files can be matched
    sc_df.index = [idx.split(".")[0] for idx in sc_df.index]
    gloc.index = [idx.split(".")[0] for idx in gloc.index]
    # get rid of genes with 0 exp across all samples
    sc_df_filtered = sc_df.loc[sc_df.sum(axis=1)!=0]
    # filter gene by variance across samples/tissues if thus specified
    # isinstance also accepts numpy numbers (e.g. np.float64); bool is rejected explicitly because it subclasses int
    is_number = isinstance(hivar_pctl,(int,float,np.integer,np.floating)) and not isinstance(hivar_pctl,bool)
    if is_number and 0 <= hivar_pctl <= 100:
        sc_df_filtered = filter_genes_by_variance(sc_df_filtered,hivar_pctl)
    # get rid of chromosomes in exclusion list
    kept_chroms = set(gloc['seqname']).difference(excl_chrom)
    gloc_filtered = gloc[gloc['seqname'].isin(kept_chroms)]
    tad_filtered = tad[tad['chrom'].isin(kept_chroms)]
    # get intersecting genes in both
    # .loc needs a list (recent pandas rejects set indexers); sorting makes the row order reproducible between runs
    gene_list = sorted(set(sc_df_filtered.index).intersection(gloc_filtered.index))
    sc_df_filtered = sc_df_filtered.loc[gene_list]
    gloc_filtered = gloc_filtered.loc[gene_list]
    return sc_df_filtered, gloc_filtered, tad_filtered
def filter_genes_by_variance(sc_df,percentile):
    """
    keep only the genes whose variance across tissues/samples reaches the given percentile

    sc_df: TPM dataframe laid out as genes x tissue (genes are rows !!!!)
    percentile: number in [0,100]; a gene is kept when its variance is >= this percentile of all gene variances
    ===
    returns the filtered dataframe, still genes x samples/tissues

    IMPORTANT: feed in raw (un-normalized) data. Filtering normalized data by variance would be pointless.
    """
    samples_by_genes = sc_df.T # genes become columns, so .var() below is computed per gene
    per_gene_variance = samples_by_genes.var()
    cutoff = np.percentile(per_gene_variance,percentile)
    keep = per_gene_variance >= cutoff
    return samples_by_genes.loc[:,keep].T # flips back to genes x samples/tissues


def log2norm_tpm(tpm_data):
    """
    returns log2-normalized tpm data: log2(tpm+1), then z-scored along axis 1 (within each row)

    tpm_data: dataframe or array of tpms
    ===
    a dataframe input always comes back as a dataframe with the same index and columns;
    any other input returns whatever scipy's zscore returns for it.
    rows with no variation have a standard deviation of 0 and therefore come back as NaN.
    """
    log_tpm = np.log2(tpm_data+1)
    if isinstance(tpm_data,pd.DataFrame):
        # whether zscore hands back a DataFrame or a bare ndarray depends on the scipy version,
        # so compute on the raw values and re-attach the labels explicitly
        scaled = zscore(log_tpm.to_numpy(dtype=float),axis=1)
        return pd.DataFrame(scaled,index=tpm_data.index,columns=tpm_data.columns)
    return zscore(log_tpm,axis=1)


def tad_gene_dict(tad_locs,gloc,filter_bar=5):
    """
    return dictionary mapping each TAD to the list of genes located inside it

    tad_locs: dataframe of TADs with 'chrom', 'start' and 'end' columns
    gloc: dataframe of gene locations (passed through to get_genes_in_interval)
    filter_bar: minimum number of genes a TAD needs in order to be kept; values <= 0 keep every TAD
    ===
    keys are the row positions (0, 1, 2, ...) of the TADs in tad_locs, values are lists of gene ids
    """
    tg_dict = {}
    # walk the rows by position: tad_locs may have been filtered, so its index labels
    # are not guaranteed to run 0..n-1 and label lookups with .loc[i] could miss or fail
    intervals = zip(tad_locs['chrom'],tad_locs['start'],tad_locs['end'])
    for i,(chrom,start,end) in enumerate(intervals):
        tg_dict[i] = get_genes_in_interval(chrom,start,end,gloc)
    if filter_bar > 0:
        tg_dict = {k:v for k,v in tg_dict.items() if len(v)>=filter_bar}
    return tg_dict
def get_genes_in_interval(chrom,start,end,gloc):
    """
    list the genes that lie inside an interval on a chromosome

    chrom: chromosome name, compared against the 'seqname' column of gloc
    start, end: interval bounds; a gene counts when gene start >= start and gene end < end
    gloc: dataframe of gene locations with 'seqname', 'start' and 'end' columns, indexed by gene id
    ===
    returns the matching gene ids (index labels of gloc) as a list
    """
    on_chrom = gloc['seqname']==chrom
    inside = (gloc['start'] >= start) & (gloc['end'] < end)
    hits = gloc[on_chrom & inside]
    return list(hits.index)

In [3]:
def get_genes_from_chromosome(chr_name,tpm_data,tad_data,gloc_data):
    """
    chr_name: e.g. 'chr1', the string corresponding to the chromosome you want to extract data on
    tpm_data: dataframe of tpms, indexed by gene id
    tad_data: dataframe of TADs with a 'chrom' column
    gloc_data: dataframe of gene locations with a 'seqname' column, indexed by gene id
    ===
    returns (tpm rows, gene location rows, TAD rows) restricted to the chromosome of interest
    """
    genes_on_chr = gloc_data['seqname']==chr_name
    tads_on_chr = tad_data['chrom']==chr_name
    chr_gloc = gloc_data[genes_on_chr]
    chr_tad = tad_data[tads_on_chr]
    # the tpm rows are picked by the gene ids that survived the location filter
    chr_tpm = tpm_data.loc[chr_gloc.index]
    return chr_tpm, chr_gloc, chr_tad

def chromosome_gene_dict(gloc_data):
    """
    build a lookup from chromosome to the genes located on it

    gloc_data: dataframe of gene locations with a 'seqname' column, indexed by gene id
    ===
    returns a dictionary where keys are chromosomes and values are lists of gene ids
    """
    seqnames = gloc_data.seqname
    return {chrom: list(gloc_data[seqnames==chrom].index) for chrom in set(seqnames)}

def get_chr_lengths(path_to_file="../data/chr_lengths"):
    """
    return a dict of chromosomes and their lengths

    path_to_file: text file with one chromosome per line: name, whitespace, integer length
    ===
    blank lines are skipped; a line without a second field or with a non-integer length still raises
    """
    chr_lengths = {}
    # the with-block closes the file even if a line fails to parse
    with open(path_to_file) as infile:
        for line in infile:
            fields = line.split()
            if not fields: # blank or whitespace-only line
                continue
            chr_lengths[fields[0]] = int(fields[1])
    return chr_lengths

def gene_corr_dictionary(tpm):
    """
    return dictionary mapping every unordered pair of genes to the correlation of their expression

    tpm: dataframe with genes as rows and samples as columns
    ===
    keys are (gene1, gene2) tuples in the order produced by itertools.combinations over tpm.index,
    values are the entries of tpm.T.corr() for that pair. self pairs are not included.
    if gene ids repeat in tpm.index, later pairs overwrite earlier ones with the same key.
    """
    corr_values = tpm.T.corr().to_numpy()
    # the strict upper triangle, read row by row, lists the (i, j) entries with i < j in exactly
    # the order combinations() yields the pairs, so one array slice replaces a label lookup per pair
    upper = np.triu_indices(len(tpm.index),k=1)
    gene_pairs = itertools.combinations(tpm.index,2)
    return dict(zip(gene_pairs,corr_values[upper]))

def genes_in_same_tad(gene_pair,tg_dict,return_false_if_same_genes=True):
    """
    given a pair of genes as a tuple, determine whether the genes are in the same TAD.

    gene_pair: (gene1, gene2)
    tg_dict: dictionary mapping each TAD to the collection of genes inside it
    return_false_if_same_genes: if True, a pair made of the same gene twice is reported as False
    ===
    returns True if at least one TAD contains both genes, otherwise False
    """
    first, second = gene_pair[0], gene_pair[1]
    if return_false_if_same_genes and first==second:
        return False
    # check every TAD: stopping at the first TAD that holds gene1 would give a wrong False
    # when a gene is listed under more than one TAD
    return any(first in genes and second in genes for genes in tg_dict.values())

In [4]:
def gene_tad_dict(gene_pairs, tg_dict):
    """
    Flags, for every gene pair, whether genes_in_same_tad reports the two genes as sharing a TAD
    gene_pairs: iterable of (gene1, gene2) tuples
    tg_dict: dictionary mapping each TAD to its genes
    ===
    Returns e.g. {(gene1, gene2) : True/False}
    """
    return {gene_pair: genes_in_same_tad(gene_pair,tg_dict) for gene_pair in gene_pairs}
def gene_corr_permutation_test(gc_dict, tg_dict, percentile=90, n_iter=100, rand_seed=1):
    """
    Permutation test: are highly correlated gene pairs in the same TAD more often than random pairs?

    gc_dict: dictionary of gene pairs mapping to correlations; should not have autocorrelation key-value pairs
    tg_dict: dictionary mapping each TAD to its genes
    percentile: pairs with a correlation strictly above this percentile count as highly correlated (default 90th)
    n_iter: number of random samples to draw
    rand_seed: seed for the random sampling, so repeated calls give the same p value
    ===
    Returns the fraction of random samples whose same-TAD proportion is strictly greater than the
    proportion among the highly correlated pairs (the p value).
    Returns None when there is nothing to test: no pairs at all, no pair above the threshold
    (this also happens when the correlations contain NaN), or no pair sharing a TAD.
    """
    np.random.seed(rand_seed) # kept so callers that rely on the global numpy seed being set see no change
    # the sampling below uses the stdlib random module, which np.random.seed does not affect,
    # so a dedicated seeded generator is what actually makes the result reproducible
    rng = random.Random(rand_seed)

    gpairs = list(gc_dict.keys())
    if not gpairs:
        return None # nothing to test
    correlations = list(gc_dict.values())
    threshold = np.percentile(correlations, percentile)
    hi_corr_pairs = [gp for gp in gpairs if gc_dict[gp]>threshold]
    sample_size = len(hi_corr_pairs)
    if sample_size == 0:
        return None # no pair above the threshold; the proportions below would divide by zero

    gt_dict = gene_tad_dict(gpairs, tg_dict)
    if sum(gt_dict.values()) == 0:
        return None # everything will just be 0; no point

    proportion = sum(gt_dict[p] for p in hi_corr_pairs)/sample_size

    proportions = np.ones(n_iter) # proportion in TAD, one entry per random sample
    for i in range(n_iter):
        sample = rng.sample(gpairs,sample_size)
        proportions[i] = sum(gt_dict[gpair] for gpair in sample)/sample_size

    # NOTE(review): the comparison is strict and there is no +1 correction, so the result can be exactly 0
    return np.count_nonzero(proportions>proportion)/n_iter # p value

In [5]:
# input files, given relative to the directory the notebook is run from
data_path = "../data/ENCODE_bulk_rna_seq.csv"
gene_loc_path = "../data/gene_locations.tsv"
tad_path = "../data/TAD_strong_boundary_start_end.csv"

# minimum number of genes a TAD must contain to be kept in tg_dict (passed as filter_bar below)
min_genes_in_tad = 5 #30


# read in data
# no hivar_pctl is passed, so no variance filter is applied; the default chromosome exclusion list is used
tpm_data, gene_loc_data, tad_data = extract_data(data_path, gene_loc_path, tad_path)
# normalize tpm data (log2(tpm+1), then z-score within each gene)
norm_tpm = log2norm_tpm(tpm_data)
# make dictionaries
#cg_dict = chromosome_gene_dict(gene_loc_data)
# tg_dict: TAD -> genes inside it, built over all chromosomes at once
tg_dict = tad_gene_dict(tad_data,gene_loc_data,filter_bar=min_genes_in_tad)

In [None]:
# chromosomes to analyse: 'chr1' through 'chr19'
chromosome_list = [f"chr{n}" for n in range(1,20)]

# analysis 4: permutation test per chromosome - do highly correlated gene pairs share a TAD more often than random pairs?
for chromosome in chromosome_list:
    print(f"Processing {chromosome}...")
    tpm, gene_loc, tad = get_genes_from_chromosome(chromosome,norm_tpm,tad_data,gene_loc_data)
    gc_dict = gene_corr_dictionary(tpm)
    p_val = gene_corr_permutation_test(gc_dict, tg_dict) # None when the test could not be run
    # append mode: one line is written as soon as each chromosome finishes, so partial results survive
    # an interrupted run - but re-running this cell adds a second set of lines to the same file.
    # the with-block closes the file itself; no explicit close is needed
    with open("../results/pvalues.txt",'a') as f:
        f.write(chromosome+"\t"+str(p_val)+"\n")
    print("Analysis 4 done.")

Processing chr1...
Analysis 4 done.
Processing chr2...
Analysis 4 done.
Processing chr3...
Analysis 4 done.
Processing chr4...
Analysis 4 done.
Processing chr5...
Analysis 4 done.
Processing chr6...
Analysis 4 done.
Processing chr7...
Analysis 4 done.
Processing chr8...
Analysis 4 done.
Processing chr9...
Analysis 4 done.
Processing chr10...
Analysis 4 done.
Processing chr11...
Analysis 4 done.
Processing chr12...
Analysis 4 done.
Processing chr13...
Analysis 4 done.
Processing chr14...
Analysis 4 done.
Processing chr15...
Analysis 4 done.
Processing chr16...
Analysis 4 done.
Processing chr17...
Analysis 4 done.
Processing chr18...
Analysis 4 done.
Processing chr19...


In [None]:
def high_low_corr_genes2(corr_df, margin=10):
    """
    split all gene pairs into three groups by where their correlation falls

    corr_df: square gene x gene correlation dataframe
    margin: width in percent of each tail; the cutoffs are the (100-margin)th and margin-th
            percentiles of the values above the diagonal
    ===
    returns three lists of gene pairs, in this order:
      pairs strictly above the upper cutoff (and below 1)
      pairs between the two cutoffs, cutoffs included
      pairs strictly below the lower cutoff (and above -1)
    pairs that match none of the three conditions (e.g. a correlation of exactly 1 or -1 outside the middle band) are left out
    """
    matrix = corr_df.to_numpy()
    off_diagonal = matrix[np.triu_indices_from(matrix, 1)].flatten()
    upper_cut = np.percentile(off_diagonal, 100-margin) # default 90th percentile
    lower_cut = np.percentile(off_diagonal, margin) # default 10th percentile
    top_pairs, middle_pairs, bottom_pairs = [], [], []
    for pair in get_gene_pairs(corr_df.index):
        corr = corr_df.loc[pair[0],pair[1]]
        if upper_cut < corr < 1:
            top_pairs.append(pair)
        elif -1 < corr < lower_cut:
            bottom_pairs.append(pair)
        elif lower_cut <= corr <= upper_cut:
            middle_pairs.append(pair)
    return top_pairs, middle_pairs, bottom_pairs
def get_gene_pairs(genes):
    """
    enumerate every unordered pair that can be formed from the given genes (a gene is never paired with itself)

    genes: iterable of gene ids
    ===
    returns a list of 2-tuples
    """
    pair_iter = itertools.combinations(genes, 2)
    return [pair for pair in pair_iter]

In [None]:
# ignore warnings
# NOTE: this installs a global filter, so every warning raised after this cell is hidden too
warnings.filterwarnings("ignore")

# NOTE(review): sliding_windows, hc_genes_in_tads and corr_vs_dist are not defined in the visible part of
# this notebook - confirm they are defined or imported before this cell runs on a fresh kernel.
for chromosome in chromosome_list:
    print(f"Processing {chromosome}...")
    # restrict the normalized expression, gene locations and TADs to the current chromosome
    tpm, gene_loc, tad = get_genes_from_chromosome(chromosome,norm_tpm,tad_data,gene_loc_data)
    # analysis 1: correlation vs sliding windows
    sliding_windows(tpm, gene_loc, tad, plot=True)
    print("Analysis 1 done.")
    # analysis 2: are highly correlated genes in the same TAD?
    # gene x gene correlation matrix for this chromosome (genes are rows of tpm, hence the transpose)
    all_genes_corr_df = tpm.transpose().corr()
    analysis2_df = hc_genes_in_tads(chromosome, all_genes_corr_df, gene_loc, tad, tg_dict, plot=True)
    analysis2_df.to_csv(f"../results/analysis2_{chromosome}_df.csv")
    print("Analysis 2 done.")
    # analysis 3: correlation as a function of distance between genes
    # NOTE(review): this passes the genome-wide gene_loc_data rather than the per-chromosome gene_loc - confirm that is intended
    analysis3_high_corr_df, analysis3_other_corr_df = corr_vs_dist(all_genes_corr_df, gene_loc_data, plot=True, title = f"Gene distance between high correlated and low correlated genes \n in {chromosome}", save=f"../results/corr_vs_dist_plot_{chromosome}.png")
    analysis3_high_corr_df.to_csv(f"../results/analysis3_{chromosome}_high_corr_df.csv")
    analysis3_other_corr_df.to_csv(f"../results/analysis3_{chromosome}_other_corr_df.csv")
    print(f"Analysis 3 done.")
    
    

In [None]:
# analysis 5: do NMF-identified gene modules tend to be in the same TAD?

In [None]:
# analysis 6: do PCA component elements tend to be in the same TAD?