# This notebook takes Hi-C interaction data in IJV format and chromatin data in .bigwig format and learns, for each chromosome, an attraction-repulsion mapping using all other chromosomes, which it then uses to predict an interaction map.

In [None]:
import matplotlib
%matplotlib inline

In [None]:
import numpy as np

#Hi-C input is given as a directory plus a file prefix: for speed we want one IJV (a.k.a. triplet / COO /
#coordinate) format file of interactions per chromosome, named
#<hic_directory><hic_prefix><chrom>_<chrom>_<resolution>.txt
celltype = "HCT116-untreated"
hic_directory = "/Zulu/mike/dumped-hic/HCT116/"
hic_prefix = "HCT116-untreated-q30-KR-dumped"

male = True #whether the .bigwig signals on the X chromosome should be doubled

sizefile = "/Zulu/mike/hg19.chrom.sizes" #genome size file

resolution = 100000 #resolution of the analysis
small_resolution = 100000 #resolution of bigwig signals if you want it to be finer

distance_min = 0 #minimum distance below which not to simulate
distance_min_bins = int(distance_min/resolution)
distance_max = 1000000000 #maximum distance above which not to simulate
distance_max_bins = int(distance_max/resolution)

chrstart = 0 #the index of the first chromosome to simulate
chrstop = 23 #one past the index of the last chromosome to simulate (exclusive, as in range(chrstart, chrstop))

tolerance = 10000 #stop the gradient descent once an iteration improves the objective by less than this...
max_iter = 10 #...or after this many iterations

#these are thresholds used for excluding bins of each chromosome from the analysis for various reasons

#100kb
if resolution == 100000:
    correction = "sums nonzero medians quant"
    medianhighstd = 3
    medianlowstd= 3
    highstd = 3.5
    lowstd = 2.5
    nz_low_threshold = 3
    nz_high_threshold = 3
    zscore_threshold = 10

#25kb
elif resolution == 25000:
    correction = "sums nonzero"
    medianhighstd = 3
    medianlowstd= 3
    highstd = 5
    lowstd = 5
    nz_low_threshold = 3
    nz_high_threshold = 3
    zscore_threshold = 25
else:
    #fail here rather than later with a NameError on `correction` or one of the thresholds
    raise ValueError("not a supported resolution, manually specify thresholds.")

ignorebins = [False] #the bins which will be excluded from analysis initialized to nothing.


#signal files
filenames = []
#signal names
names = []
filenames.append("hct116-H3K9me3-ENCFF402WZH.bigWig")
names.append("H3K9me3")
filenames.append("hct116-H3K27me3-ENCFF030SYQ.bigWig")
names.append("H3K27me3")
filenames.append("hct116-H3K27ac-ENCFF225QAB.bigWig")
names.append("H3K27ac")

comp_num = len(filenames) #number of compartmental forces being simulated

ignoreplots = False #when True, the bin-exclusion threshold steps plot their statistics with the threshold lines.
#Useful for tuning the above thresholds

tiles = 21 #the number of quantile boundaries (np.linspace(0, 1, tiles)); the learned maps are (tiles-1) x (tiles-1).
#increased quantiles decreases bias, but in practice introduces more noise and increases variance.


In [None]:
#here we read in the sizefile: one "<name> <length>" line per chromosome.
#The Y and mitochondrial (MT) chromosomes are skipped.

chrnames = []
chrsizes = []
with open(sizefile,'r') as sizes:
    for line in sizes:
        li = line.split()
        if not li:
            continue #tolerate blank lines (e.g. a trailing newline) instead of raising IndexError
        if li[0] in ('Y', 'MT'):
            continue
        chrnames.append(li[0])
        chrsizes.append(int(li[1]))

print(chrnames)

In [None]:
import math
#size of each chromosome in bins; the +1 inside ceil gives every chromosome one bin beyond ceil(size/resolution)
chrsizebins = [int(math.ceil(length/resolution + 1)) for length in chrsizes]

totalsize = sum(chrsizebins) #total number of bins over all chromosomes
print(totalsize)
print(chrsizebins)

In [None]:
#here we initialize the model state that is handed to every per-chromosome fit

import numpy as np

max_size = max(chrsizebins) #largest chromosome, in bins

#distance-normalization vector, one entry per bin offset up to the largest chromosome
D = list(np.zeros(max_size))

