In [None]:
#multiprocessing
import multiprocessing
import subprocess
import os

#numerics and vis
import math
import scipy.stats as stats
import numpy as np
import pandas as pd
import altair as alt

#disable altair max rows
alt.data_transformers.disable_max_rows()
#get default altair style
%run ~/scripts/altair_style_config_default.py

#trees
import ete3
from ete3 import Tree, TreeStyle, TextFace


# Base data directory; the %cd below makes it the kernel's working directory, so the
# relative paths used in later cells ('microcosm4/', 'euk72/euk72', ...) resolve against it.
# NOTE(review): absolute, user-specific path -- adjust before running on another machine.
# NOTE(review): later cells rebind `root` to other values (e.g. 'microcosm4/').
root = '/home/tobiassonva/data/eukgen/'
%cd {root}

In [None]:
#unpickle one object from a file on disk
def load_pkl(pkl_file):
    """Return the object stored in the pickle file ``pkl_file``.

    The file is opened in binary mode and is closed again before returning.

    NOTE(review): unpickling can execute arbitrary code; only use on trusted files.
    """
    import pickle

    with open(pkl_file, 'rb') as handle:
        return pickle.load(handle)

# Hard-coded query id lists (presumably NCBI protein accessions -- TODO confirm).
# Only euk_queries_test5 is referenced by the driver cells visible in this notebook;
# test2/test3/test4 are not used elsewhere in the visible cells.

# query set "test2"
euk_queries_test2 = ['CBN77353.1', 'CEL94470.1', 'CEL98020.1', 'CEM00912.1',
       'CEM13793.1', 'CEO94447.1', 'CEP02189.1', 'CEP02404.1',
       'EPZ31333.1', 'GBG32138.1', 'GBG34166.1', 'GBG34636.1',
       'GBG88810.1', 'KAA0151157.1', 'KAA0167757.1', 'KAA6364588.1',
       'KAA6383781.1', 'NP_001022034.1', 'NP_001105121.2',
       'NP_001170744.1', 'NP_001189295.1', 'NP_001242666.1',
       'NP_001259573.1', 'NP_001261837.1', 'NP_001294564.1',
       'NP_001307934.1', 'NP_001328712.1', 'NP_012528.1', 'NP_050092.1',
       'NP_051148.1', 'NP_189541.1', 'NP_197350.1', 'NP_498455.2',
       'NP_505960.3', 'NP_588329.1', 'NP_595422.1', 'NP_609709.1',
       'NP_611238.2', 'NP_649295.1', 'OAD00700.1', 'OAD03858.1',
       'OAD05886.1', 'OAE33051.1', 'OLP78629.1', 'OLQ06972.1',
       'OLQ08228.1', 'OLQ08510.1', 'OLQ11720.1', 'OLQ12045.1',
       'OLQ14344.1']

# query set "test3"
euk_queries_test3 = ['XP_008911403.1', 'XP_011408184.1', 'XP_002681038.1', 'XP_002673113.1',
                     'OAD09041.1', 'XP_001634466.1', 'XP_005765180.1', 'XP_011407364.1', 
                     'XP_005789988.1', 'KAA6344160.1', 'KAA6409619.1', 'XP_002287408.1',
                     'OAE21175.1', 'RKP17192.1', 'XP_013760427.1', 'KAA0163767.1',
                     'XP_002119908.1', 'XP_009692086.1']

# query set "test4"
euk_queries_test4 = ['NP_001002332.1', 'NP_001240313.1',
       'NP_001259573.1', 'NP_001260847.1', 
       'NP_001278869.1', 'NP_001307724.1', 'NP_001307934.1', 'NP_001334755.1',
       'NP_001356620.1', 'NP_115888.1', 
       'NP_610753.1', 'NP_611238.2', 'NP_956312.1',
       'NP_998197.1', 'NP_998403.1', 'XP_005256905.1',  'XP_017206845.1', 'XP_021326060.1', 'XP_021336265.1']

# query set "test5": the set iterated by the "prepare all", "all cluster alignment"
# and "all haln alignment" cells further down
euk_queries_test5 = ['AGK83073.1', 'CBN73833.1', 'CBN79086.1', 'CEM35385.1', 'CEO94447.1', 'CEP00213.1', 'CEP03651.1', 'EPZ30938.1', 'EPZ31301.1', 'GBG60132.1', 
'GBG70565.1', 'GBG80562.1', 'GBG83744.1', 'KAA0165271.1', 'KAA0172078.1', 'KAA6408708.1', 'NP_001002332.1', 'NP_001259573.1', 'NP_001294564.1',
 'NP_001307724.1', 'NP_001328712.1', 'NP_001334755.1', 'NP_001356620.1', 'NP_011081.1', 'NP_050092.1', 'NP_594946.1', 'NP_610753.1', 'NP_848958.1',
  'NP_849074.1', 'NP_956312.1', 'NP_998197.1', 'OAD04802.1', 'OAD06369.1', 'OAE33370.1', 'OLP84660.1', 'OLQ06972.1', 'OLQ08228.1', 'OLQ14344.1',
   'OSX69435.1', 'OSX71470.1', 'OSX72678.1', 'OSX75094.1', 'OSX77054.1', 'PTQ50428.1', 'PXF41822.1', 'PXF45288.1', 'RKP17849.1', 'RKP18091.1', 
   'RKP20265.1', 'RWR93989.1', 'RWR97906.1', 'RWR98344.1', 'SLM34047.1', 'SLM40311.1', 'SLM40671.1', 'SPQ96285.1', 'SPQ98172.1']


#euk_clust = load_pkl(root+'analysis/core_data/euk72_filtered-prof-search-clust.pkl')['members']

