In [9]:
import sys
sys.path.append('/home/dmoi/projects/foldtree2')

import torch
import torch.nn.functional as F
from Bio import Phylo, AlignIO, SeqIO
from scipy.special import gamma as gamma_function
import numpy as np
import dendropy
from dendropy import StandardCharacterMatrix
from matplotlib import pyplot as plt
import glob
from dendropy.datamodel import charmatrixmodel 

datadir = '/home/dmoi/datasets/'
families = glob.glob( datadir + 'afdbclusters/structfams/*/')

alphabet = """0 1 2 3 4 5 6 7 8 9 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z ! " # $ % & ' ( ) * + , / : ; < = > @ [ \ ] ^ _ { | } ~""".split()
alphamap = { a: i for i, a in enumerate(alphabet) } 

state_alphabet = dendropy.datamodel.charstatemodel.StateAlphabet(fundamental_states=alphabet, 
ambiguous_states=None, polymorphic_states=None, symbol_synonyms=None, 
no_data_symbol=None, gap_symbol='-', label=None, case_sensitive=True)

data_type = dendropy.datamodel.charmatrixmodel.CharacterType(
        state_alphabet=state_alphabet ) 

def read_tree(tree_file):
    tree = Tree.get(path=tree_file, schema='newick')
    return tree

def read_msa(msa_file, format='fasta'):
    msa = StandardCharacterMatrix.get(
        path=msa_file,
        schema=format,
        data_type='standard',
        default_state_alphabet=state_alphabet,
    )
    return msa

def msa2array(msa):
    #use biopython to read the alignment
    msa = AlignIO.read(msa, 'fasta')
    index = {seq.id: i for i, seq in enumerate(msa)}
    index_rev = { i: seq.id for i, seq in enumerate(msa)}
    return index, index_rev, np.array([list(rec) for rec in msa], np.character)

class msaarray:
    def __init__(self, msa):
        self.index, self.array = msa2array(msa)
        self.n, self.L = self.array.shape
        self.alphabet = np.unique(self.array)
        self.alphabet_size = len(self.alphabet)

    def __getitem__(self, i):
        if type(i) is slice:
            return self.array[i]
        else:
            return self.array[self.index[i]]

    def __len__(self):
        return self.n

    def __iter__(self):
        for i in self.index:
            yield i


def read_seq(seq_file, format='fasta'):
    print(seq_file)
    try:
        return SeqIO.read(seq_file, format)
    except:
        print('Error reading sequence file')
        print('Trying to read as a list of sequences')
        return [ s for s in SeqIO.parse(seq_file, format) ]


  alphabet = """0 1 2 3 4 5 6 7 8 9 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z ! " # $ % & ' ( ) * + , / : ; < = > @ [ \ ] ^ _ { | } ~""".split()


In [10]:
#import the emperically observed charfreqs
import pickle
with open( datadir + 'afdbclusters/charfreqs.pkl' , 'rb') as f:
    freqs = pickle.load(f)

In [11]:
#alignments should be in the encoded alphabet but come from foldmason alignment logic. assumed to be true
#trees are from FT1. assumed to be true
res = {}
import glob
import os
import pandas as pd
import tqdm

dfs = []
def fasta2df(fasta):
    seqs = {}
    s = read_seq(fasta)
    for i, seq in enumerate(s):
        seqs[i] = {'id': seq.id, 'seq': str(seq.seq)}    
    return pd.DataFrame.from_dict(seqs, orient='index')
def df2fasta(df, fasta , key = 'seq'):
    with open(fasta, 'w') as f:
        for i, row in df.iterrows():
            f.write('>' + row['id'] + '\n')
            f.write(row[key] + '\n')
def copy_aln(row):
    encoded = list(row['encoded'])
    foldmason = list(row['foldmason'])
    for i, c in enumerate(foldmason):
        if c == '-':
            encoded.insert(i, '-')
    return ''.join(encoded)