#one (tiles x tiles) attraction-repulsion mapping per signal; these are the mappings the model will learn
M = [np.zeros((tiles,tiles)) for _ in range(comp_num)]

In [None]:
#here we read in the bigwig files and put them into dataframes.
#For every chromosome in chrnames[chrstart:chrstop] and every signal, the bigwig is summarised into
#small_resolution bins, then aggregated resolution_ratio small bins at a time:
#the median goes into dfs and the mean into mean_dfs (one dataframe per chromosome, one column per signal).

import pandas as pd
import pyBigWig
import math
from os.path import commonprefix

resolution_ratio = resolution//small_resolution

dfs = []
mean_dfs = []
for chrom in chrnames[chrstart:chrstop]:
    medi_dic = {}
    mean_dic = {}
    lastlength = 0
    for name,filen in zip(names,filenames):
        try:
            bw = pyBigWig.open("/Zulu/mike/chips/"+filen)
        except RuntimeError:
            print("Trouble opening {0}".format(filen))
            raise #do not fall through and silently reuse the previous file's handle
        try:
            chroms = bw.chroms()
            print(chroms)
            #chromosome-name prefix used by this bigwig (e.g. "chr"); commonprefix needs an indexable
            #sequence, a dict_keys view raises TypeError
            prefix = commonprefix(list(chroms.keys()))
            #calculate number of bins
            chromlength = bw.chroms(prefix+chrom)
            binnum = int(math.ceil(chromlength/small_resolution))
            smalls = bw.stats(prefix+chrom,nBins=binnum)
        finally:
            bw.close()
        medi_vals = []
        mean_vals = []
        #the upper bound includes the final (possibly partial) chunk of small bins
        for pos in range(resolution_ratio,binnum+resolution_ratio,resolution_ratio):
            chunk = np.array(smalls[pos-resolution_ratio:pos], dtype=float) #None (no data) becomes NaN
            try:
                medi_vals.append(np.median(chunk))
                mean_vals.append(np.mean(chunk))
            except TypeError:
                print("TypeError at:")
                print(chunk)

        if lastlength:
            #align every later signal to the length of the first signal on this chromosome:
            #pad with ten zeros, then truncate
            medi_vals = (medi_vals + [0]*10)[:lastlength]
            mean_vals = (mean_vals + [0]*10)[:lastlength]
        lastlength = len(medi_vals)
        medi_dic[name] = medi_vals
        mean_dic[name] = mean_vals

    dfs.append(pd.DataFrame(data=medi_dic).fillna(0))
    mean_dfs.append(pd.DataFrame(data=mean_dic).fillna(0))

In [None]:
#here we double every (median) signal on the X chromosome if male is true.
#dfs only holds chromosomes chrnames[chrstart:chrstop], so translate the absolute index
#and skip the step when X is not among the simulated chromosomes.

if male:
    xindex = chrnames.index('X')
    if chrstart <= xindex < chrstop:
        xpos = xindex - chrstart
        dfs[xpos] = dfs[xpos] * 2

In [None]:
#Here we convert the values from the bigwig files into quantiles. Quantiles aren't evenly spaced across the range;
#instead each contains the same number of bins.

quantiles = np.linspace(0,1,num=tiles,endpoint=True)

#first we construct one big concatenated dataframe of the raw signal columns so we can call quantile on it.
#Then we apply those genome-wide thresholds to each individual chrom df.
#Only the first comp_num (raw) columns are used, so re-running this cell does not fold the "-tiles" columns in.
concatdf = pd.concat([df.iloc[:,:comp_num] for df in dfs])

#get rid of zeros: since we are using fold/control bigwigs, zeros indicate unmappable regions we should exclude
nozconcatdf = concatdf[concatdf.sum(axis=1) != 0]

wg_quantiles = nozconcatdf.quantile(quantiles)

del concatdf #free the concatenated copy

#now that we know the quantile thresholds, add a "<signal>-tiles" column to every df holding, for each bin,
#the index of the first genome-wide quantile threshold that is >= the bin's value (searchsorted, side='left')
for df in dfs:
    for column in list(df.columns[:comp_num]):
        places = wg_quantiles[column].searchsorted(df[column].to_numpy())
        df[column+"-tiles"] = pd.Series(places, index=df.index, dtype='int16')