In [None]:
#value distribution plus cumulative fraction for a pd.Series
#returns the tabulated counts (DataFrame) and a layered altair chart
def plot_cumsum_counts(series, title='Chart', x_label='value', y_label='count', 
                       x_min=0, y_min=0, x_max=None, y_max=None,
                       x_scale_type='log', y_scale_type='log', decimals=2):
    """Tabulate how often each value occurs in ``series`` and plot it.

    Zero entries are dropped and the remaining values are rounded to
    ``decimals`` places before counting. The resulting table has one row per
    distinct value (ascending) with the columns ``x_label`` (the value),
    ``y_label`` (its count), ``'cumsum'`` (running total of counts) and
    ``'frac_cumsum'`` (running total divided by the overall total).

    The chart layers a step-area of the counts with a line of the cumulative
    fraction; the two layers use independent y scales and the result is
    interactive.

    ``x_max`` / ``y_max`` replace the data-derived upper axis bound only when
    they are truthy, so a value of 0 is ignored.

    Relies on a global ``colorlib`` mapping (presumably defined by the style
    script %run in the first cell -- TODO confirm).

    Returns
    -------
    (pandas.DataFrame, altair chart)
    """
    #drop zeros, then round to reduce float display jaggedness
    nonzero = series[series != 0].round(decimals)

    #occurrences per distinct value, ordered by value
    counts = pd.DataFrame(nonzero.value_counts())
    counts.columns = ['amount']
    counts = counts.sort_index()

    #running total and its fraction of the grand total
    running_total = counts['amount'].cumsum()
    counts['cumsum'] = running_total
    counts['frac_cumsum'] = running_total / running_total.max()

    #move the values out of the index and apply the plotting labels
    counts = counts.reset_index()
    counts.columns = [x_label, y_label, 'cumsum', 'frac_cumsum']

    #axis domains: data-derived unless an explicit (truthy) maximum is given
    x_range = [x_min, x_max if x_max else nonzero.max()]
    y_range = [y_min, y_max if y_max else counts[y_label].max()]

    tooltip_fields = [x_label, y_label, 'frac_cumsum']

    #cumulative fraction as a line
    line_mark = alt.Chart(counts, title=title).mark_line(
        color=colorlib['twilight_shifted_r_perm'][2],
        strokeWidth=3,
    )
    chart_cumsum = line_mark.encode(
        x=alt.X(x_label, title=x_label, scale=alt.Scale(type=x_scale_type)),
        y=alt.Y('frac_cumsum', title='Cumulative Fraction',
                scale=alt.Scale(domain=[0,1]), axis=alt.Axis(labelAlign='left')),
        tooltip=alt.Tooltip(tooltip_fields),
    )

    #per-value counts as a stepped area
    area_mark = alt.Chart(counts).mark_area(interpolate='step-after',
                                            fillOpacity=0.2, line=True)
    chart_bar = area_mark.encode(
        x=alt.X(x_label+':Q', scale=alt.Scale(domain=x_range, type=x_scale_type)),
        y=alt.Y(y_label, scale=alt.Scale(domain=y_range, type=y_scale_type)),
        tooltip=alt.Tooltip(tooltip_fields),
    )

    #overlay both layers, each with its own y axis
    layered = alt.layer(chart_bar, chart_cumsum)
    merge = layered.resolve_scale(y='independent').interactive()

    return counts, merge




In [None]:
#helper IO functions for basic fasta reading into {id:seq} dict
#id taken as string up until first space character
def fasta_to_dict(fastastring=None, file=None):
    """Parse FASTA text into an ``{id: sequence}`` dict.

    Parameters
    ----------
    fastastring : str or iterable of str, optional
        FASTA text, either as one string or as pieces (for example the list
        returned by ``readlines()``) that are concatenated before parsing.
    file : path, optional
        When given, the file is read and its content is parsed instead of
        ``fastastring``.

    Returns
    -------
    dict
        Maps each record id (header text after '>' up to the first space) to
        its sequence with line breaks removed. A later record with the same
        id overwrites an earlier one. Lines before the first header are
        ignored.

    Raises
    ------
    TypeError
        If neither ``fastastring`` nor ``file`` is given.

    Notes
    -----
    The characters B, J, U, X and Z are replaced by 'A' in every sequence and
    a warning is printed for each replaced character.
    """
    if file is not None:
        with open(file, 'r') as fastafile:
            fastastring = fastafile.read()

    if fastastring is None:
        raise TypeError('fasta_to_dict requires either fastastring or file')

    fastas = {}
    current_id = None
    chunks = []

    #a record starts only where a LINE starts with '>', so a '>' inside a
    #header description no longer splits the record in two
    for line in ''.join(fastastring).splitlines():
        line = line.strip()
        if line.startswith('>'):
            if current_id is not None:
                fastas[current_id] = ''.join(chunks)
            current_id = line[1:].strip().split(' ')[0]
            chunks = []
        elif line and current_id is not None:
            chunks.append(line)

    if current_id is not None:
        fastas[current_id] = ''.join(chunks)

    #replace unknown characters with A
    blacklist = set('BJUXZ')
    for key, seq in fastas.items():

        #sorted so the warnings come out in a reproducible order
        for char in sorted(blacklist.intersection(seq)):
            seq = seq.replace(char, 'A')
            print(f'WARNING: {key} Replaced illegal {char} with A')

        fastas[key] = seq
    return fastas

#serialise an {id:seq} dict as fasta text with one line per sequence
def dict_to_fasta(seq_dict, write_file=False):
    """Render ``seq_dict`` as FASTA text and optionally write it to disk.

    Each entry becomes a ``>id`` line followed by the sequence on a single
    line; records are joined by newlines without a trailing newline.

    If ``write_file`` is anything other than ``False`` it is used as the
    output path: the text is written there and a confirmation is printed.

    Returns the FASTA text in either case.
    """
    records = []
    for name, sequence in seq_dict.items():
        records.append(f'>{name}\n{sequence}')
    fasta_str = '\n'.join(records)

    if write_file != False:
        with open(write_file, 'w') as handle:
            handle.write(fasta_str)
        print(f'Wrote {write_file}')

    return fasta_str

#run an external command through subprocess, optionally feeding it text on stdin
#equivalent to "$ command < stdin > stdout" if given a stdin string
#command string must be single-space delimited, e.g. "clustalo -i test -o test.clu"
def run_subprocess_with_stdin(command_str, stdin_str=''):
    """Run ``command_str`` and return its ``(stdout, stderr)`` as text.

    The command string is split on single spaces (no shell, no quoting), so
    arguments that themselves contain spaces cannot be expressed.
    ``stdin_str`` is UTF-8 encoded and piped to the process. Both output
    streams are decoded as UTF-8 with undecodable bytes dropped.
    """
    argv = command_str.split(' ')
    pipe = subprocess.PIPE
    process = subprocess.Popen(argv, stdin=pipe, stdout=pipe, stderr=pipe)

    #the pipes carry bytes, so encode on the way in and decode on the way out
    raw_out, raw_err = process.communicate(input=stdin_str.encode('utf-8'))

    #hhconsensus occasionally emits extra bytes after the consensus sequence;
    #errors="ignore" drops whatever cannot be decoded
    def as_text(raw):
        return raw.decode('utf-8', errors="ignore")

    return as_text(raw_out), as_text(raw_err)


