In [None]:
import os,sys
import re
import numpy as np
import pandas as pd
import glob

## 2021-10-25  Expectation Maximization for Multinomial mixture model estimation using Python
## Reference:
## Shenhav, L., Thompson, M., Joseph, T.A. et al. FEAST: fast expectation-maximization for microbial source tracking. Nat Methods 16, 627–632 (2019). 
##https://doi.org/10.1038/s41592-019-0431-x


def run_EM(countfile,metadata,iteration,tol=1e-6):
    """Estimate the source proportions of one sink sample with FEAST-style EM.

    Preparation: the known-source count matrix is read from ``countfile``,
    an extra "unknown" source column is initialised with
    ``unknown_sources_init``, an "observed" source matrix is built with an
    all-zero unknown column, and the initial mixing proportions are drawn from
    a Uniform(0, 1) distribution and normalised to sum to 1. E- and M-steps
    then alternate until convergence or ``iteration`` rounds.

    Parameters
    ----------
    countfile : str
        Tab-separated count table; the first column is used as the taxa index
        and the header row holds sample IDs.
    metadata : str
        Tab-separated metadata table indexed by sample ID, with a
        ``SourceSink`` column (``'Source'`` / ``'Sink'``) and a ``Type`` column.
    iteration : int
        Maximum number of EM rounds (must be >= 1).
    tol : float, optional
        Stop once the largest absolute change of any mixing proportion between
        two consecutive rounds is below this value (default 1e-6).

    Returns
    -------
    dict
        The last ``M_step_refined`` result (``'alpha'``, ``'params'``) plus
        ``'sink_type'`` (list of sink ``Type`` values), ``'sink_name'``,
        ``'sample'`` (known source IDs, without the unknown source) and
        ``'taxa'`` (taxa index of the count table).

    Raises
    ------
    ValueError
        If ``iteration`` < 1 or the metadata does not list exactly one sink.
    """
    # portable basename: accept both Windows and POSIX separators
    filename = os.path.basename(metadata.replace("\\", "/")).replace("metadata.txt","")
    print(f'{filename} estimation started\n')
    if iteration < 1:
        raise ValueError(f"iteration must be >= 1, got {iteration}")

    df = pd.read_csv(countfile,sep="\t",header=0,index_col = 0)
    mdf = pd.read_csv(metadata,sep="\t",header = 0,index_col = 0)
    ## labeling all sink samples and source samples
    is_source = mdf['SourceSink']=='Source'
    is_sink = mdf['SourceSink']=='Sink'
    source_idx = list(mdf[is_source].index)
    sink_idx = list(mdf[is_sink].index)
    # the downstream steps squeeze the sink to one vector, so exactly one sink is supported
    if len(sink_idx) != 1:
        raise ValueError(f"{metadata}: expected exactly one 'Sink' sample, found {len(sink_idx)}")
    taxa_idx = list(df.index)
    sink_name = sink_idx[0]
    source_matrix = np.array(df[source_idx])
    sink_matrix = np.array(df[sink_idx])
    # observed source counts, with an all-zero column standing in for the unknown source
    ob_source_matrix = np.c_[source_matrix,np.zeros(len(sink_matrix),dtype=int)]
    source_matrix = unknown_sources_init(sink_matrix,source_matrix)
    print("unknown source initiation completed")
    alpha = np.random.uniform(low=0,high=1,size=(source_matrix.shape[1],1))
    alpha1 = alpha/np.sum(alpha)

    ## estimate latent variable alpha & dist params by iteration
    new_res = None
    for i in range(1,iteration+1):
        alpha2 = E_step_refined(source_matrix,alpha1)
        new_res = M_step_refined(source_matrix,sink_matrix,alpha2,ob_source_matrix)
        # converge on ALL proportions, not only the first source's
        delta = float(np.max(np.abs(np.asarray(new_res['alpha'])-np.asarray(alpha1))))
        alpha1 = new_res['alpha']
        source_matrix = new_res['params']
        print(f'round {i} completed\n')
        print(delta)
        if delta < tol:
            print("iteration convergence has come\n")
            break

    new_res['sink_type'] = mdf.loc[is_sink,'Type'].tolist()
    new_res['sink_name'] = sink_name
    new_res['sample'] = source_idx
    new_res['taxa'] = taxa_idx
    return(new_res)

    


