# DQN


## Imports and Setup

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import copy
import os
from pathlib import Path
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset

# Setup: when launched from inside a "notebooks" directory the project root is
# its parent, otherwise it is the current working directory.
if Path.cwd().name == 'notebooks':
    project_root = Path.cwd().parent
else:
    project_root = Path.cwd()

# Use the GPU when one is visible to torch, else fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Constants: NUCLEOTIDES maps an integer code to its symbol and NUCLEOTIDE_MAP
# is the inverse lookup (symbol -> code); code 0 is the padding token.
NUCLEOTIDES = ["PAD", "A", "T", "C", "G", "-"]
NUCLEOTIDE_MAP = {"A": 1, "T": 2, "C": 3, "G": 4, "-": 5, "PAD": 0}

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
import numpy as np
import torch

def masked_epsilon_greedy(q_values: torch.Tensor, valid_mask: np.ndarray, epsilon: float, rng=None) -> int:
    """Pick an action index with epsilon-greedy exploration over valid actions.

    Args:
        q_values: Q-value tensor; anything with more than one dimension is
            flattened before use.
        valid_mask: Boolean-like array, True where the action is allowed.
        epsilon: Probability threshold for taking a uniformly random valid action.
        rng: Object exposing ``random()`` and ``choice()``; defaults to the
            global ``np.random`` module.

    Returns:
        The chosen action index. When no action is valid, the unmasked argmax
        of ``q_values`` is returned (and no random number is drawn).
    """
    flat_q = q_values.reshape(-1) if q_values.ndim > 1 else q_values
    allowed = np.asarray(valid_mask, dtype=bool)
    sampler = np.random if rng is None else rng

    candidates = np.flatnonzero(allowed)
    if candidates.size == 0:
        # Nothing is valid: fall back to the plain greedy choice.
        return int(torch.argmax(flat_q).item())

    explore = sampler.random() < epsilon
    if explore:
        return int(sampler.choice(candidates))

    # Work on a copy so the caller's tensor storage is never modified.
    scores = flat_q.detach().cpu().numpy().copy()
    scores[~allowed] = -np.inf
    return int(np.argmax(scores))

In [None]:
def get_expected_alignment(sample):
    """Return the reference alignment stored in ``sample`` as a list of strings.

    The keys "solution", "aligned" and "target" are tried in that order; the
    first one that is present and not None wins. A plain string is wrapped in
    a one-element list, any other value is iterated and each item stringified.
    Returns None when none of the keys yields a value.
    """
    for field in ("solution", "aligned", "target"):
        if field not in sample:
            continue
        value = sample[field]
        if value is None:
            continue
        if isinstance(value, str):
            return [value]
        return [str(item) for item in value]
    return None

def get_expected_gaps(sample):
    """Return the number of gaps expected to be inserted for ``sample``.

    Resolution order:
      1. ``sample['n_gaps']`` when present and not None;
      2. ``len(sample['moves'])`` when ``moves`` is a list;
      3. dashes in ``solution`` minus dashes in ``start`` (floored at 0), or
         just the dashes in ``solution`` when ``start`` is absent;
      4. the default of 20 when nothing above applies.
    """
    if 'n_gaps' in sample:
        declared = sample['n_gaps']
        if declared is not None:
            return declared

    if 'moves' in sample:
        moves = sample['moves']
        if isinstance(moves, list):
            return len(moves)

    if 'solution' not in sample or not sample['solution']:
        return 20

    solution = sample['solution']
    if 'start' not in sample:
        return sum(seq.count('-') for seq in solution)

    dashes_before = sum(str(row).count('-') for row in sample['start'])
    dashes_after = sum(str(row).count('-') for row in solution)
    inserted = dashes_after - dashes_before
    return inserted if inserted > 0 else 0

## ReplayMemory and Q-Network Modules