#iterate over seq and create a boolean list with insertions marked False
#apply_matchlist then turns False into '-' and True into the next character of a sample string
def calculate_matchlist(seq, ref_seq):
    """Mark which positions of ``seq`` consume the next character of ``ref_seq``.

    ``seq`` is walked left to right with a pointer into ``ref_seq``. A position
    is ``True`` when its character equals the current reference character and
    that reference character is not '-'; the pointer then advances. Every
    other position is ``False``.

    Once the reference is used up all remaining positions are ``False``.
    (Previously the pointer was clamped to the last index, so a trailing
    character equal to the last reference character was matched again and the
    list could contain more ``True`` entries than ``ref_seq`` has characters;
    an empty ``ref_seq`` raised IndexError.)

    Note that a '-' inside ``ref_seq`` is never matched, so the pointer does
    not move past it.

    Returns a list of bool with ``len(seq)`` entries.
    """
    ref_len = len(ref_seq)
    ref_i = 0
    matchlist = []

    for c in seq:
        is_match = ref_i < ref_len and ref_seq[ref_i] != '-' and c == ref_seq[ref_i]
        matchlist.append(is_match)

        if is_match:
            ref_i += 1

    return matchlist

#expand seq along a matchlist of [True, False, ...] where False marks a gap
#each True takes the next character of seq, each False becomes '-'
def apply_matchlist(seq, matchlist):
    """Spread the characters of ``seq`` over the ``True`` slots of ``matchlist``.

    Returns a string with one character per matchlist entry: truthy entries
    receive consecutive characters of ``seq`` (in order), falsy entries
    receive '-'. ``matchlist`` itself is not modified. Raises IndexError if
    there are more truthy entries than characters in ``seq``.
    """
    rebuilt = []
    consumed = 0

    for keep in matchlist:
        if not keep:
            rebuilt.append('-')
            continue

        rebuilt.append(seq[consumed])
        consumed += 1

    return ''.join(rebuilt)

#information content of one alignment column, protein=True for amino acids, protein=False for DNA
def column_entropy(string, protein=True, gaptoken='-'):
    """Return the information content (bits) of an alignment column.

    The Shannon entropy of the non-gap symbols is computed with frequencies
    taken relative to the full column length (gaps included). The result is
    ``entropy_uniform - entropy - gap_entropy`` clipped at 0, where
    ``entropy_uniform`` is ``log2(20)`` for protein and 2 otherwise, and
    ``gap_entropy`` is ``entropy_uniform`` scaled by the gap fraction.

    An empty column carries no information and returns 0 (it previously
    raised ZeroDivisionError).
    """
    size = len(string)
    if size == 0:
        return 0

    #sorted so the float summation order does not depend on set iteration order
    counts = sorted(string.count(symbol) for symbol in set(string).difference({gaptoken}))
    entropy = -sum(n/size*math.log2(n/size) for n in counts)

    if protein:
        entropy_uniform = math.log2(20)
    else:
        entropy_uniform = 2

    gap_entropy = entropy_uniform*(string.count(gaptoken)/size)
    information = entropy_uniform - entropy - gap_entropy

    return max(information,0)



#columnwise cut based on criteria
def filter_by_entropy(seq_dict, entropy_min, filter_accs=[], gaptoken='-'):
    """Drop alignment columns whose information content is not above ``entropy_min``.

    Parameters
    ----------
    seq_dict : dict
        ``{id: aligned sequence}``. Columns are formed with ``zip``, so the
        alignment is truncated to the shortest sequence.
    entropy_min : number
        A column is kept only if ``column_entropy(column) > entropy_min``.
    filter_accs : collection of ids, optional
        When non-empty, the score of each column is computed from these rows
        only; the kept columns are still taken from all rows. The default is
        never mutated.
    gaptoken : str
        Gap symbol, forwarded to ``column_entropy`` (it was previously
        accepted but ignored).

    Returns
    -------
    dict
        Same keys as ``seq_dict`` with the filtered sequences. If no column
        survives every key maps to '' (previously the keys were dropped).

    Raises
    ------
    ValueError
        If ``filter_accs`` is given but matches no key of ``seq_dict``.
    """
    keys = list(seq_dict.keys())

    #transpose seqs into columns
    cols = [''.join(col) for col in zip(*seq_dict.values())]

    #score only the rows listed in filter_accs
    if filter_accs:
        wanted = set(filter_accs)
        scored_rows = [seq for key, seq in seq_dict.items() if key in wanted]
        if not scored_rows:
            raise ValueError('filter_accs does not match any key in seq_dict')
        scored_cols = [''.join(col) for col in zip(*scored_rows)]

    else:
        scored_cols = cols

    #include based on entropy threshold
    kept_cols = [col for col, scored in zip(cols, scored_cols)
                 if column_entropy(scored, gaptoken=gaptoken) > entropy_min]

    #nothing left: keep every id with an empty sequence
    if not kept_cols:
        return {key: '' for key in keys}

    #transpose back to alignment
    filter_aln = [''.join(row) for row in zip(*kept_cols)]

    #return with original keys
    return dict(zip(keys, filter_aln))