def unknown_sources_init(sink_cm,source_cm):
    """Build an initial profile for the unknown source and append it as the last column.

    Three passes; later passes override earlier ones for the taxa they select:
      1. every taxon j gets max(sink[j] - sum of known-source counts of j, 0);
      2. "core" taxa (non-zero in more than half of the known sources) get
         min(known-source counts of j) / 2;
      3. taxa that are zero in every known source get
         max(sink[j] - Poisson(lam=0.5) draw, 0), drawn from the global
         NumPy RNG, one draw per such taxon in ascending taxon order.

    Parameters
    ----------
    sink_cm : ndarray
        Sink counts, one row per taxon; it is squeezed to 1-D, so it is
        expected to hold a single sink column.
    source_cm : ndarray
        Known-source counts, one row per taxon and one column per source.

    Returns
    -------
    ndarray
        ``source_cm`` with the unknown-source column appended. Lengths and
        shapes are printed as a progress trace.
    """
    sink_counts = sink_cm.squeeze()
    n_taxa = sink_cm.shape[0]
    rows = [source_cm[t] for t in range(n_taxa)]

    # pass 1: the part of the sink not explained by known sources, clipped at zero
    unknown_taxa = [max(float(sink_counts[t]-np.sum(rows[t])),0) for t in range(n_taxa)]
    # taxa selected for the override passes, in ascending order
    core_taxa = [t for t, row in enumerate(rows) if np.count_nonzero(row>0) > (len(row)/2)]
    missing_taxa = [t for t, row in enumerate(rows) if np.count_nonzero(row==0) == len(row)]

    # pass 2: core taxa
    for t in core_taxa:
        unknown_taxa[t] = min(rows[t])/2
    # pass 3: taxa absent from all known sources
    for t in missing_taxa:
        unknown_taxa[t] = max(sink_counts[t]-np.random.poisson(lam=0.5),0)

    print(len(unknown_taxa))
    print(np.array(unknown_taxa).shape)
    new_source_cm = np.c_[source_cm,np.array(unknown_taxa)]
    print(new_source_cm.shape)
    return new_source_cm


def E_step_refined(source_sample,alpha):
    """Aggregate the E-step posterior P(i|j) over taxa for every source.

    For each source column i of ``source_sample`` (gamma) this computes

        sum_j alpha_i * gamma_ij / sum_k sum_j alpha_k * gamma_kj

    where ``alpha`` holds one weight per source column (shape (k, 1) or (k,)).

    Returns
    -------
    ndarray of shape (n_sources, 1)
        The new proportions. Their sum and the shape are printed as a trace.
    """
    weights = np.asarray(alpha).reshape(-1)
    # weighted[j, i] = gamma_ij * alpha_i
    weighted = np.asarray(source_sample)*weights
    column_totals = weighted.sum(axis=0)
    grand_total = weighted.sum()
    new_alpha = (column_totals/grand_total)[:,np.newaxis]
    print(np.sum(new_alpha))
    print(new_alpha.shape)
    return new_alpha
