def generate_pairs(alignments):
    """Collect residue pairs that share a column of a multiple alignment.

    For every column, each pair of rows ``j < k`` whose characters are both
    non-gap (not ``'-'``) contributes the tuple
    ``(alignments[j][col], alignments[k][col])``. Pairs are emitted column by
    column and, within a column, in row order.

    Args:
        alignments: Sequence of equal-length aligned strings (or other
            sequences of characters).

    Returns:
        A list of 2-tuples. The list is empty if *alignments* is empty.

    Raises:
        ValueError: If the aligned sequences do not all have the same length.
    """
    if not alignments:
        return []
    width = len(alignments[0])
    if any(len(seq) != width for seq in alignments):
        raise ValueError("all aligned sequences must have the same length")

    pairs = []
    for column in zip(*alignments):
        # Dropping gaps first keeps the remaining residues in row order, so
        # pairing them j < k yields exactly the non-gap pairs of the column.
        residues = [ch for ch in column if ch != '-']
        for j, first in enumerate(residues):
            for second in residues[j + 1:]:
                pairs.append((first, second))
    return pairs

def count_pair_frequencies(pairs):
    """Count pairs regardless of orientation.

    Each pair is normalised to its sorted tuple, so ``('C', 'A')`` and
    ``('A', 'C')`` are tallied under the same key ``('A', 'C')``.

    Returns:
        A plain dict mapping each normalised pair to its occurrence count, with
        keys in first-seen order.
    """
    tallies = {}
    for key in (tuple(sorted(p)) for p in pairs):
        tallies[key] = tallies.get(key, 0) + 1
    return tallies

def calculate_frequencies(alignments):
    """Return the relative frequency of every non-gap symbol in *alignments*.

    Gap characters (``'-'``) are excluded from both the counts and the total.
    If there are no non-gap symbols at all, an empty dict is returned.
    """
    tally = {}
    for symbol in (ch for seq in alignments for ch in seq if ch != '-'):
        tally[symbol] = tally.get(symbol, 0) + 1
    grand_total = sum(tally.values())
    return {symbol: n / grand_total for symbol, n in tally.items()}

def calculate_scores(pair_counts, freqs, scale=3):
    """Compute integer log-odds scores for residue pairs.

    For each pair ``(x, y)`` the score is
    ``round(scale * log2(observed / expected))``, where:

    * ``observed`` is the pair's share of all counted pairs.
    * ``expected`` is the chance frequency derived from *freqs*: ``p_x * p_x``
      for identical residues, and ``2 * p_x * p_y`` for distinct residues.

    The factor of 2 is needed because *pair_counts* is expected to hold
    unordered pairs, as produced by ``count_pair_frequencies``, which merges
    ``(x, y)`` and ``(y, x)``.

    Pairs whose expected frequency is zero (a residue missing from *freqs*)
    score 0.

    Args:
        pair_counts: Mapping of unordered pair tuple -> positive count.
        freqs: Mapping of residue -> background frequency.
        scale: Multiplier applied to the log2 odds ratio before rounding.

    Returns:
        A dict mapping each key of *pair_counts* to an int score.
    """
    import math  # no module-level import of math is visible in this file

    total_pairs = sum(pair_counts.values())
    scores = {}
    for (x, y), observed_count in pair_counts.items():
        observed_freq = observed_count / total_pairs
        expected_freq = freqs.get(x, 0) * freqs.get(y, 0)
        if x != y:
            # An unordered pair {x, y} can be drawn as x-then-y or y-then-x.
            expected_freq *= 2
        if expected_freq > 0:
            score = scale * math.log2(observed_freq / expected_freq)
            scores[(x, y)] = round(score)
        else:
            scores[(x, y)] = 0
    return scores

def create_blosum_matrix(scores, nucleotides):
    """Expand a pair-keyed score dict into a nested ``matrix[row][col]`` dict.

    For each cell, the ``(row, col)`` entry of *scores* is used if present.
    Otherwise the mirrored ``(col, row)`` entry is used. Otherwise the cell
    is 0. On the diagonal the two lookups coincide.
    """
    def lookup(row, col):
        # The (row, col) orientation takes precedence over (col, row).
        return scores.get((row, col), scores.get((col, row), 0))

    return {row: {col: lookup(row, col) for col in nucleotides} for row in nucleotides}

def print_blosum_matrix(blosum_matrix, nucleotides):
    """Print *blosum_matrix* as a table with row and column labels.

    Every column, including the header, uses the same right-aligned cell
    width, so headers stay above their values. The width is at least 3 and
    grows to fit the widest label or value. Rows and columns follow the
    order of *nucleotides*.
    """
    labels = [str(n) for n in nucleotides]
    cells = [[str(blosum_matrix[row][col]) for col in nucleotides] for row in nucleotides]
    label_width = max((len(lab) for lab in labels), default=1)
    cell_width = max([3, *(len(lab) for lab in labels), *(len(c) for r in cells for c in r)])

    print(" " * label_width + "".join(f" {lab:>{cell_width}}" for lab in labels))
    for lab, row_cells in zip(labels, cells):
        print(f"{lab:<{label_width}}" + "".join(f" {c:>{cell_width}}" for c in row_cells))

def init(m, n, sigma):
    """Create an (m+1) x (n+1) alignment grid with linear gap-penalty borders.

    Cell ``[i][0]`` holds ``-i * sigma`` and cell ``[0][j]`` holds
    ``-j * sigma``. Every other cell, including ``[0][0]``, starts at 0.
    Each row is a separate list.
    """
    top_row = [0] + [-col * sigma for col in range(1, n + 1)]
    body = [[-row * sigma] + [0] * n for row in range(1, m + 1)]
    return [top_row] + body