#input helper to read tsv from file, merge singletons
def read_cluster_tsv(cluster_file, split_large=False, max_size=500, batch_single=False, single_cutoff=1):
    """Read a two-column cluster TSV into ``{cluster_id: [member, ...]}``.

    Each non-blank line must hold exactly ``cluster_id<TAB>member``; members
    are grouped by the first column in file order. Blank lines are skipped
    (they previously crashed the tuple unpacking).

    Parameters
    ----------
    cluster_file : path
        TSV file to read.
    batch_single : bool
        If true, all clusters with at most ``single_cutoff`` members are
        merged into one cluster keyed by its first member.
    single_cutoff : int
        Size limit used by ``batch_single``.
    split_large : bool
        If true, clusters with more than ``max_size`` members are cut into
        consecutive batches of at most ``max_size``, each keyed by its first
        member. Applied after ``batch_single``.
    max_size : int
        Batch size used by ``split_large``; must be at least 1.

    Raises
    ------
    ValueError
        For a non-blank line that does not split into exactly two columns, or
        when ``split_large`` is set with ``max_size < 1``.
    """
    #read TSV and group clusters based on first tsv column
    clusters = {}
    with open(cluster_file, 'r') as infile:
        for line in infile:
            line = line.strip()
            if not line:
                continue

            cluster_acc, acc = line.split('\t')
            clusters.setdefault(cluster_acc, []).append(acc)

    #merge all clusters not larger than the cutoff into one
    if batch_single:
        filter_dict = {}
        singles = []
        for key, accs in clusters.items():
            #gather singletons
            if len(accs) <= single_cutoff:
                singles.extend(accs)

            #keep larger clusters
            else:
                filter_dict[key] = accs

        if singles:
            filter_dict[singles[0]] = singles

        clusters = filter_dict

    #split clusters larger than max_size into smaller pieces
    if split_large:
        if max_size < 1:
            raise ValueError('max_size must be at least 1 when split_large is set')

        filter_dict = {}
        for key, accs in clusters.items():
            if len(accs) > max_size:
                #partition large clusters into batches of max_size
                for split in range(0, len(accs), max_size):
                    batch = accs[split:split + max_size]
                    filter_dict[batch[0]] = batch

            #keep smaller clusters
            else:
                filter_dict[key] = accs

        clusters = filter_dict

    return clusters

In [None]:
#pool-friendly wrapper around align_with_consensus that labels the result
def pool_align_with_consensus(run_seq_dict, return_consensus=True):
    """Align one cluster and return the result keyed by the cluster's first id.

    ``run_seq_dict`` is an ``{id: sequence}`` dict; its first key is used as
    the label. Returns ``{first_id: alignment_dict}`` where
    ``alignment_dict`` is whatever ``align_with_consensus`` returns for the
    given ``return_consensus`` flag.
    """
    #pid is only needed by the optional debug prints below
    worker_pid = multiprocessing.current_process().pid
    first_id = list(run_seq_dict)[0]
    #print(f'{worker_pid}: ref_seq = {first_id} started with {len(run_seq_dict)} sequences\n', end='')

    aligned = align_with_consensus(run_seq_dict, return_consensus)

    #print(f'{worker_pid} finished\n', end='')

    return {first_id: aligned}


#clustalo alignment fed through stdin, optionally piped on to hhconsensus
#everything is captured from stdout so no intermediate files are written
def align_with_consensus(seq_dict, return_consensus=True):
    """Align ``seq_dict`` with clustalo and optionally add a consensus.

    The sequences are sent to ``clustalo`` (single thread) on stdin. With
    ``return_consensus`` false the parsed clustalo output is returned. With it
    true, the clustalo output is piped into ``hhconsensus``; the first two
    lines of that output are replaced by a ``>consensus`` header (presumably
    they are hhconsensus' own name/header lines -- TODO confirm) and the
    parsed result is returned.

    stderr and exit status of both tools are not inspected.

    Returns an ``{id: sequence}`` dict.
    """
    aligner_command = 'clustalo -i - --threads 1'
    #aligner_command = 'mafft --quiet --auto /dev/stdin'
    hhconsensus_command = 'hhconsensus -i stdin -o stdout'

    #fasta-like text built from the input dict
    fasta_records = [f'>{key}\n{value}' for key, value in seq_dict.items()]
    aligner_in = '\n'.join(fasta_records)

    print(f'Aligning {len(seq_dict.keys())} seqs\n', end='')

    aligner_stdout, _aligner_stderr = run_subprocess_with_stdin(aligner_command, aligner_in)

    if not return_consensus:
        return fasta_to_dict(aligner_stdout)

    hhcons_stdout, _hhcons_stderr = run_subprocess_with_stdin(hhconsensus_command, aligner_stdout)
    body_lines = hhcons_stdout.split('\n')[2:]
    hhcons_fasta = '\n'.join(['>consensus'] + body_lines)

    return fasta_to_dict(hhcons_fasta)