#let's export those quantiles to a file for use in predicting maps in the subsequent notebook.
filename = "{0}-small:{1}-actual:{2}-{3}".format(celltype,small_resolution,resolution,"wgtiles")
with open(filename,'w') as wgt:
    wg_quantiles.to_csv(wgt)


In [None]:
def mu(X,B):
    """
    Return the model's expected value for every observation: 1 + X @ B.

    X has one row per observation and one column per weight in B; the constant
    1 is a baseline that the learned weights push up (attraction) or down (repulsion).
    """
    baseline = np.ones(X.shape[0])
    return baseline + X @ B


def logL(Y,X,B):
    """
    Return the fit objective for Y given X and B: the sum of squared residuals Y - mu(X, B), ignoring NaNs.

    Despite the name, this is not a log-likelihood: smaller is better, and the
    training loop measures improvement as previous value minus current value.
    """
    residuals = Y - mu(X,B)
    return np.nansum(residuals**2)
    

def grad(Y,X,B):
    """
    Return the gradient of the squared-error objective (see logL) with respect to B: -2 X^T (Y - mu(X, B)).
    """
    residuals = Y - mu(X,B)
    return -2 * X.T @ residuals

def hess(Y,X,B):
    """
    Return -2 X^T X, i.e. the *negated* Hessian of the squared-error objective (see logL) with respect to B.

    The true Hessian of the sum of squared residuals is +2 X^T X. MLE_ncomp compensates
    by adding (not subtracting) pinv(hess) @ grad to B, so the two signs cancel into an
    ordinary Newton step. Y and B are unused: the objective is quadratic in B, so its
    Hessian does not depend on them. They are accepted to mirror grad/logL.
    """
    scaled_transpose = -2 * X.T
    return scaled_transpose @ X