In [None]:
class ReplayMemory:
    """Fixed-capacity FIFO buffer of transition tuples.

    Each stored item is expected to be a 6-tuple
    ``(state, next_state, action, reward, done, next_mask)``; that arity is
    enforced when sampling, not when storing.
    """

    def __init__(self, memory_size=1000):
        self.storage = []
        self.memory_size = memory_size
        self.size = 0

    def store(self, data: tuple):
        """Append ``data``, evicting the oldest entry once the buffer is full."""
        buffer = self.storage
        if len(buffer) == self.memory_size:
            del buffer[0]
        buffer.append(data)
        self.size = min(self.size + 1, self.memory_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them column-wise.

        Returns six parallel lists: state, next_state, action, reward, done,
        next_mask. ``random.sample`` raises ValueError when ``batch_size``
        exceeds the number of stored transitions.
        """
        state, next_state, action = [], [], []
        reward, done, next_mask = [], [], []
        for s, ns, a, r, d, m in random.sample(self.storage, batch_size):
            state.append(s)
            next_state.append(ns)
            action.append(a)
            reward.append(r)
            done.append(d)
            next_mask.append(m)
        return state, next_state, action, reward, done, next_mask

class ScaledDotProductAttention(nn.Module):
    """Attention weights from scaled query/key dot products, applied to values.

    Inputs are indexed as 4-D tensors: keys are transposed over axes 2 and 3.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attention_weights)``.

        Positions where ``mask == 0`` receive a score of -1e9 before the
        softmax, which drives their weight to (numerically) zero.
        """
        scores = torch.matmul(q / self.temperature, k.transpose(2, 3))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, v), weights

class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position signal to a batch of embeddings.

    The table is registered as a (non-trainable) buffer of shape
    ``(1, n_position, d_hid)``. Along the feature axis the sine terms come
    first and the cosine terms follow (concatenated, not interleaved).
    """

    def __init__(self, d_hid, n_position=200):
        super().__init__()
        self.register_buffer("pos_table", self._get_sinusoid_table(n_position, d_hid))

    @staticmethod
    def _get_sinusoid_table(n_position, d_hid):
        """Build the ``(1, n_position, d_hid)`` sinusoid table."""
        positions = torch.arange(n_position).float().unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_hid, 2).float() * (-np.log(10000.0) / d_hid))
        angles = positions * div_term
        table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        # An odd d_hid yields ceil(d_hid / 2) frequencies, so sin+cos is one
        # column wider than the embedding; trim so the addition in forward()
        # always lines up. For even d_hid the slice is a no-op.
        return table[:, :d_hid].unsqueeze(0)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions."""
        return x + self.pos_table[:, :x.size(1), :].clone().detach()

class SelfAttention(nn.Module):
    """Single-head attention with a residual connection and layer norm.

    Queries/keys are projected to ``d_k`` features and values to ``d_v``; the
    attended values are projected back to ``d_model``, passed through dropout,
    added to the original query input and layer-normalised.
    """

    def __init__(self, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.w_qs = nn.Linear(d_model, d_k, bias=False)
        self.w_ks = nn.Linear(d_model, d_k, bias=False)
        self.w_vs = nn.Linear(d_model, d_v, bias=False)
        self.fc = nn.Linear(d_v, d_model, bias=False)
        self.attention = ScaledDotProductAttention(temperature=d_k**0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attention_weights)`` for the given q/k/v inputs."""
        batch = q.size(0)
        n_query = q.size(1)
        skip = q

        def as_single_head(projected, steps, width):
            # (batch, steps, width) -> (batch, 1, steps, width): one head axis.
            return projected.view(batch, steps, 1, width).transpose(1, 2)

        query = as_single_head(self.w_qs(q), n_query, self.d_k)
        key = as_single_head(self.w_ks(k), k.size(1), self.d_k)
        value = as_single_head(self.w_vs(v), v.size(1), self.d_v)

        # Give the mask a head axis so it broadcasts against the scores.
        head_mask = None if mask is None else mask.unsqueeze(1)

        context, attn = self.attention(query, key, value, mask=head_mask)
        context = context.transpose(1, 2).contiguous().view(batch, n_query, -1)
        projected = self.dropout(self.fc(context))
        return self.layer_norm(projected + skip), attn

