# NB1: Testing things out

## Try out the dataset: what's in it?

In [None]:
from datasets import load_dataset

# Hub identifier of the nucleotide MSA dataset explored in this notebook.
DATASET_NAME = "dotan1111/MSA-nuc-3-seq"
ds = load_dataset(DATASET_NAME)

In [None]:
# Show the first example of the training split.
train_split = ds["train"]
train_split[0]

## Run the main script

In [None]:
# Execute the project's entry point, mainrun.main() (see mainrun.py for what it runs).
import mainrun
mainrun.main()

Here is how to load a pretrained model: first build the environment, which initializes everything the agent needs, and then call the agent's internal `load()` method to load the saved weights.

In [10]:
# Rebuild an agent from one test example, then load saved weights into it.
# NOTE(review): torch and DQN are not referenced in this cell; confirm whether
# they are needed here (e.g. for later cells) or can be dropped.
import torch
from dqn import DQN
from mainrun import make_agent_from_example
from MSADataset import MSADataset
# Per the captured output, this loads the test split and filters MSAs longer than 150 bases.
test_ds = MSADataset('test')
# presumably (input sequences, reference MSA) for one example — confirm against MSADataset.sample_row
seqs, msa_ref = test_ds.sample_row()
# The agent (and its environment) must be built from an example before weights can be loaded.
agent = make_agent_from_example(seqs)
agent.load("shared_dqn_final")


Loading test split from dotan1111/MSA-nuc-3-seq ...
Filtering MSAs longer than 150 bases ...
✅ Loaded 2678 examples (max MSA length = 150)
The length of the sequences are [35, 35, 35]
shared_dqn_final has been loaded...


## Debugging and testing

In [None]:
def _residue_columns(char_cols, n_seq):
    """Convert character columns into residue-index columns, dropping all-gap columns.

    Each output tuple holds, per sequence, the 0-based index of the residue placed
    in that column (counting only that sequence's non-gap characters), or None
    where the sequence has a gap '-'.
    """
    next_residue = [0] * n_seq
    out = []
    for col in char_cols:
        entry = []
        for row, ch in enumerate(col):
            if ch == "-":
                entry.append(None)
            else:
                entry.append(next_residue[row])
                next_residue[row] += 1
        # A column of only gaps aligns no residues, so it carries no information.
        if any(idx is not None for idx in entry):
            out.append(tuple(entry))
    return out


def colwise_msa_metrics(pred_aln: str, ref_aln: str, verbose: bool = False):
    """
    Evaluate multiple sequence alignment quality using column-wise precision/recall/F1.

    A column is identified by *which residues it aligns*: for every sequence it
    records the index of the residue placed in that column, or None for a gap.
    A predicted column is a true positive when the reference contains the same
    residue tuple. Comparing raw characters position by position causes two errors.
    It credits columns that merely look alike but align different residues. It
    also marks every column after an inserted gap column as wrong. Columns made
    only of gaps align nothing and are ignored.

    pred_aln: multiline string where each line is a predicted aligned sequence.
              Example:
                  A-C
                  AG-
                  ACC
    ref_aln:  flattened column-major reference string (e.g. 'AAA--CCGC')
              which encodes the same number of sequences (n_seq).
    verbose:  if True, print the parsed lines and character columns (debug aid).

    Returns a dict with accuracy (TP / (TP + FP + FN)), precision, recall, f1,
    the raw TP/FP/FN counts, and the raw column counts of each alignment
    (n_cols_pred / n_cols_ref, including any all-gap columns).

    Raises AssertionError if there are fewer than two predicted sequences, the
    predicted rows differ in length, or len(ref_aln) is not a multiple of n_seq.
    """
    # --- Parse predicted alignment (row-major: one sequence per line) ---
    pred_lines = [ln.strip() for ln in pred_aln.strip().split("\n") if ln.strip()]
    n_seq = len(pred_lines)
    assert n_seq > 1, "Need at least 2 sequences"

    L_pred = len(pred_lines[0])
    assert all(len(ln) == L_pred for ln in pred_lines), "Predicted sequences must be same length"

    # Transpose rows into character columns [(A,A,A), (-,G,C), (C,-,C), ...]
    pred_chars = list(zip(*pred_lines))

    # --- Parse reference alignment (already column-major: contiguous chunks) ---
    assert len(ref_aln) % n_seq == 0, (
        f"Reference alignment length ({len(ref_aln)}) must be divisible by number of sequences ({n_seq})"
    )
    L_ref = len(ref_aln) // n_seq
    ref_chars = [tuple(ref_aln[c * n_seq:(c + 1) * n_seq]) for c in range(L_ref)]

    if verbose:
        print(f"[DEBUG] pred lines is: {pred_lines}")
        print(f"[DEBUG] pred cols is: {pred_chars}")
        print(f"[DEBUG] ref cols is: {ref_chars}")

    # --- Compare columns by the residues they align, independent of position ---
    pred_cols = _residue_columns(pred_chars, n_seq)
    ref_cols = _residue_columns(ref_chars, n_seq)

    # Within one alignment every residue sits in exactly one column, so these
    # tuples are unique and a set intersection counts each match once.
    TP = len(set(pred_cols) & set(ref_cols))
    FP = len(pred_cols) - TP
    FN = len(ref_cols) - TP

    precision = TP / (TP + FP + 1e-8)
    recall = TP / (TP + FN + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    accuracy = TP / (TP + FP + FN + 1e-8)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "TP": TP,
        "FP": FP,
        "FN": FN,
        "n_cols_pred": L_pred,
        "n_cols_ref": L_ref,
    }