for f in tqdm.tqdm( families):
    if os.path.exists(f+'fident_distmat.txt_tree.txt') and os.path.exists(f+'foldmason.fasta_aa.fa') and os.path.exists(f+'encoded.fasta'):
        #read encoded, foldtree tree, and foldmason msa
        encoded = fasta2df(f+'encoded.fasta')
        msa = fasta2df(f+'foldmason.fasta_aa.fa')
        #merge encoded and msa
        merged = pd.merge(encoded, msa, on='id')
        merged.columns = ['id', 'encoded', 'foldmason']
        merged['family'] = f.split('/')[-2]
        #transfer gaps to encoded
        merged['encoded_aln_foldmason'] = merged.apply( copy_aln , axis=1)
        merged['aln_len'] = merged['encoded_aln_foldmason'].apply(len)
        print(merged.aln_len.value_counts())
        print(merged)
        #output encoded aln to fasta
        df2fasta(merged[['id', 'encoded_aln_foldmason']], f+'encoded_aln_foldmason.fasta', key = 'encoded_aln_foldmason')
        dfs.append( merged )

df = pd.concat(dfs)

 13%|▏| 13/101 [00:00<00:00, 126.94it/s

/home/dmoi/datasets/afdbclusters/structfams/A0A011N458/encoded.fasta
Error reading sequence file
Trying to read as a list of sequences
/home/dmoi/datasets/afdbclusters/structfams/A0A011N458/foldmason.fasta_aa.fa
Error reading sequence file
Trying to read as a list of sequences
aln_len
170    10
Name: count, dtype: int64
           id                                            encoded  \
0      G6F3X8  TT2+_S)}|2E]|_77]E4M+]<|8|4']S/VVSZKR2/IZ8I|||...   
1  A0A0H3KU55  DDG]9]9=/44XX999828UD9=8=KZ]9/:=9G=X2&:Z8IK:Z=...   
2      A9HPZ2  DDYE9]:R|:}PRF'/]RV&8|&:DJ|]:8J|6JRC$$}/FX2]PZ...   
3  A0A4P0Y7M4  TT9_ZE4]{RR_}]}RZ4{8]='SS/::Y}:RPC:Z8'==V|9:92...   
4  A0A535F069  TT2P+2|EE+_A38_77_}_R+8_{8M]'RE/]:7'KR'):ZK'|:...   
5  A0A378W409  DD:&~8P_B]8_RJ{WR@}:|K8DV'VR/V$&WVR2):ZR:|:)|S...   
6  A0A645JEF6  DD2=9C842]{E98{RS_}]8]]||8|:'8}/::}/KR+|<]JL|_...   
7  A0A7X6IPW2  ]488R4{{R_}P}_234{_R}{W{[,844CR_6JR/:|}Z\J2)1Z...   
8      E3T6K5  TT2{E9{{9S_9_}+8==8I4I&8/:V}9VR2/:Z}P||/]W:W]6...  

 40%|▍| 40/101 [00:00<00:00, 84.60it/s]

           id                                            encoded  \
