In [2]:
import torch
from typing import Dict, List

In [3]:
CODON_TABLE : Dict[str, List[str]] = {
    'A': ['GCU', 'GCC', 'GCA', 'GCG'],
    'C': ['UGU', 'UGC'],
    'D': ['GAU', 'GAC'],
    'E': ['GAA', 'GAG'],
    'F': ['UUU', 'UUC'],
    'G': ['GGU', 'GGC', 'GGA', 'GGG'],
    'H': ['CAU', 'CAC'],
    'I': ['AUU', 'AUC', 'AUA'],
    'K': ['AAA', 'AAG'],
    'L': ['UUA', 'UUG', 'CUU', 'CUC', 'CUA', 'CUG'],
    'M': ['AUG'],
    'N': ['AAU', 'AAC'],
    'P': ['CCU', 'CCC', 'CCA', 'CCG'],
    'Q': ['CAA', 'CAG'],
    'R': ['CGU', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],
    'S': ['UCU', 'UCC', 'UCA', 'UCG', 'AGU', 'AGC'],
    'T': ['ACU', 'ACC', 'ACA', 'ACG'],
    'V': ['GUU', 'GUC', 'GUA', 'GUG'],
    'W': ['UGG'],
    'Y': ['UAU', 'UAC'],
    '*': ['UAA', 'UAG', 'UGA'],  # Stop codons
}

# Dictionary ambiguous amino acids to standard amino acids
AMBIGUOUS_AMINOACID_MAP: Dict[str, list[str]] = {
    "B": ["N", "D"],  # Asparagine (N) or Aspartic acid (D)
    "Z": ["Q", "E"],  # Glutamine (Q) or Glutamic acid (E)
    "X": ["A"],  # Any amino acid (typically replaced with Alanine)
    "J": ["L", "I"],  # Leucine (L) or Isoleucine (I)
    "U": ["C"],  # Selenocysteine (typically replaced with Cysteine)
    "O": ["K"],  # Pyrrolysine (typically replaced with Lysine)
}

AA_LIST = list(CODON_TABLE.keys())

In [None]:
def protein_to_tensor(protein):

    amino_acid_counts = [0] * len(protein)

    if protein is None or protein == '':
        return torch.tensor(amino_acid_counts, dtype=torch.float)
    
    for amino_acid in protein:
        if amino_acid in AA_LIST:
            idx = AA_LIST.index(amino_acid)
            amino_acid_counts[idx] += 1

    return torch.tensor(amino_acid_counts, dtype=torch.float)

In [9]:
protein = 'MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG'

In [10]:
t = protein_to_tensor(protein)
t

tensor([ 8.,  2.,  2.,  5.,  3.,  7.,  2.,  0.,  1., 14.,  2.,  1.,  4.,  3.,
         4.,  1.,  2.,  5.,  2.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])

In [12]:
len(protein)

70

In [13]:
amino_acid_counts = [0] * len(protein)

for amino_acid in protein:
    if amino_acid in AA_LIST:
        idx = AA_LIST.index(amino_acid)
        amino_acid_counts[idx] += 1

In [13]:
import torch
device = torch.device("cpu")

In [14]:
s0 = torch.zeros(0, dtype=torch.long, device=device)  # Use torch.long for integer indices

# Mapping from codons to indices
codon_to_index = {"AUG": 0, "UAA": 1, "UAG": 2}  # Example mapping, extend as needed

def step(state, action):
    # Convert the codon to its corresponding index
    action_index = torch.tensor([codon_to_index[action]], dtype=torch.long, device=device)
    # Append the chosen codon index to the mRNA sequence
    return torch.cat([state, action_index.unsqueeze(-1)], dim=-1)

In [17]:
actions = ["AUG", "UAA", "UAG"]

state1 = step(s0, actions[0])
state2 = step(state1, actions[1])
state3 = step(state2, actions[2])

state3

tensor([[0, 1, 2]])

In [23]:
len(s0)

0

In [19]:
from typing import Dict, List

