# PHYLOGENETIC TREES WITH DIGITAL ANNEALER INSPIRED BY NEIGHBOUR JOINING

The main objective of this method is to consecutively **group** a set of nodes or species with Fujitsu's Digital Annealer until a phylogenetic tree is obtained.

In [None]:
# local DA imports
from dadk.BinPol import *
from dadk.QUBOSolverCPU import *
from dadk.Solution_SolutionList import *

In [42]:
# real DA imports
from dadk.QUBOSolverDAv3c import *
from dadk.internal.ConnectionParameter import ConnectionParameter
from dadk.ProfileUtils import ProfileUtils
profile = ProfileUtils.get_annealer_access_profile(access_profile_file=r'~/.dadk/profiles/Lantik.prf')
connection_parameter = ConnectionParameter.decode(profile)
# ProfileUtils.list_annealer_access_profiles()
# ProfileUtils.store_annealer_access_profile(name_in_profile_store='Lantik',**profile)

In [43]:
# other imports
from openpyxl import load_workbook
import networkx as nx
import csv
import numpy as np
import time
import json
from datetime import datetime, timedelta
import enum

## Data collecting

To obtain a phylogenetic tree, we need a set of species and a measurement that tells us how close they are to each other. These constitute the nodes and the edges of the graph we are going to cut.

For that, we originally took datasets from [Phylome DB](https://phylomedb.org/phylomes?s=expl), an online database of phylomes that offers both the standard accepted phylogenetic tree and the amino acid sequences of the species. We cleaned up the .fasta files containing these amino acid sequences with $\verb|func_cleanfasta|$ and fed them to BLAST+. The program analyzes its database of amino acid sequences and, following a local comparison algorithm, returns several metrics, among which we find the bit scores.

Bit scores are a good metric that takes several characteristics into account to determine how closely related two species are. They are not normalized when they first come out of BLAST+, so the first step towards a proper similarity matrix is to normalize them. Let us see how we do that with the following code.

In [2]:
# IMPORT SIMILARITY MATRIX (TSV FILE)

def import_from_tsv(filename):

    """
    Imports a similarity matrix from a tsv file and returns a list of nodes and a bitscore matrix.

    Args:
        filename (str): Path to the TSV file.

    Returns:
        nodes (list): List of unique node names.
        bitscore_matrix (numpy.ndarray): 2D numpy array representing the bitscore matrix.
    """

    tsv_path = filename
    column_list = []
    
    # Reads all rows as lists first
    
    with open(tsv_path, newline='', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter='\t')
        rows = list(reader)
    
    # Transpose rows for columns
    
    for col in zip(*rows):
        column_list.append(list(col))
    
    # No header
    
    nodes_query = column_list[0]
    nodes_sub = column_list[1]
    bit_scores = column_list[11]
    
    bit_scores = [float(i) for i in bit_scores]

    # returns nodes and bitscore matrix

    nodes = extract_nodes(nodes_query)
    bitscore_matrix = extract_matrix(nodes_query, nodes_sub, nodes, bit_scores)

    return(nodes, bitscore_matrix)

$\verb|nodes_query|$ and $\verb|nodes_sub|$ exist because BLAST+ first creates a database from our .fasta file and calls the species in that database 'Subject'. It then iteratively compares the sequences in the original .fasta ('Queries') with the ones in the database, so we get two columns with all the possible pairs in the set. To have a proper array with only the nodes (sequences) for future use, we search through the first column and extract the unique names.
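As a compact illustration of the same idea, the sketch below extracts the distinct names in order of first appearance. The helper name `unique_in_order` and the sample column are hypothetical. Unlike a scan that only checks when the name changes, this variant also tolerates names that reappear later in the column.

```python
def unique_in_order(names):
    """Return the distinct names in order of first appearance."""
    # dict keys keep insertion order (Python 3.7+), so only repeats are dropped
    return list(dict.fromkeys(names))

# hypothetical query column: each name repeats once per reported hit
query_column = ["sp_A", "sp_A", "sp_A", "sp_B", "sp_B", "sp_C"]
nodes = unique_in_order(query_column)
```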

In [3]:
def extract_nodes(nodes_query):
    """
    Subroutine that extracts unique nodes from the nodes_query list.

    Args:
        nodes_query (list): Original query nodes list.

    Returns:
        nodes (list): List of unique node names.
    """

    nodes = []

    # we initialize the current node selected
    # the nodes are in order, so we just need to check when the name changes
    current_node_name = nodes_query[0]
    nodes.append(current_node_name)
    
    # we inspect all the nodes in nodes_query
    for i in range(len(nodes_query)):
        # select a new node only when the name changes
        if nodes_query[i] != current_node_name:
            nodes.append(nodes_query[i])
            # we update current_node_name
            current_node_name = nodes_query[i]
    
    return(nodes)

Now, we will normalize the bit scores. We will use the formula

$$ \mathrm{Norm}_{bit}(i,j) = \frac{bit(i,j)}{\mathrm{mean}\{bit(i,i), bit(j,j)\}} \times 100. $$



We must also rearrange the $\verb|bit_scores_norm|$ array into a matrix. This is not a trivial step, as the total length of the array is not exactly $nodes \times nodes$: some species are so distantly related that BLAST+ has purged the pair. In those cases, we assign the default value of zero to the edge between those nodes.
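To make the normalization concrete, here is a minimal sketch on made-up scores for three sequences. The dictionary layout, the helper name `normalize_bitscores` and all values are assumptions chosen for illustration; the pair B-C is deliberately missing to show the zero default.

```python
# Hypothetical raw bit scores keyed by (query, subject); B-C has no reported hit.
raw = {
    ("A", "A"): 200.0, ("B", "B"): 100.0, ("C", "C"): 300.0,
    ("A", "B"): 75.0, ("B", "A"): 75.0,
    ("A", "C"): 50.0, ("C", "A"): 50.0,
}
names = ["A", "B", "C"]

def normalize_bitscores(raw, names):
    """Return the matrix of bit(i,j) / mean(bit(i,i), bit(j,j)) * 100."""
    index = {name: pos for pos, name in enumerate(names)}
    # pairs absent from `raw` keep the default value of zero
    matrix = [[0.0] * len(names) for _ in names]
    for (query, subject), score in raw.items():
        mean_self = (raw[(query, query)] + raw[(subject, subject)]) / 2
        matrix[index[query]][index[subject]] = score * 100 / mean_self
    return matrix

norm = normalize_bitscores(raw, names)
```

With these numbers, A-B is normalized by the mean of 200 and 100, and the self-scores on the diagonal all become 100.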

In [4]:
def extract_matrix(nodes_query, nodes_sub, nodes, bit_scores):

    """
    Subroutine that builds the structure of the bitscore matrix from the nodes_query and 
    nodes_sub lists and fills it with the normalized values of the bit_scores list.

    Args:
        nodes_query (list): List of query nodes.
        nodes_sub (list): List of subject nodes.
        nodes (list): List of unique node names.
        bit_scores (list): List of bit scores.

    Returns:
        bitscore_matrix (numpy.ndarray): 2D numpy array representing the bitscore matrix.
    """

    n_total = len(bit_scores)
    
    bit_scores_norm = []
    
    # We iterate over the entire array of bit scores
    
    for k in range(n_total):
    
        # We save the name of the query and subject
        query_name = nodes_query[k]
        sub_name = nodes_sub[k]
    
        # We iterate over the rest of the list finding the bit(query,query) = bit(i,i)
        # and bit(subject,subject) = bit(j,j)
    
        for i in range(n_total):
            
            if nodes_query[i] == query_name and nodes_sub[i] == query_name:
                bit_i_i = bit_scores[i]
    
            if nodes_query[i] == sub_name and nodes_sub[i] == sub_name:
                bit_j_j = bit_scores[i]
    
        # Calculate the mean
        mean_scores = (bit_i_i + bit_j_j)/2
    
        # Apply formula
        bit_scores_norm.append(bit_scores[k]*100/mean_scores)
        

    ### BIT SCORE MATRIX BUILD ###
    
    bitscore_matrix = np.zeros([len(nodes),len(nodes)])
    
    for k in range(n_total):
    
        # for each element in bit_scores_norm, we search for the corresponding
        # indexes (matching names) in the nodes array, thus we find the
        # corresponding column and row in the matrix
        
        index_query = nodes.index(nodes_query[k])
        index_sub = nodes.index(nodes_sub[k])
        bitscore_matrix[index_query,index_sub] = bit_scores_norm[k]


    ### TRIANGULATION ###
    
    # The matrix must be symmetrical and sometimes it is not symmetrical when it comes
    # directly out of the bitscore list values, so we force the symmetry by 
    # calculating the mean of the corresponding values 
    for i in range(len(nodes)):
        for j in range(len(nodes)):
            
            if bitscore_matrix[i][j] != bitscore_matrix[j][i]:
                # mean between element (i,j) and element (j,i)
                mean_score = (bitscore_matrix[i][j] + bitscore_matrix[j][i])/2
                # reset value of elements (i,j) and (j,i) to the new mean
                bitscore_matrix[i][j] = mean_score
                bitscore_matrix[j][i] = mean_score
    
            # Diag zeros !! NOT DIAG ZEROS THIS TIME
    
            # if i == j:
            #     bitscore_matrix[i][j]= 0

    return(bitscore_matrix)       

Printing this matrix reveals an interesting result: it is not exactly symmetrical. This is due to the algorithm that BLAST+ uses to compare queries against subjects. For example, if the amino acid sequence of a certain species is too short, comparing that species with another as subject-query may yield a different result than comparing them as query-subject. To fix this, we have decided to symmetrize the matrix (the triangulation step in the code) by replacing each pair of corresponding entries with their mean, $\mathrm{mean}(\verb|bitscore_matrix[i][j]|, \verb|bitscore_matrix[j][i]|)$.
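The averaging step described above can also be expressed in a single NumPy operation. This is only a sketch on a made-up matrix, not a replacement for the function above:

```python
import numpy as np

# Illustrative asymmetric similarity matrix (values are made up).
scores = np.array([[100.0, 40.0, 10.0],
                   [60.0, 100.0, 0.0],
                   [20.0, 30.0, 100.0]])

# Averaging a matrix with its transpose replaces entries (i,j) and (j,i)
# by their mean and leaves the diagonal untouched.
symmetric = (scores + scores.T) / 2
```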

## Neighbour Joining matrix

The classical Neighbour Joining matrix $Q$ is defined as follows, where $n$ is the number of nodes and $d(i,j)$ is the distance between nodes $i$ and $j$.
$$
Q(i,j) = (n-2) d(i,j) - \sum_{k=1}^n d(i,k) - \sum_{k=1}^n d(j,k)
$$

We will implement this definition element by element in the following function.

In [5]:
def calculate_real_nj(bitscore_inverse):

    """
    Function to calculate the Q matrix of the real nodes (leaf nodes), i.e., without the combined nodes yet.
    
    Args:
    bitscore_inverse (np.ndarray): Inverse bitscore matrix (dissimilarity matrix).

    Returns:
    real_nj (np.ndarray): Q matrix of the real nodes, based on the inverse bitscore matrix.
    """

    # we get the number of real nodes by the size of the bitscore_inverse matrix
    N = bitscore_inverse.shape[0]

    # open an empty nj matrix of that size
    real_nj = np.zeros((N, N))

    # we fill only the upper triangular nj matrix with the Q formula
    for i in range(N):
        for j in range(i+1, N):
            real_nj[i,j] = (N-2)*bitscore_inverse[i][j] - np.sum(bitscore_inverse[i,:]) - np.sum(bitscore_inverse[j,:])

    return(real_nj)
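
For reference, the same $Q$ formula can be written with NumPy broadcasting instead of nested loops. The function name `q_matrix_upper` and the 4-node distance matrix are illustrative assumptions; the sketch keeps only the upper triangle, like the loop version.

```python
import numpy as np

def q_matrix_upper(dist):
    """Upper-triangular Q: (n-2) d(i,j) - sum_k d(i,k) - sum_k d(j,k)."""
    n = dist.shape[0]
    row_sums = dist.sum(axis=1)
    # row_sums[:, None] broadcasts the sum of row i, row_sums[None, :] that of row j
    q = (n - 2) * dist - row_sums[:, None] - row_sums[None, :]
    return np.triu(q, k=1)  # keep only the entries with i < j

# made-up symmetric distance matrix for four leaf nodes
dist = np.array([[0.0, 5.0, 9.0, 9.0],
                 [5.0, 0.0, 10.0, 10.0],
                 [9.0, 10.0, 0.0, 8.0],
                 [9.0, 10.0, 8.0, 0.0]])
q = q_matrix_upper(dist)
```

In this toy matrix the pairs (0,1) and (2,3) obtain the lowest $Q$ values, so they are the natural candidates to be joined.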

## Combined nodes

We first calculate the distances between all the real nodes ($\verb|bitscore_matrix|$) and the combined nodes (also referred to as artificial nodes or intermediate nodes). There is one candidate combined node per pair of real nodes, i.e. $n(n-1)/2$ of them, so their number grows as $O(n^2)$.

The distance of the combined nodes to the rest of the nodes can be calculated through the Neighbour-Joining formula for the creation of a new node:
$$
d(u,k) = \frac{1}{2} [d(f,k)+d(g,k)-d(f,g)]
$$
where $f$ and $g$ are the joined nodes and $k$ is any other node left. $d(a,b)$ denotes the distance (normalized bitscore) between nodes $a$ and $b$.
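As a small worked example of this formula, the sketch below joins nodes 0 and 1 of a made-up 4-node distance matrix and computes the distance from the new node to the two remaining ones. The helper name `combined_distance` and the values are illustrative only.

```python
def combined_distance(d, f, g, k):
    """Distance from the node that joins f and g to another node k."""
    return 0.5 * (d[f][k] + d[g][k] - d[f][g])

# made-up symmetric distance matrix for four leaf nodes
d = [[0, 5, 9, 9],
     [5, 0, 10, 10],
     [9, 10, 0, 8],
     [9, 10, 8, 0]]

# join f = 0 and g = 1, then measure the new node against k = 2 and k = 3
d_u2 = combined_distance(d, 0, 1, 2)  # 0.5 * (9 + 10 - 5)
d_u3 = combined_distance(d, 0, 1, 3)  # 0.5 * (9 + 10 - 5)
```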

In [6]:
def change_index(len_nodes, index_list, n_index=None):
    """
    Change the index of the matrix from (i,j) to the corresponding planar index up to N(N+1)/2

    Args:
    len_nodes (int): Number of nodes in the original matrix (N).
    index_list (list): List with the two indexes [i,j] to be converted to planar index.
    n_index (int): Planar index to be converted to (i,j) indexes.

    Returns:
    n_index (int): Planar index corresponding to (i,j) if index_list is provided.
    index_list (list): List with the two indexes [i,j] corresponding to n_index if n_index is provided.
    """ 

    # case where [i,j] is provided
    if index_list != []:
        if index_list[1] <= index_list[0]:
            i = index_list[1]
            j = index_list[0]
        else:
            i = index_list[0]
            j = index_list[1]
        n_index = len_nodes + sum([x for x in reversed(range(len_nodes))][:i]) + j - i - 1
        return(n_index)
    
    # case where n_index is provided
    elif n_index:
        i = 0
        suma = len_nodes
        while suma+len_nodes-i-1<= n_index:
            i += 1
            suma += len_nodes - i

        j = n_index - suma + i +1 

        return [i,j]
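
To illustrate the planar numbering, the following sketch re-implements both directions of the mapping as two separate helpers and checks that they are inverses of each other. The names `pair_to_planar` and `planar_to_pair` are hypothetical; the sketch assumes pairs with $i < j$ and planar indexes of combined nodes only (from $N$ onwards).

```python
def pair_to_planar(n, i, j):
    """Planar index of the combined node (i, j), placed after the n leaf indexes."""
    if i > j:
        i, j = j, i
    # pairs whose first element is 0..i-1 come first: (n-1) + (n-2) + ...
    offset = sum(n - 1 - r for r in range(i))
    return n + offset + (j - i - 1)

def planar_to_pair(n, index):
    """Inverse mapping, valid for n <= index < n*(n+1)//2."""
    i = 0
    start = n  # planar index of the pair (i, i+1)
    while start + (n - 1 - i) <= index:
        start += n - 1 - i
        i += 1
    return i, i + 1 + (index - start)

# for n = 4 the pairs (0,1), (0,2), (0,3), (1,2), (1,3), (2,3) map to 4..9
n = 4
pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
planar = [pair_to_planar(n, i, j) for i, j in pairs]
```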


In [7]:
def filter_combined_nodes(len_nodes, nj_matrix, del_comb):

    """  
    Function that eliminates a number of the highest-energy combined nodes based on a percentage 
    of deletion and returns a list with the surviving pairs (real and combined).

    Args:
    len_nodes (int): Number of real nodes (N).
    nj_matrix (np.ndarray): NJ matrix with only REAL (leaf) nodes.
    del_comb (float): Deletion ratio for combined nodes (between 0 and 1).

    Returns:
    index_list (list): Leaf indexes (int) followed by the surviving combined nodes as (i, j) tuples.
    """
    
    N = len_nodes

    index_list = []

    # we create a list with all the possible combinations of indices based on 
    # the total number of nodes N

    for k in range(int(N*(N+1)/2)):

        # leaf nodes only get one index (k)
        if k < len_nodes:
            index_list.append(k)
        # combined nodes get two indexes (i,j) based on the leaf nodes they join
        else:
            i_index, j_index = change_index(N, [], k)
            index_list.append((i_index,j_index))

    # in the particular case where there are fewer than 4 nodes,
    # we create all the possible combined nodes (we are approaching
    # the end of the reconstruction and we may get stuck in a loop)
    if N<4:
        return(index_list)
    
    # we calculate the number of combined nodes to be deleted
    # based on the del_comb ratio
    number_del_nodes = round(del_comb*N*(N-1)/2)

    # we create the nj matrix of the real nodes only
    # (copied, so that zeroing entries below does not modify the caller's matrix)
    short_nj_matrix = nj_matrix[:,:len_nodes].copy()

    for k in range(number_del_nodes):
        # we look for the highest value in the short_nj_matrix 
        # (pair of nodes with the highest energy) and we delete that pair
        max_value = short_nj_matrix.max()
        indices = np.argwhere(short_nj_matrix == max_value)
        i_index = int(indices[0][0])
        j_index = int(indices[0][1])
        
        l_index = index_list.index((i_index,j_index))
        index_list.pop(l_index)
        # we set the highest value of the matrix to zero and continue
        short_nj_matrix[i_index,j_index] = 0

    return(index_list)

In [8]:
def calculate_adj_matrix(bitscore_inverse, comb_index_list):

    """   
    Function that creates the adjacency matrix again, now including the combined nodes
    (contrary to the bitscore_inverse matrix, where only the real nodes were taken into account).

    Args:
    bitscore_inverse (np.ndarray): Inverse bitscore matrix (we will keep the real node values from here).
    comb_index_list (list): List of indexes corresponding to the combined nodes to be created.

    Returns:
    adj_matrix (np.ndarray): Adjacency matrix including both real and combined nodes.

    """

    # we get the size of the real nodes from the bitscore_inverse matrix
    N = bitscore_inverse.shape[0]
    # we get the size of the real + combined nodes from the comb_index_list
    N_big = len(comb_index_list)
    # the adj_matrix does not need to be square yet
    adj_matrix = np.zeros((N, N_big))

    # we copy the bitscore_inverse values for the real nodes
    for i in range(N):
        for j in range(i+1, N):
            adj_matrix[i,j] = bitscore_inverse[i][j]
            adj_matrix[j,i] = adj_matrix[i,j]

    # we calculate the distances of the combined nodes for the rest
    for item in comb_index_list:
        # check that we are dealing with combined nodes (two indices)
        if isinstance(item, tuple):
            i_index = item[0]
            j_index = item[1]
            n_index = comb_index_list.index(item)
            # apply distance formula
            for k in range(N):
                adj_matrix[k,n_index] = 0.5*(bitscore_inverse[i_index][k] + bitscore_inverse[j_index][k] - bitscore_inverse[i_index][j_index])

    return(adj_matrix)


In [9]:
def correct_nj(nj_matrix):

    """     
    Function that shifts all the values in the NJ matrix to be positive and sets the lower triangular
    part of the matrix to zero.

    Args:
    nj_matrix (np.ndarray): NJ matrix to be corrected.

    Returns:
    new_nj_matrix (np.ndarray): Corrected NJ matrix.
    """

    N = nj_matrix.shape[0]
    # shifts matrix to be all positive
    new_nj_matrix = nj_matrix - nj_matrix.min() + 1

    # sets lower triangular to zero (not needed)
    for i in range(N):
        for j in range(N):
            if j<= i:
                new_nj_matrix[i][j] = 0

    return(new_nj_matrix)
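
The shift-and-mask step above can also be sketched with `np.triu`, which works on rectangular matrices as well. The function name `shift_and_mask` and the small test matrix are assumptions made for illustration.

```python
import numpy as np

def shift_and_mask(matrix):
    """Shift all values so the minimum becomes 1, then zero every entry with j <= i."""
    shifted = matrix - matrix.min() + 1
    return np.triu(shifted, k=1)

# made-up rectangular matrix (2 real nodes, 3 columns)
example = np.array([[-3.0, 2.0, 5.0],
                    [4.0, -1.0, 0.0]])
corrected = shift_and_mask(example)
```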

In [10]:
def calculate_combined_nj(real_nj, adj_matrix, comb_index_list):

    """
    Function that calculates the Q matrix with combined nodes based on the
    adj_matrix with combined nodes.

    Args:
    real_nj (np.ndarray): NJ matrix with only REAL (leaf) nodes.
    adj_matrix (np.ndarray): Adjacency matrix including both real and combined nodes.
    comb_index_list (list): List of indexes corresponding to the surviving combined nodes.

    Returns:
    corrected_nj_matrix (np.ndarray): Corrected (positive) NJ matrix including both real and combined nodes. 
       
    """

    # we get N and N_big from the shapes of the real_nj matrix and the combined nodes list
    N = real_nj.shape[0]
    N_big = len(comb_index_list)

    # dimension of the combined nj matrix (only combined nodes)
    combined_nj = np.zeros((N, N_big-N))

    # we concatenate that part to the final nj matrix 
    total_nj_matrix = np.concatenate((real_nj, combined_nj), axis=1)

    # we fill only the combined part
    for item in comb_index_list:

        if isinstance(item,tuple):
            i_index = item[0]
            j_index = item[1]
            n_index = comb_index_list.index(item)
            
            # a real node paired with a combined node built from it gets a penalization term
            for k in range(N):
                if k==i_index or k==j_index:
                    total_nj_matrix[k, n_index] = 1000
                # we calculate the real-combined values of the Q matrix
                else:
                    total_nj_matrix[k, n_index] = (N-3)*adj_matrix[k, n_index] - np.sum(adj_matrix[k,:N]) - np.sum(adj_matrix[:,n_index])
                    total_nj_matrix[k, n_index] += - adj_matrix[k][n_index] + adj_matrix[k][i_index] + adj_matrix[k][j_index]

    # we make the matrix positive
    corrected_nj_matrix = correct_nj(total_nj_matrix)

    return(corrected_nj_matrix)

In [11]:
def main_calculate_nj(adj_matrix, del_comb):

    """    
    Function that integrates all the subroutines to calculate the NJ matrix.

    Args:
    adj_matrix (np.ndarray): Adjacency matrix including both real and combined nodes.
    del_comb (float): Deletion ratio for combined nodes (between 0 and 1).

    Returns:
    comb_index_list (list): Leaf indexes followed by the surviving combined nodes as (i, j) tuples.
    filtered_adj (np.ndarray): Adjacency matrix including both real and combined nodes after filtering.
    filtered_nj (np.ndarray): NJ matrix including both real and combined nodes after filtering.
    
    """

    N = adj_matrix.shape[0]
    real_adj = adj_matrix[:, :N]

    real_nj = calculate_real_nj(real_adj)
    corrected_nj = correct_nj(real_nj)
    comb_index_list = filter_combined_nodes(N, corrected_nj, del_comb)
    filtered_adj = calculate_adj_matrix(real_adj, comb_index_list)
    filtered_nj = calculate_combined_nj(real_nj, filtered_adj, comb_index_list)

    return(comb_index_list, filtered_adj, filtered_nj)

## QUBO formulation


The base expression of our QUBO minimizes the sum of the distances over all the pairs in the $K$ communities, where the binary variable $x_{i,k}$ equals 1 when node $i$ belongs to community $k$:
$$
\sum_{k=1}^{K} \left( \sum_{i, j = 1}^{\frac{N(N+1)}{2}} d(i,j) x_{i,k} x_{j,k}\right)
$$
Nodes $\{1, \dots, N\}$ are real leaf nodes and the rest of the indexes $\{N+1, \dots, \frac{N(N+1)}{2}\}$ are reserved for combined nodes.

The QUBO is subject to the following constraints:

- Each group must have exactly two elements. This is enforced through the following penalization term.
$$
\alpha \sum_{k=1}^{K} \left( \sum_{i=1}^{\frac{N(N+1)}{2}} x_{i,k} - 2 \right)^2
$$

- Each initial node must be in one and only one group. We apply this constraint through one-hot groups instead of a penalization term.
$$
\sum_{k=1}^{K} x_{i,k} = 1 \qquad i \in \{1, \dots, N\}
$$

- A real node cannot belong to the same pair as an artificial node that sprouts out of it. This is enforced by setting the distance between the virtual node $u = (i,j)$ and each of its real nodes $i$ and $j$ to a high value, such as 1000.

The idea is that, if the number of pairs does not match half the number of real nodes, the artificial nodes fill up the remaining places, and we interpret that as the real node being left alone in the first iteration.

Therefore, the final QUBO that is going to be implemented is
$$
\sum_{k=1}^{K} \left( \sum_{i, j = 1}^{\frac{N(N+1)}{2}} d(i,j) x_{i,k} x_{j,k}\right) + \alpha \sum_{k=1}^{K} \left( \sum_{i=1}^{\frac{N(N+1)}{2}} x_{i,k} - 2 \right)^2
$$
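To see how the two terms behave, the sketch below evaluates them with plain NumPy for a toy instance, independently of any solver. It is only an illustration under simplifying assumptions: four leaf nodes, no combined nodes, each pair counted once ($i < j$), and made-up distances. The function name `qubo_energy` is hypothetical.

```python
import numpy as np

def qubo_energy(d_upper, x, alpha):
    """Return (distance term, penalty term) for an assignment x[i, k] in {0, 1}."""
    # sum_k sum_{i<j} d(i,j) * x[i,k] * x[j,k]
    distance = float(np.einsum("ij,ik,jk->", d_upper, x, x))
    # alpha * sum_k (sum_i x[i,k] - 2)^2
    penalty = alpha * float(np.sum((x.sum(axis=0) - 2) ** 2))
    return distance, penalty

# made-up distances, upper triangle only
d_upper = np.triu(np.array([[0.0, 5.0, 9.0, 9.0],
                            [5.0, 0.0, 10.0, 10.0],
                            [9.0, 10.0, 0.0, 8.0],
                            [9.0, 10.0, 8.0, 0.0]]), k=1)

# rows are nodes, columns are the K = 2 communities
valid = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])    # pairs {0,1} and {2,3}
invalid = np.array([[1, 0], [1, 0], [1, 0], [0, 1]])  # groups of size 3 and 1

valid_energy = qubo_energy(d_upper, valid, alpha=1000)
invalid_energy = qubo_energy(d_upper, invalid, alpha=1000)
```

The valid assignment pays only the distance term, whereas the invalid one is dominated by the penalty, which is why $\alpha$ must be large compared with the distances.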

In [12]:
# Simple function to hide the output of the QUBO solver

def _hide_pol_info(pol):
    pol.user_data['hide_scaling_info'] = True
    pol.user_data['hide_sampling_info'] = True
    return pol

In [13]:
def fix_varshapeset(A_matrix,K):
    """
    Fix the variable shape set for the given number of nodes N and number of communities K.
    """
    N = A_matrix.shape[0]
    N_big = A_matrix.shape[1]

    # shape of variables x_{i,k}
    bit_array_shape = BitArrayShape('x',(N_big,K))

    # one-hot groups restriction
    one_hot_groups = [OneHotGroup(('x',i,None)) for i in range(N)]

    # shape of our variable set (bit_array_shape + one_hot_groups)
    BinPol.freeze_var_shape_set(VarShapeSet(bit_array_shape, one_hot_groups=one_hot_groups))

In [14]:
def get_qubo(A_matrix,K_vars, alpha):
    """
    Get the QUBO BinPol object for the given NJ matrix, number of communities K_vars
    and penalization factor alpha.

    Args:
    A_matrix (np.ndarray): NJ matrix including both real and combined nodes.
    K_vars (int): Number of communities.
    alpha (float): Penalization factor. 

    Returns:
    qubo (BinPol): total QUBO BinPol object.
    q_dist (BinPol): minimization part of QUBO BinPol object.
    H_alpha (BinPol): penalization part of QUBO BinPol object.
    """

    # fix variable shape set
    fix_varshapeset(A_matrix, K=K_vars)
    
    # Build the QUBO matrix
    qubo = BinPol()
    N = A_matrix.shape[0]
    N_big = A_matrix.shape[1]

    for k in range(K_vars):
        for i in range(N):
            for j in range(i+1, N_big):
                # we only create the upper triangular matrix
                qubo.add_term(A_matrix[i,j], ('x', i, k), ('x', j, k)) # d_ij * x_i * x_j

    q_dist = qubo.clone()


    # build the penalization term 
    H_alpha = BinPol()

    for k in range(K_vars):
        H_aux = BinPol()
        for i in range(N_big):
            H_aux.add_term(1, ('x', i, k))
        H_aux.add_term(-2, ())
        H_aux.power(2)
        H_alpha.add(H_aux)
        
    H_alpha.multiply_scalar(alpha)

    # total qubo
    qubo = qubo.add(H_alpha)

    return(qubo, q_dist, H_alpha)

In [15]:
def get_communities(A_matrix,K,solution_list):
    """
    Transform the solution list into a list of communities.

    Args:
    A_matrix (np.ndarray): NJ matrix including both real and combined nodes.
    K (int): Number of communities.
    solution_list (list): object returned from the QUBO solver (list of 0s and 1s).

    Returns:
    communities (list): List of communities with the corresponding node indexes.
    """

    # initialize empty list of communities
    communities = [[] for i in range(K)]

    N_big = A_matrix.shape[1]

    # enumerate nodes (planar)
    node_list = [i for i in range(N_big)]

    for i in range(len(solution_list)):
        for k in range(K):
            if solution_list[i][k] == 1:
                communities[k].append(node_list[i])
                
    return communities
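
To make the decoding concrete, here is a minimal sketch on a hand-written bit matrix, using plain nested lists instead of a solver result. The helper name `decode_communities` and the bit values are hypothetical; node 3 plays the role of a combined node that was not selected for any community.

```python
def decode_communities(bits, n_groups):
    """bits[i][k] == 1 means that node i was assigned to community k."""
    return [[i for i, row in enumerate(bits) if row[k] == 1]
            for k in range(n_groups)]

bits = [[1, 0],
        [0, 1],
        [1, 0],
        [0, 0],
        [0, 1]]
communities = decode_communities(bits, 2)
```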

In [16]:
def QUBO_local(nj_matrix, K_vars, file_name):

    """    
    Function that builds and solves the QUBO problem for the given NJ matrix 
    in the CPU (local solver).

    Args:
    nj_matrix (np.ndarray): NJ matrix including both real and combined nodes.
    K_vars (int): Number of communities.
    file_name (str): Name of the file to save the graphs (not used currently).

    Returns:
    comms (list): List of communities with the corresponding node indexes.
    an_time (dict): Dictionary with the timing information of the annealing process.
    an_energy (dict): Dictionary with the energy information of the annealing process.
    
    """

    # built-in penalization factor
    alpha = 1000

    # building the QUBO
    qubo, q_dist, H_alpha = get_qubo(nj_matrix, K_vars, alpha)
    qubo = _hide_pol_info(qubo)

    # SOLVER
    # arguments for the QUBO solver
    solver_args = {
                    'optimization_method':'annealing',
                    'number_iterations':30000,
                    'number_runs':10,
                    'scaling_bit_precision':32,   ### change here to 32 or 16
                    'scaling_action':ScalingAction.AUTO_SCALING,
                    }
                    # 'graphics':GraphicsDetail.SINGLE

    solver = QUBOSolverCPU(**solver_args)

    # launch solver
    solution_list = solver.minimize(qubo)
    solution_list.encode(solution_list)
    # solution_list.display_graphs(file=file_name)

    # get communities from solution list
    comms = get_communities(nj_matrix, K_vars ,solution_list.min_solution['x'].data)

    # CONDITION CHECK
    # checks that the one-hot condition and penalization 
    # are fulfilled (no empty or overfull communities)

    max_iter = 3
    it = 0

    # relaunches solver increasing alpha until conditions are met
    while (max(len(x) for x in comms)>2 or min(len(x) for x in comms)<2) and it < max_iter:
        it += 1
        alpha = alpha*10
        qubo, q_dist, H_alpha = get_qubo(nj_matrix, K_vars, alpha)
        qubo = _hide_pol_info(qubo)
        solution_list = solver.minimize(qubo)
        comms = get_communities(nj_matrix, K_vars ,solution_list.min_solution['x'].data)

    # saves time dictionary
    an_time =  {'execution':        solution_list.solver_times.duration_execution.total_seconds(),
                'solve':            solution_list.solver_times.duration_solve.total_seconds(),
                'scaling':          solution_list.solver_times.duration_scaling.total_seconds(),
                'elapsed':          solution_list.solver_times.duration_elapsed.total_seconds()}
                # 'send_request':     solution_list.solver_times.duration_send_request.total_seconds(),
                # 'receive_response': solution_list.solver_times.duration_receive_response.total_seconds(),
    
    # saves energy dictionary
    an_energy = {'total_energy':    qubo.compute(solution_list.min_solution.configuration),
                 'dist_energy':     q_dist.compute(solution_list.min_solution.configuration),
                 'penal_energy':    H_alpha.compute(solution_list.min_solution.configuration)}
        
        
    return(comms, an_time, an_energy)


In [17]:
def QUBO_annealer(nj_matrix, K_vars, time_limit, file_name):
    
    """    
    Function that builds and solves the QUBO problem for the given NJ matrix 
    in Fujitsu's Digital Annealer (DAv3).

    Args:
    nj_matrix (np.ndarray): NJ matrix including both real and combined nodes.
    K_vars (int): Number of communities.
    time_limit (int): Time limit in seconds, passed to the solver as time_limit_sec.
    file_name (str): Name of the file to save the graphs.

    Returns:
    comms (list): List of communities with the corresponding node indexes.
    an_time (dict): Dictionary with the timing information of the annealing process.
    an_energy (dict): Dictionary with the energy information of the annealing process.
    
    """

    # built-in penalization factor
    alpha = 1000

    # building the QUBO
    qubo, q_dist, H_alpha = get_qubo(nj_matrix, K_vars, alpha)
    # qubo = _hide_pol_info(qubo)

    # SOLVER
    # arguments for the QUBO solver
    solver = QUBOSolverDAv3c(
            time_limit_sec=time_limit,
            # user_access_profile=True,
            scaling_bit_precision=32,
            # num_group=kwargs['num_group'],
            ohs_xw1h_internal_penalty=1,               # if there is one-hot, leave at =1
            scaling_action=ScalingAction.AUTO_SCALING,
            offline_request_file= "request.json",
            offline_response_file = "response.json",
            connection_parameter=connection_parameter,)
    
    # launch solver
    # solution_list = solver.minimize(q_dist, H_alpha)
    solution_list = solver.minimize(q_dist, H_alpha)
    solution_list.encode(solution_list)
    # solution_list.display_graphs(file=file_name)

    # only valid for real annealer (DAV3)
    solution_list.print_progress(csv_report=file_name + '.csv', fig_report=file_name + '.png')
    json_file_name = file_name + '.json'
    write_json(solution_list, json_file_name)
    
    # get communities from solution list
    comms = get_communities(nj_matrix, K_vars ,solution_list.min_solution['x'].data)  

    # saves time dictionary
    an_time =  {'execution':        solution_list.solver_times.duration_execution.total_seconds(),
                'solve':            solution_list.solver_times.duration_solve.total_seconds(),
                'scaling':          solution_list.solver_times.duration_scaling.total_seconds(),
                'elapsed':          solution_list.solver_times.duration_elapsed.total_seconds(),
                'send_request':     solution_list.solver_times.duration_send_request.total_seconds(),
                'receive_response': solution_list.solver_times.duration_receive_response.total_seconds()}
    
    an_energy = {'total_energy':    qubo.compute(solution_list.min_solution.configuration),
                 'dist_energy':     q_dist.compute(solution_list.min_solution.configuration),
                 'penal_energy':    H_alpha.compute(solution_list.min_solution.configuration)}
      
        
    return(comms, an_time, an_energy)

In [18]:
def order_comms(len_nodes, comms, nj_matrix):

    """
    Function that orders the list of communities based on their energy values.

    Args:
    len_nodes (int): Number of real nodes (N).
    comms (list): List of communities with the corresponding node indexes.
    nj_matrix (np.ndarray): NJ matrix including both real and combined nodes.

    Returns:
    comms_list_sorted (list): List of communities ordered by energy values.
    energies_list_sorted (list): List of energy values ordered.
    
    """

    # initialize both lists
    comms_list = []
    energies_list = []

    # save the values of communities made of two real nodes
    for i in range(len(comms)):
        # we only consider communities in which both nodes are real
        if comms[i][0]<=len_nodes-1 and comms[i][1]<=len_nodes-1:
            comms_list.append(comms[i])
            index_0 = comms[i][0]
            index_1 = comms[i][1]
            # save the energy values from the nj_matrix
            energies_list.append(nj_matrix[index_0, index_1])

    # pair both lists and sort them based on the energy values
    paired = list(zip(comms_list, energies_list))

    # sort based on energy values (second element of the tuple)
    paired_sorted = sorted(paired, key=lambda x: x[1], reverse=True) # sorted based on second list (x[1])

    # Unzip back into two lists
    comms_list_sorted, energies_list_sorted = zip(*paired_sorted)

    return list(comms_list_sorted), list(energies_list_sorted)

### Post-processing

In [19]:
def complete_combined_adj(adj_matrix, comb_index_list):

    """    
    Subroutine that completes the adjacency matrix with the distance between
    pairs of combined nodes.

    Args:
    adj_matrix (np.ndarray): Adjacency matrix.
    comb_index_list (list): List of indexes corresponding to the total amount of nodes.

    Returns:
    new_adj_matrix (np.ndarray): Completed adjacency matrix including distances between combined nodes.
    
    """

    N = adj_matrix.shape[0]
    N_big = adj_matrix.shape[1]

    # create a new adjacency matrix that is square
    new_adj_matrix = np.zeros((N_big, N_big))

    for i in range(N_big):
        for j in range(N_big):

            # previous adjacency matrix (real and combined nodes)
            if i<N:
                new_adj_matrix[i,j] = adj_matrix[i,j]
            # complete lower triangular part with symmetry
            elif i>=N and j<N:
                new_adj_matrix[i,j] = adj_matrix[j,i]
            # combined-combined pairs
            elif i>=N and j>=N and j>i:
                # i indicates the position in the comb_index_list array
                # j indicates the column
                i_index = comb_index_list[i][0]
                j_index = comb_index_list[i][1]
                # apply distance formula
                new_adj_matrix[i,j] =0.5*(adj_matrix[i_index][j] + adj_matrix[j_index][j] - adj_matrix[i_index][j_index])
                new_adj_matrix[j,i] = new_adj_matrix[i,j]
            
            # diagonal zeros
            if i==j: 
                new_adj_matrix[i,j] = 0

    return(new_adj_matrix)
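
The combined-combined distance in the function above follows the neighbour-joining update $d(u,j) = \frac{1}{2}\left(d(a,j) + d(b,j) - d(a,b)\right)$, where $u$ is the node obtained by joining $a$ and $b$. The following minimal sketch checks the arithmetic on a hand-made 3x3 matrix. The matrix values and the helper name `distance_to_combined` are assumptions made for this illustration and are not part of the notebook.

```python
def distance_to_combined(dist, a, b, j):
    """Distance from node j to the node obtained by joining a and b."""
    return 0.5 * (dist[a][j] + dist[b][j] - dist[a][b])

# hypothetical symmetric distance matrix for three nodes
dist = [
    [0.0, 4.0, 6.0],
    [4.0, 0.0, 8.0],
    [6.0, 8.0, 0.0],
]

# join nodes 0 and 1, then measure the distance to node 2
d_u2 = distance_to_combined(dist, a=0, b=1, j=2)
print(d_u2)  # 0.5 * (6 + 8 - 4) = 5.0
```

The subtraction of $d(a,b)$ removes the part of the path that lies inside the joined pair, so only the stretch from the new internal node to $j$ is counted.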

In [20]:
def update_nodes_adj_matrix(node_list, pair_sublist, adj_matrix):

    """    
    Function that looks for the indices of the nodes to be removed and the combined nodes to be added 
    based on the pairs formed in the sublist.

    Args:
    node_list (list): List of current nodes.
    pair_sublist (list): List of pairs of nodes to be combined.
    adj_matrix (np.ndarray): Current adjacency matrix.

    Returns:
    removed_indices (list): List with the indices of the nodes that were removed.
    added_indices (list): List with the indices of the combined nodes that were added.
    new_node_list (list): Updated list of nodes after combining.
    new_full_adj (np.ndarray): Updated adjacency matrix after combining nodes.

    """

    # Create an empty array with the indices to be removed
    # and another with the combined node indices to be added
    removed_indices = []
    added_indices = [i for i in range(len(node_list))]

    for j in range(len(pair_sublist)):
        for i in range(len(node_list)):

            if node_list[i] in pair_sublist[j]:
                removed_indices.append(i)

        # combined indices that were added need to be taken out from the real nodes index list
        # so if combined node (1,3) is created, then indices 1 and 3 are removed
        added_indices.append((removed_indices[j*2], removed_indices[j*2+1]))


    # MATRIX UPDATE

    # We purge all the combined nodes from the matrix
    bitscore_inverse = adj_matrix[:,:len(node_list)]

    # Add the corresponding combined nodes (as columns and as rows)
    new_adj_matrix = calculate_adj_matrix(bitscore_inverse, added_indices)

    # Complete with the distances between combined-combined nodes
    new_full_adj = complete_combined_adj(new_adj_matrix, added_indices)

    # Delete the corresponding real nodes
    new_full_adj = np.delete(new_full_adj, removed_indices, axis=0)  # rows
    new_full_adj = np.delete(new_full_adj, removed_indices, axis=1) # columns

    # We create a new node list removing the real nodes that have been paired up
    # (enumerate avoids node_list.index, which is quadratic and returns the first match only)
    new_node_list = [x for i, x in enumerate(node_list) if i not in removed_indices]

    # We flatten the combined nodes list
    for i in range(len(pair_sublist)):
        flattened_sublist = []
        for item in pair_sublist[i]:
            if isinstance(item, list):
                flattened_sublist.extend(item)
            else:
                flattened_sublist.append(item)

        new_node_list.append(flattened_sublist)

    return(removed_indices, added_indices, new_node_list, new_full_adj)


In [21]:
def create_pair_sublist(node_list, comms):

    """     
    Function that creates a sublist from node_list with the communities of nodes joined.

    Args:
    node_list (list): List of current nodes.
    comms (list): List of communities with the corresponding node indexes.

    Returns:
    pair_sublist (list): List of pairs of nodes that were combined.
    """

    len_nodes = len(node_list)
    pair_sublist = []

    for i in range(len(comms)):
        if len(comms[i]) == 2:
            # only take into account pairs of real-real nodes
            # (i.e., those with indices lower than N)
            if comms[i][0]<=len_nodes-1 and comms[i][1]<=len_nodes-1:
                index_coms_1 = comms[i][0]
                index_coms_2 = comms[i][1]
                pair_sublist.append([node_list[index_coms_1], node_list[index_coms_2]])
        # elif comms[i][0]<=len_nodes-1 and comms[i][1]>len_nodes-1:
        #     index_coms_1 = comms[i][0]
        #     index_coms_2 = change_index(len_nodes, [], comms[i][1])[0]
        #     index_coms_3 = change_index(len_nodes, [], comms[i][1])[1]
        #     print(nodes[index_coms_1], (nodes[index_coms_2], nodes[index_coms_3]))
        #     # print([comms[i][0], change_index(len_nodes, [], comms[i][1])])
        # elif comms[i][0]>len_nodes-1 and comms[i][1]>len_nodes-1:
        #     print(change_index(len_nodes, [], comms[i][0]), change_index(len_nodes, [], comms[i][1]))
        
    return(pair_sublist)

In [22]:
def filter_comms(len_nodes, comms, nj_matrix, del_filter):

    """    
    Function that filters the communities list by removing a percentage of the highest-energy
    communities based on the del_filter ratio.

    Args:
    len_nodes (int): Number of real nodes (N).
    comms (list): List of communities with the corresponding node indexes.
    nj_matrix (np.ndarray): NJ matrix including both real and combined nodes.
    del_filter (float): Filtering percentage for communities (between 0 and 1).

    Returns:
    comms_list (list): Filtered list of communities after removing the highest-energy ones.

    """

    # if no filtering, skip this function
    if del_filter == 0:
        return(comms)
    
    # initialize both lists
    comms_list = []
    energies_list = []

    # for pairs of real-real nodes, save their energy values
    for i in range(len(comms)):
        if comms[i][0]<=len_nodes-1 and comms[i][1]<=len_nodes-1:
            comms_list.append(comms[i])
            index_0 = comms[i][0]
            index_1 = comms[i][1]
            energies_list.append(nj_matrix[index_0, index_1])

    # count how many pairs are going to be deleted
    deletion_number = del_filter*len(comms_list)

    # start deleting highest-energy pairs
    for i in range(round(deletion_number)):
        max_index = energies_list.index(max(energies_list))
        energies_list.pop(max_index)
        comms_list.pop(max_index)

    return(comms_list)
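
To make the effect of `del_filter` concrete, here is a small sketch of the same idea on plain lists: a fraction of the pairs with the highest energy is dropped and the rest is kept. The helper name `drop_highest` and the sample pairs and energies are assumptions for illustration; unlike `filter_comms`, the sketch sorts once instead of popping the maximum repeatedly, so results may differ when energies are tied.

```python
def drop_highest(pairs, energies, del_filter):
    """Keep the pairs that survive after removing the highest-energy fraction."""
    n_drop = round(del_filter * len(pairs))
    # indices ordered from lowest to highest energy
    order = sorted(range(len(pairs)), key=lambda i: energies[i])
    keep = sorted(order[:len(pairs) - n_drop])
    return [pairs[i] for i in keep]

# hypothetical communities and their NJ energies
pairs = [(0, 1), (2, 3), (4, 5), (6, 7)]
energies = [3.0, -1.0, 2.0, 0.5]

kept = drop_highest(pairs, energies, del_filter=0.5)
print(kept)  # [(2, 3), (6, 7)]
```

Note that Python's `round` rounds halves to the nearest even integer, so `round(2.5)` is `2`. With five pairs and `del_filter = 0.5`, two pairs are removed rather than three.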

In [23]:
def assign_key(species,
               node_list, 
               bitscore_inverse, 
               del_filter, 
               del_comb, 
               data=None, 
               result=None, 
               counter=0, 
               prints=False):
    
    """    
    Recursive function that calls all the functions of the algorithm and assigns a binary 
    key to the grouping process and saves the dictionary with this tree topology.

    Args:
    species (str): Name of the species.
    node_list (list): List of current nodes.
    bitscore_inverse (np.ndarray): Inverse bitscore matrix (dissimilarity matrix).
    del_filter (float): Filtering percentage for communities (between 0 and 1).
    del_comb (float): Deletion ratio for combined nodes (between 0 and 1).
    data (list): List to save the data of each iteration (number of nodes, times, etc.).
    result (dict): Dictionary to save the binary keys for each original node.
    counter (int): Iteration counter.
    prints (bool): Whether to print debug information.

    Returns:
    result (dict): Dictionary with the binary keys for each original node.
    data (list): List with the data of each iteration (number of nodes, times, etc.).
    
    """

    # only for the first iteration, create the result dictionary
    if result is None:
        result = {}
        data = []
        for i in range(len(node_list)):
            result[node_list[i]] = ''

    # increment counter
    counter += 1
    # name of the data dump file
    file_name = species + '_' + str(int(del_comb*100)) + '_' + str(int(del_filter*100)) + '_iter_' + str(counter)

    # prints for debugging
    if prints:
        print('----------------------------------------')
        print('result_before', result)
        print('node_list=', node_list)
        print('len_nodes', len(node_list))

    # only if it is not the last iteration
    if len(node_list)>1:

        # calculate NJ matrix
        comb_index_list, adj_matrix, nj_matrix = main_calculate_nj(bitscore_inverse, del_comb)
        # set number of communities based on number of nodes
        K_vars = round(len(node_list)*3/4)
        
        # special case for 2 nodes
        if len(node_list) == 2:
            K_vars = 1

        # special case for more than 30 iterations, we remove the filter
        # (avoids infinite loops with filtering percentage)
        if counter>= 30:
            del_filter = 0

        # solves QUBO problem
        comms, annealer_times, annealer_energies = QUBO_local(nj_matrix, K_vars, file_name)

        # processes communities
        filtered_comms = filter_comms(len(node_list), comms, nj_matrix, del_filter)
        pair_sublist = create_pair_sublist(node_list, filtered_comms)

        # prints for debugging
        if prints:
            print('comms=', comms)
            print('pair_sublist', pair_sublist)

        # routine that creates the binary code for the tree topology
        # assigns '0' and '1' to each node based on the pairs created
        # it is cumulative, so the key keeps growing each iteration
        for i in range(len(pair_sublist)):
            # old pairs, assigns key to the entire group
            if isinstance(pair_sublist[i][0], list):
                for j in range(len(pair_sublist[i][0])):
                    result[pair_sublist[i][0][j]] += '0'
            # new pairs
            else:
                result[pair_sublist[i][0]] += '0'

            if isinstance(pair_sublist[i][1], list):
                for j in range(len(pair_sublist[i][1])):
                    result[pair_sublist[i][1][j]] += '1'
            else:
                result[pair_sublist[i][1]] += '1'

        # update node list and adjacency matrix
        removed, added, new_nodes, new_bitscore = update_nodes_adj_matrix(node_list, pair_sublist, adj_matrix)
        
        # prints for debugging
        if prints:
            print('added_indices', added)

        # save data
        data.append([counter, len(node_list), annealer_times])
        
        # recursive call
        assign_key(species, new_nodes, new_bitscore, del_filter, del_comb, data, result, counter, prints)

    return(result, data)
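
The key-assignment step inside `assign_key` can be traced by hand. The sketch below replays two iterations on three made-up species: first `A` and `B` are paired, then the group `[A, B]` is paired with `C`. The helper name `append_bits` and the species names are assumptions for illustration only.

```python
def append_bits(result, pair_sublist):
    """Append '0' to every leaf of the first member and '1' to every leaf of the second."""
    for left, right in pair_sublist:
        for bit, member in (('0', left), ('1', right)):
            # a member is either a single leaf or a list of already grouped leaves
            leaves = member if isinstance(member, list) else [member]
            for leaf in leaves:
                result[leaf] += bit
    return result

result = {name: '' for name in ['A', 'B', 'C']}
append_bits(result, [['A', 'B']])          # iteration 1: A and B are joined
append_bits(result, [[['A', 'B'], 'C']])   # iteration 2: (A, B) is joined with C
print(result)  # {'A': '00', 'B': '10', 'C': '1'}
```

Because bits are appended while moving from the leaves towards the root, each key reads leaf-to-root; it has to be reversed to obtain the path from the root down to the species.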

-------------
-------------

## Phylogenetic tree reconstruction

Now that we have the final result dictionary, let us reconstruct the tree from the binary code assigned to each species. Once the tree is in Newick format, we use the Biopython package to visualize it.
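
As a small illustration of how a set of binary keys encodes a topology, the sketch below turns three made-up keys into a Newick string without branch lengths. It assumes that the keys were accumulated from the leaves towards the root, so each key is reversed first (the notebook does the same with `value[::-1]` in `main_local`). The function name `keys_to_newick` and the sample keys are hypothetical.

```python
def keys_to_newick(result):
    """Build a Newick string (topology only) from leaf-to-root binary keys."""
    root = {}
    for leaf, key in result.items():
        path = key[::-1]  # root-to-leaf order
        node = root
        for bit in path[:-1]:
            node = node.setdefault(bit, {})
        node[path[-1]] = leaf

    def render(node):
        if isinstance(node, str):
            return node
        return '(' + ','.join(render(node[bit]) for bit in sorted(node)) + ')'

    return render(root) + ';'

newick = keys_to_newick({'A': '00', 'B': '10', 'C': '1'})
print(newick)  # ((A,B),C);
```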

In [24]:
# biological imports

from Bio import Phylo
import dendropy
from dendropy.calculate import treecompare
from itertools import combinations
from sklearn.metrics import mutual_info_score
from scipy.stats import entropy
from io import StringIO
import matplotlib.pyplot as plt
import re
import time

In [25]:
# PACKAGES RELATED TO R

import os
os.environ['R_HOME'] = r'C:\Users\alfonsorodrr\AppData\Local\Programs\R\R-4.5.0'
from rpy2.robjects import r, globalenv
from rpy2.robjects.packages import importr
from rpy2.robjects.vectors import StrVector
ape = importr('ape')
treedist = importr('TreeDist')

In [26]:
def import_real_tree(filename):

    """   
    Function that imports a real tree from an Excel file.
    
    Args:
    filename (str): path to the excel file.

    Returns:
    row_list (list): List of lists with the rows from the excel file (trees in newick format).
    
    """

    # Open up the Excel file
    workbook = load_workbook(filename)
    
    # Get the first sheet
    worksheet = workbook.worksheets[0]
    
    # Read the rows into a list
    row_list = []
    
    for r in worksheet.rows:
        column = [cell.value for cell in r]
        row_list.append(column)
    
    return(row_list)

In [27]:
def prune(tree):

    """  
    Function that prunes the intermediate nodes from a newick tree string (if any).

    Args:
    tree (str): Newick string of the tree.

    Returns:
    tree_root (str): Newick string of the tree without intermediate nodes.

    """

    # substitutes string of the shape ...)NAME_OF_NODE_0001:0.1...
    # for ...):0.1... (standard format with only leaf names)

    tree_nodes = re.sub(r'\)\s*[A-Za-z0-9\._]+:', '):', tree)
    tree_root = re.sub(r'\)\s*[A-Za-z0-9\._]+;', ');', tree_nodes)

    return(tree_root)
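
The effect of the two substitutions in `prune` is easiest to see on a short string. The sketch below copies the same regular expressions into a stand-alone helper (the name `strip_internal_labels` and the sample trees are assumptions for illustration).

```python
import re

def strip_internal_labels(newick):
    # same two substitutions as in prune(): first the internal labels that are
    # followed by a branch length, then the label of the root
    without_inner = re.sub(r'\)\s*[A-Za-z0-9\._]+:', '):', newick)
    return re.sub(r'\)\s*[A-Za-z0-9\._]+;', ');', without_inner)

labelled = '((A:0.1,B:0.2)Node_1:0.3,C:0.4)Root;'
pruned = strip_internal_labels(labelled)
print(pruned)  # ((A:0.1,B:0.2):0.3,C:0.4);
```

Since the character class also contains digits and the dot, numeric support values written after a closing parenthesis, such as `)0.95:`, are removed in the same way.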

In [45]:
def dict_to_newick(tree_dict, branch_length=1.0):

    """    
    Function that transforms the reconstructed tree dictionary into Newick format.

    Args:
    tree_dict (dict): Dictionary with the binary code as keys and list of species as values.
    branch_length (float): Branch length to be assigned to each branch in the Newick format.    

    Returns:
    newick_str (str): Newick string of the reconstructed tree.
    
    """

    def insert_path(tree, path, name):
        """    
        Helper function that inserts a path (sequence of directions) into 
        a nested tree dictionary.
        """
        # iterates through all but the last element of the path
        for direction in path[:-1]:
            # checks if current direction exists in the current level of the tree
            if direction not in tree:
                # if not, creates a new dictionary at that direction
                tree[direction] = {}
            # moves to the next level in the tree
            tree = tree[direction]
        # assigns the name at the last element of the path
        tree[path[-1]] = name

    def build_newick(subtree):
        """    
        Helper function that recursively builds the Newick string from a subtree.
        """
        # checks if the current subtree is a leaf node
        if isinstance(subtree, str):
            # if so, returns the species name with branch length
            return f'{subtree}:{branch_length}'
        # otherwise, recursively builds the Newick string for each child
        children = [build_newick(child) for child in subtree.values()]
        # joins the children with commas and wraps them in parentheses
        # outer double quotes keep the f-string valid on Python versions before 3.12
        return f"({','.join(children)}):{branch_length}"

    # initializes the root of the tree
    root = {}
    # inserts each path and species into the tree
    for path, species_list in tree_dict.items():
        for species in species_list:
            insert_path(root, path, species)

    # builds the Newick string from the root and appends final semicolon
    return build_newick(root) + ';'

### Percentage of correctly reconstructed tree

In order to measure the similarity between the real tree and the reconstructed one, we use the unweighted Robinson-Foulds distance and the formula

$$ percentage_{correct} = \left( 1 - \frac{\text{Robinson-Foulds distance}}{2N-6} \right) \cdot 100, $$

where $N$ denotes the total number of external tips and $2N-6$ is the maximum Robinson-Foulds distance between two unrooted binary trees with $N$ tips. A value of $100$ therefore means that every branch of the true tree is recovered in the reconstructed tree.
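
To see the numbers behind this percentage, the sketch below computes the Robinson-Foulds distance by hand for two five-tip trees written as nested tuples, and then converts it to a percentage as 100 minus the normalised distance, as the `percentage_rf` function in this notebook does. It is an illustration only and does not use dendropy; the helper names and the two example trees are assumptions.

```python
def collect_clades(tree, found):
    """Return the leaf set below `tree`, adding every internal clade to `found`."""
    if isinstance(tree, str):
        return frozenset([tree])
    leaves = frozenset().union(*(collect_clades(child, found) for child in tree))
    found.add(leaves)
    return leaves

def nontrivial_splits(tree):
    """Splits with at least two leaves on each side, stored by the side without the anchor leaf."""
    found = set()
    all_leaves = collect_clades(tree, found)
    anchor = min(all_leaves)
    splits = set()
    for clade in found:
        other = all_leaves - clade
        if len(clade) > 1 and len(other) > 1:
            splits.add(other if anchor in clade else clade)
    return splits

def rf_percentage(tree_a, tree_b, n_tips):
    rf = len(nontrivial_splits(tree_a) ^ nontrivial_splits(tree_b))
    return rf, 100 - rf * 100 / (2 * n_tips - 6)

real = (("A", "B"), "C", ("D", "E"))
reconstructed = (("A", "C"), "B", ("D", "E"))
rf, percentage = rf_percentage(real, reconstructed, n_tips=5)
print(rf, percentage)  # 2 50.0
```

Both trees share the split that separates `D` and `E` from the rest, and each has one split the other lacks, which gives a distance of 2 out of a maximum of $2 \cdot 5 - 6 = 4$.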

In [46]:
def percentage_rf(tree_real, tree_reconstructed, len_nodes):

    """    
    Function that calculates the percentage of similarity between two trees based 
    on the Robinson-Foulds distance.

    Args:
    tree_real (str): Newick string of the real tree.
    tree_reconstructed (str): Newick string of the reconstructed tree.
    len_nodes (int): Number of nodes in the tree.

    Returns:
    percentage (float): Percentage of similarity between the two trees (100 = total similarity).

    """

    # shared_taxa creates a shared taxon namespace of node names for the two trees to use
    shared_taxa = dendropy.TaxonNamespace()

    # read trees in newick format into the dendropy syntax
    tree1 = dendropy.Tree.get(data = tree_real, schema="newick", taxon_namespace = shared_taxa)
    tree2 = dendropy.Tree.get(data = tree_reconstructed, schema="newick", taxon_namespace = shared_taxa)

    # calculate robinson-foulds distance
    robinson_foulds_distance = treecompare.symmetric_difference(tree1, tree2)

    # implement percentage formula
    percentage = 100 - robinson_foulds_distance*100/(2*len_nodes-6)

    return(percentage)

### Percentage of correctly reconstructed tree using Clustering distance

Sometimes, the Robinson-Foulds distance may not be an appropriate measure, as it has been criticized for yielding biased results. It may not scale well to a large number of nodes and may overlook correctly recovered clusters if a partition close to the root is not correctly reconstructed. Thus, the Clustering Information Distance has been proposed as an alternative to the Robinson-Foulds distance for measuring the percentage of correct reconstruction.

The Clustering Information Distance measures the correctly recovered clusters using information-theoretic metrics. It is implemented in R in the TreeDist library, so we call the function directly from Python through the $\verb|rpy2|$ package. The function $\verb|ClusteringInfoDistance|$ can be normalized and returns $0$ when the two trees are identical; therefore, we have implemented the percentage as

$$  percentage_{correct} = \left( 1- Clustering \ distance \right)*100. $$
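
To give some intuition for the information-theoretic ingredient, the sketch below computes the mutual information between two splits of the same four leaves, treating each split as a clustering of the leaves into two groups. This is only the split-level quantity under our reading of the method; it is not TreeDist's implementation, which additionally matches the splits of the two trees and normalises the result. The function name and the example splits are assumptions.

```python
from math import log2

def mutual_clustering_information(side1, side2, leaves):
    """Mutual information (in bits) between two bipartitions of `leaves`."""
    n = len(leaves)
    total = 0.0
    for part1 in (side1, leaves - side1):
        for part2 in (side2, leaves - side2):
            joint = len(part1 & part2) / n
            if joint > 0:
                total += joint * log2(joint / ((len(part1) / n) * (len(part2) / n)))
    return total

leaves = frozenset('ABCD')
same = mutual_clustering_information(frozenset('AB'), frozenset('AB'), leaves)
crossed = mutual_clustering_information(frozenset('AB'), frozenset('AC'), leaves)
print(same, crossed)  # 1.0 0.0
```

Two identical even splits share one full bit of information, whereas the splits `AB|CD` and `AC|BD` tell us nothing about each other, which is the kind of graded similarity that a purely equal-or-different comparison of splits cannot express.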

In [47]:
def percentage_cd(tree_real, tree_reconstructed):

    """    
    Function that calculates the percentage of similarity between two trees based 
    on the Clustering Information distance.

    Args:
    tree_real (str): Newick string of the real tree.
    tree_reconstructed (str): Newick string of the reconstructed tree.

    Returns:
    percentage (float): Percentage of similarity between the two trees (100 = total similarity).
    
    """

    # transform trees into R format
    tree1 = ape.read_tree(text=tree_real)
    tree2 = ape.read_tree(text=tree_reconstructed)

    # calculate clustering information distance
    dist_r = treedist.ClusteringInfoDistance(tree1, tree2, normalize=True)
    
    distance = float(dist_r[0])
    
    # apply formula based on percentage
    return((1-distance)*100)

In [None]:
def visualize(tree, file_save, fig_size=(20,18)):

    """   
    Function to visualize the phylogenetic trees using matplotlib.

    Args:
    tree (str): Newick string of the tree.
    file_save (str): Name of the file to save the output image.
    fig_size (tuple): Size of the figure to be created.
    
    """

    # read tree in newick format into Phylo syntax
    handle_tree = StringIO(tree)
    tree = Phylo.read(handle_tree, 'newick')

    # build plot and save output
    fig = plt.figure(figsize=fig_size)
    axes = fig.add_subplot(1,1,1)
    Phylo.draw(tree, do_show=False, axes=axes)
    plt.savefig(file_save +'.jpg', dpi=300)
    plt.show()

In [48]:
def calculate_branch_length(tree_file):

    """      
    Function that calculates the mean and standard deviation of branch lengths
    from a given tree in Newick format.

    Args:
    tree_file (str or Phylo.BaseTree.Tree): Newick string of the tree or Phylo tree object.

    Returns:
    mean_length (float): Mean branch length of the tree.
    std_length (float): Standard deviation of branch lengths of the tree.
    
    """

    # check if the tree is given as a Newick string
    if isinstance(tree_file, str):
        # if so, transform it into Phylo object
        tree = Phylo.read(StringIO(tree_file), 'newick')
    else:
        tree = tree_file

    # get all the branch lengths
    lengths = [clade.branch_length for clade in tree.find_clades() if clade.branch_length is not None]

    if not lengths:
        return None, None

    # calculate mean and standard deviation
    mean_length = np.mean(lengths)
    std_length = np.std(lengths)

    return mean_length, std_length

-------------
------------

## Tables and others

In [None]:
def combined_deletion_percentage(len_nodes, nj_matrix, comms):

    """    
    Function that calculates how many combined nodes were at play in the 
    reconstruction to estimate how many can be deleted.

    Args:
    len_nodes (int): Number of real nodes (N).
    nj_matrix (np.ndarray): NJ matrix including both real and combined nodes.
    comms (list): List of communities with the corresponding node indexes.

    Returns:
    combined_deletion_perc (float): Percentage of combined nodes that can be deleted.
    
    """

    energies = []
    
    for i in range(len(comms)):
        if comms[i][0]<=len_nodes-1 and comms[i][1]<=len_nodes-1:
            index_0 = comms[i][0]
            index_1 = comms[i][1]
            print('indices= (', index_0 ,',', index_1, ')')
            energies.append(nj_matrix[index_0, index_1])
            print(nj_matrix[index_0, index_1])

    print('comms', comms)
    print('nj_matrix_reals', nj_matrix[:,:len_nodes])

    counter = 0
    for i in range(len_nodes):
        for j in range(len_nodes):
            if nj_matrix[i,j]>max(energies):
                counter += 1

    total = len_nodes*(len_nodes-1)/2

    combined_deletion_perc = counter/total

    print('len_nodes', len_nodes)
    print('comb_perc=', combined_deletion_perc)
    print('------------------------------------------------')

    line = ["6_Phy000D0PL_SCHPO", len_nodes, "", combined_deletion_perc]

    with open("combined_nodes_deletion.txt", "a") as txtfile:
        # Append one tab-separated result line
        txtfile.write('\t'.join(str(item) for item in line) + '\n')

    return combined_deletion_perc

    

In [34]:
def write_annealer_times(filename, species, data_list):
    """
    Saves a formatted table of annealer times to a text file.
    Each element in data_list should be:
    [iteration, len_nodes, { "execution": x, "solve": y, "scaling": z, "elapsed": w }]
    """

    with open(filename, "w") as f:
    
        for iteration, len_nodes, annealer_times in data_list:
            # Header line
            f.write(f"Species: {species:<20} | iter: {iteration:<5}\n")
            # Data lines
            f.write(f"{'':<44}| len_nodes:       {len_nodes:<10}\n")
            f.write(f"{'':<44}| execution_time:  {annealer_times.get('execution', 'N/A'):<10}\n")
            f.write(f"{'':<44}| solve_time:      {annealer_times.get('solve', 'N/A'):<10}\n")
            f.write(f"{'':<44}| scaling_time:    {annealer_times.get('scaling', 'N/A'):<10}\n")
            f.write(f"{'':<44}| elapsed_time:    {annealer_times.get('elapsed', 'N/A'):<10}\n")
           
            # Separator
            f.write("_" * 41 + "\n")

In [35]:
def write_percentages(filename, perc_filter, deletions, times):
    """
    Appends one row of percentages and one row of times to a text file.

    Parameters:
    - filename: str, path to file
    - perc_filter: float, filtering ratio for communities (between 0 and 1)
    - deletions: list of percentages, one per deletion ratio (e.g., for Del_0, Del_0.6, Del_0.7, Del_0.8)
    - times: list of times corresponding to deletions
    """

    with open(filename, "a") as f:

        # Deletions row
        # del_row = "Percentages ".ljust(16) + "".join(f"{d:<10}" for d in deletions)
        # f.write(del_row + "\n")
        f.write("Filter_" + str(int(perc_filter*100)) + '\t' + '\t'.join(str("{:.4f}".format(d)) for d in deletions) + '\n')

        # Times row
        # times_row = "Times".ljust(16) + "".join(f"{t:<10}" for t in times)
        # f.write(times_row + "\n\n")
        f.write("Times\t\t" + '\t'.join(str("{:.4f}".format(t)) for t in times) + '\n')

In [49]:
def serialize_datetime(obj):

    """    
    Helper function to serialize datetime, timedelta, enum, and numpy ndarray objects for JSON serialization.
    
    """
    
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, timedelta):
        return obj.total_seconds()
    if isinstance(obj, enum.Enum):
        return obj.name
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Type {type(obj)} not serializable")
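
The `default` argument of `json.dump` and `json.dumps` is called only for objects that the encoder cannot serialize by itself. The sketch below shows the mechanism with a reduced copy of the helper above that leaves out the NumPy branch so that it runs with the standard library alone; the names `to_jsonable` and `Status` and the sample record are assumptions for illustration.

```python
import enum
import json
from datetime import datetime, timedelta

class Status(enum.Enum):  # hypothetical enum used only in this example
    DONE = 1

def to_jsonable(obj):
    """Fallback called by json for objects it cannot encode natively."""
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, timedelta):
        return obj.total_seconds()
    if isinstance(obj, enum.Enum):
        return obj.name
    raise TypeError(f"Type {type(obj)} not serializable")

record = {
    "start": datetime(2024, 1, 2, 3, 4, 5),
    "elapsed": timedelta(seconds=1.5),
    "status": Status.DONE,
}
encoded = json.dumps(record, default=to_jsonable)
print(encoded)  # {"start": "2024-01-02T03:04:05", "elapsed": 1.5, "status": "DONE"}
```

Raising `TypeError` for unknown types keeps the standard behaviour of the encoder, so unexpected objects are reported instead of being written silently in some arbitrary form.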

In [50]:
def write_json(solution_list, json_file_name):

    """   
    Function that writes the solution list to a JSON file.
    
    """

    json_solution = SolutionList.encode(solution_list)
    with open(json_file_name, 'w') as file:
        json.dump(json_solution, file, indent=2, default=serialize_datetime)

In [51]:
def save_intermediate_solutions(filename, node_list, adj_matrix, nj_matrix, comms, pair_sublist, result):

    """   
    Function that writes the intermediate solutions to a JSON file.
    
    """

    data = {
        "nodes": node_list,
        "adj_matrix": adj_matrix,
        "nj_matrix": nj_matrix,
        "comms": comms,
        "pair_sublist": pair_sublist,
        "result": result
    }
    with open(filename, "w") as f:
        json.dump(data, f, indent=2, default=serialize_datetime)

def load_intermediate_solutions(filename):
    with open(filename, "r") as f:
        data = json.load(f)
    return data["nodes"], data["adj_matrix"], data["nj_matrix"], data["comms"], data["pair_sublist"], data["result"]

In [None]:
def save_times_energies(filename, len_nodes, num_variables, time_dict, energy_dict):

    """   
    Function that saves the times and energies dictionary to a JSON file.

    """
    
    data = {
        "len_nodes": len_nodes,
        "num_variables": num_variables,
        "times": time_dict,
        "energies": energy_dict
    }
    with open(filename, "w") as f:
        json.dump(data, f, indent=2, default=serialize_datetime)

In [53]:
def write_branch_percentage_table(filename, header=None, data_rows=None):
    """
    Appends either the header or the data_rows to filename (tab-separated).
    If header is given, only the header is written; otherwise the rows are written.
    Each row should be a list: [species, avg_branch_length, std_branch_length, nodes, filter_label, percentage, std_percentage]
    """
    
    with open(filename, "a", encoding="utf-8") as f:
        # Write header
        if header:
            f.write('\t'.join(header) + '\n')
        # Write rows
        else:
            for row in data_rows or []:
                f.write('\t'.join(str(item) for item in row) + '\n')


-------------
------------

## EXECUTION

In [54]:
def main_local(directory, input_file, file_real_trees, del_filter, del_comb, prints=False):

    """
    Main function that executes the local annealer reconstruction process.

    Args:
    directory (str): Directory where the input file is located.
    input_file (str): Name of the input TSV file.
    file_real_trees (str): Path to the Excel file with real trees.
    del_filter (float): Filtering percentage for communities (between 0 and 1).
    del_comb (float): Deletion ratio for combined nodes (between 0 and 1).
    prints (bool): Whether to print debug information.

    Returns:
    branch_length (float): Mean branch length of the reconstructed tree.
    branch_std (float): Standard deviation of branch lengths of the reconstructed tree.
    percentage (float): Clustering Distance percentage of similarity between the real and reconstructed trees.
    annealer_times (list): List with the data of each iteration ([counter, len_nodes, annealer_times]).

    """

    input_file_name = input_file.replace(".tsv", "")

    nodes, bitscore_matrix = import_from_tsv(os.path.join(directory, input_file))
    bitscore_inverse = 100 - bitscore_matrix
    result, annealer_times = assign_key(input_file_name, nodes, bitscore_inverse, del_filter, del_comb, prints=prints)
    swapped_result =  {value[::-1]: [str(key)] for key, value in result.items()}

    # for key in sorted(swapped_result):
    #     print(f"{key}: {swapped_result[key]}")

    newick_reconstructed = dict_to_newick(swapped_result, branch_length=1.0)
    # visualize(newick_reconstructed, 'prueba_reconstructed')

    # file_real_trees = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\Github\Folders\New_benchmarking\Homo_sapiens\Trees\Homo_sapiens.xlsx'
    real_tree_lists = import_real_tree(file_real_trees)
    # species = re.sub(r'^\d+_', '', input_file_name)
    species = re.sub(r'^[^_]+_|_[^_]+$', '', input_file_name)
    species_list = [sp[0] for sp in real_tree_lists]
    index_tree = species_list.index(species)
    newick_real = prune(real_tree_lists[index_tree][3])  

    branch_length, branch_std = calculate_branch_length(newick_real)

    percentage = percentage_cd(newick_real, newick_reconstructed)

    return(branch_length, branch_std, percentage, annealer_times)
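
The regular expression that extracts the species identifier in `main_local` removes everything up to the first underscore and everything from the last underscore onwards. A quick check on a file name that appears elsewhere in this notebook (the helper name `extract_species` is an assumption for illustration):

```python
import re

def extract_species(input_file_name):
    # drop the leading "<prefix>_" and the trailing "_<suffix>"
    return re.sub(r'^[^_]+_|_[^_]+$', '', input_file_name)

species = extract_species('6_Phy000D0PL_SCHPO')
print(species)  # Phy000D0PL
```

The result is the value that is looked up in the first column of the Excel sheet with the real trees, so the file names need to follow the `<prefix>_<identifier>_<suffix>` pattern for the lookup to succeed.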

In [None]:
"""   
Main execution block

"""

# directory_of_tsv = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\Softwares\ncbi-blast-2.12.0+\databases\Miranda\Schpo 25-30'
directory_of_tsv = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\QCare - Árboles filogenéticos\Github\Folders\New_benchmarking\Bos_taurus\Selection branches'
output_file = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\QCare - Árboles filogenéticos\Github\Folders\New_benchmarking\Bos_taurus\Branch_lengths_results_NJ_80_all.txt'
file_real_trees = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\QCare - Árboles filogenéticos\Github\Folders\New_benchmarking\Bos_taurus\Trees\Bos_taurus.xlsx'


# percentages_list = np.linspace(0,95,5) 
# filter_list = [0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95]
filter_list = [0.70]
del_list = [0.8]

header = ["species", "av_branch", "std_branch", "nodes", "filter", "av_perc", "std_perc"]
write_branch_percentage_table(output_file, header=header)

tsv_files = [f for f in os.listdir(directory_of_tsv) if f.endswith('.tsv')]
tsv_files_every_5 = tsv_files[::1]  # step of 1 keeps every file; use [::5] to take every 5th file

for file in tsv_files_every_5:
    input_file = os.fsdecode(file)
    species_name = input_file.replace(".tsv", "")
    if input_file.endswith(".tsv"):

        nodes = import_from_tsv(os.path.join(directory_of_tsv, input_file))[0]

        input_file_name = input_file.replace(".tsv", "")
        real_tree_lists = import_real_tree(file_real_trees)
        # species = re.sub(r'^\d+_', '', input_file_name)
        species = re.sub(r'^[^_]+_|_[^_]+$', '', input_file_name)
        species_list = [sp[0] for sp in real_tree_lists]
        index_tree = species_list.index(species)
        newick_real = prune(real_tree_lists[index_tree][3])  

        branch_length, branch_std = calculate_branch_length(newick_real)

        if branch_length > 1.3718 and branch_length < 1.3720:

            print('----------------------------------')
            print('Species', species_name)

            # header = ["", "", 'Del_0', 'Del_60', 'Del_70', 'Del_80']

            # with open('all_comm_results.txt', "a") as f:

            #     f.write(f"Species: {species_name}\n")
            #     f.write('\t'.join(header) + '\n')

            for perc_filter in filter_list:

                del_results = []
                time_results = []

                for del_comb in del_list:

                    # print('Species', input_file)
                    # print('Filter:', perc_filter)
                    # print('Del:', del_comb)

                    repetitions = []

                    for i in range(5):

                        start_time = time.time()
                        branch_length, branch_std, percentage, data_annealers = main_local(directory_of_tsv, input_file, file_real_trees, del_filter=perc_filter, del_comb = del_comb, prints=False)
                        end_time = time.time()

                        repetitions.append(percentage)

                    mean_percentage = sum(repetitions)/len(repetitions)
                    std_percentage = np.std(repetitions)
                        
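                    # note: start_time and end_time are overwritten in the loop above,
                    # so this is the wall-clock time of the last repetition only
                    elapsed_time = end_time - start_time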

                    del_results.append(mean_percentage)
                    time_results.append(elapsed_time)

                    print('Percentage_', round(perc_filter*100), percentage)
                    print('Std', std_percentage)

                    # row to append to the results table
                    data_rows = [
                        [species_name, 
                         round(branch_length, 4),
                         round(branch_std, 4), 
                         len(nodes),
                         int(perc_filter*100),
                         round(mean_percentage, 4),
                         round(std_percentage, 4)
                        ]
                    ]
                    # write_branch_percentage_table(output_file, data_rows=data_rows)
                    

                    # annealer_file_name = input_file.replace(".tsv", "")+ '_DA_times_' + str(int(del_comb*100)) + '_' + str(int(perc_filter*100)) + '.txt'

                    # write_annealer_times(annealer_file_name, species_name, data_annealers)

                # write_percentages('all_comm_results.txt', perc_filter, del_results, time_results)
                
                # print('Percentage_', percentage)


            # with open('all_comm_results.txt', "a") as f:

                # f.write('-------------------------------------------------------------------\n')


### D-WAVE

In [None]:
# Hybrid Solver
import os
from dwave.system import LeapHybridSampler

# D-Wave key (never commit a real token; replace the placeholder locally)
os.environ['DWAVE_API_TOKEN']='<YOUR_DWAVE_API_TOKEN>'

In [56]:
def as_bqm(self) -> 'dimod.BinaryQuadraticModel':
    """
    Function that converts the QUBO problem into a Binary Quadratic Model (BQM) for D-Wave.

    Args:
    self: QUBO problem instance.

    Returns:
    dimod.BinaryQuadraticModel: BQM representation of the QUBO problem.
    
    """

    try:
        import dimod
    except ImportError:
        # dimod is part of the D-Wave Ocean SDK: show the install hint and re-raise
        print('\n\n' + (100 * '#'))
        print('pip install dwave-ocean-sdk')
        print((100 * '#') + '\n\n')
        raise

    return dimod.BinaryQuadraticModel(
        {i0: self._p1[i0] for i0 in self._p1},
        {(i0, i1): self._p2[i0][i1] for i0 in self._p2 for i1 in self._p2[i0]},
        self._p0,
        dimod.BINARY)

In [50]:
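def unflatten_solution(flat, n_i, n_k):
    """
    Function that reshapes a flat solution vector into one row per node.

    Args:
    flat (list): Flat list of n_i*n_k binary values, ordered node by node (index i*n_k + k).
    n_i (int): Number of nodes.
    n_k (int): Number of groups.

    Returns:
    list: n_i rows, each with the n_k group indicators of one node.

    """
    return [flat[i*n_k : (i+1)*n_k] for i in range(n_i)]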
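
As a quick illustration of the helper above, the following sketch reshapes a toy flat vector for 3 nodes and 2 groups. The values are made up for the example, and the helper is repeated only so that the snippet runs on its own.

```python
def unflatten_solution(flat, n_i, n_k):
    # Same helper as above: row i holds the n_k group indicators of node i
    return [flat[i*n_k : (i+1)*n_k] for i in range(n_i)]

# Toy flat solution for 3 nodes and 2 groups (illustrative values only)
flat = [1, 0, 0, 1, 1, 0]
rows = unflatten_solution(flat, n_i=3, n_k=2)
print(rows)  # [[1, 0], [0, 1], [1, 0]]
```

Here nodes 0 and 2 land in group 0 and node 1 in group 1, which only holds if the solver returns the variables in node-major order.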

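The one-hot-groups restriction, which requires every node $i$ to belong to exactly one of the $K$ groups,

$$
\sum_{k=1}^K x_{i,k} = 1 \qquad i \in \{1, ..., N\}
$$

now becomes a penalization term that is added to the objective function. The term is zero when the restriction holds for every node and positive otherwise. In the code below it is multiplied by the weight `beta`.

$$
\sum_{i=1}^N \left( \sum_{k=1}^K x_{i,k} -1 \right)^2
$$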
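
To make the penalization term concrete, here is a minimal pure-Python sketch that evaluates it on two toy assignments. The function name `one_hot_penalty` and the example matrices are assumptions made for illustration; they are not part of the notebook.

```python
def one_hot_penalty(x):
    """Sum over nodes i of (sum_k x[i][k] - 1)^2 for a 0/1 assignment matrix x."""
    return sum((sum(row) - 1) ** 2 for row in x)

# Every node sits in exactly one group: the penalty vanishes
valid = [[1, 0, 0],
         [0, 0, 1]]

# Node 0 sits in two groups and node 1 in none: each adds 1 to the penalty
invalid = [[1, 1, 0],
           [0, 0, 0]]

print(one_hot_penalty(valid))    # 0
print(one_hot_penalty(invalid))  # 2
```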

In [None]:
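def get_qubo_dwave(A_matrix, K_vars, alpha, beta):
    """
    Get the QUBO object (D-Wave) for the given matrix and number of clusters K.

    Args:
    A_matrix (np.ndarray): NJ matrix; entry (i, j) is the cost of putting nodes i and j in the same group.
    K_vars (int): Number of groups (clusters).
    alpha (float): Weight of the group-size penalty (every group must contain exactly two nodes).
    beta (float): Weight of the one-hot-groups penalty (every node i < N must be in exactly one group).

    Returns:
    tuple: (qubo, q_dist, H_alpha, H_one_hot), i.e. the full QUBO, the distance term alone and the two penalty terms.
    """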
    fix_varshapeset(A_matrix, K=K_vars)
    
    # Build the QUBO matrix
    qubo = BinPol()
    N = A_matrix.shape[0]
    N_big = A_matrix.shape[1]
    # N_big = 2*N
    for k in range(K_vars):
        for i in range(N):
            for j in range(i+1, N_big):
                qubo.add_term(A_matrix[i,j], ('x', i, k), ('x', j, k)) # d_ij * x_i * x_j
                # if i < N:
                #     qubo.add_term(A_matrix[i,j], ('x', i, k), ('x', j, k)) # d_ij * x_i * x_j
                # elif i >= N and j < N:
                #     qubo.add_term(A_matrix[j,i], ('x', i, k), ('x', j, k))

    q_dist = qubo.clone()

    # build the group-size penalization term: every group k must contain exactly 2 nodes

    H_alpha = BinPol()

    for k in range(K_vars):
        H_aux = BinPol()
        for i in range(N_big):
            H_aux.add_term(1, ('x', i, k))
        H_aux.add_term(-2, ())
        H_aux.power(2)
        H_alpha.add(H_aux)
        
    H_alpha.multiply_scalar(alpha)

    qubo = qubo.add(H_alpha)


    # build the one-hot-groups term as penalization

    H_one_hot = BinPol()

    for i in range(N):
        H_aux = BinPol()
        for k in range(K_vars):
            H_aux.add_term(1, ('x', i, k))
        H_aux.add_term(-1, ())
        H_aux.power(2)
        H_one_hot.add(H_aux)
        
    H_one_hot.multiply_scalar(beta)

    qubo = qubo.add(H_one_hot)


    return(qubo, q_dist, H_alpha, H_one_hot)
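
To check what the function above encodes without needing `dadk`, the sketch below evaluates the same three terms (distance term, group-size penalty and one-hot penalty) directly on an assignment matrix. The name `qubo_energy` and the toy distance matrix are assumptions made for this illustration. It is a way to sanity-check small instances by hand, not the notebook's solver.

```python
def qubo_energy(A, x, alpha, beta):
    """Energy of a 0/1 assignment x[i][k] for the QUBO structure used above."""
    n = len(A)          # rows of the matrix
    n_big = len(A[0])   # columns of the matrix (all nodes)
    n_groups = len(x[0])

    # distance term: sum_k sum_{i<j} A[i][j] * x[i][k] * x[j][k]
    dist = sum(A[i][j] * x[i][k] * x[j][k]
               for k in range(n_groups)
               for i in range(n)
               for j in range(i + 1, n_big))

    # group-size penalty: every group must contain exactly two nodes
    size_pen = sum((sum(x[i][k] for i in range(n_big)) - 2) ** 2
                   for k in range(n_groups))

    # one-hot penalty: every node i < n must be in exactly one group
    one_hot_pen = sum((sum(x[i]) - 1) ** 2 for i in range(n))

    return dist + alpha * size_pen + beta * one_hot_pen

# Toy symmetric matrix for 4 nodes (illustrative values only)
A = [[0, 1, 5, 6],
     [1, 0, 7, 8],
     [5, 7, 0, 2],
     [6, 8, 2, 0]]

x_good = [[1, 0], [1, 0], [0, 1], [0, 1]]   # pairs (0, 1) and (2, 3)
x_bad = [[1, 0], [1, 0], [1, 0], [1, 0]]    # all four nodes in group 0

print(qubo_energy(A, x_good, alpha=1000, beta=1000))  # 3
print(qubo_energy(A, x_bad, alpha=1000, beta=1000))   # 8029
```

The valid pairing only pays the two pair distances (1 + 2), while the invalid one pays all six distances (29) plus eight units of the group-size penalty.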

In [None]:
def QUBO_dwave(nj_matrix, K_vars):

    """   
    Function that solves the QUBO problem using D-Wave's Leap Hybrid Solver.

    Args:
    nj_matrix (np.ndarray): NJ matrix including both real and combined nodes.
    K_vars (int): Number of communities (clusters) to form.

    Returns:
    comms (list): List of communities with the corresponding node indexes.
    timing_dict (dict): Dictionary with timing information from the D-Wave solver.
    best.energy (float): Energy of the best solution found by the solver.
    
    """

    # old penalization term
    alpha = 1000
    # new penalization term (previously one-hot)
    beta = 1000

    N = nj_matrix.shape[1]
    
    qubo, q_dist, H_alpha, H_beta = get_qubo_dwave(nj_matrix, K_vars, alpha, beta)

    print('number x_i', N)
    # print('qubo x_i', qubo.var_shape_set.get_symbolic(i))
    print('number x_k = k_vars', K_vars)
    print('total num of variables', qubo.N)


    bqm=qubo.as_bqm()

    sampler = LeapHybridSampler()    
    answer = sampler.sample(bqm)

    best = min(answer.data(['sample', 'energy']), key=lambda d: d.energy)

    # print("Best sample:", best.sample)
    # print("Best energy:", best.energy)

    solution_dwave = list(best.sample.values())

    solution_list_dwave = [int(x) for x in solution_dwave]
    print([qubo.var_shape_set.get_symbolic(i) for i,sol_i in enumerate(solution_list_dwave) if sol_i==1])
    ### comparison with DA ###

    solver_args = {
                    'optimization_method':'annealing',
                    'number_iterations':30000,
                    'number_runs':10,
                    'scaling_bit_precision':32,   ### CHANGE HERE TO 32 OR 16
                    'scaling_action':ScalingAction.AUTO_SCALING,
                    }
                    # 'graphics':GraphicsDetail.SINGLE

    solver = QUBOSolverCPU(**solver_args)

    solution_list_da = solver.minimize(qubo)

    ##############################
    print('solution_dwave \n', best.sample, best.energy)
    print('solution_list_dwave \n', solution_list_dwave)
    print('variables_dwave', len(solution_dwave))
    print('solution_list_da \n', solution_list_da.min_solution['x'].data)

    unflattened_sol_dwave = unflatten_solution(solution_list_dwave, N, K_vars)

    print('unflattened_sol_dwave \n', np.array(unflattened_sol_dwave))
    print('xk_dwave', len(unflattened_sol_dwave[0]))
    print('xi_dwave', len(unflattened_sol_dwave))

    comms = get_communities(nj_matrix, K_vars, unflattened_sol_dwave) 

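    # If some community does not contain exactly two nodes, strengthen the
    # group-size penalty and solve again. The number of retries is bounded by
    # the number of invalid communities in the first solution, and the loop
    # stops as soon as every community is a pair.
    n_invalid = sum(len(comm) != 2 for comm in comms)

    for _ in range(n_invalid):
        if all(len(comm) == 2 for comm in comms):
            break
        alpha = alpha*10
        qubo, q_dist, H_alpha, H_beta = get_qubo_dwave(nj_matrix, K_vars, alpha, beta)
        bqm=qubo.as_bqm()
        sampler = LeapHybridSampler()    
        answer = sampler.sample(bqm)
        best = min(answer.data(['sample', 'energy']), key=lambda d: d.energy)
        solution_dwave = list(best.sample.values())
        solution_list_dwave = [int(x) for x in solution_dwave]
        unflattened_sol_dwave = unflatten_solution(solution_list_dwave, N, K_vars)
        comms = get_communities(nj_matrix, K_vars, unflattened_sol_dwave) 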


    timing_dict = dict(answer.info)


    return(comms, timing_dict, best.energy)
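
The validity check used after solving (every community must be a pair) can be illustrated on a small assignment matrix. The helpers `groups_from_assignment` and `all_pairs` below are hypothetical stand-ins written for this sketch; the notebook's own `get_communities` is defined elsewhere and may behave differently.

```python
def groups_from_assignment(x):
    """For each group k, list the node indices i with x[i][k] == 1."""
    n_groups = len(x[0]) if x else 0
    return [[i for i, row in enumerate(x) if row[k] == 1]
            for k in range(n_groups)]

def all_pairs(groups):
    """True when every group contains exactly two nodes."""
    return all(len(group) == 2 for group in groups)

# Valid solution: groups {0, 1} and {2, 3}
x_valid = [[1, 0], [1, 0], [0, 1], [0, 1]]
# Invalid solution: group 0 has three nodes and group 1 only one
x_invalid = [[1, 0], [1, 0], [1, 0], [0, 1]]

print(groups_from_assignment(x_valid), all_pairs(groups_from_assignment(x_valid)))
print(groups_from_assignment(x_invalid), all_pairs(groups_from_assignment(x_invalid)))
```

A solution like `x_invalid` is the case in which the group-size weight `alpha` is increased and the problem is solved again.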

#### D-Wave pipeline

In [None]:
"""   
Main D-Wave execution block

"""

directory_of_tsv = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\QCare - Árboles filogenéticos\Nicolás - benchmarking\.tsv\Schizosaccharomyces_pombe'
directory_output = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\QCare - Árboles filogenéticos\Nicolás - benchmarking\newick outputs\Schizosaccharomyces_pombe'
os.chdir(directory_output)

# directory = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\QCare - Árboles filogenéticos\Nicolás - benchmarking\newick outputs\Bos_taurus'
# input_file = '0-2516_Phy0001SS5_BOVIN_nd.tsv'

for file in os.listdir(directory_of_tsv):
    input_file = os.fsdecode(file)
    if input_file.endswith(".tsv"):

        species = input_file.replace(".tsv", "")
        filename = species + '_iter1'
        filtered = True
        del_filter = 0.8
        old_nodes, bitscore_matrix = import_from_tsv(os.path.join(directory_of_tsv, input_file))
        bitscore_inverse = 100 - bitscore_matrix
        initial_len_nodes = len(old_nodes)
        result = None

        for it in range(1, initial_len_nodes): 

            if result is not None:

                old_filename = filename
                filename = species + '_iter' + str(it)
                filtered = False
                old_nodes, old_adj_matrix, nj_matrix, comms, pair_sublist, result = load_intermediate_solutions(old_filename + '.json')
                removed, added, new_nodes, new_bitscore = update_nodes_adj_matrix(old_nodes, pair_sublist, np.array(old_adj_matrix))

                old_nodes = new_nodes
                bitscore_inverse = new_bitscore

            comb_index_list, adj_matrix, nj_matrix = main_calculate_nj(bitscore_inverse, 0.8)

            if len(old_nodes)>=4:
                K_vars = round(len(old_nodes)*3/4)
            elif len(old_nodes)==3:
                K_vars = 2
                filtered = False
            elif len(old_nodes)==2:
                K_vars = 1
            else:
                break

            comms, info_dict, min_energy = QUBO_dwave(nj_matrix, K_vars)
            comms2, an_time2, an_energy2 = QUBO_local(nj_matrix, K_vars, filename)

            # make sure that all comms have len2
            for i in range(len(comms)):
                if len(comms[i])>2:
                    comms[i] = [comms[i][0], comms[i][1]]
                elif len(comms[i])<2:
                    comms[i] = [comms[i][0], len(comms)]

            if result is None:
                result = {}
                data = []
                for i in range(len(old_nodes)):
                    result[old_nodes[i]] = ''

            print('filename', filename)
            print('len_nodes', len(old_nodes))
            print('comms_annealer', comms)
            print('comms_local', comms2)
            print('annealer_times', info_dict)
            print('annealer_energies', min_energy)

            comms1_sorted, energies1_sorted = order_comms(len(old_nodes), comms, nj_matrix)
            comms2_sorted, energies2_sorted = order_comms(len(old_nodes), comms2, nj_matrix)

            print('########## DWAVE ############')
            for i in range(len(comms1_sorted)):
                print(comms1_sorted[i], energies1_sorted[i])

            print('########## LOCAL ############')
            for i in range(len(comms2_sorted)):
                print(comms2_sorted[i], energies2_sorted[i])

            if filtered:
                filtered_comms = filter_comms(len(old_nodes), comms, nj_matrix, del_filter)
                pair_sublist = create_pair_sublist(old_nodes, filtered_comms)
            else:
                pair_sublist = create_pair_sublist(old_nodes, comms)

            # append one bit per merge to every leaf: '0' for the first member of the pair, '1' for the second
            for i in range(len(pair_sublist)):
                for side, bit in ((0, '0'), (1, '1')):
                    member = pair_sublist[i][side]
                    # a member is either a single leaf or a list of already merged leaves
                    leaves = member if isinstance(member, list) else [member]
                    for leaf in leaves:
                        result[leaf] += bit

            num_variables = K_vars*nj_matrix.shape[1]

            save_intermediate_solutions(filename + '.json', old_nodes, adj_matrix, nj_matrix, comms, pair_sublist, result)
            save_times_energies(filename + '_times_energies.json', len(old_nodes), num_variables, info_dict, min_energy)

        # outside of the loop: `result` already holds the encoding of the last saved iteration,
        # so it does not need to be reloaded from `old_filename` (undefined after a single iteration)
        swapped_result =  {value[::-1]: [str(key)] for key, value in result.items()}
        newick_reconstructed = dict_to_newick(swapped_result, branch_length=1.0)

        # distinct handle name so the loop variable `file` is not shadowed
        with open(species + '_tree.txt', 'w') as f:
            f.write(f"{newick_reconstructed}\n")


number x_i 32
number x_k = k_vars 10
total num of variables 320
[('x', 0, 7), ('x', 1, 3), ('x', 2, 7), ('x', 3, 6), ('x', 4, 6), ('x', 5, 4), ('x', 6, 8), ('x', 7, 8), ('x', 8, 4), ('x', 9, 3), ('x', 10, 1), ('x', 11, 0), ('x', 12, 1), ('x', 13, 0), ('x', 19, 9), ('x', 22, 9), ('x', 25, 2), ('x', 29, 2), ('x', 29, 5), ('x', 31, 5)]

********************************************************************************
  temperature_start:                       7.30816572e+02
  temperature_end:                         1.22691322e+02
  offset_increase_rate:                    3.10474874e+02
  duration:                                0.036 sec
********************************************************************************

  max_abs_coefficient:                   5094.425

********************************************************************************
Effective values (including scaling factor)
  scaling_factor:                           4.21493000e+05
  temperature_start:                   

#### DA pipeline

In [None]:
"""   
Main Digital Annealer execution block

"""

directory_of_tsv = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\QCare - Árboles filogenéticos\Nicolás - benchmarking\.tsv\Schizosaccharomyces_pombe'
directory_output = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\QCare - Árboles filogenéticos\Nicolás - benchmarking\newick outputs - DA\Schizosaccharomyces_pombe'
os.chdir(directory_output)

# directory = r'C:\Users\alfonsorodrr\OneDrive - FUJITSU\Escritorio\FUJITSU\QCare - Árboles filogenéticos\Nicolás - benchmarking\newick outputs\Bos_taurus'
# input_file = '0-2516_Phy0001SS5_BOVIN_nd.tsv'

for file in os.listdir(directory_of_tsv):
    input_file = os.fsdecode(file)
    if input_file.endswith(".tsv"):

        species = input_file.replace(".tsv", "")
        filename = species + '_iter1'
        filtered = True
        del_filter = 0.8
        old_nodes, bitscore_matrix = import_from_tsv(os.path.join(directory_of_tsv, input_file))
        bitscore_inverse = 100 - bitscore_matrix
        initial_len_nodes = len(old_nodes)
        result = None

        for it in range(1, initial_len_nodes): 

            if result is not None:

                old_filename = filename
                filename = species + '_iter' + str(it)
                filtered = False
                old_nodes, old_adj_matrix, nj_matrix, comms, pair_sublist, result = load_intermediate_solutions(old_filename + '.json')
                removed, added, new_nodes, new_bitscore = update_nodes_adj_matrix(old_nodes, pair_sublist, np.array(old_adj_matrix))

                old_nodes = new_nodes
                bitscore_inverse = new_bitscore

            comb_index_list, adj_matrix, nj_matrix = main_calculate_nj(bitscore_inverse, 0.8)

            if len(old_nodes)>=4:
                K_vars = round(len(old_nodes)*3/4)
                time_limit = 10
            elif len(old_nodes)==3:
                K_vars = 2
                filtered = False
                time_limit = 2
            elif len(old_nodes)==2:
                K_vars = 1
                filtered = False
                time_limit = 1
            else:
                break

            comms, info_dict, min_energy = QUBO_annealer(nj_matrix, K_vars, time_limit, filename)
            comms2, an_time2, an_energy2 = QUBO_local(nj_matrix, K_vars, filename)

            # make sure that all comms have len2
            for i in range(len(comms)):
                if len(comms[i])>2:
                    comms[i] = [comms[i][0], comms[i][1]]
                elif len(comms[i])<2:
                    comms[i] = [comms[i][0], len(comms)]

            if result is None:
                result = {}
                data = []
                for i in range(len(old_nodes)):
                    result[old_nodes[i]] = ''

            print('filename', filename)
            print('len_nodes', len(old_nodes))
            print('comms_annealer', comms)
            print('comms_local', comms2)
            print('annealer_times', info_dict)
            print('annealer_energies', min_energy)

            comms1_sorted, energies1_sorted = order_comms(len(old_nodes), comms, nj_matrix)
            comms2_sorted, energies2_sorted = order_comms(len(old_nodes), comms2, nj_matrix)

            print('########## ANNEALER ############')
            for i in range(len(comms1_sorted)):
                print(comms1_sorted[i], energies1_sorted[i])

            print('########## LOCAL ############')
            for i in range(len(comms2_sorted)):
                print(comms2_sorted[i], energies2_sorted[i])

            if filtered:
                filtered_comms = filter_comms(len(old_nodes), comms, nj_matrix, del_filter)
                pair_sublist = create_pair_sublist(old_nodes, filtered_comms)
            else:
                pair_sublist = create_pair_sublist(old_nodes, comms)

            # append one bit per merge to every leaf: '0' for the first member of the pair, '1' for the second
            for i in range(len(pair_sublist)):
                for side, bit in ((0, '0'), (1, '1')):
                    member = pair_sublist[i][side]
                    # a member is either a single leaf or a list of already merged leaves
                    leaves = member if isinstance(member, list) else [member]
                    for leaf in leaves:
                        result[leaf] += bit

            num_variables = K_vars*nj_matrix.shape[1]

            save_intermediate_solutions(filename + '.json', old_nodes, adj_matrix, nj_matrix, comms, pair_sublist, result)
            save_times_energies(filename + '_times_energies.json', len(old_nodes), num_variables, info_dict, min_energy)

        # outside of the loop: `result` already holds the encoding of the last saved iteration,
        # so it does not need to be reloaded from `old_filename` (undefined after a single iteration)
        swapped_result =  {value[::-1]: [str(key)] for key, value in result.items()}
        newick_reconstructed = dict_to_newick(swapped_result, branch_length=1.0)

        # distinct handle name so the loop variable `file` is not shadowed
        with open(species + '_tree.txt', 'w') as f:
            f.write(f"{newick_reconstructed}\n")
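
Both pipelines encode the tree as one string of `'0'`/`'1'` characters per leaf, appended bottom-up at every merge and reversed at the end. The sketch below shows how such reversed strings can be read as root-to-leaf paths and turned into a Newick string. The function `codes_to_newick` is a hypothetical, simplified stand-in for the notebook's `dict_to_newick` (it ignores branch lengths and takes `{leaf: path}` instead of `{path: [leaf]}`), and the toy `result` dictionary is invented for the example.

```python
def codes_to_newick(codes):
    """Build a Newick string from {leaf_name: root-to-leaf '0'/'1' path}."""
    def build(items):
        # items: list of (leaf_name, remaining_path) below the current node
        if len(items) == 1 and items[0][1] == '':
            return items[0][0]
        left = [(name, path[1:]) for name, path in items if path.startswith('0')]
        right = [(name, path[1:]) for name, path in items if path.startswith('1')]
        if not left or not right:
            raise ValueError('codes do not describe a full binary tree')
        return '(' + build(left) + ',' + build(right) + ')'
    return build(sorted(codes.items())) + ';'

# Toy bottom-up encoding after three merge rounds:
# round 1 pairs (a, b) and (d, e), round 2 pairs (ab, c), round 3 pairs (abc, de)
result = {'a': '000', 'b': '100', 'c': '10', 'd': '01', 'e': '11'}

# Reverse every string so that it reads from the root down to the leaf
paths = {leaf: code[::-1] for leaf, code in result.items()}
newick = codes_to_newick(paths)
print(newick)  # (((a,b),c),(d,e));
```

Leaves that are not paired in a given round simply receive no character in that round, which is why the strings have different lengths.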