def fill_matrix(matrix, a, b, blosum_matrix, sigma):
    """Fill the interior of a Needleman-Wunsch grid in place and return it.

    Row ``i`` corresponds to ``a[i - 1]`` and column ``j`` to ``b[j - 1]``.
    Row 0 and column 0 are read but never written. Each interior cell takes
    the best of three moves:

    * a diagonal step scored by ``blosum_matrix[a_char][b_char]``;
    * a step from the left costing *sigma*;
    * a step from above costing *sigma*.
    """
    for row_idx, a_char in enumerate(a, start=1):
        for col_idx, b_char in enumerate(b, start=1):
            diagonal = matrix[row_idx - 1][col_idx - 1] + blosum_matrix[a_char][b_char]
            from_left = matrix[row_idx][col_idx - 1] - sigma
            from_above = matrix[row_idx - 1][col_idx] - sigma
            matrix[row_idx][col_idx] = max(diagonal, from_left, from_above)
    return matrix

def get_new_score(up, left, middle, s_score, gap_penalty):
    """Return the best score for a cell from its three predecessors.

    The candidates are:

    * the diagonal neighbour *middle* plus the substitution score *s_score*;
    * the *left* neighbour minus *gap_penalty*;
    * the *up* neighbour minus *gap_penalty*.
    """
    candidates = (middle + s_score, left - gap_penalty, up - gap_penalty)
    return max(candidates)

def align(top_seq, bottom_seq, gap_penalty, blosum_matrix):
    """Build the global-alignment (Needleman-Wunsch) score matrix.

    The grid has ``len(bottom_seq) + 1`` rows and ``len(top_seq) + 1``
    columns. Row 0 and column 0 hold cumulative linear gap penalties.
    Substitution scores are looked up as ``blosum_matrix[top_char][bottom_char]``.
    Each interior cell ``[r][c]`` is the best score for aligning
    ``bottom_seq[:r]`` with ``top_seq[:c]``.
    """
    width, height = len(top_seq), len(bottom_seq)
    header = [0] + [-col * gap_penalty for col in range(1, width + 1)]
    grid = [header] + [[-row * gap_penalty] + [0] * width for row in range(1, height + 1)]

    for row, bottom_char in enumerate(bottom_seq, start=1):
        for col, top_char in enumerate(top_seq, start=1):
            substitution = blosum_matrix[top_char][bottom_char]
            grid[row][col] = max(
                grid[row - 1][col] - gap_penalty,
                grid[row][col - 1] - gap_penalty,
                grid[row - 1][col - 1] + substitution,
            )
    return grid

def get_alignment(top_seq, bottom_seq, score_matrix, gap_penalty, blosum_matrix):
    """Trace back through a filled score matrix to recover a global alignment.

    *score_matrix* must be laid out as ``align`` builds it:

    * one row per position of *bottom_seq*, after the gap row 0;
    * one column per position of *top_seq*, after the gap column 0;
    * substitution scores taken from ``blosum_matrix[top_char][bottom_char]``.

    When more than one move explains a cell, ties are broken in this order:
    diagonal (match/mismatch), then a gap in *bottom_seq*, then a gap in
    *top_seq*.

    Returns:
        ``(aligned_bottom, aligned_top)``. Note that the bottom sequence comes
        first, the reverse of the argument order.

    Raises:
        ValueError: If some cell cannot be explained by any of its three
            predecessors. This means *score_matrix* was not built from these
            sequences with this *gap_penalty* and *blosum_matrix*. Previously
            this case looped forever.
    """
    aligned_top = []
    aligned_bottom = []

    # i walks rows (bottom_seq), j walks columns (top_seq).
    i, j = len(bottom_seq), len(top_seq)

    while i > 0 and j > 0:
        current_score = score_matrix[i][j]
        # Same orientation as align(): the outer key is the top character.
        # The old [bottom][top] lookup broke traceback for asymmetric matrices.
        s_score = blosum_matrix[top_seq[j - 1]][bottom_seq[i - 1]]

        if current_score == score_matrix[i - 1][j - 1] + s_score:
            aligned_top.append(top_seq[j - 1])
            aligned_bottom.append(bottom_seq[i - 1])
            i -= 1
            j -= 1
        elif current_score == score_matrix[i][j - 1] - gap_penalty:
            # Gap in bottom_seq
            aligned_top.append(top_seq[j - 1])
            aligned_bottom.append('-')
            j -= 1
        elif current_score == score_matrix[i - 1][j] - gap_penalty:
            # Gap in top_seq
            aligned_top.append('-')
            aligned_bottom.append(bottom_seq[i - 1])
            i -= 1
        else:
            raise ValueError(
                f"score_matrix[{i}][{j}] = {current_score!r} cannot be reached from "
                "any neighbouring cell; the matrix does not match these inputs"
            )

    # Consume whatever remains of the longer prefix as leading gaps; at most
    # one of these two loops runs, because i == 0 or j == 0 here.
    while j > 0:
        aligned_top.append(top_seq[j - 1])
        aligned_bottom.append('-')
        j -= 1

    while i > 0:
        aligned_top.append('-')
        aligned_bottom.append(bottom_seq[i - 1])
        i -= 1

    # Traceback collects characters end-to-start, so reverse before joining.
    aligned_top = ''.join(reversed(aligned_top))
    aligned_bottom = ''.join(reversed(aligned_bottom))

    return aligned_bottom, aligned_top