def M_step_refined(source_sample,sink_sample,alpha,ob_source_sample):
    """Run one M-step: update the mixing proportions and the source distributions.

    With gamma = ``source_sample`` (taxa in rows, sources in columns),
    x = ``sink_sample`` counts, y = ``ob_source_sample`` observed counts and
    P(i|j) = alpha_i * gamma_ij / sum_k alpha_k * gamma_kj:

      alpha_i  = sum_j (x_j / sum x) * P(i|j)
      gamma_ij = (x_j * P(i|j) + y_ij) / sum_j (x_j * P(i|j) + y_ij)

    Divisions by zero are silenced; NaN results become 0 and +/-inf are
    clipped to the largest finite float (``np.nan_to_num``).

    Parameters
    ----------
    source_sample : ndarray, shape (n_taxa, n_sources)
    sink_sample : array-like, one count per taxon (e.g. shape (n_taxa, 1));
        flattened, so it must hold a single sink.
    alpha : array-like with n_sources entries (e.g. shape (n_sources, 1))
    ob_source_sample : array-like, shape (n_taxa, n_sources)

    Returns
    -------
    dict
        ``'alpha'``: ndarray (n_sources, 1); ``'params'``: ndarray (n_taxa, n_sources).
    """
    gamma = np.asarray(source_sample,dtype=np.float64)
    weights = np.asarray(alpha,dtype=np.float64).reshape(-1)
    sink = np.asarray(sink_sample,dtype=np.float64).reshape(-1)
    observed = np.asarray(ob_source_sample,dtype=np.float64)

    with np.errstate(divide="ignore",invalid="ignore"):
        # relative abundance of each taxon in the sink (all-zero sink -> 0)
        sink_reab = np.nan_to_num(sink/np.sum(sink))
        # mixture[j] = sum_i alpha_i * gamma_ij, the denominator of P(i|j)
        mixture = (gamma*weights).sum(axis=1)
        ratio = gamma/mixture[:,np.newaxis]

        ## new alpha: αi = JΣ (xj/Σx) * P(i|j)
        cp1 = np.nan_to_num(sink_reab[:,np.newaxis]*ratio)
        new_alpha = (cp1*weights).sum(axis=0)[:,np.newaxis]

        ## new source parameters: γij = (xj*P(i|j)+yij) / JΣ(xj*P(i|j)+yij)
        cp3 = np.nan_to_num(sink[:,np.newaxis]*ratio)
        sub_num = cp3*weights+observed
        new_source_sample = np.nan_to_num(sub_num/sub_num.sum(axis=0))

    return {'alpha':new_alpha,'params':new_source_sample}
def _read_count_table(file):
    """Read one raw tab-separated count table and collapse duplicate taxa.

    The first row is the header. If any header cell equals '#ID' the file is
    treated as an OTU table: taxa names are in the LAST column and the first
    ('#ID') column is skipped. Otherwise taxa names are in the first column.
    All remaining columns are samples.

    Returns (sample_names, abundance): abundance maps taxa name (str) to a
    float64 array of per-sample counts, rows sharing a name are summed, and
    the dict keeps first-appearance order.
    """
    raw = pd.read_csv(file,sep="\t",header=None)
    header = raw.iloc[0]
    otu_flag = '#ID' in header.tolist()
    first_col, last_col = raw.columns[0], raw.columns[-1]
    taxa_col = last_col if otu_flag else first_col
    sample_cols = [c for c in raw.columns
                   if c != taxa_col and not (otu_flag and c == first_col)]
    sample_names = [header[c] for c in sample_cols]

    body = raw.iloc[1:]
    taxa_names = [str(t) for t in body[taxa_col]]
    counts = body[sample_cols].astype(float).to_numpy()
    abundance = {}
    for name, row in zip(taxa_names,counts):
        if name in abundance:
            abundance[name] = abundance[name]+row
        else:
            abundance[name] = row
    return sample_names, abundance