0      K8AUN2  AT07P<<Z+8<00R/+MR/<11]#/R28/M8]04KM+||0L7<R+Z...   
1      U2YMB4  AA4R3MP<<]3%[0<Z//A>0PQ<]CR28L780+|Q/]|2M074P%...   
2  A0A009Q7Y6  ;;28(50HP<NZ+RM0(Z//A7<<|K]]L8728+M80P]</|]Q7N...   
3  A0A7U9NTI9  AA47+<<4+8]]]/+M)M{43]+]R|80M80+|0C]]%7<R+427]...   
4  A0A0A6D371  ;!FRJM(QP+N0Z20+1X/Z5XCQ(/2QP/X28Z5;Q+N0Z(28(0...   
5      K8AK78  AAZ[]N0]+Z%2N]RR<]2<]2</85R+((+(%ZNRQ80M80P4QH...   
6  A0A0F3N252  TT8H^+8(R0M+Q<]%2|<]]0]/>T]<MR+]]08M28+M8K+|</...   
7  A0A6N6X927  ]]RZ82]<]P807F71C15C2O1+Z/0;8C+N#8R+8QM8ZP|QC0...   
8  A0A7W4Z9R7  AA0HP%<]<<]0]HQ]/ZAP+<]+]7P8+M8/P<|C(L7M<P<P<P...   
9  A0A0N0M1K4  AA<M+P<++QP<(R//T/>7P+0K+8M/8L78)K|KK]L/8<0RP0...   

                                           foldmason      family  \
0  -----------MSIDGGQN----Y---G--SSVRNLVRGGG-TERV...  A0A009Q7Y6   
1  ---------MSVFLDGG------T---FAISGQRRMKSDAG--THV...  A0A009Q7Y6   
2  -------MPWTLSFDGGQNVL---------STQRRMIGGASTTE

 65%|▋| 66/101 [00:00<00:00, 103.77it/s

/home/dmoi/datasets/afdbclusters/structfams/A0A010RFJ8/encoded.fasta
Error reading sequence file
Trying to read as a list of sequences
/home/dmoi/datasets/afdbclusters/structfams/A0A010RFJ8/foldmason.fasta_aa.fa
Error reading sequence file
Trying to read as a list of sequences
aln_len
289    10
Name: count, dtype: int64
           id                                            encoded  \
0      C1GVW5  DD8:2Z9I_8I]9_P2C8ZP4_')}R8C}}Z]//C85RP9}/]2T2...   
1  A0A6J5W2Z7  TTR87)R<4M)R38)RCR'PRR)'[}PP/8[{+MR<]724+/RRRM...   
2  A0A010RFJ8  TT/842{Z4CZ2/]E+{SPT/24]{2+S/GG9ZR{9C2_944+29]...   
3  A0A0D3B9M2  TT_E9E9RT{RR{/]%%]22X5]XA42PNP[XP8[*[*#Z2%3PC2...   
4  A0A7V4P7R4  TTG]{{SRT{_{94R4{C6]+6ZA_ZWW:1JQ1\1W15{W{_PPPP...   
5  A0A835J6P6  TTP_S4/]P_T9P/SI2{2+/2_/2RGI]G9P4XX/*8*84]NX/P...   
6  A0A2D8HEC8  TT]2XZ{P2228=2|C:8/+)CG_)+Z+8RR%2R23N[PRN/8*/#...   
7      N1PIT3  TA]4/[824X{4/4]E+4'_MSR={S48@&]/D8K]X2/2{{P2_E...   
8  A0A258GZP4  TTG9{9{4{{9{{S9{42{E{ZZ224]]T{22R{Z_4R/RI_R={I...  

 93%|▉| 94/101 [00:00<00:00, 120.11it/s

aln_len
1410    10
Name: count, dtype: int64
           id                                            encoded  \
0  A0A1S9RPQ1  ;;5>2/Q(Q(/#5XF5P1^8+BQ5105PPZ^P((8(Z2WF26WZX2...   
1  A0A2N2Q006  A;Q^2C8/12X1/8QQ88/R%3]QN#Q5,Q0N/X,5X5/1XP8/8Z...   
2  A0A0B7KIB4  DT{52Z0Z2P8++5>X5Z/+/52]/{CE8+8}P']4{}]E+/+)K|...   
3  A0A5B5WZL5  TT9G]R{ER9_R]{58X4{{_{_8_6#%22X]//;X8+/18#B/5#...   
4  A0A2E9CV65  TT{]RT4RZ{{4/2_II/2%0^Q/5/F+B+\P8/BZ^1P%21QOL%...   
5  A0A6L6QDR5  TT/29]R9Y9Y899=_42{R4{9{4+_2{]4ZW1Z02/Q/6+B+\1...   
6  A0A385SNN1  TT4S/]E4993[8[392SRRH0/02/O(5+B+/P8>B/^1PCQ1QC...   
7      U3P295  ;;XX282^LF\5F1^Z/C#PCAR8//],#X2K181X101Z/X55+H...   
8  A0A6N3BSJ1  TTGG=&8&:88858_XX8_XZ{A+]PN/#RX+W+/68(8Z/1PC2<...   
9  A0A3N2MZL7  TT/9+&989Y9SS{45*80%0/R[P+B+/P8/8X51PR|1QLC2P@...   

                                           foldmason      family  \
0  -----------------------------------ML--KPR-ATA...  A0A010YC40   
1  ----------------------------------------------...  A0A010YC40   
2 

100%|█| 101/101 [00:00<00:00, 105.24it/

/home/dmoi/datasets/afdbclusters/structfams/A0A009EYQ5/encoded.fasta
Error reading sequence file
Trying to read as a list of sequences
/home/dmoi/datasets/afdbclusters/structfams/A0A009EYQ5/foldmason.fasta_aa.fa
Error reading sequence file
Trying to read as a list of sequences
aln_len
147    10
Name: count, dtype: int64
           id                                            encoded  \
0  A0A378X829  TTGS94]{99R{{]9R{_R|CQP|<2QR))00R<848ZZ<O)R)R]...   
1  A0A220S0G0  TTGGR4RS9E{{9{RT{_{{{Z{{24]0RZ8%8Z*NO/R]R080H+...   
2  A0A7X2GXR0  TDG>88XA{TR{RT{{4_{{Z%%<++2]R0))4R08/8PZ</]M)R...   
3  A0A496NBR5  TTGG499S{4RS{R{4T2M{C|P8|4P4/<<C<0]O]]O]R+8C8Z...   
4  A0A7H1MEK6  DDG+98]]99992{{{94_{{8{LLA2<2RROO(RRZ8R8ZZNOOR...   
5      Q2KZL1  TT{{C{CXPQPXZX%C8C8P%N/2,/XR8#,PC2/2LQCXC^CN/Q...   
6  A0A547ET74  TTSR4+9{C||<LZ4P%/L%CL0]R)]4]RO848Z+<OOM0R)8<H...   
7  A0A009EYQ5  AAO%37]PZZZ2K+ZZ+8RR+8O8P4+L+M4R]8CH++ZO]4%]R]...   
8      C6M3X3  TDPZQO8Z(8/8OP+N0X0X/8B>(#PNQQQZX/,ONQ#BRQ2,02...  




In [12]:
#load the families
res = {}
datadir = '/home/dmoi/datasets/'
families = glob.glob( datadir + 'afdbclusters/structfams/*/')
for f in tqdm.tqdm( families):
    res[f.split('/')[-2]] = {}
    res[f.split('/')[-2]]['tree'] = read_tree(f+'fident_distmat.txt_tree.txt')
    res[f.split('/')[-2]]['msa'] = read_msa(f+'encoded_aln_foldmason.fasta')
print(res.keys())
print(res['A0A011N458'])

100%|█| 101/101 [00:00<00:00, 600.86it/

dict_keys(['A0A011N458', 'A0A009NI06', 'A0A010RA97', 'A0A011P9Q6', 'A0A010RDZ2', 'A0A010QV66', 'A0A009LLL0', 'A0A009QF58', 'A0A009ECR5', 'A0A011M8L1', 'A0A010QBH7', 'A0A010RFX4', 'A0A010QH57', 'A0A009I5S8', 'A0A009QKS2', 'A0A009Q7Y6', 'A0A010QIB6', 'A0A009J8A1', 'A0A010Z6J6', 'A0A009LG02', 'A0A010ZQ01', 'A0A009LHJ9', 'A0A010QTJ4', 'A0A010Q4R4', 'A0A009ZZM7', 'A0A011MY18', 'A0A010Q175', 'A0A010SDI4', 'A0A010Q0U6', 'A0A009JM56', 'A0A010RCJ1', 'A0A010ZM43', 'A0A010R9P6', 'A0A010R299', 'A0A010S131', 'A0A009K3J7', 'A0A009PU45', 'A0A010Q311', 'A0A009EYT8', 'A0A010SSY2', 'A0A010R5R1', 'A0A011NB92', 'A0A011M2B8', 'A0A010JSD3', 'A0A010RFJ8', 'A0A010YDU6', 'A0A010TBK8', 'A0A011MIK5', 'A0A011NVJ3', 'A0A011MH15', 'A0A011P9P5', 'A0A010RZH3', 'A0A010PZP1', 'A0A010PTL2', 'A0A011NHP0', 'A0A011NP66', 'A0A010YHB2', 'A0A010YC68', 'A0A010RTK1', 'A0A011NVJ4', 'A0A010QYK0', 'A0A010SM59', 'A0A011P8I0', 'A0A010R0N5', 'A0A011NXL4', 'A0A010QA35', 'A0A010Z4N9', 'A0A010QIQ0', 'A0A010YLN3', 'A0A010YC40', 'A0A011NW




In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dendropy import Tree, CharacterMatrix
import numpy as np
from scipy.linalg import expm

class QMat(nn.Module):
    def __init__(self, num_chars=64, device='cpu'):
        super(QMat, self).__init__()
        self.device = device
        self.num_chars = num_chars
        # Initialize rate parameters
        num_rates =  int(((num_chars * num_chars) - num_chars ) / 2)
        self.rates = nn.Parameter(torch.ones(num_rates, device=device) / num_chars)
        self.scaling = torch.tensor( 1.0 , device=device)
        print(self.rates.shape)
        # Equilibrium frequencies
        self.freqs = nn.Parameter(torch.ones(num_chars, device=device) / num_chars)

    def forward(self):
        # Normalize frequencies to sum to 1
        freqs = self.freqs / torch.sum(self.freqs)

        # Initialize Q matrix
        Q = torch.zeros((self.num_chars, self.num_chars), device=self.device)
        
        # Fill the off-diagonal elements
        idx = 0
        for i in range(self.num_chars):
            for j in range(i+1, self.num_chars):
                if i != j:
                    rate = self.rates[idx]
                    Q[i, j] = rate * freqs[j]
                    Q[j, i] = rate * freqs[i]    
                    idx += 1
        # Set diagonal elements
        diag =-torch.sum(Q, dim=1)
        Q = torch.diagonal_scatter(Q,diag )
        self.scaling  = 1.0 / torch.sum(diag , dim=0)
        Q = Q * self.scaling
        return Q 



In [52]:

def compute_likelihood(tree, msa, Q, freqs, device='cpu', num_chars = 64 , verbose = False):
    
    # Exponentiate Q matrices for each unique branch length
    branch_lengths = set(edge.length for edge in tree.edges() if edge.length is not None)
    P_matrices = {}
    for t in branch_lengths:
        if t is not None:
            # Compute P(t) = expm(Q * t)
            Qt = Q * t  
            P = torch.matrix_exp(Qt)
            P_matrices[t] = P
    # Initialize likelihood vectors at the leaves
    taxon_namespace = msa.taxon_namespace
    char_to_index = alphamap
    for leaf in tree.leaf_node_iter():
        taxon_label = leaf.taxon.label
        sequence = msa[taxon_label]
        # Initialize the likelihood vector for the leaf
        leaf.likelihoods = []
        for site in range(len(sequence)):
            likelihood = torch.full((num_chars,), -np.inf, device=device)  # Log-likelihoods initialized to -inf
            observed_char = str(sequence[site])
            if observed_char in char_to_index:
                idx = char_to_index[observed_char]
                likelihood[idx] = 0.0  # log(1) = 0
            else:
                # Handle missing data or gaps by setting all states to 0 (log(1))
                likelihood[:] = 0.0
            leaf.likelihoods.append(likelihood)
    
    # Post-order traversal to compute likelihoods
    for node in tree.postorder_node_iter():
        if node.is_leaf():
            continue  # Likelihoods already initialized
        else:
            child_likelihoods = []
            for child in node.child_node_iter():
                edge_length = child.edge_length
                if edge_length is None:
                    edge_length = 0.0  # Assume zero length if not specified
                P = P_matrices.get(edge_length)
                if P is None:
                    # Compute P for this branch length
                    Qt = Q * edge_length
                    P = torch.matrix_exp(Qt)
                    P_matrices[edge_length] = P
                child_likelihoods.append((P, child.likelihoods))
            # Combine likelihoods from children
            node.likelihoods = []
            for site in range(len(leaf.likelihoods)):
                log_likelihood = torch.full((num_chars,), -np.inf, device=device)
                for state in range(num_chars):
                    log_prob = 0.0
                    for P, child_likelihood in child_likelihoods:
                        # Multiply (add in log-space) the probabilities from each child
                        log_p = torch.log(P[state, :]) + child_likelihood[site]
                        log_prob += torch.logsumexp(log_p, dim=0)
                    log_likelihood[state] = log_prob
                node.likelihoods.append(log_likelihood)
    
    print(  node.likelihoods[0] )

    # At the root, compute the total log-likelihood
    root = tree.seed_node
    total_log_likelihood = 0.0
    freqs = freqs / torch.sum(freqs)
    column_likelihoods = []

    for site in range(len(leaf.likelihoods)):
        root_likelihood = root.likelihoods[site]
        # Combine with equilibrium frequencies
        log_likelihood_site = torch.log(freqs) + root_likelihood
        # Sum over all states
        total_log_likelihood += torch.logsumexp(log_likelihood_site, dim=0)
    
    return total_log_likelihood

def optimize_likelihood(trees, msas ,  epochs=5, learning_rate=.1, lambda_penalty = 10, device='cpu' , verbose = False):
    # Initialize QMat
    qmat = QMat(num_chars=64, device=device)
    qmat.train()
    # Define optimizer 
    optimizer = torch.optim.Adam([ qmat.rates , qmat.freqs , qmat.scaling ] , lr=learning_rate  )
    # Get number of sequences
    for epoch in range(epochs):
        order = [i for i in range(len(msas))]
        np.random.shuffle(order)
        for i in tqdm.tqdm(order):
            msa = msas[i]
            tree = trees[i]
            num_sequences = len(msa)
            optimizer.zero_grad()
            # Compute log-likelihood
            Q = qmat()
            total_log_likelihood = compute_likelihood(tree, msa, Q , qmat.freqs , device=device , verbose = verbose)
            # Negative log-likelihood for minimization (normalized)
            loss = -total_log_likelihood / num_sequences
            # Backpropagation
            loss.backward()
            optimizer.step()
            
            # visualize qmat
            qdetached = qmat().detach().cpu().numpy()
            plt.imshow(np.log(qdetached))
            plt.colorbar()
            #label the axes with the alphabet
            plt.xticks(range(64), alphabet)
            plt.yticks(range(64), alphabet)
            plt.show()

            qfreqs = qmat.freqs.detach().cpu().numpy()
            plt.bar(range(64), qfreqs)
            plt.xticks(range(64), alphabet)
            plt.show()

            
            print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}')
    # After optimization, access optimized parameters
    optimized_rates = qmat.rates.detach().cpu().numpy()
    print("Optimized Rates:", optimized_rates)



In [53]:
trees = [ res[fam]['tree'] for fam in res ]
msas = [ res[fam]['msa'] for fam in res ]
optimize_likelihood(trees, msas, epochs=5, learning_rate=.001, device='cpu' , verbose = True)

torch.Size([2016])


  0%|          | 0/101 [00:00<?, ?it/s]

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       grad_fn=<CopySlices>)


  0%|          | 0/101 [01:51<?, ?it/s]


KeyboardInterrupt: 

In [None]:
#version with gamma and invariant sites
class QMat(nn.Module):
    def __init__(self, num_chars=64, device='cpu'):
        super(QMat, self).__init__()
        self.device = device
        self.num_chars = num_chars

        # Initialize rate parameters
        num_rates = num_chars * (num_chars - 1) // 2
        self.rates = nn.Parameter(torch.ones(num_rates, device=device))
        
        # Equilibrium frequencies
        self.freqs = nn.Parameter(torch.ones(num_chars, device=device))

        # Gamma shape parameter (alpha)
        self.alpha = nn.Parameter(torch.tensor(1.0, device=device))

        # Proportion of invariant sites (p_inv)
        self.p_inv = nn.Parameter(torch.tensor(0.1, device=device))

def gamma_discretization(alpha, k, device='cpu'):
    from scipy.stats import gamma
    quantiles = [i / k for i in range(k+1)]
    rates = []
    weights = []
    for i in range(k):
        # Calculate the mean rate in the interval
        q1 = gamma.ppf(quantiles[i], alpha)
        q2 = gamma.ppf(quantiles[i+1], alpha)
        mean_rate = (gamma.cdf(q2, alpha) - gamma.cdf(q1, alpha)) / (quantiles[i+1] - quantiles[i])
        rates.append(mean_rate)
        weights.append(quantiles[i+1] - quantiles[i])
    rates = torch.tensor(rates, device=device)
    weights = torch.tensor(weights, device=device)
    # Normalize rates so that the average rate is 1
    rates = rates / torch.mean(rates * weights)
    return rates, weights


def compute_likelihood(tree, msa, qmat, device='cpu', gamma_categories=4):
    Q = qmat()
    num_chars = qmat.num_chars

    # Discretize gamma distribution
    alpha = torch.exp(qmat.alpha)  # Ensure alpha is positive
    gamma_rates, gamma_weights = gamma_discretization(alpha.item(), gamma_categories, device=device)

    # Exponentiate Q matrices for each rate category and unique branch length
    unique_branch_lengths = set(edge.length for edge in tree.edges() if edge.length is not None)
    P_matrices = {}
    for rate in gamma_rates:
        for t in unique_branch_lengths:
            if t is not None:
                # Compute P(t) = expm(Q * rate * t)
                Qt = Q.detach().cpu().numpy() * (rate.item() * t)
                P = expm(Qt)
                P = torch.from_numpy(P).to(device)
                P_matrices[(rate.item(), t)] = P

    # Initialize likelihood vectors at the leaves
    taxon_namespace = msa.taxon_namespace
    custom_alphabet = [char.symbol for char in msa.default_state_alphabet]
    char_to_index = {char: idx for idx, char in enumerate(custom_alphabet)}

    for leaf in tree.leaf_node_iter():
        taxon_label = leaf.taxon.label
        sequence = msa[taxon_label]
        # Initialize the likelihood vector for the leaf
        leaf.likelihoods = []
        for site in range(len(sequence)):
            likelihood = torch.full((num_chars,), -np.inf, device=device)  # Log-likelihoods initialized to -inf
            observed_char = sequence[site]
            if observed_char in char_to_index:
                idx = char_to_index[observed_char]
                likelihood[idx] = 0.0  # log(1) = 0
            else:
                # Handle missing data or gaps by setting all states to 0 (log(1))
                likelihood[:] = 0.0
            leaf.likelihoods.append(likelihood)

    # Post-order traversal to compute likelihoods
    for node in tree.postorder_node_iter():
        if node.is_leaf():
            continue  # Likelihoods already initialized
        else:
            child_likelihoods = []
            for child in node.child_node_iter():
                edge_length = child.edge_length
                if edge_length is None:
                    edge_length = 0.0  # Assume zero length if not specified
                child_likelihoods.append((edge_length, child.likelihoods))
            # Combine likelihoods from children
            node.likelihoods = []
            for site in range(len(leaf.likelihoods)):
                # For invariant sites
                site_invariant = False
                observed_chars = set(msa[taxon_label][site] for taxon_label in msa.taxon_namespace.labels())
                if len(observed_chars) == 1:
                    site_invariant = True

                # Initialize total log-likelihood for the site
                total_log_likelihood = None

                # Compute likelihood for gamma-distributed rates
                log_likelihood_gamma = None
                for rate, weight in zip(gamma_rates, gamma_weights):
                    log_likelihood_rate = torch.full((num_chars,), 0.0, device=device)
                    for state in range(num_chars):
                        log_prob = 0.0
                        for edge_length, child_likelihood in child_likelihoods:
                            P = P_matrices[(rate.item(), edge_length)]
                            log_p = torch.log(P[state, :]) + child_likelihood[site]
                            log_prob += torch.logsumexp(log_p, dim=0)
                        log_likelihood_rate[state] = log_prob
                    if log_likelihood_gamma is None:
                        log_likelihood_gamma = torch.log(weight) + log_likelihood_rate
                    else:
                        log_likelihood_gamma = torch.logaddexp(log_likelihood_gamma, torch.log(weight) + log_likelihood_rate)

                # Compute likelihood for invariant sites
                if site_invariant:
                    # Likelihood is 1 for the observed state, 0 for others (log(1) = 0, log(0) = -inf)
                    invariant_likelihood = torch.full((num_chars,), -np.inf, device=device)
                    observed_char = sequence[site]
                    idx = char_to_index.get(observed_char, None)
                    if idx is not None:
                        invariant_likelihood[idx] = 0.0  # log(1) = 0
                else:
                    invariant_likelihood = torch.full((num_chars,), -np.inf, device=device)

                # Combine invariant and gamma likelihoods
                p_inv = torch.sigmoid(qmat.p_inv)  # Ensure p_inv is between 0 and 1
                log_p_inv = torch.log(p_inv)
                log_p_var = torch.log(1.0 - p_inv)

                combined_log_likelihood = torch.logaddexp(
                    log_p_inv + invariant_likelihood,
                    log_p_var + log_likelihood_gamma
                )

                node.likelihoods.append(combined_log_likelihood)

    # At the root, compute the total log-likelihood
    root = tree.seed_node
    total_log_likelihood = 0.0
    freqs = qmat.freqs / torch.sum(qmat.freqs)
    for site in range(len(leaf.likelihoods)):
        root_likelihood = root.likelihoods[site]
        # Combine with equilibrium frequencies
        log_likelihood_site = torch.log(freqs) + root_likelihood
        # Sum over all states
        total_log_likelihood += torch.logsumexp(log_likelihood_site, dim=0)

    return total_log_likelihood


In [None]:
#mafft matrix file example max 248 char

"""
0x01 0x01 2   # (comment)
0x1e 0x1e 2
0x1f 0x1f 2
0x21 0x21 2   # ! × !
0x41 0x41 2   # A × A
0x42 0x42 2   # B × B
0x43 0x43 2   # C × C
"""

def formathex(hexnum):
    if len(hexnum) == 3:
        return hexnum[0:2] + '0' + hexnum[2]
    else:
        return hexnum

def output_mafft_matrix( submat , char_set ,  outpath='mafft_submat.mtx' ):
    print( submat.shape , len( char_set ) )
    with open(outpath, 'w') as f:
        for i in range(len(char_set)):
            for j in range(len(char_set)):
                if i <= j:
                    stringi = char_set[i]
                    stringj = char_set[j]
                    
                    if stringi in replace_dict.keys():
                        stringi = replace_dict[stringi]
                    if stringj in replace_dict.keys():
                        stringj = replace_dict[stringj]
                  
                    hexi = formathex(hex(ord(stringi)))
                    hexj = formathex(hex( ord(stringj)))
                    
                    f.write( f'{hexi} {hexj} {submat[i,j]} \n ')# '+ stringi + 'x' + stringj + ' \n' )
        f.write('\n')
    return outpath


output_mafft_matrix( submat , alphabet_str , outpath='mafft_submat.mtx' )