In [20]:
CODON_TABLE : Dict[str, List[str]] = {
    'A': ['GCU', 'GCC', 'GCA', 'GCG'],
    'C': ['UGU', 'UGC'],
    'D': ['GAU', 'GAC'],
    'E': ['GAA', 'GAG'],
    'F': ['UUU', 'UUC'],
    'G': ['GGU', 'GGC', 'GGA', 'GGG'],
    'H': ['CAU', 'CAC'],
    'I': ['AUU', 'AUC', 'AUA'],
    'K': ['AAA', 'AAG'],
    'L': ['UUA', 'UUG', 'CUU', 'CUC', 'CUA', 'CUG'],
    'M': ['AUG'],
    'N': ['AAU', 'AAC'],
    'P': ['CCU', 'CCC', 'CCA', 'CCG'],
    'Q': ['CAA', 'CAG'],
    'R': ['CGU', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],
    'S': ['UCU', 'UCC', 'UCA', 'UCG', 'AGU', 'AGC'],
    'T': ['ACU', 'ACC', 'ACA', 'ACG'],
    'V': ['GUU', 'GUC', 'GUA', 'GUG'],
    'W': ['UGG'],
    'Y': ['UAU', 'UAC'],
    '*': ['UAA', 'UAG', 'UGA'],  # Stop codons
}

# Dictionary ambiguous amino acids to standard amino acids
AMBIGUOUS_AMINOACID_MAP: Dict[str, list[str]] = {
    "B": ["N", "D"],  # Asparagine (N) or Aspartic acid (D)
    "Z": ["Q", "E"],  # Glutamine (Q) or Glutamic acid (E)
    "X": ["A"],  # Any amino acid (typically replaced with Alanine)
    "J": ["L", "I"],  # Leucine (L) or Isoleucine (I)
    "U": ["C"],  # Selenocysteine (typically replaced with Cysteine)
    "O": ["K"],  # Pyrrolysine (typically replaced with Lysine)
}


AA_LIST = list(CODON_TABLE.keys())
AMBIG_AA_LIST = list(AMBIGUOUS_AMINOACID_MAP.keys())

CODON_MAP = {codon: i for i, codon in enumerate(sorted(set(c for codons in CODON_TABLE.values() for c in codons)))}

In [21]:
CODON_MAP

{'AAA': 0,
 'AAC': 1,
 'AAG': 2,
 'AAU': 3,
 'ACA': 4,
 'ACC': 5,
 'ACG': 6,
 'ACU': 7,
 'AGA': 8,
 'AGC': 9,
 'AGG': 10,
 'AGU': 11,
 'AUA': 12,
 'AUC': 13,
 'AUG': 14,
 'AUU': 15,
 'CAA': 16,
 'CAC': 17,
 'CAG': 18,
 'CAU': 19,
 'CCA': 20,
 'CCC': 21,
 'CCG': 22,
 'CCU': 23,
 'CGA': 24,
 'CGC': 25,
 'CGG': 26,
 'CGU': 27,
 'CUA': 28,
 'CUC': 29,
 'CUG': 30,
 'CUU': 31,
 'GAA': 32,
 'GAC': 33,
 'GAG': 34,
 'GAU': 35,
 'GCA': 36,
 'GCC': 37,
 'GCG': 38,
 'GCU': 39,
 'GGA': 40,
 'GGC': 41,
 'GGG': 42,
 'GGU': 43,
 'GUA': 44,
 'GUC': 45,
 'GUG': 46,
 'GUU': 47,
 'UAA': 48,
 'UAC': 49,
 'UAG': 50,
 'UAU': 51,
 'UCA': 52,
 'UCC': 53,
 'UCG': 54,
 'UCU': 55,
 'UGA': 56,
 'UGC': 57,
 'UGG': 58,
 'UGU': 59,
 'UUA': 60,
 'UUC': 61,
 'UUG': 62,
 'UUU': 63}

In [None]:
import torch
from typing import List, Dict, Union
from torchgfn.env import DiscreteEnv

# Amino acids and codon tables
AMINO_ACIDS: List[str] = [
    "A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
    "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"
]

CODON_TABLE: Dict[str, List[str]] = {
    'A': ['GCU', 'GCC', 'GCA', 'GCG'],
    'C': ['UGU', 'UGC'],
    'D': ['GAU', 'GAC'],
    'E': ['GAA', 'GAG'],
    'F': ['UUU', 'UUC'],
    'G': ['GGU', 'GGC', 'GGA', 'GGG'],
    'H': ['CAU', 'CAC'],
    'I': ['AUU', 'AUC', 'AUA'],
    'K': ['AAA', 'AAG'],
    'L': ['UUA', 'UUG', 'CUU', 'CUC', 'CUA', 'CUG'],
    'M': ['AUG'],
    'N': ['AAU', 'AAC'],
    'P': ['CCU', 'CCC', 'CCA', 'CCG'],
    'Q': ['CAA', 'CAG'],
    'R': ['CGU', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],
    'S': ['UCU', 'UCC', 'UCA', 'UCG', 'AGU', 'AGC'],
    'T': ['ACU', 'ACC', 'ACA', 'ACG'],
    'V': ['GUU', 'GUC', 'GUA', 'GUG'],
    'W': ['UGG'],
    'Y': ['UAU', 'UAC'],
    '*': ['UAA', 'UAG', 'UGA'],  # Stop codons
}

