In [None]:
!pip install biopython

In [None]:

from Bio import Align

In [None]:
!wget http://www.lbgi.fr/balibase/BalibaseDownload/BAliBASE_R1-5.tar.gz

In [None]:
!tar -xvzf BAliBASE_R1-5.tar.gz

In [None]:
import os
from Bio import SeqIO

In [None]:
import os

def get_filenames_without_extension(directory):
    """
    Collect the unique base names (extension stripped) of the files in a directory.

    Only regular files directly inside ``directory`` are considered; nested
    directories are not descended into. Files that differ only by extension
    (e.g. ``name.tfa`` and ``name.msf``) collapse into a single entry.

    Args:
        directory: Path of the directory to scan.

    Returns:
        Sorted list of unique base names.
    """
    stems = {
        os.path.splitext(entry)[0]
        for entry in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, entry))
    }
    # The iteration order of a set of strings varies between interpreter runs
    # (hash randomisation); sorting keeps downstream processing reproducible.
    return sorted(stems)

# Example usage
# directory = "/content/bb3_release/RV50"
# filenames_without_extension = get_filenames_without_extension(directory)
# print(filenames_without_extension)
# print(len(filenames_without_extension), len(list(set(filenames_without_extension))))


In [None]:
def parse_balibase_directory(directory, extn):
    """
    Recursively collect the paths of all files under a directory with a given suffix.

    Args:
        directory: Root directory to walk.
        extn: Filename suffix to match, e.g. ``".tfa"`` or ``".msf"``.

    Returns:
        Sorted list of matching file paths, each joined onto the directory in
        which the file was found. Empty when nothing matches.
    """
    matching_paths = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(extn):
                matching_paths.append(os.path.join(root, file))
    # os.walk yields entries in arbitrary order; sort for reproducible results.
    return sorted(matching_paths)

In [None]:
def load_sequences(file):
    """
    Read a FASTA-formatted file (e.g. a ``.tfa`` file) into a name -> sequence dict.

    Each parsed record contributes one entry keyed by ``record.name`` whose value
    is the record's sequence converted to a plain string. If two records share a
    name, the later one replaces the earlier one.
    """
    with open(file, "r") as handle:
        parsed_records = list(SeqIO.parse(handle, "fasta"))
    return {entry.name: str(entry.seq) for entry in parsed_records}

In [None]:
def sp_score(reference_alignment, test_alignment):
    """
    Compute the Sum-of-Pairs (SP) score of a test alignment against a reference.

    An "aligned pair" is a column in which two different sequences both carry a
    residue. Each pair is identified by the two sequence positions in the
    alignment and by the *ungapped residue indices* it joins, so the comparison
    does not depend on how many gap columns either alignment contains (for
    instance all-gap columns left over when a pair of rows is taken from a
    multiple alignment).

    Args:
        reference_alignment: Sequence of gapped strings (gap character ``'-'``),
            one per aligned sequence.
        test_alignment: Gapped strings for the same sequences, in the same order.

    Returns:
        Fraction of the reference's aligned residue pairs that also occur in the
        test alignment, in ``[0, 1]``; ``0`` when the reference has no aligned
        residue pairs.

    Note:
        Rows are walked column by column with ``zip``; if two rows of one
        alignment differ in length, the surplus columns of the longer row are
        ignored.
    """
    ref_pairs = _aligned_residue_pairs(reference_alignment)
    test_pairs = _aligned_residue_pairs(test_alignment)
    if not ref_pairs:
        return 0
    return len(ref_pairs & test_pairs) / len(ref_pairs)


def _aligned_residue_pairs(alignment):
    """Return the set of (row_i, row_j, residue_i, residue_j) pairs aligned in ``alignment``."""
    pairs = set()
    for i in range(len(alignment)):
        for j in range(i + 1, len(alignment)):
            # Running count of residues seen so far in each row == index of the
            # next residue in the corresponding ungapped sequence.
            residue_i = 0
            residue_j = 0
            for char_i, char_j in zip(alignment[i], alignment[j]):
                has_i = char_i != '-'
                has_j = char_j != '-'
                if has_i and has_j:
                    pairs.add((i, j, residue_i, residue_j))
                if has_i:
                    residue_i += 1
                if has_j:
                    residue_j += 1
    return pairs