colwise_msa_metrics(pred_aln = "A-C\nAG-\nACC", ref_aln = "AAA--CCGC")

In [None]:
def reshape_ref_alignment(ref_aln: str, n_seq: int):
    """
    Split a flattened column-major alignment into one gapped string per sequence.

    Because ``ref_aln`` stores the alignment column after column, sequence ``r``
    consists of every ``n_seq``-th character starting at offset ``r``.
    """
    assert len(ref_aln) % n_seq == 0, (
        f"Reference alignment length {len(ref_aln)} must be divisible by number of sequences {n_seq}"
    )
    # A stride slice extracts one row of the column-major layout at a time.
    return ["".join(ref_aln[row::n_seq]) for row in range(n_seq)]


import config

def calc_ref_PS(ref_aln: str, n_seq: int):
    """
    Return the Sum-of-Pairs (SP) score of a flattened column-major alignment.

    Every unordered pair of sequences is compared in every column:
    gap/gap contributes nothing, residue/gap adds config.GAP_PENALTY,
    identical residues add config.MATCH_REWARD and differing residues add
    config.MISMATCH_PENALTY. The scheme is intended to mirror
    environment.calc_score().
    """
    rows = reshape_ref_alignment(ref_aln, n_seq)
    total = 0

    for col_idx in range(len(rows[0])):
        column = [row[col_idx] for row in rows]
        for first_idx in range(n_seq):
            first = column[first_idx]
            for second in column[first_idx + 1:]:
                if first == '-' and second == '-':
                    # Gap aligned to gap is neutral.
                    continue
                if first == '-' or second == '-':
                    total += config.GAP_PENALTY
                elif first == second:
                    total += config.MATCH_REWARD
                else:
                    total += config.MISMATCH_PENALTY

    return total

def calc_ref_CS(ref_aln: str, n_seq: int):
    """
    Compute the Column Score (CS) for the reference alignment.

    Returns the fraction of alignment columns in which every sequence carries
    the same character. Gaps are compared like any other character. A column
    made entirely of gaps therefore counts as a match, while a column mixing
    residues and gaps (e.g. 'AA-') does not. Returns 0 for an empty alignment.

    ref_aln: flattened column-major alignment string; column c occupies
             ref_aln[c * n_seq:(c + 1) * n_seq].
    n_seq:   number of sequences encoded in ref_aln.

    Raises AssertionError if len(ref_aln) is not a multiple of n_seq.
    """
    assert len(ref_aln) % n_seq == 0, (
        f"Reference alignment length {len(ref_aln)} must be divisible by number of sequences {n_seq}"
    )
    n_cols = len(ref_aln) // n_seq

    # The input is already column-major, so each column is a contiguous slice
    # and there is no need to transpose into per-sequence rows first.
    exact_match_count = sum(
        1
        for c in range(n_cols)
        if len(set(ref_aln[c * n_seq:(c + 1) * n_seq])) == 1  # all same char, gaps included
    )

    return exact_match_count / n_cols if n_cols > 0 else 0

# Sanity-check the reference helpers on a tiny 3-sequence, 3-column alignment.
ref_aln, n_seq = "AAA--CCGC", 3
for _check in (reshape_ref_alignment, calc_ref_PS, calc_ref_CS):
    print(_check(ref_aln, n_seq))