# Flatten and index all codons
ALL_CODONS = sorted({codon for codons in CODON_TABLE.values() for codon in codons})
CODON_TO_IDX = {codon: idx for idx, codon in enumerate(ALL_CODONS)}
IDX_TO_CODON = {idx: codon for codon, idx in CODON_TO_IDX.items()}


def get_synonymous_indices(amino_acid: str) -> List[int]:
    """
    Return the list of global codon indices that encode the given amino acid.
    """
    codons = CODON_TABLE.get(amino_acid, [])
    return [CODON_TO_IDX[c] for c in codons]


def compute_gc_content_from_indices(indices: torch.LongTensor) -> torch.FloatTensor:
    """
    Given a tensor of codon indices (batch x seq_len), compute GC-content (%) per sequence.
    """
    # Map indices to codon strings
    batch, seq_len = indices.shape
    # Expand to list of strings
    contents = []
    for seq in indices:
        rna = ''.join(IDX_TO_CODON[int(i)] for i in seq)
        gc = (rna.count('G') + rna.count('C')) / len(rna) * 100 if len(rna) > 0 else 0.0
        contents.append(gc)
    return torch.tensor(contents, dtype=torch.float, device=indices.device)


class CodonDesignEnv(DiscreteEnv):
    """
    Environment for designing mRNA codon sequences for a given protein.
    States are LongTensors of shape (batch, t) representing chosen codon indices.
    Action space is global codon set of size len(ALL_CODONS);
    dynamic masks restrict to synonymous codons at each step.
    Rewards are GC-content of the full sequence so far.
    """

    def __init__(
        self,
        protein_seq: str,
        discount_factor: float = 1.0,
    ):
        self.protein_seq = protein_seq
        self.seq_length = len(protein_seq)
        self.n_actions = len(ALL_CODONS)
        self.device = torch.device('cpu')

        # Precompute valid indices per position
        self.syn_indices = [get_synonymous_indices(aa) for aa in protein_seq]

        # Initial empty state
        initial_state = torch.empty((0,), dtype=torch.long, device=self.device)

        super().__init__(
            n_actions=self.n_actions,
            initial_state=initial_state,
            state_shape=(None,),         # variable-length
            action_shape=(),             # scalar action
            dummy_action=None,
            exit_action=None,
            discount_factor=discount_factor,
        )

    def step(
        self,
        states: torch.LongTensor,
        actions: torch.LongTensor,
    ) -> torch.LongTensor:
        # Append action indices to states
        # states: (batch, t), actions: (batch,)
        actions = actions.unsqueeze(-1)
        return torch.cat([states, actions], dim=1)

    def backward_step(
        self,
        states: torch.LongTensor,
    ) -> torch.LongTensor:
        # Remove last codon index
        return states[:, :-1]

    def update_masks(
        self,
        states: torch.LongTensor,
    ) -> torch.BoolTensor:
        # For each sequence in batch, mask only synonymous codons for next aa.
        batch = states.shape[0]
        next_pos = states.shape[1]
        if next_pos >= self.seq_length:
            # No valid actions beyond terminal
            return torch.zeros((batch, self.n_actions), dtype=torch.bool, device=self.device)
        valid = torch.zeros((batch, self.n_actions), dtype=torch.bool, device=self.device)
        valid_indices = self.syn_indices[next_pos]
        valid[:, valid_indices] = True
        return valid

    def reward(
        self,
        states: torch.LongTensor,
    ) -> torch.FloatTensor:
        # Compute GC-content percentage of each sequence
        return compute_gc_content_from_indices(states)

    def reset(self, batch_size: int) -> torch.LongTensor:
        # Return batch of empty sequences
        return torch.empty((batch_size, 0), dtype=torch.long, device=self.device)

    def is_terminal(
        self,
        states: torch.LongTensor,
    ) -> torch.BoolTensor:
        # Terminal when sequence length equals protein length
        return states.shape[1] >= self.seq_length