def load_reference_alignment(file):
    """
    Load a reference alignment from a GCG ``.msf`` file (as shipped with BAliBASE).

    Ref code: https://publish.illinois.edu/msaevaluation/files/2017/07/msf2fasta.txt

    Everything up to and including the first line starting with ``//`` is treated
    as header and ignored. After that, every line containing at least one letter
    is read as ``<name> <block> <block> ...``; the blocks are concatenated and
    appended to the sequence collected so far for that name. Lines without any
    letter (blank lines, numeric position rulers) are skipped.

    Both MSF gap characters, ``.`` and ``~``, are converted to ``-``.

    Args:
        file: Path of the ``.msf`` file.

    Returns:
        Dict mapping each sequence name to its full gapped sequence string.
    """
    with open(file) as f:
        lines = f.read().splitlines()

    datadict = {}
    infosection = True
    for line in lines:
        if not infosection:
            if any(c.isalpha() for c in line):
                linesplit = line.split()
                taxonname = linesplit[0]
                subseq = ''.join(linesplit[1:])
                # Normalise both MSF gap symbols to the '-' used elsewhere.
                subseq = subseq.replace('.', '-').replace('~', '-')
                datadict[taxonname] = datadict.get(taxonname, '') + subseq
        # Checked after the body so the '//' separator line itself is never parsed.
        if line.startswith('//'):
            infosection = False

    return datadict

