In [119]:
import itertools
from enum import IntEnum

import numpy as np
from Bio import SeqIO


class Action(IntEnum):
    """
    Kind of step a pair of sequences (i, j) takes when moving from one matrix cell to the next.

    Members are integers (IntEnum), so comparisons with the plain values 1-4 keep working.
    """

    # Both sequences advance by one character (covers matches and mismatches alike).
    MATCH = 1
    # Only sequence j advances; sequence i receives a gap.
    GAP_I = 2
    # Only sequence i advances; sequence j receives a gap.
    GAP_J = 3
    # Neither sequence advances; both have a gap in this column.
    TWO_GAPS = 4


class MultiSequenceAligner:
    """
    Exact multiple sequence aligner using n-dimensional dynamic programming with sum-of-pairs scoring.

    One matrix axis is allocated per sequence, so the score matrix has
    prod(len(sequence) + 1) cells and every cell inspects up to 2**n - 1 preceding cells.
    Gaps are written as "." in the returned alignments.
    """

    # Alignment modes accepted by get_alignment_matrix.
    _METHODS = ("global", "local")

    def __init__(self,
                 match: int = 5,
                 mismatch: int = -2,
                 indel: int = -4,
                 two_gaps: int = 0):
        """
        :param match: Score if AA match
        :param mismatch: Score if AA mismatch
        :param indel: Linear gap penalty (applied for every gap)
        :param two_gaps: Additional gap penalty applied if two or more sequences have a gap at given position
        """
        self.match = match
        self.mismatch = mismatch
        self.indel = indel
        self.two_gaps = two_gaps

    def get_alignment_matrix(self, sequences: list[str], method: str = "global") -> tuple[np.ndarray, np.ndarray]:
        """
        Returns a matrix of scores for all possible alignments of the given sequences and a backtrack matrix.
        :param sequences: List of sequences to align
        :param method: Method to use for alignment. Either "global" or "local"
        :return: Tuple (scores, backtrack). ``scores`` is a float array with one axis per sequence
                 (length of the sequence + 1 for the initial gap). ``backtrack`` is an integer array with one
                 extra trailing axis holding the coordinates of the predecessor cell; cells without a
                 predecessor (the origin, and zero-score cells in local mode) are filled with -1.
        :raises ValueError: If method is neither "global" nor "local"
        """
        # Raise instead of assert: asserts are stripped when Python runs with -O.
        if method not in self._METHODS:
            raise ValueError("Method must be either 'global' or 'local'")

        # n-dimensional matrix where n is the number of sequences
        dimensions = [len(sequence) + 1 for sequence in sequences]
        matrix = np.zeros(dimensions)

        # Integer storage keeps the stored coordinates usable as indices; -1 marks "no predecessor".
        backtrack_matrix = np.full(dimensions + [len(sequences)], -1, dtype=int)

        # C-order enumeration is a valid fill order: every predecessor of a cell is obtained by
        # decreasing some coordinates, so it is lexicographically smaller and already filled.
        for idx, _ in np.ndenumerate(matrix):
            score, neighbour = self._get_score(idx, sequences, matrix, method)
            matrix[idx] = score
            if neighbour is not None:
                backtrack_matrix[idx] = neighbour

        return matrix, backtrack_matrix

    def _get_score(self, idx: tuple[int, ...], sequences: list[str], matrix: np.ndarray,
                   method: str) -> "tuple[float, tuple[int, ...] | None]":
        """
        Returns the score for the given position in the matrix. The score is completely analogous to the pairwise
        case, only now the scores for each position are equal to the sum of the individual
        pairwise comparisons (i.e. a position that is identical for three sequences has a
        score of 5s1,s2 + 5s1,s3 + 5s2,s3 = 15).

        :param idx: Index of the position in the matrix
        :param sequences: List of sequences to align
        :param matrix: Partially completed matrix of scores for all possible alignments of the given sequences
        :param method: Method to use for alignment. Either "global" or "local"
        :return: Score for given position in the matrix and the index of the neighbour that generated the score
                 (None for the origin and, in local mode, for cells whose best score is not positive).
        :raises ValueError: If method is neither "global" nor "local"
        """
        # The origin has no predecessor and scores 0.
        if not any(idx):
            return 0, None

        # The pair list does not depend on the neighbour, so build it once per cell.
        pairs = self._get_pairs(len(sequences))

        best_score = -np.inf
        best_neighbour = None
        for neighbour in self._get_preceeding_neigbours(idx):
            # Neighbours that fall outside the matrix are not reachable.
            if min(neighbour) < 0:
                continue

            total = matrix[neighbour]
            for i, j in pairs:
                total += self._get_pair_step_score(neighbour, idx, i, j, sequences)

            # Strict comparison keeps the first neighbour on ties.
            if total > best_score:
                best_score, best_neighbour = total, neighbour

        if method == "global":
            return best_score, best_neighbour
        elif method == "local":
            # Local alignments restart at 0 instead of carrying a non-positive score forward.
            if best_score > 0:
                return best_score, best_neighbour
            return 0, None
        else:
            raise ValueError("Method must be either 'global' or 'local'")

    def _get_preceeding_neigbours(self, idx: tuple[int, ...]) -> list[tuple[int, ...]]:
        """
        Returns a list of all possible preceeding neighbours of the given index. A neighbour is defined as a tuple of
        indices where each index is either the same as the given index or one less; the index itself is excluded.
        Neighbours may contain -1 when the index lies on a matrix border.
        :param idx: Index of the position in the matrix
        :return: List of all possible preceeding neighbours of the given index
        """
        return [
            tuple(int(coordinate) + step for coordinate, step in zip(idx, offset))
            for offset in itertools.product((0, -1), repeat=len(idx))
            if any(offset)  # skip the all-zero offset, which is idx itself
        ]

    def _get_pair_step_score(self, neighbour: tuple[int, ...], idx: tuple[int, ...], i: int, j: int,
                             sequences: list[str]) -> float:
        """
        Returns only the score that the pair (i, j) contributes to the neighbour -> idx transition.
        :param neighbour: Neighbour position (start of the transition)
        :param idx: Current position (end of the transition)
        :param i: Index of first sequence
        :param j: Index of second sequence
        :param sequences: List of sequences to align
        :return: Score contributed by the pair
        :raises ValueError: If either sequence moves by anything other than 0 or 1
        """
        diff_i = idx[i] - neighbour[i]
        diff_j = idx[j] - neighbour[j]
        step = (diff_i, diff_j)

        if step == (1, 1):
            # Both sequences consume a character: compare the characters just consumed.
            return self._get_match_score(sequences[i][idx[i] - 1], sequences[j][idx[j] - 1])
        if step == (1, 0) or step == (0, 1):
            # Exactly one sequence consumes a character: the other one gets a gap.
            return self.indel
        if step == (0, 0):
            # Neither sequence consumes a character: both have a gap in this column.
            return self.two_gaps
        raise ValueError(f"Invalid offset: {diff_i}, {diff_j}")

    def _get_pairwise_score(self, neighbour: tuple[int, ...], i: int, j: int, matrix: np.ndarray,
                            sequences: list[str], idx: tuple[int, ...]) -> tuple[int, int]:
        """
        Returns the score of a pair for given neighbour to idx transition, together with the action taken.

        The matrix fill itself only needs the score and therefore calls _get_pair_step_score directly.
        :param neighbour: Neighbour position
        :param i: Index of first sequence
        :param j: Index of second sequence
        :param matrix: Partially completed matrix of scores (not used; kept so the signature stays unchanged)
        :param sequences: List of sequences to align
        :param idx: Current position in the matrix
        :return: Score for given pair and action that generated the score
        :raises ValueError: If either sequence moves by anything other than 0 or 1
        """
        # Validates the offsets as a side effect, so they are known to be 0 or 1 below.
        score = self._get_pair_step_score(neighbour, idx, i, j, sequences)

        moved_i = idx[i] != neighbour[i]
        moved_j = idx[j] != neighbour[j]
        if moved_i and moved_j:
            action = Action.MATCH
        elif moved_i:
            # only i advanced: gap in j
            action = Action.GAP_J
        elif moved_j:
            # only j advanced: gap in i
            action = Action.GAP_I
        else:
            action = Action.TWO_GAPS

        return score, action

    def _get_pairs(self, n):
        """
        Returns all possible pairs of n numbers (i.e. for n=3, returns [(0,1), (0,2), (1,2)])
        :param n: Number of numbers
        :return: List of all possible pairs of n numbers
        """
        return list(itertools.combinations(range(n), 2))

    def alignment(self, sequences: list[str], method: str):
        """
        Returns the optimal alignment of the given sequences
        :param sequences: List of sequences to align
        :param method: Method to use for alignment, either "global" or "local". Note that any value other
                       than "global" is treated as "local".
        :return: List of aligned sequences
        """
        return self.global_alignment(sequences) if method == "global" else self.local_alignment(sequences)

    def global_alignment(self, sequences: list[str]):
        """
        Returns the optimal global alignment of the given sequences
        :param sequences: List of sequences to align
        :return: List of aligned sequences
        """
        alignment_matrix, backtrack_matrix = self.get_alignment_matrix(sequences, method="global")
        return self._get_alignment(sequences, alignment_matrix, backtrack_matrix, method="global")

    def local_alignment(self, sequences: list[str]):
        """
        Returns the optimal local alignment of the given sequences
        :param sequences: List of sequences to align
        :return: List of aligned sequences (empty strings if no cell has a positive score)
        """
        alignment_matrix, backtrack_matrix = self.get_alignment_matrix(sequences, method="local")
        return self._get_alignment(sequences, alignment_matrix, backtrack_matrix, method="local")

    def _get_alignment(self, sequences: list[str], alignment_matrix: np.ndarray, backtrack_matrix: np.ndarray,
                       method: str):
        """
        Backtraces through the alignment matrix to get the optimal alignment of the given sequences
        :param sequences: List of sequences to align
        :param alignment_matrix: Matrix of scores for all possible alignments of the given sequences
        :param backtrack_matrix: Matrix of backtrack directions for all possible alignments of the given sequences
        :param method: Method to use for alignment. Either "global" or "local"
        :return: List of aligned sequences, with "." marking a gap
        :raises ValueError: If method is neither "global" nor "local"
        """
        if not sequences:
            return []

        if method == "local":
            # Local traceback starts at the best-scoring cell (first one in C order on ties).
            start = np.unravel_index(np.argmax(alignment_matrix), alignment_matrix.shape)
        else:
            # Global traceback starts at the corner where every sequence is fully consumed.
            start = [len(sequence) for sequence in sequences]
        position = tuple(int(coordinate) for coordinate in start)

        # Columns are collected back-to-front and reversed once at the end.
        reversed_columns = [[] for _ in sequences]

        # Checking the stop condition before reading the backtrack entry means a start cell that is
        # already terminal (all sequences empty, or no positive local score) yields empty alignments
        # instead of dereferencing a missing predecessor.
        while not self._backtrack_break_condition(position, method, alignment_matrix):
            previous = tuple(int(coordinate) for coordinate in backtrack_matrix[position])

            for sequence_idx, sequence in enumerate(sequences):
                if position[sequence_idx] == previous[sequence_idx]:
                    # This sequence did not advance in this column: gap.
                    reversed_columns[sequence_idx].append(".")
                else:
                    reversed_columns[sequence_idx].append(sequence[position[sequence_idx] - 1])

            position = previous

        return ["".join(reversed(column)) for column in reversed_columns]

    def _backtrack_break_condition(self, previous_neighbour, method, alignment_matrix):
        """
        Tells whether the traceback must stop at the given cell.
        :param previous_neighbour: Coordinates of the cell the traceback is about to move to
        :param method: Method to use for alignment. Either "global" or "local"
        :param alignment_matrix: Matrix of scores for all possible alignments of the given sequences
        :return: True at the origin (global) or at a cell with score 0 (local)
        :raises ValueError: If method is neither "global" nor "local"
        """
        if method == "global":
            return bool(np.all(np.asarray(previous_neighbour) == 0))
        elif method == "local":
            position = tuple(int(coordinate) for coordinate in previous_neighbour)
            return bool(np.all(alignment_matrix[position] == 0))
        else:
            raise ValueError(f"Invalid method: {method}")

    def _get_match_score(self, aa_i: str, aa_j: str) -> float:
        """
        Returns the score for matching the given amino acids (or other characters)
        :param aa_i: Amino acid i
        :param aa_j: Amino acid j
        :return: Score for matching the given amino acids (-inf if either one is None)
        """
        if aa_i is None or aa_j is None:
            return -np.inf
        return self.match if aa_i == aa_j else self.mismatch

    def align_fasa(self, input_file: str, output_file: str = None, method: str = "global") -> None:
        """
        Aligns the sequences in the given fasta file and writes the result to the given output file. If no output
        file is given, the result is printed to the console. Each output line has the form "<id>: <aligned sequence>".
        :param input_file: Path to the fasta file containing the sequences to align
        :param output_file: Path to the output file to write the result to
        :param method: Method to use for alignment. Either "global" or "local"
        """
        sequence_dict: dict = self._read_fasta(input_file)

        # Read keys and values separately: unpacking zip(*items) fails when the file has no records.
        ids = list(sequence_dict)
        sequences = list(sequence_dict.values())

        aligned_sequences = self.alignment(sequences, method=method)

        output_str = "\n".join(
            f"{seq_id}: {sequence}" for seq_id, sequence in zip(ids, aligned_sequences)
        )

        if output_file is None:
            print(output_str)
        else:
            with open(output_file, "w") as f:
                f.write(output_str)

    # Correctly spelled alias; the original name is kept so existing callers keep working.
    align_fasta = align_fasa

    def _read_fasta(self, input_file: str) -> dict[str, str]:
        """
        Reads the given fasta file and returns a dictionary containing the ids and sequences.
        If several records share an id, the last one replaces the earlier ones.
        :param input_file: Path to the fasta file
        :return: Dictionary containing the ids and sequences
        """
        return {record.id: str(record.seq) for record in SeqIO.parse(input_file, "fasta")}