def hierarchical_realignment(seq_fasta, cluster_file, filter_entropy=False, parallel_n=None):
    """Align a FASTA file cluster by cluster, then merge via cluster consensus sequences.

    Steps
    -----
    1. Read ``seq_fasta`` and the clusters in ``cluster_file``; clusters of at
       most 10 members are merged into one batch and clusters above 200
       members are split.
    2. Every cluster with more than one member is aligned with
       ``pool_align_with_consensus`` (in a process pool of ``parallel_n``
       workers, or serially when ``parallel_n`` is None).
    3. The cluster consensus sequences are written to
       ``<seq_fasta>.cons.fasta``, combined with the single-member clusters,
       aligned against each other and written to ``<seq_fasta>.cons.clu``.
    4. The gap pattern the consensus acquired in step 3 is transferred to each
       cluster member, replacing the consensus rows.
    5. Columns are optionally filtered by information content and the result
       is written to ``<seq_fasta>.haln``.

    Parameters
    ----------
    seq_fasta : str
        Path of the input FASTA; also the prefix of all output files.
    cluster_file : str
        Two-column cluster TSV readable by ``read_cluster_tsv``.
    filter_entropy : number or None
        Threshold passed to ``filter_by_entropy``. Only ``None`` disables the
        filter; the default ``False`` compares like 0.
    parallel_n : int or None
        Pool size for step 2 and thread count for the aligner in step 3.

    Returns
    -------
    dict
        The final ``{id: aligned sequence}`` alignment, the same content that
        is written to ``<seq_fasta>.haln``. (The function used to return
        None although callers bind its result.)
    """
    thread = multiprocessing.current_process().pid
    threadID_string = f'{thread} | {seq_fasta.split("/")[-1]} alignment:'

    aligner_threads = 1 if parallel_n is None else parallel_n

    seq_dict = fasta_to_dict(file=seq_fasta)

    #split larger clusters to avoid large realignments
    max_size = 200
    singleton_threshold = 10

    clusters = read_cluster_tsv(cluster_file, split_large=True, max_size=max_size,
                                batch_single=True, single_cutoff=singleton_threshold)

    #in some edge cases partitioning and merging clusters leaves a cluster with a single sequence
    #add all clusters larger than one to the realignment list
    seq_dicts = [{key:seq_dict[key] for key in seqs} for seqs in clusters.values() if len(seqs) > 1]

    #keep those lonely edge sequences in a separate dict for the merge later
    singles_dict = {seq[0]:seq_dict[seq[0]] for seq in clusters.values() if len(seq) == 1}

    print(f'{threadID_string} realigning a total of {len(seq_dicts)} clusters')

    if parallel_n is not None:
        #multiprocessing
        print(f'{threadID_string} Cluster alignment')
        with multiprocessing.Pool(processes=parallel_n) as pool:
            cluster_alignments_stream = pool.map(pool_align_with_consensus, seq_dicts)

    else:
        #serial implementation (own loop name, so seq_dict is not shadowed)
        cluster_alignments_stream = [pool_align_with_consensus(cluster_seqs) for cluster_seqs in seq_dicts]

    #merge results from stream
    cluster_alignments = {key: value for result in cluster_alignments_stream for key, value in result.items()}

    #collect the consensus sequence of every cluster
    consensus_dict = {key: value['consensus'] for key, value in cluster_alignments.items()}
    print(f'{threadID_string} Running consensus alignment of length {len(consensus_dict)}')

    dict_to_fasta(consensus_dict, write_file=f'{seq_fasta}.cons.fasta')

    #add edge case singles as consensus sequences
    consensus_dict.update(singles_dict)

    if len(consensus_dict) > 1:
        aligner_command = f'clustalo -i - --threads {aligner_threads}'
        aligner_in = dict_to_fasta(consensus_dict)

        aligner_stdout, aligner_stderr = run_subprocess_with_stdin(aligner_command, aligner_in)
        singles_consensus_alignment_dict = fasta_to_dict(aligner_stdout)

        #surface aligner problems instead of silently continuing with missing rows
        missing = [acc for acc in consensus_dict if acc not in singles_consensus_alignment_dict]
        if missing:
            print(f'{threadID_string} WARNING: aligner returned no row for {len(missing)} sequences; stderr: {aligner_stderr.strip()}')

    else:
        #with fewer than two sequences there is nothing to align against each other
        #(clustalo is expected to reject single-sequence input -- TODO confirm), so use them as they are
        singles_consensus_alignment_dict = dict(consensus_dict)

    dict_to_fasta(singles_consensus_alignment_dict, write_file=f'{seq_fasta}.cons.clu')

    print(f'{threadID_string} Recombining cluster and consensus alignments')
    #replace consensus sequences with original aligned sequences
    for cluster_acc, cluster_alignment in cluster_alignments.items():

        #set the reference and realigned consensus sequences
        realigned_consensus_seq = singles_consensus_alignment_dict[cluster_acc]
        consensus_seq = cluster_alignment['consensus']

        #calculate the matches
        matchlist = calculate_matchlist(realigned_consensus_seq, consensus_seq)

        #apply the same gap pattern to every cluster member
        #the consensus row (stored under the first member's id) is overwritten by that member's row
        for acc, alignment in cluster_alignment.items():
            if acc != 'consensus':
                singles_consensus_alignment_dict[acc] = apply_matchlist(alignment, matchlist)

    #delete columns with information content not above the filter level
    if filter_entropy is not None:
        print(f'{threadID_string} Filtering alignment by bitscore > {filter_entropy}')
        singles_consensus_alignment_dict = filter_by_entropy(singles_consensus_alignment_dict, filter_entropy)

    with open(seq_fasta+'.haln', 'w') as out:
        out.write(dict_to_fasta(singles_consensus_alignment_dict))

    #a former, disabled experiment re-added every sequence to this profile with
    #`mafft --add --keeplength` and filtered again; it is not part of the current flow

    print(f'{threadID_string} DONE!')

    return singles_consensus_alignment_dict
    
    

In [None]:
#re-cluster one target with mmseqs; needs a temporary accession file on disk
def mmseqs_run_target(target, members, root, seq_DB, threads=1):
    """Extract ``members`` from ``seq_DB`` and cluster them with mmseqs.

    The accessions are written one per line to ``root + target``; that path
    is also the prefix of everything mmseqs is asked to write
    (``.DB``, ``.fasta``, ``.cluster``, ``.cluster.tsv``). Afterwards the
    temporary DB, the cluster DB and the accession file are removed. The
    mmseqs working directory is ``{root}tmp/{target}``.

    All steps run through ``os.system``; exit codes are not checked.
    Returns None.
    """
    #boilerplate for informative print()
    thread = multiprocessing.current_process().pid
    threadID_string = f'{thread} | {target}:'
    print(f'{threadID_string} Preparing mmseqs data for target {target}\n', end='')

    basename = f'{root+target}'

    #temporary accession list, one id per line
    accession_lines = ''.join(acc+'\n' for acc in members)
    with open(basename, 'w') as outfile:
        outfile.write(accession_lines)

    shell_steps = (
        #sub-database, fasta export, clustering, cluster table
        f'mmseqs createsubdb -v 0 --id-mode 1 {basename} {seq_DB} {basename}.DB',
        f'mmseqs convert2fasta -v 0 {basename}.DB {basename}.fasta',
        f'mmseqs cluster -v 0 --remove-tmp-files 1 --threads {threads} -s 7.5 {basename}.DB {basename}.cluster {root}tmp/{target}',
        f'mmseqs createtsv -v 0 {seq_DB} {seq_DB} {basename}.cluster {basename}.cluster.tsv',
        #clean up
        f'mmseqs rmdb {basename}.DB -v 0',
        f'mmseqs rmdb {basename}.cluster -v 0',
        f'rm {basename}',
    )

    for shell_command in shell_steps:
        os.system(shell_command)

    return