def merge_count_matrix(filepath,output_path="D:/ShanxiJiankang/2021_shanghai_household_data/multiomics/merge_indoor_countfile.txt"):
    """Merge every ``*.txt`` count table in ``filepath`` into one taxa-by-sample matrix.

    Rows are the union of all taxa, in order of first appearance (files are
    processed in sorted name order). Each sample column holds that sample's
    counts, with 0 for taxa absent from its file. The first column is
    'taxa'. The result is written tab-separated to ``output_path``.

    Parameters
    ----------
    filepath : str
        Directory containing the per-site count tables.
    output_path : str, optional
        Destination file; defaults to the original hard-coded location
        (replace by your own path).
    """
    # sorted so row/column order does not depend on directory listing order
    file_list = sorted(glob.glob(os.path.join(filepath,"*.txt")))
    print(file_list)

    ## find union of all samples' taxa (each file is read once)
    tables = []
    union_taxa_list = []
    seen_taxa = set()
    for file in file_list:
        sample_names, abundance = _read_count_table(file)
        tables.append((sample_names,abundance))
        for taxa in abundance:
            if taxa not in seen_taxa:
                seen_taxa.add(taxa)
                union_taxa_list.append(taxa)

    ## merge all samples into one count matrix
    sample_list = ['taxa']
    merge_table = [union_taxa_list]
    for file, (sample_names, abundance) in zip(file_list,tables):
        for pos, sample in enumerate(sample_names):
            sample_list.append(sample)
            merge_table.append([float(abundance[t][pos]) if t in abundance else float(0)
                                for t in union_taxa_list])
        print(f'file {file} screening completed\n')

    t_merge_table = np.transpose(merge_table).tolist()
    t_merge_df = pd.DataFrame(t_merge_table,columns = sample_list)
    t_merge_df.to_csv(output_path,sep="\t")



        


def main():
    """Run EM source tracking for every metadata file in a directory.

    For each ``*.txt`` metadata file (in sorted order) ``run_EM`` is run
    against one shared count matrix. Two tab-separated tables named after the
    sink sample are written:

    * ``<sink>_source.txt``: per-taxon proportions in each source, including 'unknown';
    * ``<sink>_alpha.txt``: the estimated proportion contributed by each source.

    All paths below are hard-coded; replace them with your own.
    """
    # merge count matrix, mostly unused
    # NOTE(review): this guard tests for mice_gut/true_countfile_p1.txt, but
    # merge_count_matrix writes multiomics/merge_indoor_countfile.txt, so the
    # guard stays false after a merge; the merged file is also not the
    # `countfile` used below. Confirm which file is intended.
    if not os.path.exists("D:/ShanxiJiankang/2021_shanghai_household_data/mice_gut/true_countfile_p1.txt"):
        merge_count_matrix("D:\\ShanxiJiankang\\2021_shanghai_household_data\\mice_gut\\read_count_p1")

    #replace by your own meta file directory
    metapath = "D:\\ShanxiJiankang\\2021_shanghai_household_data\\another_analysis\\cohort_seq_meta\\6m_meta\\"
    # sorted so sinks are processed in a reproducible order
    meta_list = sorted(glob.glob(os.path.join(metapath,"*.txt")))
    #replace by your own count matrix file path
    countfile = "D:/ShanxiJiankang/2021_shanghai_household_data/another_analysis/Species_counts.txt"
    #replace by your own output directory (source matrix and alpha files)
    outdir = "D:/ShanxiJiankang/2021_shanghai_household_data/multiomics/source_res_indoor/"
    # create it up front so a long EM run does not fail at the first write
    os.makedirs(outdir,exist_ok=True)

    for metadata in meta_list:
        new_res = run_EM(countfile,metadata,1000)
        sink_name = new_res['sink_name']
        # the last source column/row produced by the EM is the unknown source
        sample_list = list(new_res['sample'])+['unknown']
        taxa_list = new_res['taxa']
        candidate_source_df = pd.DataFrame(new_res['params'],columns=sample_list,index=taxa_list)
        candidate_alpha_df = pd.DataFrame(new_res['alpha'],index=sample_list)

        ## output 1. alpha: proportion in different sources 2. source matrix:
        ## proportion in different sources per taxa
        # f-string so non-string sample IDs still form a valid filename
        candidate_source_df.to_csv(os.path.join(outdir,f"{sink_name}_source.txt"),sep="\t")
        candidate_alpha_df.to_csv(os.path.join(outdir,f"{sink_name}_alpha.txt"),sep="\t")
# Run only when executed as a script / notebook cell, not on import.
if __name__ == "__main__":
    main()