In [120]:
def print_alignments(sequences, method='global', params=None):
    """
    Align the given sequences and print one aligned sequence per line.

    :param sequences: List of sequences to align
    :param method: Alignment method passed on to the aligner ('global' by default)
    :param params: Optional keyword arguments for the MultiSequenceAligner constructor;
                   None means the aligner's default scoring is used
    """
    scoring_kwargs = {} if params is None else params
    aligned = MultiSequenceAligner(**scoring_kwargs).alignment(sequences, method=method)
    print(*aligned, sep='\n')

In [121]:
# Global alignment (the default method) of three sequences with the aligner's default scoring.
print_alignments(["ACTGGTCA", "CAGGGTCA", "CCAGGGACCA"])

ACTGG.TC.A
.CAGGGTC.A
CCAGGGACCA


In [122]:
# Build the global score and backtrack matrices for three short sequences with default scoring.
aligner = MultiSequenceAligner()
seqs = ["ABC", "AC", "AB"]
alignment_matrix, backtrack_matrix = aligner.get_alignment_matrix(seqs, method="global")
# Bare expression so the score matrix is displayed as the cell output
# (one axis per sequence, each of size len(sequence) + 1).
alignment_matrix

array([[[  0.,  -8., -16.],
        [ -8.,  -3., -11.],
        [-16., -11., -13.]],

       [[ -8.,  -3., -11.],
        [ -3.,  15.,   7.],
        [-11.,   7.,   5.]],

       [[-16., -11.,  -6.],
        [-11.,   7.,  12.],
        [-13.,   5.,  16.]],

       [[-24., -19., -14.],
        [-19.,  -1.,   4.],
        [-14.,   4.,   9.]]])

In [123]:
# Global alignment of the same three sequences; "." marks a gap in the returned strings.
aligner.alignment(seqs, method="global")

['ABC', 'A.C', 'AB.']

In [124]:
# Local alignment of two sequences: the traceback starts at the best-scoring cell and
# stops at the first zero-score cell, so only the best-scoring region is printed.
print_alignments(["ABCDEF", "CD"], method="local")

CD
CD


In [126]:
# Local alignment of four sequences with the aligner's default scoring.
print_alignments(["ABC", "AC", "AB", "BCD"], method="local")

AB
AC
AB
.B