#main script for formatting query and target .fasta and .cluster.tsv files
def microcosm_prepare_mmseqs(query, query_DB, target_DB, root):
    """Prepare clustered query and target sequence files for one query via mmseqs.

    All files live under ``root + query + '/'`` and share the prefix
    ``<query_root><query>`` (``basename``). The function expects
    ``<basename>.targets`` (cluster TSV) and ``<basename>.acc`` (used as the
    id list for ``createsubdb``) to exist already.

    It runs, through ``os.system`` and without checking exit codes:
    * for the query: createsubdb from ``query_DB``, convert2fasta, cluster, createtsv
      (output names ``<basename>.fasta`` and ``<basename>.cluster.tsv``);
    * for the merged target members: the member ids are written to
      ``<basename>.members``, then createsubdb from ``target_DB``, convert2fasta,
      createdb, cluster, createtsv (output names ``<basename>.members.fasta`` and
      ``<basename>.members.cluster.tsv``).

    NOTE(review): the function returns right after the first clean-up step; the
    per-target pool and the final clean-up below that ``return`` are unreachable.
    """
    
    #configure paths and flags
    thread = multiprocessing.current_process().pid
    query_root = root+query+'/'
    basename = query_root+query
    threadID_string = f'{thread} | {query}:'
    #NOTE(review): plain mkdir, no -p; the exit status is ignored if the directory exists
    os.system(f'mkdir {query_root}/tmp')
    
    #parse target clusters
    target_clusters = read_cluster_tsv(f'{basename}.targets')  

    print(f'{threadID_string} Started \n', end='')
    print(f'{threadID_string} Preparing mmseqs data for query\n', end='')

    #create a new DB for the query sequences and cluster it
    os.system(f"mmseqs createsubdb -v 0 --id-mode 1 --subdb-mode 1 {basename}.acc {query_DB} {basename}.DB")
    os.system(f'mmseqs convert2fasta -v 0 {basename}.DB {basename}.fasta')
    os.system(f'mmseqs cluster -v 0 --remove-tmp-files 1 --threads 1 -s 7.5 {basename}.DB {basename}.cluster {query_root}/tmp/{query}')
    os.system(f'mmseqs createtsv -v 0 {query_DB} {query_DB} {basename}.cluster {basename}.cluster.tsv')
    
    print(f'{threadID_string} Preparing mmseqs data for merged target hits\n', end='')

    #create a subDB and extract sequences, create a new seqDB in order to recreate the header lookup
    #otherwise each query using header info searches the entire original header lookup
    with open(f'{basename}.members', 'w') as outfile:
        for members in target_clusters.values():
            outfile.writelines([member+'\n' for member in members])
    
    os.system(f"mmseqs createsubdb -v 0 --id-mode 1 --subdb-mode 1 {basename}.members {target_DB} {basename}.members.DB")
    os.system(f"mmseqs convert2fasta -v 0 {basename}.members.DB {basename}.members.fasta")
    os.system(f"mmseqs createdb -v 0 --createdb-mode 1 {basename}.members.fasta {basename}.members.DB")
    os.system(f'mmseqs cluster -v 0 --remove-tmp-files 1 --threads 1 -s 7.5 {basename}.members.DB {basename}.members.cluster {query_root}/tmp/{query}')
    os.system(f'mmseqs createtsv -v 0 {basename}.members.DB {basename}.members.DB {basename}.members.cluster {basename}.members.cluster.tsv')

    
    #clean
    os.system(f'mmseqs rmdb -v 0 {basename}.DB')
    os.system(f'mmseqs rmdb -v 0 {basename}.cluster')
    
    #temporary block for testing
    #NOTE(review): everything below this return is currently dead code
    return

    #multiprocess the individual DB creation and clustering, one task per target cluster
    with multiprocessing.Pool(processes=8) as pool:
        
        targets = target_clusters.keys()
        members = target_clusters.values()
        #NOTE(review): rebinds the `root` parameter to the per-query directory
        root = query_root
        seq_DB = f'{basename}.members.DB'
        submit_vars = zip(list(targets), list(members), [root]*len(targets), [seq_DB]*len(targets))
        
        mmseqs_target_runs = pool.starmap(mmseqs_run_target, submit_vars)
    
    #final clean
    os.system(f'mmseqs rmdb -v 0 {basename}.members.DB')
    #NOTE(review): '\;' is an invalid Python escape sequence (kept literally, but newer
    #Python versions warn about it); '\\;' would spell the same shell text explicitly
    os.system(f'find {query_root} -maxdepth 1 -type l -exec unlink {{}} \;')
    #NOTE(review): the *members* glob has no directory prefix, so the shell expands it in
    #the current working directory, not in query_root -- confirm this is intended
    os.system(f'rm -r *members* {query_root}/tmp')
    
    #TODO: perform hierarchical alignment of query and all clusters > 10

    
    

In [None]:
#wrapper function to dictate alignment type based on cluster members
def choose_alignment_type(seq_fasta, cluster_file, filter_entropy=False, parallel_n=None):
    """Align ``seq_fasta`` hierarchically or with plain clustalo, depending on its size.

    Files with more than 10 sequences are passed to
    ``hierarchical_realignment`` together with ``cluster_file``,
    ``filter_entropy`` and ``parallel_n``. Smaller files are aligned directly
    with ``clustalo --iter 2``. Both branches write ``<seq_fasta>.haln``.

    The direct branch used to build its command from an undefined name
    ``basename`` and therefore raised NameError; it now uses ``seq_fasta``,
    which yields the same output name as the hierarchical branch.

    Returns None.
    """
    hierarchical_threshold = 10

    thread = multiprocessing.current_process().pid
    threadID_string = f'{thread} | {seq_fasta.split("/")[-1]} alignment:'

    seq_dict = fasta_to_dict(file=seq_fasta)
    seq_members = len(seq_dict)

    if seq_members > hierarchical_threshold:
        print(f'{threadID_string} Aligning heirarchically with {seq_members} sequences\n', end='')
        hierarchical_realignment(seq_fasta, cluster_file, filter_entropy, parallel_n)

    else:
        print(f'{threadID_string} Aligning naively with {seq_members} sequences\n', end='')
        #exit status is not checked, matching the other os.system calls in this notebook
        os.system(f'clustalo --in {seq_fasta} --out {seq_fasta}.haln --iter 2')
        
