# DQN


## Imports and Setup

In [12]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import copy
import os
import csv
from pathlib import Path
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset
from itertools import combinations
from datetime import datetime

# Setup
project_root = Path.cwd().parent if Path.cwd().name == 'notebooks' else Path.cwd()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"The device being used is: {device}")

# Constants
NUCLEOTIDE_MAP = {"A": 1, "T": 2, "C": 3, "G": 4, "-": 5, "PAD": 0}
NUCLEOTIDES = ["PAD", "A", "T", "C", "G", "-"]

# config file info
training_size = 1000
max_nt_num = 150
MAX_MSA_LEN = 50
MAX_N_SEQS = 3


The device being used is: cpu


In [4]:
def masked_epsilon_greedy(q_values: torch.Tensor, valid_mask: np.ndarray, epsilon: float, rng=None) -> int:
    if q_values.ndim > 1:
        q_values = q_values.reshape(-1)
    valid_mask = np.asarray(valid_mask, dtype=bool)
    
    if rng is None:
        rng = np.random
    
    valid_idx = np.flatnonzero(valid_mask)
    
    if valid_idx.size == 0:
        return int(torch.argmax(q_values).item())
    
    if rng.random() < epsilon:
        return int(rng.choice(valid_idx))
    
    q = q_values.detach().cpu().numpy().copy()
    q[~valid_mask] = -np.inf
    return int(np.argmax(q))

In [35]:
def get_expected_alignment(sample):
    for key in ("solution", "aligned", "target"):
        if key in sample and sample[key] is not None:
            result = sample[key]
            if isinstance(result, str):
                return [result]
            return [str(s) for s in result]
    return None

'''def get_expected_gaps(sample):
    if 'n_gaps' in sample and sample['n_gaps'] is not None:
        print("using the n_gaps from sample")
        return sample['n_gaps']
    
    if 'moves' in sample and isinstance(sample['moves'], list):
        print("using the moves from sample")
        return len(sample['moves'])
    
    if 'solution' in sample and sample['solution']:
        if 'start' in sample:
            start_gaps = sum(str(s).count('-') for s in sample['start'])
            solution_gaps = sum(str(s).count('-') for s in sample['solution'])
            return max(0, solution_gaps - start_gaps)
        return sum(seq.count('-') for seq in sample['solution'])
    
    return 20 # why is this default? '''
def get_expected_gaps(sample, scale_factor=1.2, min_gaps=5, max_gaps=150):
    """
    Estimate expected number of gaps to insert during alignment, 
    without using ground-truth solution info.
    """
    '''if 'n_gaps' in sample and sample['n_gaps'] is not None:
        return int(sample['n_gaps'])'''

    if 'start' in sample and isinstance(sample['start'], list):
        lengths = [len(seq) for seq in sample['start']]
        max_len = max(lengths)
        total_missing = sum(max_len - l for l in lengths)
        est_gaps = int(scale_factor * total_missing)
        return int(np.clip(est_gaps, min_gaps, max_gaps))

    return 20


## ReplayMemory

In [6]:
class ReplayMemory:
    def __init__(self, memory_size=1000):
        self.storage = []
        self.memory_size = memory_size
        self.size = 0
    
    def store(self, data: tuple):
        if len(self.storage) == self.memory_size:
            self.storage.pop(0)
        self.storage.append(data)
        self.size = min(self.size + 1, self.memory_size)
    
    def sample(self, batch_size):
        samples = random.sample(self.storage, batch_size)
        state = [s for s, _, _, _, _, _ in samples]
        next_state = [ns for _, ns, _, _, _, _ in samples]
        action = [a for _, _, a, _, _, _ in samples]
        reward = [r for _, _, _, r, _, _ in samples]
        done = [d for _, _, _, _, d, _ in samples]
        next_mask = [m for _, _, _, _, _, m in samples]
        return state, next_state, action, reward, done, next_mask

class ScaledDotProductAttention(nn.Module):
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
    
    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn

class PositionalEncoding(nn.Module):
    def __init__(self, d_hid, n_position=200):
        super().__init__()
        self.register_buffer("pos_table", self._get_sinusoid_table(n_position, d_hid))
    
    @staticmethod
    def _get_sinusoid_table(n_position, d_hid):
        positions = torch.arange(n_position).float().unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_hid, 2).float() * (-np.log(10000.0) / d_hid))
        sin = torch.sin(positions * div_term)
        cos = torch.cos(positions * div_term)
        return torch.cat([sin, cos], dim=-1).unsqueeze(0)
    
    def forward(self, x):
        return x + self.pos_table[:, :x.size(1), :].clone().detach()