In [None]:
def _packed_tile_column(comp, lo, hi, ttiles):
    """
    Return the column of the design matrix X (equivalently the index into B) for a signal and a quantile pair.

    For every signal, B stores the upper triangle (diagonal included) of a ttiles x ttiles
    attraction-repulsion map, packed row by row; the signals are laid out one after another.
    comp is the signal index and lo <= hi are 0-based quantile indices. Works elementwise on
    scalars or NumPy arrays. A non-positive lo contributes no row offset, which reproduces the
    original closed form sum(range(-lo, 0) + ttiles + 1). That case is only relevant for tile 0,
    i.e. when the "quant" correction is off.
    """
    trisize = (ttiles * (ttiles + 1)) // 2
    lo = np.asarray(lo)
    row_offset = np.where(lo > 0, lo * ttiles - lo * (lo - 1) // 2, 0)
    return trisize * np.asarray(comp) + row_offset + np.asarray(hi) - lo


def MLE_ncomp(chrindex,resolution=resolution,distance_min=distance_min,comp_num=comp_num,tiles=tiles,
                     distance_min_bins=distance_min_bins,distance_max=distance_max,
                    distance_max_bins=distance_max_bins,chrnames=chrnames,D=D,M=M,
                     correction=correction,
                     chrsizebins=chrsizebins,ignoreplots=ignoreplots,tolerance=tolerance,
                     max_iter=max_iter,hic_directory=hic_directory,hic_prefix=hic_prefix,
                     ignorebins=ignorebins,dfs=dfs,chrstart=chrstart,chrstop=chrstop,
                     medianhighstd=medianhighstd,medianlowstd=medianlowstd,highstd=highstd,lowstd=lowstd,
                      nz_low_threshold=nz_low_threshold,nz_high_threshold=nz_high_threshold,
                      zscore_threshold=zscore_threshold):
    """
    Learn, for one chromosome, an attraction-repulsion mapping per bigwig signal that explains its Hi-C map.

    Steps:
      1. Load the chromosome's intra-chromosomal IJV file into a dense symmetric matrix.
      2. Flag bins to exclude, using the criteria named in ``correction`` ("medians", "sums",
         "nonzero", "quant"), a per-diagonal z-score outlier filter (``zscore_threshold``),
         and any bin not covered by the signal dataframe.
      3. Flatten the kept upper-triangle entries with offset in
         [1 + distance_min_bins, distance_max_bins] into Y. Build the sparse design matrix X,
         which has one 1 per signal per entry selecting that signal's quantile pair, and the
         average distance decay H.
      4. Take Newton steps on B (the packed maps) against Y / (H / per-distance model mean).
         Stop when the improvement drops to ``tolerance`` or below, or after ``max_iter``
         iterations. Diagnostics are plotted every iteration.

    All keyword parameters default to the notebook globals of the same name, bound when this
    function is defined. ``chrindex`` indexes ``chrnames``; ``dfs`` is indexed from ``chrstart``.

    Returns
    -------
    (maps, sigs, adj_dists, ignorebins)
        maps : array (comp_num, tiles-1, tiles-1) of learned weights, symmetric per signal
               (``list(M)`` if no iteration ran).
        sigs : per-bin list of quantile indices, one per signal.
        adj_dists : distance-decay vector divided by the learned per-distance adjustment.
        ignorebins : per-bin exclusion flags.
    """

    import numpy as np
    import pandas as pd
    import math
    import matplotlib.pyplot as plt
    from scipy.sparse import coo_matrix
    from scipy.sparse import csr_matrix
    import seaborn as sns
    from scipy import stats

    chrom1 = chrnames[chrindex]
    chrom2 = chrnames[chrindex]
    sizebins = chrsizebins[chrindex]
    print("Loading: {0} vs {1}".format(chrom1,chrom2))
    print("Size: {0}".format(sizebins))

    matrix = np.zeros((sizebins,sizebins)) #initialize empty matrix

    #each line is "<pos1> <pos2> <score>"; positions become bin indices and the matrix is symmetrized
    interaction_path = hic_directory+hic_prefix+chrom1+"_"+chrom2+"_"+str(resolution)+".txt"
    with open(interaction_path,'r') as interactions:
        for line in interactions:
            li = line.split()
            if not li:
                continue #tolerate blank lines
            left = int(li[0])//resolution
            right = int(li[1])//resolution
            score = np.nan_to_num(float(li[2]))
            matrix[left,right] = score
            matrix[right,left] = score

    #--- bin exclusion: column medians outside mean +/- k*std (stats over nonzero medians) ---
    if "medians" in correction:
        medians = np.median(matrix,axis=0)

        median_mean = np.mean(medians[medians != 0])
        median_stddev = np.std(medians[medians != 0])

        print("Median Mean: {0}".format(median_mean))
        print("Median Std dev: {0}".format(median_stddev))
        median_high = median_mean+(median_stddev*medianhighstd)
        median_low = median_mean-(median_stddev*medianlowstd)
        if ignoreplots:
            plt.plot(medians)
            plt.plot(np.full(len(medians), median_high), 'r--')
            plt.plot(np.full(len(medians), median_low), 'r--')
            plt.show()
        medianignorebins = [(median_high < c) or median_low > c for c in medians]
    else:
        medianignorebins = [False for c in range(len(matrix[0]))]

    #--- bin exclusion: row+column sums outside mean +/- k*std (stats over nonzero sums) ---
    if "sums" in correction:
        sums = np.sum(matrix,axis=0) + np.sum(matrix,axis=1)
        sum_mean = np.mean(sums[sums != 0])
        sum_stddev = np.std(sums[sums != 0])
        print("Sum Mean: {0}".format(sum_mean))
        print("Sum Std dev: {0}".format(sum_stddev))
        sum_high = sum_mean+(sum_stddev*highstd)
        sum_low = sum_mean-(sum_stddev*lowstd)

        if ignoreplots:
            plt.plot(sums)
            #lower limit first so the axis is not inverted
            plt.ylim(sum_mean-(sum_stddev*lowstd*5), sum_mean+(sum_stddev*highstd*5))
            plt.plot(np.full(len(sums), sum_high), 'r--')
            plt.plot(np.full(len(sums), sum_low), 'r--')
            plt.show()
        sumignorebins = [(sum_high < c) or sum_low > c for c in sums]
    else:
        sumignorebins = [False for c in range(len(matrix[0]))]

    #--- bin exclusion: number of nonzero contacts per column outside mean +/- k*std ---
    if "nonzero" in correction:
        nonzerocounts = np.count_nonzero(matrix,axis=0)
        nonzero_mean = np.mean(nonzerocounts[nonzerocounts != 0])
        nonzero_stddev = np.std(nonzerocounts[nonzerocounts != 0])
        print("Nonzero Mean: {0}".format(nonzero_mean))
        print("Nonzero Std dev: {0}".format(nonzero_stddev))
        nonzero_high = nonzero_mean+(nonzero_stddev*nz_high_threshold)
        nonzero_low = nonzero_mean-(nonzero_stddev*nz_low_threshold)
        nonzeroignorebins = [(nonzero_high < c) or (nonzero_low > c) for c in nonzerocounts]
        if ignoreplots:
            plt.plot(nonzerocounts)
            plt.plot(np.full(len(nonzerocounts), nonzero_high), 'r--')
            plt.plot(np.full(len(nonzerocounts), nonzero_low), 'r--')
            plt.show()
    else:
        nonzeroignorebins = [False for c in range(len(matrix[0]))]

    df = dfs[chrindex-chrstart] #dfs starts at chrstart

    #load in the per-bin quantile indices (last comp_num columns, i.e. the "-tiles" columns)
    sigs = []
    signames = df.columns[-comp_num:].tolist()
    print(signames)
    for index, row in df.iloc[:,-comp_num:].iterrows():
        try:
            sigs.append([int(r) for r in row][-comp_num:])
        except ValueError:
            print("NaN")
            print(index)
            print(row)

    #--- bin exclusion: any signal in quantile 0 (at or below the genome-wide minimum, incl. zero/unmappable) ---
    quantignorebins = [False for c in range(len(matrix[0]))]
    if "quant" in correction:
        for x, row in enumerate(sigs):
            if any(t == 0 for t in row[:len(signames)]):
                quantignorebins[x] = True

    ignorebins = [w | x | y | z for (w,x,y,z) in zip(medianignorebins, sumignorebins, nonzeroignorebins,quantignorebins)]
    print("Ignoring {0} median bins".format(sum(medianignorebins)))
    print("Ignoring {0} sum bins".format(sum(sumignorebins)))
    print("Ignoring {0} zero bins".format(sum(nonzeroignorebins)))
    print("Ignoring {0} quant bins".format(sum(quantignorebins)))

    #now we need to ignore z-norm outliers.
    #first we set ignored rows and columns to nan so they are omitted by the zscore operation
    nan_ignore_matrix = np.copy(matrix)
    nan_ignore_matrix[ignorebins,:] = np.nan
    nan_ignore_matrix[:,ignorebins] = np.nan

    #on every diagonal d, an entry (e, e+d) with z-score above the threshold flags both bin e and bin e+d;
    #NaN z-scores (ignored bins) never flag anything
    zscoreignorebins = np.zeros(sizebins, dtype=bool)
    for d in range(sizebins):
        zscores = stats.zscore(np.diag(nan_ignore_matrix,d),nan_policy='omit')
        with np.errstate(invalid='ignore'):
            outliers = np.flatnonzero(np.ma.filled(zscores > zscore_threshold, False))
        zscoreignorebins[outliers] = True
        zscoreignorebins[outliers + d] = True

    print("Ignoring {0} zscore bins".format(sum(zscoreignorebins)))
    ignorebins = [x | y for (x,y) in zip(ignorebins,zscoreignorebins)]

    print("Ignoring {0} total bins".format(sum(ignorebins)))
    #matrix is still our entirely unnormalized matrix
    #initialize our variables (copies, so the shared defaults D and M are not modified):
    dists = list(D[:sizebins])
    maps = list(M[:])

    #set the ignore flag for any bins outside the signal coverage
    if len(sigs) < len(ignorebins):
        for x in range(len(sigs),len(ignorebins)):
            ignorebins[x] = True

    LLL = np.inf #previous objective value
    count = 0 #iterations performed
    error = np.inf #improvement of the last iteration; infinite so the first iteration always runs

    #Y: NaN out ignored rows/columns and everything outside the distance band above the diagonal,
    #then flatten the remaining entries into (i, j, v) rows
    matrixdf = pd.DataFrame(matrix)
    matrixdf.iloc[:,ignorebins] = np.nan
    matrixdf.iloc[ignorebins,:] = np.nan
    band = np.tril(np.triu(np.ones(matrixdf.shape),k=1+distance_min_bins),
                   k=min(distance_max_bins,matrixdf.shape[0])).astype(bool)
    matrixdf = matrixdf.where(band)
    #dropna keeps the result identical whether or not this pandas version's stack() drops NaN itself
    matrixdf = matrixdf.stack().dropna().reset_index()
    matrixdf.columns = ['i','j','v']
    Y = matrixdf['v'] #the vector of observed values from the interaction map
    rows_i = matrixdf['i'].to_numpy(dtype=np.int64)
    rows_j = matrixdf['j'].to_numpy(dtype=np.int64)

    #B holds the learned coefficients: the packed upper triangles of the maps, one per signal
    ttiles = tiles-1
    trisize = ((ttiles*(ttiles+1))//2)
    bsize = trisize*comp_num
    B = np.zeros(bsize)

    #H is the expected value of every entry based on its distance j - i alone
    H_dists = matrixdf['j'] - matrixdf['i']
    Hadj = np.ones(len(dists))
    Hadj_exp = np.ones(Y.shape[0])

    #reverse lookup: packed index in B -> (signal, quantile i, quantile j)
    xijdic = {}
    for x in range(comp_num):
        for i in range(ttiles):
            for j in range(i,ttiles):
                xijdic[int(_packed_tile_column(x, i, j, ttiles))] = [x,i,j]

    #X is a sparse matrix with one row per entry of Y and one column per weight in B.
    #Each row has exactly one 1 per signal, in the column of that signal's (unordered) quantile pair
    #for bins i and j; quantile indices are shifted down by one so tile 1 maps to row/column 0.
    sig_array = np.asarray(sigs, dtype=np.int64).reshape(len(sigs), comp_num)
    tiles_i = sig_array[rows_i]
    tiles_j = sig_array[rows_j]
    lo = np.minimum(tiles_i, tiles_j) - 1
    hi = np.maximum(tiles_i, tiles_j) - 1
    comps = np.broadcast_to(np.arange(comp_num), lo.shape)
    cols = _packed_tile_column(comps, lo, hi, ttiles).ravel()
    row_ids = np.repeat(np.arange(Y.shape[0]), comp_num)
    coo = coo_matrix((np.ones(cols.size), (row_ids, cols)), shape=(Y.shape[0], bsize))

    X = csr_matrix(coo)

    #average distance decay vector: mean observed value at each distance (NaN where no entries remain)
    dist_index = np.arange(1, len(dists))
    dist_means = Y.groupby(H_dists).mean().reindex(dist_index)
    H = H_dists.map(dist_means).to_numpy(dtype=float)
    dists[1:] = dist_means.tolist()

    #Plotting the distance decay vector
    plt.plot(dists)
    plt.yscale('log')
    plt.xscale('log')
    plt.title("Distance Vector {0}".format(chrindex))
    plt.show()

    #Sanity printing the shapes of each variable
    print("X {0}".format(X.shape))
    print("Y {0}".format(Y.shape))
    print("B {0}".format(B.shape))
    print("H {0}".format(H.shape))

    #now we will begin iterating
    while error > tolerance and count < max_iter:
        count += 1

        #adjust the distance normalization by the current model's mean prediction at each distance
        xb_means = pd.Series(mu(X,B)).groupby(H_dists.to_numpy()).mean().reindex(dist_index)
        Hadj[1:] = xb_means.to_numpy()
        Hadj_exp = H_dists.map(xb_means).to_numpy(dtype=float)

        #now divide by the adjusted distance expectation (0/0 and x/0 are mapped to finite values)
        Yh = np.nan_to_num(Y/(H / Hadj_exp))

        #Newton update of B. hess() returns the negated Hessian of the squared error, so subtracting
        #step = -(pinv(hess) @ grad) yields B - pinv(2 X^T X) @ grad, the ordinary Newton step.
        Grad = grad(Yh,X,B)
        Hess = hess(Yh,X,B).toarray()
        inv = np.linalg.pinv(Hess)
        step = -1 * (inv @ Grad[:, np.newaxis])
        B = np.asarray(B[:, np.newaxis] - step).reshape(B.shape[0])

        #reshape the weights (B) into one symmetric square map per signal
        B_reshape = np.zeros((comp_num,ttiles,ttiles))
        for n,b in enumerate(B):
            x,i,j = xijdic[n]
            B_reshape[x,i,j] = b
            B_reshape[x,j,i] = b

        #plot out the learned maps
        for b,n in zip(B_reshape,signames):
            sns.heatmap(b,center=0)
            plt.title("Maps "+ n)
            plt.show()

        #squared-error objective (named "log likelihood" in the output); lower is better
        LL = logL(Yh,X,B)

        print("Log likelihood: {0}".format(LL))
        print("Last log likelihood: {0}".format(LLL))
        error = LLL-LL
        print("Chromosome {0} Iteration {1} Improvement: {2}".format(chrom1,count,error))
        LLL = LL

        expected = mu(X,B)
        observed = Yh
        expectedc = expected * (H / Hadj_exp)
        observedc = Y.to_numpy()
        pearson = np.corrcoef(observed,expected)[0,1]
        print('Pearson: {0}'.format(pearson))

        #sample visualizations: expected above the diagonal, observed below (i < j for every entry)
        vis = np.zeros((sizebins,sizebins))
        vis[rows_i, rows_j] = expected
        vis[rows_j, rows_i] = observed
        visc = np.zeros((sizebins,sizebins))
        visc[rows_i, rows_j] = expectedc
        visc[rows_j, rows_i] = observedc

        matplotlib.rcParams['figure.figsize'] = (12,12)
        plt.imshow(vis)
        plt.colorbar()
        plt.show()

        plt.imshow(vis[100:200,100:200])
        plt.colorbar()
        plt.show()
        matplotlib.rcParams['figure.figsize'] = (6,6)

        matplotlib.rcParams['figure.figsize'] = (12,12)
        plt.imshow(visc,vmin=0,vmax=1000)
        plt.colorbar()
        plt.show()

        plt.imshow(visc[100:200,100:200],vmin=0,vmax=1000)
        plt.colorbar()
        plt.show()
        matplotlib.rcParams['figure.figsize'] = (6,6)

        maps = B_reshape

    #computed after the loop so it is defined even when no iteration ran
    adj_dists = np.asarray(dists)/Hadj
    return (maps, sigs, adj_dists, ignorebins)





In [None]:
#fit the model separately for every chromosome index in [chrstart, chrstop);
#msdi[k] holds the (maps, sigs, adj_dists, ignorebins) tuple for chrnames[chrstart + k]

matplotlib.rcParams['figure.figsize'] = (6,6)

msdi = []
for chrindex in range(chrstart,chrstop):
    result = MLE_ncomp(chrindex)
    msdi.append(result)

In [None]:
#split the (maps, sigs, adj_dists, ignorebins) tuples returned by MLE_ncomp into parallel
#per-chromosome lists, indexed from chrstart like msdi
mappings = [result[0] for result in msdi]
signals = [result[1] for result in msdi]
distances = [result[2] for result in msdi]
ignored = [result[3] for result in msdi]

In [None]:
#here we export the learned attraction-repulsion maps, distance normalization, and ignored bins to file.
#One CSV per chromosome and signal for the maps, one per chromosome for distances and for ignored bins.

dfnames = list(dfs[0].columns)

for chrindex in range(chrstart,chrstop):
    corrindex = chrindex - chrstart #mappings/distances/ignored are indexed from chrstart
    for x,signame in enumerate(dfnames[-comp_num:]):
        mapfileoutname = "/Zulu/mike/learnedmaps/{0}-{1}kb-chr{2}-{3}-learnedmaps-min:max-{4}kb:{5}kb.csv".format(celltype,resolution//1000,chrnames[chrindex],signame,distance_min//1000,distance_max//1000)

        mapdf = pd.DataFrame(mappings[corrindex][x])
        mapdf.to_csv(mapfileoutname,index=False)
        print(mapfileoutname)

    distancefileoutname = "/Zulu/mike/learneddistances/{0}-{1}kb-chr{2}-distances-min:max-{3}kb:{4}kb.csv".format(celltype,resolution//1000,chrnames[chrindex],distance_min//1000,distance_max//1000)
    distancedf = pd.DataFrame(distances[corrindex])
    distancedf.to_csv(distancefileoutname,index=False,na_rep=0)
    ignorefileoutname = "/Zulu/mike/learnedignores/{0}-{1}kb-chr{2}-ignore.csv".format(celltype,resolution//1000,chrnames[chrindex])
    ignoredf = pd.DataFrame(ignored[corrindex])
    ignoredf.to_csv(ignorefileoutname,index=False)