#main realignment script for a directory holding the query and target .fasta and .cluster.tsv files
#every file is realigned hierarchically or naively with clustalo, depending on its member count
def microcosm_perform_haln(query, entropy_filter, parallel_n, root):
    """Realign the query and all of its target clusters in a pool of 8 processes.

    The working directory is ``root + query + '/'``. Cluster ids are the keys
    of ``<query_root><query>.targets``; for the query itself and each cluster
    id, ``choose_alignment_type`` is called with ``<query_root><id>.fasta``
    and ``<query_root><id>.cluster.tsv``.

    NOTE(review): the ``entropy_filter`` and ``parallel_n`` arguments are
    overwritten inside the function, so every job runs with an entropy filter
    of 0 and ``parallel_n=None`` regardless of what the caller passes.

    Returns None.
    """
    #configure paths and flags
    thread = multiprocessing.current_process().pid
    threadID_string = f'{thread} | {query}:'

    query_root = root+query+'/'
    basename = query_root+query

    os.system(f'mkdir {query_root}/tmp')

    #cluster ids come from the query's .targets file; the query itself goes first
    cluster_ids = [query] + list(read_cluster_tsv(f'{basename}.targets').keys())

    with multiprocessing.Pool(processes=8) as pool:

        #caller-supplied values are deliberately replaced here
        entropy_filter = 0
        parallel_n = None

        #one (fasta, cluster tsv, entropy filter, parallel_n) tuple per job
        jobs = [(query_root+cluster+'.fasta', query_root+cluster+'.cluster.tsv',
                 entropy_filter, parallel_n)
                for cluster in cluster_ids]

        mmseqs_target_runs = pool.starmap(choose_alignment_type, jobs)
    

In [None]:

def pool_cluster_alignment(run_seq_dict, filter_entropy=None, write_cluster_align=False, cluster_align_root=''):
    """Align one cluster (no consensus), optionally filter columns and write it to disk.

    ``run_seq_dict`` is an ``{id: sequence}`` dict whose first key labels the
    cluster. The sequences are aligned with
    ``align_with_consensus(..., return_consensus=False)``. Unless
    ``filter_entropy`` is None the alignment is passed through
    ``filter_by_entropy`` with that threshold. If ``write_cluster_align`` is
    truthy the alignment is written as FASTA to
    ``<cluster_align_root><first id>.cluster.clu``.

    Returns ``{first_id: alignment_dict}``.
    """
    worker_pid = multiprocessing.current_process().pid
    first_id = list(run_seq_dict)[0]
    log_prefix = f'{worker_pid} | {first_id} alignment:'

    print(f'{log_prefix} ref_seq = {first_id} started with {len(run_seq_dict)} sequences\n', end='')

    #plain alignment of the cluster members
    aligned = align_with_consensus(run_seq_dict, return_consensus=False)

    if filter_entropy is not None:
        print(f'{log_prefix} Filtering alignment by bitscore > {filter_entropy}')
        aligned = filter_by_entropy(aligned, filter_entropy)

    if write_cluster_align:
        print(f'{log_prefix} Writing alignment to file {first_id}.cluster.clu')
        out_path = f'{cluster_align_root}{first_id}.cluster.clu'
        with open(out_path, 'w') as out:
            out.write(dict_to_fasta(aligned))

    return {first_id: aligned}


#realign every cluster of a .tsv individually with clustalo, filter by bitscore, write to file
def cluster_realignment(file_root, seq_fasta, cluster_file, filter_entropy=False, parallel_n=None):
    """Align each cluster of ``cluster_file`` separately and write one file per cluster.

    Sequences are read from ``seq_fasta``. Clusters of at most 10 members are
    merged into one batch and clusters above 200 members are split. Every
    resulting cluster with more than one member is handed to
    ``pool_cluster_alignment``, which aligns it, filters its columns with
    ``filter_entropy`` and writes ``<file_root><first id>.cluster.clu``.
    Clusters with a single member are neither aligned nor returned.

    Parameters
    ----------
    file_root : str
        Prefix (typically a directory ending in '/') for the per-cluster files.
    seq_fasta : str
        FASTA file holding all member sequences.
    cluster_file : str
        Two-column cluster TSV readable by ``read_cluster_tsv``.
    filter_entropy : number or None
        Threshold forwarded to ``pool_cluster_alignment``; only ``None``
        disables filtering there (the default ``False`` compares like 0).
    parallel_n : int or None
        Number of pool processes; ``None`` runs the clusters serially.

    Returns
    -------
    dict
        ``{first id of cluster: {id: aligned sequence}}``.

    The serial path previously called ``pool_align_with_consensus`` and so
    produced consensus alignments without filtering or file output; both paths
    now run the same worker. Unreachable post-processing after the ``return``
    and unused locals were removed.
    """
    thread = multiprocessing.current_process().pid
    threadID_string = f'{thread} | {seq_fasta.split("/")[-1]} alignment:'

    print(f'{threadID_string} started')

    seq_dict = fasta_to_dict(file=seq_fasta)

    #split larger clusters to avoid large realignments
    max_size = 200
    singleton_threshold = 10

    clusters = read_cluster_tsv(cluster_file, split_large=True, max_size=max_size,
                                batch_single=True, single_cutoff=singleton_threshold)

    #partitioning and merging can leave single-sequence clusters; only align clusters larger than one
    seq_dicts = [{key:seq_dict[key] for key in seqs} for seqs in clusters.values() if len(seqs) > 1]

    print(f'{threadID_string} realigning a total of {len(seq_dicts)} clusters')

    #one job per cluster: (sequences, entropy filter, write file, output prefix)
    submit_vars = [(cluster_seqs, filter_entropy, True, file_root) for cluster_seqs in seq_dicts]

    if parallel_n is not None:
        #multiprocessing
        print(f'{threadID_string} aligning parallel with {parallel_n} threads')
        with multiprocessing.Pool(processes=parallel_n) as pool:
            cluster_alignments_stream = pool.starmap(pool_cluster_alignment, submit_vars)

    else:
        #serial implementation, same worker as the parallel path
        cluster_alignments_stream = [pool_cluster_alignment(*job) for job in submit_vars]

    print(f'{threadID_string} merging alignments and writing filtered files')

    #merge results from stream
    cluster_alignments = {key: value for result in cluster_alignments_stream for key, value in result.items()}

    return cluster_alignments
    

In [None]:
#prepare one mmseqs
# Run the mmseqs preparation for a single query; all paths are relative to the
# working directory set in the first cell.

microcosm_prepare_mmseqs('OAE21175.1', 'euk72/euk72', 'prok2111/prok2111', 'microcosm2/')

In [None]:
#prepare all mmseqs
# Run microcosm_prepare_mmseqs for every query in euk_queries_test5, four worker
# processes at a time.
# NOTE(review): this rebinds the notebook-level `root` to 'microcosm4/'.