class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.w_qs = nn.Linear(d_model, d_k, bias=False)
        self.w_ks = nn.Linear(d_model, d_k, bias=False)
        self.w_vs = nn.Linear(d_model, d_v, bias=False)
        self.fc = nn.Linear(d_v, d_model, bias=False)
        self.attention = ScaledDotProductAttention(temperature=d_k**0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
    
    def forward(self, q, k, v, mask=None):
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
        residual = q
        
        q = self.w_qs(q).view(sz_b, len_q, 1, self.d_k)
        k = self.w_ks(k).view(sz_b, len_k, 1, self.d_k)
        v = self.w_vs(v).view(sz_b, len_v, 1, self.d_v)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        
        if mask is not None:
            mask = mask.unsqueeze(1)
        
        q, attn = self.attention(q, k, v, mask=mask)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual
        q = self.layer_norm(q)
        return q, attn

class Encoder(nn.Module):
    def __init__(self, n_src_vocab, d_model, n_position, d_k=164, d_v=164, pad_idx=0, dropout=0.1):
        super().__init__()
        self.src_word_emb = nn.Embedding(n_src_vocab, d_model, padding_idx=pad_idx)
        self.position_enc = PositionalEncoding(d_model, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.self_attention = SelfAttention(d_model, d_k, d_v, dropout=dropout)
    
    def forward(self, src_seq, mask):
        enc_output = self.src_word_emb(src_seq)
        enc_output = self.position_enc(enc_output)
        enc_output = self.dropout(enc_output)
        enc_output = self.layer_norm(enc_output)
        enc_output, _ = self.self_attention(enc_output, enc_output, enc_output, mask=mask)
        return enc_output

class QNetwork(nn.Module):
    def __init__(self, num_sequences, max_sequence_length, num_actions, max_action_value, d_model=64):
        super().__init__()
        self.num_sequences = num_sequences
        self.num_rows = num_sequences  # <-- no +2

        dim = self.num_rows * max_sequence_length
        # Encoder: your existing module that maps (B, dim) with an attention/MLP stack
        self.encoder = Encoder(6, d_model, dim)

        # Learnable row embeddings (still helpful)
        self.seq_embedding = nn.Embedding(self.num_rows, d_model)
        nn.init.normal_(self.seq_embedding.weight, 0.0, 0.1)

        self.fc1 = nn.Linear(dim * d_model, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_actions)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        B, R, C = x.shape          # R == num_sequences
        x_flat = x.view(B, R * C)
        mask_flat = (x_flat != 0).unsqueeze(1)

        h = self.encoder(x_flat, mask_flat)   # shape can be (B, R*C, d_model) or (B, d_model) per your Encoder
        if h.dim() == 3:
            h = h.view(B, R, C, -1)

        # add per-row embedding (broadcast)
        row_ids = torch.arange(R, device=h.device)
        row_emb = self.seq_embedding(row_ids)[None, :, None, :]  # (1,R,1,d_model)
        h = h + row_emb

        h = h.reshape(B, -1)
        h = F.leaky_relu(self.fc1(h)); h = self.dropout(h)
        h = F.leaky_relu(self.fc2(h)); h = self.dropout(h)
        q = self.fc3(h)
        return q


## DQN Agent


In [7]:
class DQNAgent:    
    def __init__(self, action_number, num_seqs, max_grid, max_value,
                 epsilon=0.8, delta=0.05, decrement_iteration=5,
                 update_iteration=128, batch_size=128, gamma=1.0,
                 learning_rate=0.001, memory_size=1000):
        self.seq_num = num_seqs
        self.max_seq_len = max_grid  # <-- not +1
        self.action_number = action_number

        self.eval_net = QNetwork(num_seqs, self.max_seq_len, action_number, max_value).to(device)
        self.target_net = QNetwork(num_seqs, self.max_seq_len, action_number, max_value).to(device)
        self.target_net.load_state_dict(self.eval_net.state_dict())
        
        self.replay_memory = ReplayMemory(memory_size=memory_size)
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=learning_rate)
        self.loss_func = nn.SmoothL1Loss()
        
        self.batch_size = batch_size
        self.gamma = gamma
        self.update_iteration = update_iteration
        self.update_step_counter = 0
        self.tau = 0.005
        self.use_double_dqn = True
        
        self.initial_epsilon = epsilon
        self.current_epsilon = epsilon
        self.epsilon_end = delta
        self.epsilon_decay = 0.999
        
        self.losses = []
        self.epsilons = []

    def update_epsilon(self):
        if self.update_step_counter < 5000:
            decay_rate = 0.9999
        elif self.update_step_counter < 10000:
            decay_rate = 0.999
        else:
            decay_rate = 0.995
        
        self.current_epsilon = max(self.epsilon_end, self.current_epsilon * decay_rate)
        self.epsilons.append(self.current_epsilon)

    def select_action(self, state, valid_action_mask=None):
        is_random = (random.random() <= self.current_epsilon)
        
        if is_random:
            if valid_action_mask is not None:
                valid_idx = np.flatnonzero(valid_action_mask)
                action = int(np.random.choice(valid_idx)) if len(valid_idx) else random.randrange(self.action_number)
            else:
                action = random.randrange(self.action_number)
        else:
            self.eval_net.eval()
            with torch.no_grad():
                s = torch.as_tensor(state, dtype=torch.long, device=device).view(1, self.seq_num, self.max_seq_len)
                q = self.eval_net(s).squeeze(0).detach().cpu().numpy()
            self.eval_net.train()
            
            if valid_action_mask is not None:
                q[~valid_action_mask] = -np.inf
            
            action = int(np.argmax(q))
        
        return action

    def predict(self, state, valid_action_mask=None):
        self.eval_net.eval()
        with torch.no_grad():
            s = torch.as_tensor(state, dtype=torch.long, device=device).view(1, self.seq_num, self.max_seq_len)
            q = self.eval_net(s).squeeze(0).detach().cpu().numpy()
        self.eval_net.train()
        
        if valid_action_mask is not None:
            q[~valid_action_mask] = -np.inf
        
        return int(np.nanargmax(q))

    def forward(self, state):
        self.eval_net.eval()
        with torch.no_grad():
            s = torch.as_tensor(state, dtype=torch.long, device=device).view(
                1, self.seq_num, self.max_seq_len
            )
            q = self.eval_net(s).squeeze(0)
        self.eval_net.train()
        return q
    
    @property
    def epsilon(self):
        return self.current_epsilon

    def update(self):
        self.update_step_counter += 1
        
        if self.replay_memory.size < self.batch_size:
            return None
        
        state, next_state, action, reward, done, next_mask = self.replay_memory.sample(self.batch_size)

        batch_state = torch.LongTensor(state).to(device).view(-1, self.seq_num, self.max_seq_len)
        batch_next_state = torch.LongTensor(next_state).to(device).view(-1, self.seq_num, self.max_seq_len)
        batch_action = torch.LongTensor(action).unsqueeze(-1).to(device)
        batch_reward = torch.FloatTensor(reward).unsqueeze(-1).to(device)
        batch_done = torch.FloatTensor(done).unsqueeze(-1).to(device)
        batch_next_mask = torch.BoolTensor(next_mask).to(device)
        
        q_eval = self.eval_net(batch_state).gather(1, batch_action)
        
        with torch.no_grad():
            if self.use_double_dqn:
                q_next_online = self.eval_net(batch_next_state)
                q_next_online_masked = q_next_online.clone()
                q_next_online_masked[~batch_next_mask] = float('-inf')
                best_next_actions = q_next_online_masked.max(1)[1]
                
                q_next_target = self.target_net(batch_next_state)
                q_next = q_next_target.gather(1, best_next_actions.unsqueeze(1))
            else:
                q_next_target = self.target_net(batch_next_state)
                q_next_masked = q_next_target.clone()
                q_next_masked[~batch_next_mask] = float('-inf')
                q_next = q_next_masked.max(1)[0].unsqueeze(-1)
            
            q_next = torch.where(torch.isinf(q_next), torch.zeros_like(q_next), q_next)
            q_target = batch_reward + (1.0 - batch_done) * self.gamma * q_next
            q_target = torch.clamp(q_target, -10.0, 10.0)
        
        loss = self.loss_func(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_net.parameters(), max_norm=1.0)
        self.optimizer.step()
        
        with torch.no_grad():
            for target_param, online_param in zip(self.target_net.parameters(), self.eval_net.parameters()):
                target_param.data.mul_(1.0 - self.tau)
                target_param.data.add_(self.tau * online_param.data)
        
        self.losses.append(loss.item())
        return loss.item()
    
    def save_model(self, path):
        """Save the current model (eval_net) and target_net weights."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            'eval_net_state_dict': self.eval_net.state_dict(),
            'target_net_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.current_epsilon,
            'update_step_counter': self.update_step_counter
        }, path)

    def load_model(self, path, map_location=None):
        """Load model weights from a saved checkpoint."""
        checkpoint = torch.load(path, map_location=map_location or device)
        self.eval_net.load_state_dict(checkpoint['eval_net_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.current_epsilon = checkpoint.get('epsilon', self.current_epsilon)
        self.update_step_counter = checkpoint.get('update_step_counter', 0)


## Environment

In [27]:
# Example NUCLEOTIDE_MAP / NUCLEOTIDES assumed defined elsewhere:
# NUCLEOTIDE_MAP = {'-':0,'A':1,'C':2,'G':3,'T':4}
# NUCLEOTIDES = {v:k for k,v in NUCLEOTIDE_MAP.items()}

class AlignmentEnvironment:
    def __init__(self, sequences, total_gap, max_n_seqs=None, max_msa_len=None):
        """
        consensus-free environment
        sequences: list[str] (unaligned but will be padded)
        total_gap: int (gap tokens the agent may insert)
        """
        self.original_n_seqs = len(sequences)

        # grid caps
        self.max_n_seqs = max_n_seqs if max_n_seqs is not None else len(sequences)
        self.max_msa_len = max_msa_len if max_msa_len is not None else max(len(s) for s in sequences)

        # pad rows & pad #rows
        sequences = self._pad_sequences(sequences, self.max_msa_len)
        while len(sequences) < self.max_n_seqs:
            sequences.append('-' * self.max_msa_len)

        self.sequences = sequences
        self.initial_sequences = copy.deepcopy(sequences)

        # tokenized views
        self.sep_nuc_in_seq = [[c for c in seq] for seq in sequences]
        self.label_encoded_seqs = [[NUCLEOTIDE_MAP[c] for c in seq] for seq in sequences]

        # numbers used by agent
        self.num_seqs = self.max_n_seqs
        self.action_number = self.max_n_seqs * self.max_msa_len

        # rolling aligned/int arrays for state build (same as label_encoded_seqs initially)
        self.aligned = [[] for _ in range(self.num_seqs)]
        self.original = copy.deepcopy(self.label_encoded_seqs)

        # rewards & gaps
        self.last_score = self._pairwise_column_score(self.sep_nuc_in_seq[:self.original_n_seqs])
        self.total_gap = total_gap
        self.initial_gap = total_gap
        self.gaps_per_sequence = [0] * self.num_seqs

        # bookkeeping
        self.recent_actions = []
        self.last_seq_chosen = None
        self.seq_streak_count = 0
        self.seq_permutation = None
        self.seq_permutation_inv = None

    @staticmethod
    def _pad_sequences(sequences, target_len):
        return [s.ljust(target_len, '-') for s in sequences]

    def get_valid_action_mask(self):
        """
        Valid action: placing a gap where the current character is not already '-'.
        Action space: index = seq_idx * max_msa_len + pos
        """
        A = self.action_number
        mask = np.zeros(A, dtype=bool)
        for seq_idx in range(self.original_n_seqs):
            row = self.sequences[seq_idx]
            for pos in range(self.max_msa_len):
                action = seq_idx * self.max_msa_len + pos
                mask[action] = (row[pos] != '-')
        return mask

    # ---------- CONSENSUS-FREE REWARD ----------
    def _pairwise_column_score(self, rows, bonus_perfect=1.15):
        """
        Sum over columns:
          - For each column, count majority agreement among non-gaps (sum-of-pairs proxy).
          - If all non-gap residues in the column are identical (and at least one non-gap), multiply that column's contribution by bonus_perfect.
        """
        if not rows:
            return 0.0
        n = len(rows)
        L = min(len(r) for r in rows)
        score = 0.0
        for c in range(L):
            col = [rows[r][c] for r in range(n)]
            nz = [b for b in col if b != '-']
            if not nz:
                continue
            # sum-of-pairs proxy: majority size
            majority = max(sum(b == base for b in nz) for base in set(nz))
            col_score = majority
            # perfect non-gap column bonus
            if len(set(nz)) == 1:
                col_score *= bonus_perfect
            score += col_score
        return float(score)

    def get_current_state(self):
        """
        Return a 1-D list of ints with shape (num_seqs * max_msa_len).
        (No more +2 consensus rows.)
        """
        state = []
        for seq in self.original:
            state.extend(seq)
        return state

    # Predicted-only metrics (SP/CS) for validation & optional reward shaping
    def _predicted_sp_score(self, msa):
        # +2 match (non-gap), -1 if any gap, -2 mismatch (non-gap)
        n = len(msa)
        if n == 0: return 0.0
        L = len(msa[0])
        assert all(len(s) == L for s in msa)
        score = 0
        from itertools import combinations
        for i, j in combinations(range(n), 2):
            si, sj = msa[i], msa[j]
            for c in range(L):
                a, b = si[c], sj[c]
                if a == '-' or b == '-':
                    score -= 1
                elif a == b:
                    score += 2
                else:
                    score -= 2
        return float(score)

    def _predicted_cs_fraction(self, msa):
        n = len(msa)
        if n == 0: return 0.0
        L = len(msa[0])
        assert all(len(s) == L for s in msa)
        good = 0
        for c in range(L):
            col = [msa[r][c] for r in range(n)]
            nz = [x for x in col if x != '-']
            if nz and len(set(nz)) == 1:
                good += 1
        return good / L if L > 0 else 0.0

    # Public calc for validation “reward-like” score
    def calc_reward_from_alignment(self, aligned_list_of_str):
        rows = [[ch for ch in row] for row in aligned_list_of_str[:self.original_n_seqs]]
        return self._pairwise_column_score(rows)

    def calc_reward(self):
        """Current alignment score of internal sequences (for validation prints)."""
        rows = [[ch for ch in row] for row in self.sequences[:self.original_n_seqs]]
        return self._pairwise_column_score(rows)

    # ----------- STEP -----------
    def step(self, action):
        seq_idx = int(action) // self.max_msa_len
        pos = int(action) % self.max_msa_len

        if not (0 <= seq_idx < self.original_n_seqs and 0 <= pos < self.max_msa_len):
            # invalid action
            reward = -1.0
            return reward, self.get_current_state(), 0

        row = list(self.sequences[seq_idx])
        placed = False
        if row[pos] != '-':
            row.insert(pos, '-')        # shift right from pos
            row.pop()                   # keep row length fixed
            self.sequences[seq_idx] = "".join(row)

            # sync token views
            self.sep_nuc_in_seq[seq_idx]     = row
            self.label_encoded_seqs[seq_idx] = [NUCLEOTIDE_MAP[c] for c in row]
            self.original[seq_idx]           = [NUCLEOTIDE_MAP[c] for c in row]

            self.gaps_per_sequence[seq_idx] += 1
            self.total_gap = max(0, self.total_gap - 1)
            placed = True

        # compute dense local reward
        #  - small step reward/penalty
        #  - add terminal bonus based on entropy/diversity if you kept those helpers
        reward = 0.0
        reward += 0.5 if placed else -0.1

        done = 1 if self.total_gap == 0 else 0
        if done == 1:
            # optional: add diversity term if you kept it
            # reward += self._diversity_entropy(scale=20.0)
            pass

        reward = float(np.clip(reward, -25.0, 25.0))
        return reward, self.get_current_state(), done

    # -------- Reset / Permutation (unchanged) --------
    def get_alignment(self):
        alignment = []
        for i in range(len(self.aligned)):
            alignment.append(''.join([NUCLEOTIDES[self.aligned[i][j]] for j in range(len(self.aligned[i]))]))
        return alignment

    def get_original_alignment(self):
        # returns only real sequences, up to grid width
        full_alignment = self.sequences
        return [full_alignment[i] for i in range(self.original_n_seqs)]

    def reset(self):
        self.aligned = [[] for _ in range(self.num_seqs)]
        self.original = copy.deepcopy(self.label_encoded_seqs)
        self.total_gap = self.initial_gap
        self.last_score = self._pairwise_column_score(self.sep_nuc_in_seq[:self.original_n_seqs])
        self.gaps_per_sequence = [0] * self.num_seqs
        self.sequences = copy.deepcopy(self.initial_sequences)
        self.sep_nuc_in_seq = [[c for c in seq] for seq in self.sequences]
        self.recent_actions = []
        self.last_seq_chosen = None
        self.seq_streak_count = 0
        self.seq_permutation = None
        self.seq_permutation_inv = None
        return self.get_current_state()

    def randomize_sequence_order(self, apply=True):
        if not apply or self.original_n_seqs <= 1:
            self.seq_permutation = None
            self.seq_permutation_inv = None
            return
        self.seq_permutation = np.random.permutation(self.original_n_seqs)
        self.seq_permutation_inv = np.argsort(self.seq_permutation)
        self.original[:self.original_n_seqs]        = [self.original[i]        for i in self.seq_permutation]
        self.sequences[:self.original_n_seqs]       = [self.sequences[i]       for i in self.seq_permutation]
        self.sep_nuc_in_seq[:self.original_n_seqs]  = [self.sep_nuc_in_seq[i]  for i in self.seq_permutation]
        self.label_encoded_seqs[:self.original_n_seqs] = [self.label_encoded_seqs[i] for i in self.seq_permutation]
        gaps_copy = self.gaps_per_sequence[:self.original_n_seqs].copy()
        for new_idx, old_idx in enumerate(self.seq_permutation):
            self.gaps_per_sequence[new_idx] = gaps_copy[old_idx]

    
    def get_unpermuted_gaps(self):
        if self.seq_permutation_inv is None:
            return self.gaps_per_sequence[:self.original_n_seqs]
        
        original_order = [0] * self.original_n_seqs
        for new_idx, old_idx in enumerate(self.seq_permutation):
            original_order[old_idx] = self.gaps_per_sequence[new_idx]
        return original_order
    
    @staticmethod
    def get_sp_score(pred, match=2, mismatch=-2, gap = -1):
        n = len(pred)
        L = len(pred[0])

        score = 0
        for i, j in combinations(range(n), 2):
            si, sj = pred[i], pred[j]
            for c in range(L):
                a, b = si[c], sj[c]
                if a == '-' and b == '-':
                    score += 0 
                elif a == '-' or b == '-':
                    score += gap
                elif a == b:
                    score += match
                else:
                    score += mismatch
        return float(score)
    
    @staticmethod
    def get_cs_score(pred):
        n = len(pred)
        L = len(pred[0])

        good = 0
        for c in range(L):
            col = [pred[r][c] for r in range(n)]
            nz = [x for x in col if x != '-']      # non-gap residues
            if len(nz) == 0:
                # no residues in this column → by definition not a “matching residues” column
                continue
            if len(set(nz)) == 1:
                good += 1
        return good / L if L > 0 else 0.0
    
    @staticmethod
    def compute_alignment_metrics(pred, ref):
        """
        Computes predicted-only scores (SP, CS) and reference-based metrics (Q*, TC*).

        Q(A,R): pair-based accuracy (sum-of-pairs)
        TC(A,R): column-based accuracy (total-column match)

        Columns are compared by tuples (bases_across_sequences, column_index)
        so the logic is parallel to the pair-based comparison.
        """
        assert len(pred) == len(ref), "Pred/ref must have same number of sequences."
        n = len(ref)
        if n == 0:
            return {k: 0.0 for k in [
                'pred_sp','pred_cs','Q_acc','Q_prec','Q_rec','Q_f1',
                'TC_acc','TC_prec','TC_rec','TC_f1'
            ]}
        Lp = len(pred[0])
        Lr = len(ref[0])
        assert all(len(s) == Lp for s in pred)
        assert all(len(s) == Lr for s in ref)
        L = min(Lp, Lr)

        # --- predicted-only metrics
        pred_sp = AlignmentEnvironment.get_sp_score(pred)
        pred_cs = AlignmentEnvironment.get_cs_score(pred)

        # ---------- Q (pair) metrics ----------
        def get_pairs(msa):
            pairs = set()
            for i, j in combinations(range(n), 2):
                seq_i, seq_j = msa[i], msa[j]
                for c in range(L):
                    a, b = seq_i[c], seq_j[c]
                    if a != '-' and b != '-':
                        pairs.add((i, j, c))
            return pairs

        pred_pairs = get_pairs(pred)
        ref_pairs = get_pairs(ref)

        TPp = len(pred_pairs & ref_pairs)
        FPp = len(pred_pairs - ref_pairs)
        FNp = len(ref_pairs - pred_pairs)

        Q_acc  = TPp / len(ref_pairs) if len(ref_pairs) > 0 else 0.0
        Q_prec = TPp / (TPp + FPp) if (TPp + FPp) > 0 else 0.0
        Q_rec  = TPp / (TPp + FNp) if (TPp + FNp) > 0 else 0.0
        Q_f1   = (2 * Q_prec * Q_rec / (Q_prec + Q_rec)) if (Q_prec + Q_rec) > 0 else 0.0

        # ---------- TC (column) metrics ----------
        def get_columns(msa):
            cols = set()
            for c in range(L):
                col = tuple(msa[r][c] for r in range(n))
                cols.add((c, col))  # include index for uniqueness
            return cols

        pred_cols = get_columns(pred)
        ref_cols  = get_columns(ref)

        TPc = len(pred_cols & ref_cols)
        FPc = len(pred_cols - ref_cols)
        FNc = len(ref_cols - pred_cols)

        TC_acc  = TPc / len(ref_cols) if len(ref_cols) > 0 else 0.0
        TC_prec = TPc / (TPc + FPc) if (TPc + FPc) > 0 else 0.0
        TC_rec  = TPc / (TPc + FNc) if (TPc + FNc) > 0 else 0.0
        TC_f1   = (2 * TC_prec * TC_rec / (TC_prec + TC_rec)) if (TC_prec + TC_rec) > 0 else 0.0

        return {
            "pred_sp": float(pred_sp),
            "pred_cs": float(pred_cs),
            "Q_acc": float(Q_acc),
            "Q_prec": float(Q_prec),
            "Q_rec": float(Q_rec),
            "Q_f1": float(Q_f1),
            "TC_acc": float(TC_acc),
            "TC_prec": float(TC_prec),
            "TC_rec": float(TC_rec),
            "TC_f1": float(TC_f1)
        }


    

In [10]:
def run_prediction_episode(env, agent, expected_gaps, max_factor=1.25):
    import math
    import numpy as np
    
    state = env.reset()
    max_steps = max(1, int(math.ceil(expected_gaps * max_factor)))
    steps, rewards = 0, []
    
    for t in range(max_steps):
        valid_mask = env.get_valid_action_mask().astype(bool)
        
        q_values = agent.forward(state)
        action = masked_epsilon_greedy(
            q_values=q_values,
            valid_mask=valid_mask,
            epsilon=getattr(agent, "epsilon", 0.1),
        )
        
        reward, next_state, done = env.step(action)
        rewards.append(float(reward))
        
        if hasattr(agent, "replay_memory"):
            agent.replay_memory.store(
                (state, next_state, action, reward, bool(done), valid_mask.copy())
            )
        
        state = next_state
        steps += 1
        
        if done == 1 or not valid_mask.any():
            break
    
    gaps_dist = env.get_unpermuted_gaps() if hasattr(env, "get_unpermuted_gaps") \
                else env.gaps_per_sequence[:env.original_n_seqs]
    
    return {
        "steps": steps,
        "avg_reward": float(np.mean(rewards)) if rewards else 0.0,
        "sum_reward": float(np.sum(rewards)) if rewards else 0.0,
        "gaps_distribution": gaps_dist,
    }

## 7. Load and Analyze Data

In [11]:
def _count_inserted_gaps_from_sequences(start, solution):
    dash_start = sum(str(s).count('-') for s in start)
    dash_solution = sum(str(s).count('-') for s in solution)
    return max(0, dash_solution - dash_start)

def convert_column_major_solution(msa_string, n_seq):
    """
    Converts a column-major MSA string (down columns first) into
    a row-major list of aligned sequences.
    
    Args:
        msa_string (str): e.g. "AAACC---CGGTTTT"
        n_seq (int): number of sequences (rows)
    
    Returns:
        list[str]: e.g. ["ACGT-", "A-GT-", "AC-T-"]
    """
    if not msa_string or n_seq <= 0:
        return []

    # Split into chunks of n_seq (each chunk = one column)
    columns = [msa_string[i:i+n_seq] for i in range(0, len(msa_string), n_seq)]

    # Transpose columns -> rows
    seqs = [''.join(col[i] for col in columns) for i in range(n_seq)]
    return seqs

def convert_huggingface_to_samples(dataset, max_samples=None):
    samples = []
    for i, ex in enumerate(dataset):
        if max_samples and i >= max_samples:
            break

        unaligned_seqs = ex.get('unaligned_seqs', {})
        MSA = ex.get('MSA', "")

        if not unaligned_seqs or not MSA:
            continue

        start = [unaligned_seqs[k] for k in sorted(unaligned_seqs.keys())]
        n_seq = len(start)
        solution = convert_column_major_solution(MSA, n_seq)

        accepted_pairs = [(str(a), str(b)) for a, b in combinations(range(len(start)), 2)]
        n_gaps = _count_inserted_gaps_from_sequences(start, solution)

        sample = {
            'start': start,
            'solution': solution,
            'n_gaps': n_gaps,
            'moves': [-1] * n_gaps,  # keep list length equal to n_gaps as this is never actually used in the DQN
            'n_sequences': len(start),
            'idx': i
        }
        samples.append(sample)
    return samples


def filter_by_seq_length(example, max_len=MAX_MSA_LEN):
    """Keep only samples where every unaligned sequence is <= max_len."""
    if "unaligned_seqs" not in example:
        return False
    seqs = example["unaligned_seqs"].values() if isinstance(example["unaligned_seqs"], dict) else example["unaligned_seqs"]
    return all(len(seq) <= max_len for seq in seqs)

# --- Load and filter datasets ---
ds = load_dataset("dotan1111/MSA-nuc-3-seq", split="train")
ds = ds.filter(filter_by_seq_length)
train_samples = convert_huggingface_to_samples(ds, max_samples=training_size)
print(train_samples[0])

ds = load_dataset("dotan1111/MSA-nuc-3-seq", split="validation")
ds = ds.filter(filter_by_seq_length)
val_samples = convert_huggingface_to_samples(ds)

ds = load_dataset("dotan1111/MSA-nuc-3-seq", split="test")
ds = ds.filter(filter_by_seq_length)
test_samples = convert_huggingface_to_samples(ds) 


Filter: 100%|██████████| 1494999/1494999 [00:07<00:00, 189521.61 examples/s]


{'start': ['TACTACAGTTCTTAAAAATAATCTATTAAAATTTTTTTGCT', 'TAGTACGATTCGTGAAAATAATCTGTTAAAATTCTTTTTCT', 'TTATACAATTTTTTGAGGATTAATCTGTTGAAATTATTGTTCT'], 'solution': ['TACTACAGTTCTT--AAAAATAATCTATTAAAATTTTTTTGCT', 'TAGTACGATTCGT--GAAAATAATCTGTTAAAATTCTTTTTCT', 'TTATACAATTTTTTGAGGATTAATCTGTTGAAATTATTGTTCT'], 'n_gaps': 4, 'moves': [-1, -1, -1, -1], 'n_sequences': 3, 'idx': 0}


Filter: 100%|██████████| 2000/2000 [00:00<00:00, 118644.04 examples/s]
Filter: 100%|██████████| 3001/3001 [00:00<00:00, 91076.94 examples/s]


## 8. Dimension Analysis (Using ALL Data with Padding) DEPRECATEDDDDDDD

In [None]:
max_n_seqs = max(s['n_sequences'] for s in train_samples)
max_cons_len = max(s['consensus_length'] for s in train_samples)

for sample in train_samples:
    sample['original_n_seqs'] = sample['n_sequences']
    sample['original_cons_len'] = sample['consensus_length']

max_n_seqs = max(s['n_sequences'] for s in val_samples)
max_cons_len = max(s['consensus_length'] for s in val_samples)

for sample in val_samples:
    sample['original_n_seqs'] = sample['n_sequences']
    sample['original_cons_len'] = sample['consensus_length']

max_n_seqs = max(s['n_sequences'] for s in test_samples)
max_cons_len = max(s['consensus_length'] for s in test_samples)
for sample in test_samples:
    sample['original_n_seqs'] = sample['n_sequences']
    sample['original_cons_len'] = sample['consensus_length']


## 9. Training with Variable Dimensions (Using Padding)

In [33]:
# --- Create results directory ---
os.makedirs("../result/log", exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_path = os.path.join("../result/log", f"{timestamp}_log.csv")

# --- Prepare CSV logging ---
log_fields = [
    "epoch", "avg_train_reward", "epsilon",
    "SP", "CS", "Q_acc", "Q_prec", "Q_rec", "Q_f1", "TC_acc", "TC_prec", "TC_rec", "TC_f1"
]

with open(log_path, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=log_fields)
    writer.writeheader()

max_n_seqs = MAX_N_SEQS
agent = DQNAgent(
    action_number=max_n_seqs * MAX_MSA_LEN,
    num_seqs=max_n_seqs,
    max_grid=MAX_MSA_LEN,           # not consensus length
    max_value=MAX_MSA_LEN * 100,
    epsilon=0.9,
    delta=0.01,
    batch_size=64,
    gamma=0.99,
    learning_rate=0.001,
    memory_size=5000
)

epochs = 2 # 10
samples_per_epoch = 5 # 100
val_samples_per_epoch = 5 # 20

for epoch in range(epochs):
    epoch_samples = random.sample(train_samples, min(len(train_samples), samples_per_epoch))
    epoch_rewards = []
    
    for sample in tqdm(epoch_samples, desc=f"Epoch {epoch+1}/{epochs}"):
        env = AlignmentEnvironment(
            sequences=sample['start'],               # no accepted_pairs
            total_gap=get_expected_gaps(sample),     # your existing helper
            max_n_seqs=max_n_seqs,
            max_msa_len=MAX_MSA_LEN
        )
        print(f"The env setup is initialized with {get_expected_gaps(sample)} gaps, max_n_seqs={max_n_seqs}")
        
        state = env.reset()
        env.randomize_sequence_order(apply=True)
        
        episode_reward = 0
        step_count = 0
        max_steps = get_expected_gaps(sample) * 5
        valid_mask = env.get_valid_action_mask()
        
        while step_count < max_steps:
            action = agent.select_action(state, valid_action_mask=valid_mask)
            reward, next_state, done = env.step(action)
            
            next_valid_mask = env.get_valid_action_mask()
            agent.replay_memory.store((state, next_state, action, reward, done, next_valid_mask))
            
            agent.update()
            
            episode_reward += reward
            step_count += 1
            state = next_state
            valid_mask = next_valid_mask
            
            if done == 1:
                break
        
        agent.update_epsilon()
        epoch_rewards.append(episode_reward)
    
    avg_train_reward = np.mean(epoch_rewards)
    print(f"Epoch {epoch+1}: reward={avg_train_reward:.2f}, memory={agent.replay_memory.size}, ε={agent.current_epsilon:.3f}")

    # Validation section
    val_metrics_all = {
        "pred_sp": [], "pred_cs": [],
        "Q_acc": [], "Q_prec": [], "Q_rec": [], "Q_f1": [],
        "TC_acc": [], "TC_prec": [], "TC_rec": [], "TC_f1": []
    }

    for val_sample in random.sample(val_samples, min(val_samples_per_epoch, len(val_samples))):
        env = AlignmentEnvironment(
            sequences=sample['start'],               # no accepted_pairs
            total_gap=get_expected_gaps(sample),     # your existing helper
            max_n_seqs=max_n_seqs,
            max_msa_len=MAX_MSA_LEN
        )

        predicted, _ = run_dqn_inference(agent, env, get_expected_gaps(val_sample))

        metrics = env.compute_alignment_metrics(predicted, val_sample['solution'])
        val_reward = env.calc_reward()

        # Aggregate metrics
        for k, v in metrics.items():
            val_metrics_all[k].append(v)


    # --- Average metrics across all validation samples ---
    avg_val_metrics = {k: np.mean(v) if len(v) > 0 else 0.0 for k, v in val_metrics_all.items()}

    # --- Print epoch summary ---
    print(
        f"Epoch {epoch+1}: "
        f"SP={avg_val_metrics['pred_sp']:.2f}, "
        f"CS={avg_val_metrics['pred_cs']:.3f}, "
        f"Q_acc={avg_val_metrics['Q_acc']:.3f}, "
        f"Q_prec={avg_val_metrics['Q_prec']:.3f}, "
        f"Q_rec={avg_val_metrics['Q_rec']:.3f}, "
        f"Q_f1={avg_val_metrics['Q_f1']:.3f}, "
        f"TC_acc={avg_val_metrics['TC_acc']:.3f}, "
        f"TC_prec={avg_val_metrics['TC_prec']:.3f}, "
        f"TC_rec={avg_val_metrics['TC_rec']:.3f}, "
        f"TC_f1={avg_val_metrics['TC_f1']:.3f}"
    )

    with open(log_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=log_fields)
        writer.writerow({
            "epoch": epoch + 1,
            "avg_train_reward": avg_train_reward,
            "epsilon": agent.current_epsilon,
            "SP": avg_val_metrics['pred_sp'],
            "CS": avg_val_metrics['pred_cs'],
            "Q_acc": avg_val_metrics['Q_acc'],
            "Q_prec": avg_val_metrics['Q_prec'],
            "Q_rec": avg_val_metrics['Q_rec'],
            "Q_f1": avg_val_metrics['Q_f1'],
            "TC_acc": avg_val_metrics['TC_acc'],
            "TC_prec": avg_val_metrics['TC_prec'],
            "TC_rec": avg_val_metrics['TC_rec'],
            "TC_f1": avg_val_metrics['TC_f1']
        })

# --- Save trained model ---
os.makedirs("../result/agent", exist_ok=True)
model_path = os.path.join("../result/agent", f"{timestamp}_model.pt")
agent.save_model(model_path)
print(f"\nTraining complete. Model saved to: {model_path}")
print(f"Training log saved to: {log_path}")



Epoch 1/2: 100%|██████████| 5/5 [00:00<00:00, 107.16it/s]


using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 13 gaps, max_n_seqs=3
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 2 gaps, max_n_seqs=3
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 40 gaps, max_n_seqs=3
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 2 gaps, max_n_seqs=3
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 0 gaps, max_n_seqs=3
using the n_gaps from sample
Epoch 1: reward=5.70, memory=57, ε=0.900
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gap

  batch_next_mask = torch.BoolTensor(next_mask).to(device)


using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 6 gaps, max_n_seqs=3
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 5 gaps, max_n_seqs=3
using the n_gaps from sample


Epoch 2/2:  40%|████      | 2/5 [00:02<00:03,  1.31s/it]

using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 33 gaps, max_n_seqs=3
using the n_gaps from sample


Epoch 2/2:  60%|██████    | 3/5 [00:17<00:13,  6.81s/it]

using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 6 gaps, max_n_seqs=3
using the n_gaps from sample


Epoch 2/2:  80%|████████  | 4/5 [00:20<00:05,  5.53s/it]

using the n_gaps from sample
using the n_gaps from sample
The env setup is initialized with 4 gaps, max_n_seqs=3
using the n_gaps from sample


Epoch 2/2: 100%|██████████| 5/5 [00:22<00:00,  4.48s/it]

Epoch 2: reward=5.40, memory=111, ε=0.899
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
using the n_gaps from sample
Epoch 2: SP=-58.80, CS=0.136, Q_acc=0.868, Q_prec=0.992, Q_rec=0.868, Q_f1=0.923, TC_acc=0.014, TC_prec=0.014, TC_rec=0.014, TC_f1=0.014

Training complete. Model saved to: ../result/agent/2025-11-09_02-04-25_model.pt
Training log saved to: ../result/log/2025-11-09_02-04-25_log.csv





## 10. Inference 

In [16]:
def run_dqn_inference(agent, env, expected_gaps, max_steps=None):
    """
    Runs an inference episode using the trained DQN agent in the consensus-free environment.
    """
    if max_steps is None:
        max_steps = expected_gaps * 2

    state = env.reset()
    env.randomize_sequence_order(apply=True)

    actions_taken = []
    valid_mask = env.get_valid_action_mask()

    for step in range(max_steps):
        if valid_mask.sum() == 0:
            break

        # Predict next action
        action = agent.predict(state, valid_action_mask=valid_mask)

        # Convert linear action index to (sequence, position)
        seq_idx = action // env.max_msa_len
        pos = action % env.max_msa_len
        actions_taken.append((seq_idx, pos))

        # Apply action in environment
        _, next_state, done = env.step(action)
        state = next_state
        valid_mask = env.get_valid_action_mask()

        if done == 1:
            break

    predicted = env.sequences[:env.original_n_seqs]

    return predicted, actions_taken