In [None]:
from tqdm import tqdm
def evaluate_balibase(balibase_dir, gap_open=-11, gap_extend=-1, n_samples_per_file = 5):
    """
    Evaluate the pairwise aligner behind ``get_alignment`` on a BAliBASE directory.

    Each dataset in ``balibase_dir`` consists of ``<name>.tfa`` (unaligned
    sequences) and ``<name>.msf`` (reference alignment). For every dataset, all
    pairs among the first ``n_samples_per_file`` sequences (in sorted name order)
    are aligned with ``get_alignment``. Each result is scored against the
    corresponding rows of the reference alignment with ``sp_score``, and its
    alignment score is cross-checked against Biopython's ``PairwiseAligner``.

    Datasets whose name contains ``BBS`` are reported as "truncated", all others
    as "full-length". The two average SP scores are printed; nothing is returned.

    Args:
        balibase_dir: Directory holding the ``.tfa`` / ``.msf`` files.
        gap_open: Gap-open score for the Biopython cross-check aligner.
        gap_extend: Gap-extend score for the Biopython cross-check aligner.
        n_samples_per_file: Maximum number of sequences per dataset to pair up.

    Raises:
        AssertionError: If a dataset's ``.tfa`` and ``.msf`` files name different
            sequences, or if ``get_alignment`` returns a score that differs from
            the Biopython score.

    Note:
        The cross-check aligner uses match ``+1`` / mismatch ``-1`` together with
        the gap scores above. ``get_alignment`` is configured elsewhere and must
        use the same scoring for the score comparison to be meaningful.
    """
    filenames_without_extension = get_filenames_without_extension(balibase_dir)

    # Partition the datasets by naming convention used in this directory.
    truncated_sequences = [f for f in filenames_without_extension if 'BBS' in f]
    full_length_sequences = [f for f in filenames_without_extension if 'BBS' not in f]

    # Build the cross-check aligner once; its configuration does not vary per pair.
    reference_aligner = Align.PairwiseAligner()
    reference_aligner.match_score = 1.0
    reference_aligner.mismatch_score = -1.0
    reference_aligner.open_gap_score = gap_open
    reference_aligner.extend_gap_score = gap_extend

    def compare_with_ref_sequences_and_muscle(fnames):
        """
        Return the mean SP score over all sampled sequence pairs of ``fnames``.

        Despite the name, the score cross-check is done with Biopython's
        ``PairwiseAligner``, not MUSCLE. Returns 0 when no pair was aligned.
        """
        total_sp_score = 0
        num_alignments = 0

        for fname in tqdm(fnames):
            tfa_file = os.path.join(balibase_dir, f"{fname}.tfa")
            msf_file = os.path.join(balibase_dir, f"{fname}.msf")
            unaligned_sequences_dict = load_sequences(tfa_file)
            ref_aligned_sequences_dict = load_reference_alignment(msf_file)

            # Both files must describe the same set of sequences, otherwise the
            # reference rows looked up below would be missing or mismatched.
            assert set(unaligned_sequences_dict) == set(ref_aligned_sequences_dict), (
                f"Sequence names differ between {tfa_file} and {msf_file}"
            )

            # Sorted so that the sampled subset is the same on every run.
            sequence_keys = sorted(unaligned_sequences_dict)[:n_samples_per_file]

            for i in range(len(sequence_keys)):
                for j in range(i + 1, len(sequence_keys)):
                    key_i = sequence_keys[i]
                    key_j = sequence_keys[j]
                    seq1 = unaligned_sequences_dict[key_i]
                    seq2 = unaligned_sequences_dict[key_j]

                    # Alignment produced by the algorithm under test.
                    aligned_seq1, aligned_seq2, score = get_alignment(seq1, seq2)

                    # Compare with the matching two rows of the reference alignment.
                    ref_alignment = [
                        ref_aligned_sequences_dict[key_i],
                        ref_aligned_sequences_dict[key_j],
                    ]
                    sp = sp_score(ref_alignment, [aligned_seq1, aligned_seq2])

                    # The optimal global score must agree with Biopython's.
                    expected_score = reference_aligner.score(seq1, seq2)
                    assert score == expected_score, (
                        f"Score mismatch for {fname} ({key_i}, {key_j}): "
                        f"got {score}, Biopython gives {expected_score}"
                    )

                    total_sp_score += sp
                    num_alignments += 1

        # Computed after the loop so an empty ``fnames`` yields 0 instead of
        # leaving the result unbound.
        return total_sp_score / num_alignments if num_alignments > 0 else 0

    sp_scores_full_length = compare_with_ref_sequences_and_muscle(full_length_sequences)
    sp_scores_truncated = compare_with_ref_sequences_and_muscle(truncated_sequences)

    print(f"Average SP Score on BAliBASE with full-length sequences: {sp_scores_full_length}")
    print(f"Average SP Score on BAliBASE with truncated sequences: {sp_scores_truncated}")





In [None]:
# Import your algorithm here
# NOTE(review): NeedlemanWunschAffine is neither imported nor defined in any
# visible cell of this notebook; it has to be in scope already (e.g. defined or
# imported in a cell not shown here), otherwise the assignment below raises
# NameError on a fresh kernel. The commented-out import refers to a differently
# named class (NeedlemanWunschAffineGap) — confirm which one is intended.

# from nwag2 import NeedlemanWunschAffineGap

# # needleman_wunsch_fn = NeedlemanWunschAffineGap()

# Aligner under test: match +1, mismatch -1, gap open -11, gap extend -1.
needleman_wunsch_fn = NeedlemanWunschAffine(match_score=1, mismatch_score=-1, gap_open=-11, gap_extend=-1)

def get_alignment(seq1, seq2):
    """
    Align two sequences with the module-level ``needleman_wunsch_fn``.

    Returns the three values produced by ``needleman_wunsch_fn.align`` as a tuple
    ``(aligned_seq1, aligned_seq2, score)``.
    Presumably the first two are gap-padded strings using '-' and the third is
    the alignment score — confirm against the NeedlemanWunschAffine implementation.
    """
    aligned_seq1, aligned_seq2, score = needleman_wunsch_fn.align(seq1, seq2)
    return aligned_seq1, aligned_seq2, score

# Run the evaluation on the extracted dataset directory (Colab-style absolute path).
evaluate_balibase("/content/bb3_release/RV50")