with multiprocessing.Pool(processes=4) as pool:
    #format starmap input
    queries = euk_queries_test5
    query_DB = 'euk72/euk72'
    hit_DB = 'prok2111/prok2111'
    root = 'microcosm4/'
    # one (query, query_DB, target_DB, root) tuple per query
    submit_vars = zip(queries, [query_DB]*len(queries), [hit_DB]*len(queries), [root]*len(queries))

    cluster_alignment_stream = pool.starmap(microcosm_prepare_mmseqs, submit_vars)

In [None]:
#one cluster alignment
# Per-cluster alignment of the members of a single query, using 16 pool processes.
# The result is {cluster id: {sequence id: aligned sequence}}.

seq_dicts = cluster_realignment(f'microcosm2/OAE21175.1/', 
                                         'microcosm2/OAE21175.1/OAE21175.1.members.fasta',
                                         'microcosm2/OAE21175.1/OAE21175.1.members.cluster.tsv',
                                         filter_entropy=0, parallel_n=16)


In [None]:
%%time
#all cluster alignment
# Per-cluster alignment for every query in euk_queries_test5, one query after another.
# `cluster_alignments` is overwritten on each iteration, so only the last query's
# result stays in memory; the per-cluster files are written inside the call.
root = 'microcosm4/'

for query in euk_queries_test5:
    file_root = f'{root}{query}/{query}'
    cluster_file = file_root + '.members.cluster.tsv'
    seq_fasta = file_root + '.members.fasta' 
    cluster_alignments = cluster_realignment(f'{root}{query}/', seq_fasta, cluster_file, filter_entropy=0, parallel_n=16)

In [None]:
#one haln alignment
# Hierarchical alignment of one query's member sequences; the alignment is written to
# '<members.fasta>.haln' by the function.
# NOTE(review): as written in this notebook hierarchical_realignment has no return
# statement, so `seq_dicts` is None here.
seq_dicts = hierarchical_realignment('microcosm3/NP_001002332.1/NP_001002332.1.members.fasta',
                                     'microcosm3/NP_001002332.1/NP_001002332.1.members.cluster.tsv',
                                     filter_entropy=0, parallel_n=16)

In [None]:
%%time
#all haln alignment
# Hierarchical alignment for the queries from index 38 onwards of euk_queries_test5
# NOTE(review): the [38:] slice looks like a manual resume point -- confirm before a full re-run.
root = 'microcosm4/'

for query in euk_queries_test5[38:]:
    file_root = f'{root}{query}/{query}'
    seq_fasta = file_root + '.members.fasta' 
    cluster_file = file_root + '.members.cluster.tsv'
    cluster_alignments = hierarchical_realignment(seq_fasta, cluster_file, filter_entropy=0, parallel_n=16)

In [None]:
# display the same query subset that the "all haln alignment" cell iterates over
euk_queries_test5[38:]

In [None]:
#filter by columnwise bitscore
# For each listed alignment file: parse it, keep the columns whose information content
# is above 0.5 and write the result next to the input with a '.fil' suffix.

#files = !find microcosm4/* -name '*'
files = ['microcosm4/AGK83073.1/test/AGK83073.1.members.fasta.famsa3']

for file in files:
    #euk_acc_file = file.split('.cluster.clu.euk')[0]+'.acc'
    with open(file, 'r') as infile:#, open(euk_acc_file,'r') as accfile:
        #euk_accs = accfile.read().split()
        # fasta_to_dict joins the list of lines back into one string before parsing
        aln = fasta_to_dict(infile.readlines())

    aln_filter = filter_by_entropy(aln, 0.5, gaptoken='-')#,filter_accs=euk_accs)
    dict_to_fasta(aln_filter, write_file=file+'.fil')

In [None]:
# Hierarchical alignment of one further query (XP_005256905.1) from microcosm3.
# NOTE(review): as written in this notebook hierarchical_realignment has no return
# statement, so `t` is None here.
t = hierarchical_realignment('microcosm3/XP_005256905.1/XP_005256905.1.members.fasta',
                                     'microcosm3/XP_005256905.1/XP_005256905.1.members.cluster.tsv',
                                     filter_entropy=0, parallel_n=16)

In [None]:
# Load one alignment and the matching accession list into the notebook namespace.
file = 'microcosm4/AGK83073.1/AGK83073.1.members.fasta.haln.euk'

with open(file, 'r') as infile:
    aln = fasta_to_dict(infile.readlines())

# whitespace-separated accession list; the context manager closes the handle
# (the bare open(...).read() used before left it open)
with open('microcosm4/AGK83073.1/AGK83073.1.acc', 'r') as accfile:
    euk_accs = accfile.read().split()

In [None]:
files = !find microcosm/*/merged.fasta.muscle
# `files` comes from the `!find` line above: one path per merged.fasta.muscle alignment.
# Each alignment is filtered to columns with information content above 0.5 and written
# next to the input with a '.fil' suffix.
for file in files:
    with open(file, 'r') as infile:
        aln = fasta_to_dict(infile.readlines())
        
    aln_filter = filter_by_entropy(aln, 0.5)
    dict_to_fasta(aln_filter, write_file=file+'.fil')

In [None]:
# Read the raw clusters of one query for inspection.
cluster_file = 'microcosm3/XP_005256905.1/XP_005256905.1.members.cluster.tsv'
#split larger clusters to avoid large realignments
# NOTE: with split_large=False and batch_single=False neither max_size nor
# single_cutoff has any effect, so the clusters are returned exactly as in the file.
max_size = 200
singleton_threshold = 10

clusters = read_cluster_tsv(cluster_file, 
                            split_large=False, max_size=max_size, 
                            batch_single=False, single_cutoff=singleton_threshold)

In [None]:
files = !find ./microcosm2/ -name '*members.fasta'
# `files` comes from the `!find` line above. Dropping the last two dot-separated parts
# ('members', 'fasta') of each path leaves the per-query file prefix.
roots = ['.'.join(f.split('.')[:-2]) for f in files]
# NOTE(review): the loop variable rebinds the notebook-level `root`.
for root in roots:
    seq_file = root+'.members.fasta'
    cluster_tsv = root+'.members.cluster.tsv'
    seq_dicts = hierarchical_realignment(seq_file, cluster_tsv, filter_entropy=0, parallel_n=16)

In [None]:
# NOTE(review): `!cd` runs in a throw-away subshell, so this line does not change the
# kernel's working directory; use `%cd ..` if the directory change is actually intended.
!cd ..

In [None]:
from core_functions.altair_plots import plot_alignment