class Encoder(nn.Module):
    """Token embedding + positional encoding followed by one self-attention layer."""

    def __init__(self, n_src_vocab, d_model, n_position, d_k=164, d_v=164, pad_idx=0, dropout=0.1):
        super().__init__()
        self.src_word_emb = nn.Embedding(n_src_vocab, d_model, padding_idx=pad_idx)
        self.position_enc = PositionalEncoding(d_model, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.self_attention = SelfAttention(d_model, d_k, d_v, dropout=dropout)

    def forward(self, src_seq, mask):
        """Encode integer tokens ``src_seq``; ``mask`` is forwarded to the attention layer."""
        embedded = self.position_enc(self.src_word_emb(src_seq))
        normed = self.layer_norm(self.dropout(embedded))
        # Self-attention: the same tensor serves as query, key and value.
        attended, _ = self.self_attention(normed, normed, normed, mask=mask)
        return attended

class QNetwork(nn.Module):
    """Q-value network over a ``(num_sequences + 2) x max_sequence_length`` token grid.

    The grid is flattened into one token sequence, encoded, tagged with a
    learned per-row embedding and reduced to ``num_actions`` Q-values by a
    three-layer MLP. ``max_action_value`` is accepted but not used here.
    """

    def __init__(self, num_sequences, max_sequence_length, num_actions, max_action_value, d_model=64):
        super().__init__()
        self.num_sequences = num_sequences
        self.num_rows = num_sequences + 2

        n_tokens = (num_sequences + 2) * max_sequence_length
        self.encoder = Encoder(6, d_model, n_tokens)
        self.seq_embedding = nn.Embedding(self.num_rows, d_model)
        nn.init.normal_(self.seq_embedding.weight, 0.0, 0.1)

        self.fc1 = nn.Linear(n_tokens * d_model, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_actions)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Map a ``(batch, rows, cols)`` integer grid to ``(batch, num_actions)`` Q-values."""
        batch, rows, cols = x.shape

        tokens = x.view(batch, rows * cols)
        # Token 0 is treated as padding and masked out of the attention.
        encoded = self.encoder(tokens, (tokens != 0).unsqueeze(1))

        if encoded.dim() == 3:
            encoded = encoded.view(batch, rows, cols, -1)

        # Add one learned vector per grid row, broadcast over batch and columns.
        row_ids = torch.arange(rows, device=encoded.device)
        encoded = encoded + self.seq_embedding(row_ids)[None, :, None, :]

        hidden = encoded.reshape(batch, -1)
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.leaky_relu(layer(hidden)))
        return self.fc3(hidden)

## DQN Agent


In [None]:
class DQNAgent:
    """DQN agent with action masking, optional Double-DQN and soft target updates.

    The online network (``eval_net``) is trained on mini-batches drawn from
    ``replay_memory``; after every optimisation step the target network is
    moved towards it by Polyak averaging with factor ``tau``.

    Args:
        action_number: Size of the flat action space.
        num_seqs: Number of sequence rows; states have ``num_seqs + 2`` rows.
        max_grid: States have ``max_grid + 1`` columns.
        max_value: Forwarded to ``QNetwork`` as ``max_action_value``.
        epsilon: Initial exploration rate.
        delta: Lower bound for the exploration rate.
        decrement_iteration: Accepted for compatibility; not used by this class.
        update_iteration: Stored but not read here (soft updates are used).
        batch_size: Mini-batch size; also the minimum buffer fill before learning.
        gamma: Discount factor.
        learning_rate: Adam learning rate.
        memory_size: Replay buffer capacity.
    """

    def __init__(self, action_number, num_seqs, max_grid, max_value,
                 epsilon=0.8, delta=0.05, decrement_iteration=5,
                 update_iteration=128, batch_size=128, gamma=1.0,
                 learning_rate=0.001, memory_size=1000):
        self.action_number = action_number
        self.num_seqs = num_seqs
        self.max_grid = max_grid
        self.seq_num = num_seqs
        self.max_seq_len = max_grid + 1

        self.eval_net = QNetwork(num_seqs, self.max_seq_len, action_number, max_value).to(device)
        self.target_net = QNetwork(num_seqs, self.max_seq_len, action_number, max_value).to(device)
        self.target_net.load_state_dict(self.eval_net.state_dict())
        # The target network only ever produces bootstrap targets; keep it in
        # eval mode so its dropout layers do not inject noise into them.
        self.target_net.eval()

        self.replay_memory = ReplayMemory(memory_size=memory_size)
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=learning_rate)
        self.loss_func = nn.SmoothL1Loss()

        self.batch_size = batch_size
        self.gamma = gamma
        self.update_iteration = update_iteration
        self.update_step_counter = 0
        self.tau = 0.005
        self.use_double_dqn = True

        self.initial_epsilon = epsilon
        self.current_epsilon = epsilon
        self.epsilon_end = delta
        self.epsilon_decay = 0.999

        self.losses = []
        self.epsilons = []

    def update_epsilon(self):
        """Decay epsilon multiplicatively; the rate depends on ``update_step_counter``."""
        if self.update_step_counter < 5000:
            decay_rate = 0.9999
        elif self.update_step_counter < 10000:
            decay_rate = 0.999
        else:
            decay_rate = 0.995

        self.current_epsilon = max(self.epsilon_end, self.current_epsilon * decay_rate)
        self.epsilons.append(self.current_epsilon)

    def _state_tensor(self, state):
        """Convert one flat state into a ``(1, rows, cols)`` long tensor on ``device``."""
        return torch.as_tensor(state, dtype=torch.long, device=device).view(
            1, self.seq_num + 2, self.max_seq_len
        )

    def _online_q(self, state):
        """Q-values of ``state`` from the online net with dropout disabled."""
        self.eval_net.eval()
        with torch.no_grad():
            q = self.eval_net(self._state_tensor(state)).squeeze(0)
        self.eval_net.train()
        return q

    def _masked_q_numpy(self, state, valid_action_mask=None):
        """Online Q-values as a NumPy array with invalid actions set to -inf."""
        q = self._online_q(state).detach().cpu().numpy()
        if valid_action_mask is not None:
            # Coerce first: `~` on a list raises and on an int array is bitwise.
            q[~np.asarray(valid_action_mask, dtype=bool)] = -np.inf
        return q

    def select_action(self, state, valid_action_mask=None):
        """Epsilon-greedy action; random picks are restricted to valid actions when a mask is given."""
        if random.random() <= self.current_epsilon:
            if valid_action_mask is not None:
                valid_idx = np.flatnonzero(valid_action_mask)
                if len(valid_idx):
                    return int(np.random.choice(valid_idx))
            return random.randrange(self.action_number)

        return int(np.argmax(self._masked_q_numpy(state, valid_action_mask)))

    def predict(self, state, valid_action_mask=None):
        """Greedy (no exploration) action for ``state``, honouring the mask."""
        return int(np.nanargmax(self._masked_q_numpy(state, valid_action_mask)))

    def forward(self, state):
        """Return the online network's Q-values for ``state`` as a 1-D tensor."""
        return self._online_q(state)

    @property
    def epsilon(self):
        return self.current_epsilon

    def update(self):
        """Run one optimisation step and soft-update the target network.

        Returns:
            The scalar loss, or None while the buffer holds fewer than
            ``batch_size`` transitions.
        """
        self.update_step_counter += 1

        if self.replay_memory.size < self.batch_size:
            return None

        state, next_state, action, reward, done, next_mask = self.replay_memory.sample(self.batch_size)

        batch_state = torch.as_tensor(state, dtype=torch.long, device=device).view(
            -1, self.seq_num + 2, self.max_seq_len
        )
        batch_next_state = torch.as_tensor(next_state, dtype=torch.long, device=device).view(
            -1, self.seq_num + 2, self.max_seq_len
        )
        batch_action = torch.as_tensor(action, dtype=torch.long, device=device).unsqueeze(-1)
        batch_reward = torch.as_tensor(reward, dtype=torch.float32, device=device).unsqueeze(-1)
        batch_done = torch.as_tensor(done, dtype=torch.float32, device=device).unsqueeze(-1)
        # Stack the per-transition masks in NumPy first; building a tensor
        # straight from a list of arrays is slow.
        batch_next_mask = torch.as_tensor(np.asarray(next_mask, dtype=bool), device=device)

        q_eval = self.eval_net(batch_state).gather(1, batch_action)

        with torch.no_grad():
            if self.use_double_dqn:
                # Double DQN: the online net chooses, the target net evaluates.
                q_next_online = self.eval_net(batch_next_state)
                q_next_online = q_next_online.masked_fill(~batch_next_mask, float('-inf'))
                best_next_actions = q_next_online.argmax(dim=1, keepdim=True)
                q_next = self.target_net(batch_next_state).gather(1, best_next_actions)
            else:
                q_next_target = self.target_net(batch_next_state)
                q_next_target = q_next_target.masked_fill(~batch_next_mask, float('-inf'))
                q_next = q_next_target.max(dim=1, keepdim=True)[0]

            # A next state without any valid action has nothing to bootstrap
            # from. In the Double-DQN branch the gathered target value is
            # finite, so an isinf() test alone would miss that case.
            no_valid_next = ~batch_next_mask.any(dim=1, keepdim=True)
            q_next = torch.where(no_valid_next | torch.isinf(q_next), torch.zeros_like(q_next), q_next)

            q_target = batch_reward + (1.0 - batch_done) * self.gamma * q_next
            q_target = torch.clamp(q_target, -10.0, 10.0)

        loss = self.loss_func(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_net.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Polyak averaging: target <- (1 - tau) * target + tau * online.
        with torch.no_grad():
            for target_param, online_param in zip(self.target_net.parameters(), self.eval_net.parameters()):
                target_param.data.mul_(1.0 - self.tau)
                target_param.data.add_(self.tau * online_param.data)

        loss_value = loss.item()
        self.losses.append(loss_value)
        return loss_value

## Environment

In [None]:
class AlignmentEnvironment:
    """Gap-insertion environment over a set of sequences and a consensus.

    An action is a flat index ``seq_idx * max_cons_len + pos``. Stepping
    inserts a gap ('-') at column ``pos`` of row ``seq_idx`` and drops the
    row's last symbol so the row length stays fixed. The episode is done once
    ``total_gap`` reaches 0.

    Inputs are padded to ``max_n_seqs`` rows and ``max_cons_len`` columns so
    samples of different sizes share one state/action layout: short rows are
    right-padded with '-', missing rows are all '-', and the consensus is
    padded with ('PAD', 'PAD') pairs.

    Args:
        sequences: List of sequence strings over the NUCLEOTIDE_MAP alphabet.
        consensus: List of pairs; each pair holds the symbols accepted in that column.
        total_gap: Number of gaps to place before the episode ends.
        max_n_seqs: Row count to pad to (defaults to ``len(sequences)``).
        max_cons_len: Column count to pad to (defaults to ``len(consensus)``).
    """

    def __init__(self, sequences, consensus, total_gap, max_n_seqs=None, max_cons_len=None):
        self.consensus = consensus
        self.original_n_seqs = len(sequences)
        self.original_cons_len = len(consensus)

        self.max_n_seqs = max_n_seqs if max_n_seqs is not None else len(sequences)
        self.max_cons_len = max_cons_len if max_cons_len is not None else len(consensus)

        padded_consensus = consensus + [('PAD', 'PAD')] * (self.max_cons_len - len(consensus))
        sequences = self._pad_sequences(sequences, self.max_cons_len)
        while len(sequences) < self.max_n_seqs:
            sequences.append('-' * self.max_cons_len)

        self.sequences = sequences
        self.initial_sequences = copy.deepcopy(sequences)
        self.sep_nuc_in_seq = [list(seq) for seq in sequences]
        self.label_encoded_seqs = self._encode_rows(sequences)
        self.consensus_encoded = self._encode_consensus(padded_consensus)

        self.num_seqs = self.max_n_seqs
        self.action_number = self.max_n_seqs * self.max_cons_len
        self.aligned = [[] for _ in range(self.num_seqs)]
        self.original = copy.deepcopy(self.label_encoded_seqs)

        self.original_score = self._gearbox_score(consensus, self.sep_nuc_in_seq[:self.original_n_seqs])
        self.last_score = self.original_score

        self.total_gap = total_gap
        self.initial_gap = total_gap
        self.gaps_per_sequence = [0] * self.num_seqs

        self.recent_actions = []
        self.last_seq_chosen = None
        self.seq_streak_count = 0

        self.seq_permutation = None
        self.seq_permutation_inv = None

    @staticmethod
    def _pad_sequences(sequences, target_len):
        """Return new strings right-padded with '-' to ``target_len`` (never truncated)."""
        return [s.ljust(target_len, '-') for s in sequences]

    @staticmethod
    def _encode_rows(rows):
        """Encode each row of symbols as a list of NUCLEOTIDE_MAP integer codes."""
        return [[NUCLEOTIDE_MAP[char] for char in row] for row in rows]

    def _encode_consensus(self, consensus):
        """Split the consensus pairs into two encoded rows (unknown symbols -> 0)."""
        cons_1 = [NUCLEOTIDE_MAP.get(p[0], 0) for p in consensus]
        cons_2 = [NUCLEOTIDE_MAP.get(p[1], 0) for p in consensus]
        return [cons_1, cons_2]

    def get_valid_action_mask(self):
        """Boolean mask over actions: True where a real row holds a non-gap symbol."""
        mask = np.zeros(self.action_number, dtype=bool)
        for seq_idx in range(self.original_n_seqs):
            row = self.sequences[seq_idx]
            base = seq_idx * self.max_cons_len
            for pos in range(self.max_cons_len):
                mask[base + pos] = (row[pos] != '-')
        return mask

    @staticmethod
    def _gearbox_score(consensus, sequences):
        '''Calculate Column score (how many bases match with the reference) with bonus for perfect columns'''
        bonus = 1.15
        score = 0
        max_col = min(len(consensus), min(len(seq) for seq in sequences) if sequences else 0)
        for col_ind in range(max_col):
            col_bonus = True
            col_tot = 0
            for row in sequences:
                if col_ind >= len(row):
                    continue
                char = row[col_ind]
                if char == '-':
                    # A gap anywhere in the column forfeits the bonus.
                    col_bonus = False
                    continue
                if col_ind < len(consensus) and char in consensus[col_ind]:
                    col_tot += 1
                else:
                    col_bonus = False
            if col_bonus:
                score += col_tot * bonus
            else:
                score += col_tot
        return score

    def get_current_state(self):
        """Flat state: all encoded sequence rows followed by the two consensus rows."""
        state = []
        for seq in self.original:
            state.extend(seq)
        state.extend(self.consensus_encoded[0])
        state.extend(self.consensus_encoded[1])
        return state

    def _real_rows(self):
        """Current alignment of the real (non-padding) rows as lists of symbols."""
        return [list(seq) for seq in self.get_alignment()[:self.original_n_seqs]]

    def calc_reward(self):
        """Gearbox score of the current alignment of the real rows."""
        return self._gearbox_score(self.consensus[:self.original_cons_len], self._real_rows())

    def _count_perfect_columns(self):
        """Count columns with at least one nucleotide where every nucleotide matches the consensus."""
        rows = self._real_rows()
        perfect_count = 0
        max_col = min(self.original_cons_len, min(len(r) for r in rows) if rows else 0)
        for col_ind in range(max_col):
            all_match = True
            has_nucleotide = False
            for row in rows:
                if col_ind >= len(row):
                    continue
                char = row[col_ind]
                if char == '-':
                    continue
                has_nucleotide = True
                if col_ind < len(self.consensus) and char not in self.consensus[col_ind]:
                    all_match = False
                    break
            if all_match and has_nucleotide:
                perfect_count += 1
        return perfect_count

    def _calculate_position_aware_reward(self, seq_idx, pos):
        """Distance-weighted count of consensus matches in the 5 symbols after ``pos``."""
        if seq_idx >= self.original_n_seqs or pos >= self.original_cons_len:
            return 0.0

        sequence = self.original[seq_idx]
        reward = 0
        look_ahead = min(5, len(sequence) - pos - 1)
        for offset in range(1, look_ahead + 1):
            nuc_pos = pos + offset
            if nuc_pos >= self.original_cons_len:
                break
            nuc_val = sequence[nuc_pos]
            if nuc_val in [5, 0]:
                # Codes 5 ('-') and 0 ('PAD') are not nucleotides.
                continue
            nuc = NUCLEOTIDES[nuc_val]
            if nuc_pos < len(self.consensus) and nuc in self.consensus[nuc_pos]:
                reward += 1.0 / offset
        return reward

    def _calculate_diversity_bonus(self):
        """Return ``1 - 2 * gini`` of the per-row gap counts (0.0 with no gaps or one row)."""
        gaps_in_real_seqs = self.gaps_per_sequence[:self.original_n_seqs]
        total_gaps = sum(gaps_in_real_seqs)

        if total_gaps == 0 or self.original_n_seqs == 1:
            return 0.0

        gaps_array = np.array(gaps_in_real_seqs, dtype=float)
        n = self.original_n_seqs
        sorted_gaps = np.sort(gaps_array)

        gini = (2.0 * np.sum((np.arange(n) + 1) * sorted_gaps)) / (n * total_gaps) - (n + 1.0) / n
        diversity_score = 1.0 - 2.0 * gini

        return diversity_score

    def _calculate_concentration_penalty(self):
        """Stepwise penalty when one row holds far more gaps than the average row."""
        gaps_in_real_seqs = self.gaps_per_sequence[:self.original_n_seqs]
        total_gaps = sum(gaps_in_real_seqs)

        if total_gaps == 0:
            return 0.0

        max_gaps_in_any_seq = max(gaps_in_real_seqs)
        avg_gaps = total_gaps / self.original_n_seqs

        if max_gaps_in_any_seq > avg_gaps * 2.5:
            return -5.0
        elif max_gaps_in_any_seq > avg_gaps * 2.0:
            return -3.0
        elif max_gaps_in_any_seq > avg_gaps * 1.5:
            return -1.0
        else:
            return 0.0

    def _diversity_entropy(self, scale: float = 20.0) -> float:
        """Scaled, normalised entropy of how gaps are spread over rows that received any."""
        gaps = self.gaps_per_sequence[:self.original_n_seqs]
        S = float(sum(gaps))
        if S <= 0.0:
            return 0.0

        p = np.asarray(gaps, dtype=np.float64) / S
        nz = p > 0
        R_eff = int(nz.sum())
        if R_eff <= 1:
            return 0.0

        H = float(-(p[nz] * np.log(p[nz] + 1e-12)).sum())
        # Normalise by the maximum entropy over the rows that actually got gaps.
        H_norm = H / np.log(float(R_eff))
        return scale * H_norm

    def step(self, action):
        """Apply ``action`` and return ``(reward, next_state, done)``.

        An action addressing a padding row is rejected with reward -1.0 and
        leaves the environment untouched. Otherwise a gap is inserted when the
        addressed cell is not already a gap (+0.5), else nothing changes (-0.1).
        When the gap budget hits zero the entropy bonus is added and ``done`` is 1.
        """
        seq_idx = int(action) // self.max_cons_len
        pos = int(action) % self.max_cons_len

        if not (0 <= seq_idx < self.original_n_seqs and 0 <= pos < self.max_cons_len):
            reward = -1.0
            return reward, self.get_current_state(), 0

        row = list(self.sequences[seq_idx])
        placed = False
        if row[pos] != '-':
            row.insert(pos, '-')
            # Keep the row length constant: the last symbol falls off the end.
            row.pop()
            encoded = [NUCLEOTIDE_MAP[char] for char in row]
            self.sequences[seq_idx] = "".join(row)
            self.sep_nuc_in_seq[seq_idx] = row
            self.label_encoded_seqs[seq_idx] = encoded
            self.original[seq_idx] = list(encoded)
            self.gaps_per_sequence[seq_idx] += 1
            self.total_gap = max(0, self.total_gap - 1)
            placed = True

        reward = 0.0
        if placed:
            reward += 0.5
        else:
            reward -= 0.1

        done = 1 if self.total_gap == 0 else 0
        if done == 1:
            reward += self._diversity_entropy(scale=20.0)

        reward = float(np.clip(reward, -25.0, 25.0))

        return reward, self.get_current_state(), done

    def get_alignment(self):
        """Return the current alignment as one string per row.

        ``aligned`` is not populated anywhere in this class, so decoding it
        alone always produced empty strings. It is still honoured when some
        caller has filled it; otherwise the live ``sequences`` (which
        ``step`` keeps up to date) are returned, in the current row order.
        """
        if any(self.aligned):
            return [''.join(NUCLEOTIDES[code] for code in row) for row in self.aligned]
        return list(self.sequences)

    def get_original_alignment(self):
        """Alignment restricted to the real rows and the unpadded consensus width."""
        full_alignment = self.get_alignment()
        return [full_alignment[i][:self.original_cons_len] for i in range(self.original_n_seqs)]

    def reset(self):
        """Restore the initial (unpermuted, gap-free-budget) episode and return its state."""
        # Rebuild every representation from the pristine strings. step() and
        # randomize_sequence_order() mutate label_encoded_seqs, so copying it
        # here would leak the previous episode into the new state.
        self.sequences = copy.deepcopy(self.initial_sequences)
        self.sep_nuc_in_seq = [list(seq) for seq in self.sequences]
        self.label_encoded_seqs = self._encode_rows(self.sequences)
        self.original = copy.deepcopy(self.label_encoded_seqs)
        self.aligned = [[] for _ in range(self.num_seqs)]

        self.total_gap = self.initial_gap
        self.last_score = self.original_score
        self.gaps_per_sequence = [0] * self.num_seqs
        self.recent_actions = []
        self.last_seq_chosen = None
        self.seq_streak_count = 0
        self.seq_permutation = None
        self.seq_permutation_inv = None
        return self.get_current_state()

    def randomize_sequence_order(self, apply=True):
        """Randomly permute the real rows in place (padding rows stay put).

        After the call, new row ``i`` holds what was row ``seq_permutation[i]``.
        States obtained before the call no longer match the environment.
        """
        if not apply or self.original_n_seqs <= 1:
            self.seq_permutation = None
            self.seq_permutation_inv = None
            return

        n_real = self.original_n_seqs
        self.seq_permutation = np.random.permutation(n_real)
        self.seq_permutation_inv = np.argsort(self.seq_permutation)
        self.original[:n_real] = [self.original[i] for i in self.seq_permutation]
        self.sequences[:n_real] = [self.sequences[i] for i in self.seq_permutation]
        self.sep_nuc_in_seq[:n_real] = [self.sep_nuc_in_seq[i] for i in self.seq_permutation]
        self.label_encoded_seqs[:n_real] = [self.label_encoded_seqs[i] for i in self.seq_permutation]

        gaps_copy = self.gaps_per_sequence[:n_real].copy()
        for new_idx, old_idx in enumerate(self.seq_permutation):
            self.gaps_per_sequence[new_idx] = gaps_copy[old_idx]

    def get_unpermuted_gaps(self):
        """Per-row gap counts of the real rows, mapped back to the pre-permutation order."""
        if self.seq_permutation_inv is None:
            return self.gaps_per_sequence[:self.original_n_seqs]

        original_order = [0] * self.original_n_seqs
        for new_idx, old_idx in enumerate(self.seq_permutation):
            original_order[old_idx] = self.gaps_per_sequence[new_idx]
        return original_order

In [None]:
def run_prediction_episode(env, agent, expected_gaps, max_factor=1.25):
    """Roll out one episode with masked epsilon-greedy actions and summarise it.

    Args:
        env: Environment exposing ``reset``, ``step`` and ``get_valid_action_mask``.
        agent: Object whose ``forward(state)`` returns Q-values; its ``epsilon``
            attribute (default 0.1) sets the exploration rate. When it has a
            ``replay_memory``, every transition is stored there.
        expected_gaps: Expected number of gap insertions for this sample.
        max_factor: Step budget is ``ceil(expected_gaps * max_factor)``, at least 1.

    Returns:
        Dict with ``steps``, ``avg_reward``, ``sum_reward`` and
        ``gaps_distribution``.
    """
    import math
    import numpy as np

    state = env.reset()
    max_steps = max(1, int(math.ceil(expected_gaps * max_factor)))
    steps, rewards = 0, []
    valid_mask = env.get_valid_action_mask().astype(bool)

    for _ in range(max_steps):
        q_values = agent.forward(state)
        action = masked_epsilon_greedy(
            q_values=q_values,
            valid_mask=valid_mask,
            epsilon=getattr(agent, "epsilon", 0.1),
        )

        reward, next_state, done = env.step(action)
        rewards.append(float(reward))

        # The sixth replay element is consumed as the mask of the NEXT state
        # when bootstrapping, so it must be read after the step, not before.
        next_mask = env.get_valid_action_mask().astype(bool)

        if hasattr(agent, "replay_memory"):
            agent.replay_memory.store(
                (state, next_state, action, reward, bool(done), next_mask.copy())
            )

        state = next_state
        steps += 1

        # Stop when finished, or when the action just taken had no valid choice.
        if done == 1 or not valid_mask.any():
            break

        valid_mask = next_mask

    gaps_dist = env.get_unpermuted_gaps() if hasattr(env, "get_unpermuted_gaps") \
                else env.gaps_per_sequence[:env.original_n_seqs]

    return {
        "steps": steps,
        "avg_reward": float(np.mean(rewards)) if rewards else 0.0,
        "sum_reward": float(np.sum(rewards)) if rewards else 0.0,
        "gaps_distribution": gaps_dist,
    }

## 7. Load and Analyze Data

In [None]:
import os
import json
import random
import numpy as np
from pathlib import Path

# Fall back to the working directory when no earlier cell defined project_root.
if 'project_root' not in globals():
    project_root = Path.cwd()

def _parse_accepted_pairs(accepted_pairs_raw):
    """Normalise accepted pairs into a list of 2-tuples of strings.

    A string is read as comma-separated ``a:b`` items (split on the first
    colon, both sides stripped; items without a colon are dropped). A list is
    filtered to its 2-element list/tuple entries, whose members are
    stringified. Any other input yields an empty list.
    """
    if isinstance(accepted_pairs_raw, str):
        return [
            tuple(side.strip() for side in chunk.split(':', 1))
            for chunk in accepted_pairs_raw.split(',')
            if ':' in chunk
        ]

    if isinstance(accepted_pairs_raw, list):
        return [
            (str(entry[0]), str(entry[1]))
            for entry in accepted_pairs_raw
            if isinstance(entry, (list, tuple)) and len(entry) == 2
        ]

    return []

def _count_inserted_gaps_from_sequences(start, solution):
    """Return how many more '-' characters ``solution`` holds than ``start`` (never negative)."""
    def _dashes(rows):
        # Rows are stringified first, so non-string entries are tolerated.
        return sum(str(row).count('-') for row in rows)

    before = _dashes(start)
    after = _dashes(solution)
    inserted = after - before
    return inserted if inserted > 0 else 0

def load_samples_from_jsonl(jsonl_path='hackaton_extracted_data.jsonl', max_samples=None, shuffle_data=True):
    """Load alignment samples from a JSON-lines file.

    Each line is expected to be a JSON object with ``start``, ``solution``,
    ``accepted_pairs`` and optionally ``moves``. A record is kept only when
    start, solution and the parsed pairs are all non-empty and start and
    solution have the same number of rows. Malformed lines are skipped.

    Args:
        jsonl_path: File to read; a relative path is resolved against ``project_root``.
        max_samples: Stop after this many accepted samples (falsy means no limit).
        shuffle_data: Shuffle the accepted samples in place before returning.

    Returns:
        List of sample dicts; empty when the file does not exist.
    """
    if not os.path.isabs(jsonl_path):
        jsonl_path = project_root / jsonl_path

    if not os.path.exists(jsonl_path):
        return []

    samples = []

    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f):
            if max_samples and len(samples) >= max_samples:
                break

            try:
                data = json.loads(line)

                start = data.get('start', [])
                solution = data.get('solution', [])
                accepted_pairs_raw = data.get('accepted_pairs', [])
                has_moves = 'moves' in data
                moves = data.get('moves', [])

                accepted_pairs = _parse_accepted_pairs(accepted_pairs_raw)

                # Only trust len(moves) when the record really carries a move
                # list; a missing key must fall through to counting dashes
                # instead of silently becoming "0 gaps".
                if has_moves and isinstance(moves, list):
                    n_gaps = len(moves)
                elif isinstance(moves, int):
                    n_gaps = moves
                else:
                    n_gaps = _count_inserted_gaps_from_sequences(start, solution)

                n_gaps = max(0, int(n_gaps))

                if start and solution and accepted_pairs and len(start) == len(solution):
                    samples.append({
                        'start': start,
                        'solution': solution,
                        'accepted_pairs': accepted_pairs,
                        'n_gaps': n_gaps,
                        'moves': moves,
                        'n_sequences': len(start),
                        'consensus_length': len(accepted_pairs),
                        'idx': line_num,
                    })
            except (ValueError, AttributeError, TypeError):
                # Best-effort loading: skip lines that are not valid JSON
                # (JSONDecodeError is a ValueError) or not shaped like a record.
                continue

    if shuffle_data and samples:
        random.shuffle(samples)

    return samples

## 8. Dimension Analysis (Using ALL Data with Padding)

In [None]:
# Largest row count and consensus width over all samples; every environment
# is padded to these dimensions.
max_n_seqs = max(entry['n_sequences'] for entry in samples)
max_cons_len = max(entry['consensus_length'] for entry in samples)

# Remember each sample's own, unpadded dimensions.
for sample in samples:
    sample.update(
        original_n_seqs=sample['n_sequences'],
        original_cons_len=sample['consensus_length'],
    )

## 9. Training with Variable Dimensions (Using Padding)

In [None]:
# 80/20 train/test split over the (already shuffled) samples.
train_size = int(0.8 * len(samples))
train_samples = samples[:train_size]
test_samples = samples[train_size:]

# One agent for all samples: every environment is padded to the global maxima.
action_number = max_n_seqs * max_cons_len
agent = DQNAgent(
    action_number=action_number,
    num_seqs=max_n_seqs,
    max_grid=max_cons_len - 1,
    max_value=(max_cons_len - 1) * 100,
    epsilon=0.9,
    delta=0.01,
    batch_size=64,
    gamma=0.99,
    learning_rate=0.001,
    memory_size=5000
)

epochs = 10
samples_per_epoch = 100

for epoch in range(epochs):
    epoch_samples = random.sample(train_samples, min(len(train_samples), samples_per_epoch))
    epoch_rewards = []

    for sample in tqdm(epoch_samples, desc=f"Epoch {epoch+1}/{epochs}"):
        expected_gaps = get_expected_gaps(sample)
        env = AlignmentEnvironment(
            sample['start'],
            sample['accepted_pairs'],
            expected_gaps,
            max_n_seqs=max_n_seqs,
            max_cons_len=max_cons_len
        )

        env.reset()
        env.randomize_sequence_order(apply=True)
        # Read the state only after the rows were permuted; the state returned
        # by reset() would not match the mask and the environment any more.
        state = env.get_current_state()

        episode_reward = 0
        step_count = 0
        max_steps = expected_gaps * 5
        valid_mask = env.get_valid_action_mask()

        while step_count < max_steps:
            action = agent.select_action(state, valid_action_mask=valid_mask)
            reward, next_state, done = env.step(action)

            next_valid_mask = env.get_valid_action_mask()
            agent.replay_memory.store((state, next_state, action, reward, done, next_valid_mask))

            agent.update()

            episode_reward += reward
            step_count += 1
            state = next_state
            valid_mask = next_valid_mask

            if done == 1:
                break

        agent.update_epsilon()
        epoch_rewards.append(episode_reward)

    # Avoid NaN (and a RuntimeWarning) when there were no training samples.
    avg_reward = float(np.mean(epoch_rewards)) if epoch_rewards else 0.0
    print(f"Epoch {epoch+1}: reward={avg_reward:.2f}, memory={agent.replay_memory.size}, ε={agent.current_epsilon:.3f}")

In [None]:
def run_dqn_inference(agent, env, expected_gaps, max_steps=None):
    """Greedily roll out ``agent`` on ``env`` and return the resulting rows.

    The rows are shuffled internally before the rollout; results are mapped
    back so that they refer to the row order the environment was built with.

    Args:
        agent: Object with ``predict(state, valid_action_mask=...)``.
        env: Alignment environment to act on (it is reset first).
        expected_gaps: Expected number of gap insertions for this sample.
        max_steps: Step budget; defaults to ``2 * expected_gaps``.

    Returns:
        ``(predicted, actions_taken)`` where ``predicted`` lists the real rows
        after the rollout and ``actions_taken`` lists ``(row, column)`` pairs,
        both expressed in the pre-shuffle row order.
    """
    if max_steps is None:
        max_steps = expected_gaps * 2

    env.reset()
    env.randomize_sequence_order(apply=True)
    # The state returned by reset() predates the shuffle; read it afresh so it
    # matches the action mask and the rows the environment will modify.
    state = env.get_current_state()
    permutation = getattr(env, "seq_permutation", None)

    actions_taken = []
    valid_mask = env.get_valid_action_mask()

    for _ in range(int(max_steps)):
        if valid_mask.sum() == 0:
            break

        action = agent.predict(state, valid_action_mask=valid_mask)

        seq_idx = action // env.max_cons_len
        pos = action % env.max_cons_len
        if permutation is not None and 0 <= seq_idx < len(permutation):
            # Shuffled row i holds what was originally row permutation[i].
            reported_idx = int(permutation[seq_idx])
        else:
            reported_idx = seq_idx
        actions_taken.append((reported_idx, pos))

        _, next_state, done = env.step(action)
        state = next_state
        valid_mask = env.get_valid_action_mask()

        if done == 1:
            break

    shuffled = env.sequences[:env.original_n_seqs]
    if permutation is None:
        predicted = shuffled
    else:
        # Undo the shuffle so predictions line up with the input rows.
        predicted = list(shuffled)
        for new_idx, old_idx in enumerate(permutation):
            predicted[int(old_idx)] = shuffled[new_idx]

    return predicted, actions_taken