# Breakdown overview of the protein programming language

Aim: protein design based on a protein language model
Problem: a designed sequence may not fold experimentally
Solution: build an MCMC optimizer on an energy function (such as the pLDDT score)

### PART I. ENERGY FUNCTION
Metrics: calculated from the structure or from the output of the ESMFold model.
The goal is to use these metrics to guide the optimization of the protein.

In [1]:
import os
import sys
sys.path.append('/home/yunyao/esm/examples/protein_programming_language')

In [2]:
import esm

In [3]:
from abc import ABC, abstractmethod
from typing import List, Optional

import numpy as np
#from biotite.structure import annotate_sse, AtomArray, rmsd, sasa, superimpose

from language.folding_callbacks import FoldingResult
from language.utilities import get_atomarray_in_residue_range

In [6]:
# Energy functions
# This is an abstract base class (ABC); it enforces that every energy function has a compute method.
from abc import ABC, abstractmethod
from typing import List, Optional

import numpy as np
from biotite.structure import annotate_sse, AtomArray, rmsd, sasa, superimpose

#from language.folding_callbacks import FoldingResult
from language.utilities import get_atomarray_in_residue_range

# `node` is the program node being scored; `folding_result` is the output of the ESMFold model
class EnergyTerm(ABC):
    """Abstract interface that every energy term implements.

    The optimizer in this file sums ``weight * compute(...)`` over all terms
    and minimizes that total, so a lower value returned by ``compute`` means a
    better design.
    """

    def __init__(self) -> None:
        """The base class keeps no state."""

    @abstractmethod
    def compute(self, node, folding_result: FoldingResult) -> float:
        """Score ``node`` against the folded structure in ``folding_result``.

        Args:
            node: The program node the term is attached to.
            folding_result: Structure and confidence scores of the folded
                sequence.

        Returns:
            The energy contribution of this term.
        """

#1 energy function predicted TM-score (PTM) 
class MaximizePTM(EnergyTerm):
    """Energy term that rewards a high predicted TM-score (pTM)."""

    def __init__(self) -> None:
        super().__init__()

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return ``1 - ptm`` so that a higher pTM gives a lower energy.

        ``node`` is not used by this term.
        """
        del node
        ptm_shortfall = 1.0 - folding_result.ptm
        return ptm_shortfall

#2 energy function plddt
class MaximizePLDDT(EnergyTerm):
    """Energy term that rewards a high pLDDT confidence score."""

    def __init__(self) -> None:
        super().__init__()

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return ``1 - plddt`` so that a higher pLDDT gives a lower energy.

        ``node`` is not used by this term. ``EsmFoldv1.fold`` in this file
        stores pLDDT already divided by 100.
        """
        del node
        plddt_shortfall = 1.0 - folding_result.plddt
        return plddt_shortfall

#3 Degree of symmetry of a complex made of multiple protomers (subunits)
class SymmetryRing(EnergyTerm):
    """Energy term measuring how evenly the child protomers are spaced.

    The energy is the standard deviation of distances between the protomers'
    backbone centroids; zero means all considered distances are equal.
    """

    def __init__(self, all_to_all_protomer_symmetry: bool = False) -> None:
        """
        Args:
            all_to_all_protomer_symmetry: If True, use the distances between
                every pair of protomers; otherwise use only the distances
                between consecutive protomers (with wrap-around).
        """
        super().__init__()
        self.all_to_all_protomer_symmetry: bool = all_to_all_protomer_symmetry

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return the spread of the inter-protomer centroid distances.

        Each child of ``node`` is treated as one protomer occupying the
        half-open residue range returned by its ``get_residue_index_range()``.
        """
        # Collect every child's residue range up front.
        residue_ranges = []
        for protomer in node.get_children():
            residue_ranges.append(protomer.get_residue_index_range())

        protomer_centers = []
        for first, stop in residue_ranges:
            in_protomer = np.logical_and(
                folding_result.atoms.res_id >= first,
                folding_result.atoms.res_id < stop,
            )
            backbone_xyz = get_backbone_atoms(folding_result.atoms[in_protomer]).coord
            protomer_centers.append(get_center_of_mass(backbone_xyz))
        centers = np.vstack(protomer_centers)

        if self.all_to_all_protomer_symmetry:
            spacings = pairwise_distances(centers)
        else:
            spacings = adjacent_distances(centers)
        return float(np.std(spacings))


def get_backbone_atoms(atoms:AtomArray) -> AtomArray:
    """Return only the atoms named "CA", "N" or "C" from ``atoms``."""
    names = atoms.atom_name
    is_backbone = (names == "CA") | (names == "N") | (names == "C")
    return atoms[is_backbone]


def _is_Nx3(array: np.ndarray) -> bool:
    """Tell whether ``array`` is two-dimensional with exactly three columns."""
    shape = array.shape
    if len(shape) != 2:
        return False
    return shape[1] == 3


def get_center_of_mass(coordinates: np.ndarray) -> np.ndarray:
    """Return the unweighted mean of the Nx3 ``coordinates`` as a 1x3 array."""
    assert _is_Nx3(coordinates), "Coordinates must be Nx3."
    centroid = np.mean(coordinates, axis=0)
    return centroid.reshape(1, 3)

# Two ways to estimate symmetry; the first is based on all pairwise distances.
def pairwise_distances(coordinates: np.ndarray) -> np.ndarray:
    """Return the Euclidean distance of every unordered pair of points.

    The result is a flat array holding the strict upper triangle of the
    distance matrix, so each pair appears once and no point is paired with
    itself.
    """
    assert _is_Nx3(coordinates), "Coordinates must be Nx3."
    displacement = coordinates[:, np.newaxis, :] - coordinates[np.newaxis, :, :]
    all_distances = np.linalg.norm(displacement, axis=-1)
    num_points = all_distances.shape[0]
    upper_rows, upper_cols = np.triu_indices(num_points, k=1)
    return all_distances[upper_rows, upper_cols]

# The second way uses only the distances between adjacent points.
def adjacent_distances(coordinates: np.ndarray) -> np.ndarray:
    """Return the distance from each point to the one before it.

    The first point is paired with the last one, so the points are treated as
    a closed ring and the result has one entry per point.
    """
    assert _is_Nx3(coordinates), "Coordinates must be Nx3."
    previous_points = np.roll(coordinates, shift=1, axis=0)
    return np.linalg.norm(coordinates - previous_points, axis=-1)

#4 Surface hydrophobics (number of hydrophobic residues on the surface / total number of hydrophobic residues)
# Whether a residue is on the surface is determined by sasa (solvent-accessible surface area) from biotite.structure.

class MinimizeSurfaceHydrophobics(EnergyTerm):
    """Energy term that penalizes hydrophobic atoms exposed on the surface."""

    def __init__(self) -> None:
        super().__init__()

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return the hydrophobic score of the residue range owned by ``node``."""
        first, stop = node.get_residue_index_range()
        structure = folding_result.atoms
        return hydrophobic_score(structure, first, stop)


# Three-letter residue names that hydrophobic_score() treats as hydrophobic.
_HYDROPHOBICS = {"VAL", "ILE", "LEU", "PHE", "MET", "TRP"}


def hydrophobic_score(
    atom_array: AtomArray,
    start_residue_index: Optional[int] = None,
    end_residue_index: Optional[int] = None,
) -> float:
    """
    Computes ratio of hydrophobic atoms in a biotite AtomArray that are also surface
    exposed. Typically, lower is better.

    Args:
        atom_array: Structure to score. SASA is computed on the whole array.
        start_residue_index: First residue id (inclusive) to consider; no
            lower bound when None.
        end_residue_index: Residue id (exclusive) at which to stop; no upper
            bound when None.

    Returns:
        The fraction of selected hydrophobic atoms whose SASA is strictly
        positive, or 0.0 when the selection holds no hydrophobic atoms
        (there is nothing exposed to penalize).
    """

    hydrophobic_mask = np.array(
        [aa in _HYDROPHOBICS for aa in atom_array.res_name], dtype=bool
    )

    # The bounds are residue ids, so they are compared against each atom's
    # res_id rather than against the atom's position in the array.
    residue_ids = atom_array.res_id
    selection_mask = np.ones(hydrophobic_mask.shape, dtype=bool)
    if start_residue_index is not None:
        selection_mask &= residue_ids >= start_residue_index
    if end_residue_index is not None:
        selection_mask &= residue_ids < end_residue_index

    selected_hydrophobics = selection_mask & hydrophobic_mask
    num_selected_hydrophobics = int(selected_hydrophobics.sum())
    if num_selected_hydrophobics == 0:
        # Avoid 0/0, which would yield NaN and poison the summed energy.
        return 0.0

    # sasa() returns per-atom areas as floats; make "exposed" an explicit bool.
    exposed = sasa(atom_array) > 0
    num_exposed = int((selected_hydrophobics & exposed).sum())
    return num_exposed / num_selected_hydrophobics

# Energy function 5: surface exposure
# Number of residues on the surface / total number of residues
# I think this is not ideal; we should instead maximize charged residues on the surface.
class MinimizeSurfaceExposure(EnergyTerm):
    """Energy term that rewards burying the atoms of this node."""

    def __init__(self) -> None:
        super().__init__()

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return the surface ratio of the residue range owned by ``node``."""
        first, stop = node.get_residue_index_range()
        selected_residues = list(range(first, stop))
        return surface_ratio(folding_result.atoms, selected_residues)


class MaximizeSurfaceExposure(EnergyTerm):
    """Energy term that rewards exposing the atoms of this node."""

    def __init__(self) -> None:
        super().__init__()

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return one minus the surface ratio of the range owned by ``node``."""
        first, stop = node.get_residue_index_range()
        selected_residues = list(range(first, stop))
        exposed_fraction = surface_ratio(folding_result.atoms, selected_residues)
        return 1.0 - exposed_fraction


def surface_ratio(atom_array: AtomArray, residue_indices: List[int]) -> float:
    """Computes ratio of atoms in specified residues which are on the protein surface.

    Args:
        atom_array: Structure to score. SASA is computed on the whole array.
        residue_indices: Residue ids whose atoms are counted.

    Returns:
        Number of selected atoms with non-zero SASA divided by the number of
        selected atoms. As before, an empty selection divides zero by zero
        and yields NaN.
    """

    # Vectorised membership test instead of one Python list scan per atom.
    residue_mask = np.isin(atom_array.res_id, residue_indices)
    surface = np.logical_and(residue_mask, sasa(atom_array))
    return surface.sum() / residue_mask.sum()


# Energy function 6: radius of gyration
class MaximizeGlobularity(EnergyTerm):
    """Energy term that rewards backbone atoms equidistant from their centroid."""

    def __init__(self) -> None:
        super().__init__()

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return the spread of backbone-atom distances to their centroid.

        Only atoms in the half-open residue range owned by ``node`` are used.
        """
        first, stop = node.get_residue_index_range()

        in_range = np.logical_and(
            folding_result.atoms.res_id >= first,
            folding_result.atoms.res_id < stop,
        )
        backbone_xyz = get_backbone_atoms(folding_result.atoms[in_range]).coord

        radial_distances = distances_to_centroid(backbone_xyz)
        return float(np.std(radial_distances))


def distances_to_centroid(coordinates: np.ndarray) -> np.ndarray:
    """
    Returns, for every row of the Nx3 ``coordinates``, its Euclidean distance
    to the mean of all rows.
    """
    assert _is_Nx3(coordinates), "Coordinates must be Nx3."
    centroid = get_center_of_mass(coordinates)
    offsets = coordinates - centroid
    return np.linalg.norm(offsets, axis=-1)

# Energy function 7: minimize the RMSD to a reference structure
# Two methods are implemented.
# The first uses superimpose to align the structures and then calculates the RMSD.
# Problem: the structural alignment could be highly inaccurate.
class MinimizeCRmsd(EnergyTerm):
    """Energy term: RMSD to a template after superimposing onto it."""

    def __init__(self, template: AtomArray, backbone_only: bool = False) -> None:
        """
        Args:
            template: Reference structure to match.
            backbone_only: If True, compare only backbone atoms; the template
                is reduced to its backbone once, here.
        """
        super().__init__()

        self.backbone_only: bool = backbone_only
        self.template: AtomArray = (
            get_backbone_atoms(template) if backbone_only else template
        )

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return the superimposed RMSD between the template and ``node``'s atoms."""
        first, stop = node.get_residue_index_range()

        candidate_atoms = get_atomarray_in_residue_range(
            folding_result.atoms, first, stop
        )
        if self.backbone_only:
            candidate_atoms = get_backbone_atoms(candidate_atoms)

        return crmsd(self.template, candidate_atoms)


def crmsd(atom_array_a: AtomArray, atom_array_b: AtomArray) -> float:
    """Superimpose ``atom_array_b`` onto ``atom_array_a`` and return their RMSD.

    NOTE(review): atoms are paired by position, so both arrays presumably
    need the same atoms in the same order — confirm against callers.
    """
    # TODO(scandido): Add this back.
    # atom_array_a = canonicalize_within_residue_atom_order(atom_array_a)
    # atom_array_b = canonicalize_within_residue_atom_order(atom_array_b)
    fitted_b, _transformation = superimpose(atom_array_a, atom_array_b)
    deviation = rmsd(atom_array_a, fitted_b)
    return float(deviation.mean())

# The second method calculates pairwise distances within each structure and uses them to estimate the RMSD (translation-invariant).
class MinimizeDRmsd(EnergyTerm):
    """Energy term: distance-based RMSD to a template (no superposition)."""

    def __init__(self, template: AtomArray, backbone_only: bool = False) -> None:
        """
        Args:
            template: Reference structure to match.
            backbone_only: If True, compare only backbone atoms; the template
                is reduced to its backbone once, here.
        """
        super().__init__()

        self.backbone_only: bool = backbone_only
        self.template: AtomArray = (
            get_backbone_atoms(template) if backbone_only else template
        )

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return the distance RMSD between the template and ``node``'s atoms."""
        first, stop = node.get_residue_index_range()

        candidate_atoms = get_atomarray_in_residue_range(
            folding_result.atoms, first, stop
        )
        if self.backbone_only:
            candidate_atoms = get_backbone_atoms(candidate_atoms)

        return drmsd(self.template, candidate_atoms)

def drmsd(atom_array_a: AtomArray, atom_array_b: AtomArray) -> float:
    """Return the RMS difference between the two structures' pairwise distances.

    NOTE(review): distances are paired by position, so both arrays presumably
    need the same atoms in the same order — confirm against callers.
    """
    # TODO(scandido): Add this back.
    # atom_array_a = canonicalize_within_residue_atom_order(atom_array_a)
    # atom_array_b = canonicalize_within_residue_atom_order(atom_array_b)

    distances_a = pairwise_distances(atom_array_a.coord)
    distances_b = pairwise_distances(atom_array_b.coord)

    squared_error = (distances_a - distances_b) ** 2
    return float(np.sqrt(squared_error.mean()))


def pairwise_distances(coordinates: np.ndarray) -> np.ndarray:
    """Return the distances between all unordered pairs of the Nx3 points.

    This repeats the definition of the same name given earlier in this file;
    the later definition is the one in effect once this cell has run.
    """
    assert _is_Nx3(coordinates), "Coordinates must be Nx3."
    differences = coordinates[:, np.newaxis, :] - coordinates[np.newaxis, :, :]
    full_matrix = np.linalg.norm(differences, axis=-1)
    # Keep only entries above the diagonal: each pair once, no self-distances.
    above_diagonal = np.triu_indices(full_matrix.shape[0], k=1)
    return full_matrix[above_diagonal]


# Energy function 8: match the secondary structure to the target secondary structure (SS)

class MatchSecondaryStructure(EnergyTerm):
    """Energy term: fraction of annotations that differ from a target element."""

    def __init__(self, secondary_structure_element: str) -> None:
        """
        Args:
            secondary_structure_element: Label compared against the output of
                ``annotate_sse`` (presumably Biotite's 'a'/'b'/'c' codes —
                confirm).
        """
        super().__init__()
        self.secondary_structure_element = secondary_structure_element

    def compute(self, node, folding_result: FoldingResult) -> float:
        """Return the mean mismatch between annotated and target structure."""
        first, stop = node.get_residue_index_range()

        in_range = np.logical_and(
            folding_result.atoms.res_id >= first,
            folding_result.atoms.res_id < stop,
        )
        subprotein = folding_result.atoms[in_range]
        annotated = annotate_sse(subprotein)

        mismatches = annotated != self.secondary_structure_element
        return np.mean(mismatches)



### PART II. How we will use the ESMFold model to generate structures and metrics like pLDDT

In [9]:
# This is the main engine: it takes a sequence, runs ESMFold, and outputs the structure plus the pLDDT and pTM scores.
import esm
import torch
from biotite.structure import AtomArray
import numpy as np
from language.openfold_modules import atom_order
#from openfold.np.residue_constants import atom_order
from torch.utils._pytree import tree_map

from language.utilities import pdb_file_to_atomarray
from dataclasses import dataclass



# Container for the output of one folding run (see FoldingCallback.fold).
@dataclass
class FoldingResult:
    # Predicted structure; EsmFoldv1.fold parses it from the model's PDB output.
    atoms: AtomArray
    # Predicted TM-score as reported by the folding model.
    ptm: float
    # pLDDT confidence; EsmFoldv1.fold stores the mean over CA atoms divided by 100.
    plddt: float


class FoldingCallback(ABC):
    """Abstract interface for ESMFold and other structure predictors."""

    def __init__(self) -> None:
        """The interface itself keeps no state."""

    @abstractmethod
    def load(self, device: str) -> None:
        """Prepare the folding model on ``device`` before any call to fold()."""

    @abstractmethod
    def fold(self, sequence: str, residue_indices: List[int]) -> FoldingResult:
        """Fold ``sequence`` using the given per-residue indices.

        Returns:
            The predicted structure together with its confidence scores.
        """


class EsmFoldv1(FoldingCallback):
    """Runs ESMFold v1.0."""

    def __init__(self) -> None:
        super().__init__()

        # Set by load(); fold() refuses to run while this is still None.
        self.model = None

    def load(self, device: str) -> None:
        """Create the pretrained ESMFold v1 model in eval mode and move it to ``device``."""
        self.model = esm.pretrained.esmfold_v1().eval()
        self.model = self.model.to(device)

    def fold(self, sequence: str, residue_indices: List[int]) -> FoldingResult:
        """Fold ``sequence`` and return its structure, pTM and mean CA pLDDT.

        Args:
            sequence: Amino-acid sequence to fold.
            residue_indices: One residue index per position of ``sequence``.

        Returns:
            A FoldingResult whose ``plddt`` is the mean CA pLDDT divided by 100.

        Raises:
            AssertionError: If load() has not been called yet.
        """
        # Imported locally: nothing in this file imports StringIO at module
        # level, so without this the PDB parsing below raises NameError.
        from io import StringIO

        assert self.model is not None, "Must call load() before fold()."

        # TODO: Current `esm.esmfold.v1.misc.output_to_pdb()` adds 1 to the `residx`
        # mistakenly, just subtract 1 for now but fix in a later version.
        residue_indices = np.array(residue_indices) - 1

        raw_output = self.model.infer(
            sequence, residx=torch.Tensor(residue_indices).long().reshape(1, -1),
        )
        # Move every tensor in the output to the CPU before post-processing.
        raw_output = tree_map(lambda x: x.to("cpu"), raw_output)

        pdb_string = esm.esmfold.v1.misc.output_to_pdb(raw_output)[0]
        atoms: AtomArray = pdb_file_to_atomarray(StringIO(pdb_string))

        # Take the first batch entry, then keep the row belonging to CA atoms.
        # Presumably the tensor is (batch, residues, atom types) — TODO confirm.
        plddt = raw_output["plddt"]
        plddt = plddt[0, ...].numpy()
        plddt = plddt.transpose()
        plddt = plddt[atom_order["CA"], :]
        plddt = float(plddt.mean()) / 100.0

        ptm = float(raw_output["ptm"])

        return FoldingResult(atoms=atoms, ptm=ptm, plddt=plddt)

### PART III. Algorithm used to optimize the protein


In [17]:
# Now we look at the optimizer (Metropolis-Hastings criterion and simulated annealing).
# Quite conventional: it combines simulated annealing (a temperature that decays by a constant factor)
# and the Metropolis criterion (allowing moves to higher-energy states in a probabilistic/stochastic fashion).
# All state info is stored in a dataclass (similar to a dictionary).
from copy import deepcopy
from dataclasses import dataclass

import numpy as np
from rich.live import Live
from rich.table import Table

@dataclass
class MetropolisHastingsState:
    """Snapshot of the optimizer after a Metropolis-Hastings step.

    The class is constructed with keyword arguments elsewhere in this file,
    which requires the ``@dataclass``-generated ``__init__``. The energy
    fields are ``None`` until the first step has been taken.
    """

    # Currently accepted program; quoted because ProgramNode is defined later.
    program: "ProgramNode"
    # Temperature used by the most recent step.
    temperature: float
    # Factor the temperature is multiplied by on every step.
    annealing_rate: float
    # Number of steps taken so far.
    num_steps: int
    # Total energy and per-term (name, weight, value) tuples of the last candidate.
    candidate_energy: Optional[float]
    candidate_energy_term_fn_values: Optional[list]
    # Same quantities for the currently accepted program.
    current_energy: Optional[float]
    current_energy_term_fn_values: Optional[list]
    # Same quantities for the lowest-energy candidate seen so far.
    best_energy: Optional[float]
    best_energy_term_fn_values: Optional[list]

# The ProgramNode contains all the molecule info.
# How one step of Metropolis-Hastings optimization works:
# There are three containers. The candidate is the sequence being examined;
# if the candidate is accepted, it becomes 'current' (under MC, current may not have the lowest energy).
# 'best' stores the best (lowest) energy seen so far.
# In each MC step, current and best are either updated or left unchanged according to the MC algorithm.

def metropolis_hastings_step(
    state: MetropolisHastingsState,
    folding_callback: FoldingCallback,
    verbose: bool = False,
) -> MetropolisHastingsState:
    """Propose one mutation, fold it, and accept or reject it.

    Args:
        state: Optimizer state before the step; it is not modified.
        folding_callback: Folds the candidate sequence.
        verbose: If True, print every accepted candidate.

    Returns:
        A new state holding the cooled temperature, the candidate's energies,
        and the updated current/best bookkeeping.
    """
    # Cool down first; the acceptance test below uses the cooled temperature.
    temperature = state.temperature * state.annealing_rate

    # Mutate a copy so a rejected proposal leaves the current program intact.
    candidate: ProgramNode = deepcopy(state.program)
    candidate.mutate()

    sequence, residue_indices = candidate.get_sequence_and_set_residue_index_ranges()
    folding_output = folding_callback.fold(sequence, residue_indices)

    # TODO(scandido): Log these.
    candidate_energy_term_fn_values = []
    for name, weight, energy_fn in candidate.get_energy_term_functions():
        candidate_energy_term_fn_values.append((name, weight, energy_fn(folding_output)))
    candidate_energy: float = sum(
        weight * value for _, weight, value in candidate_energy_term_fn_values
    )

    if state.current_energy is None:
        # Very first step: nothing to compare against, always accept.
        accept_candidate = True
    else:
        # NOTE(scandido): We are minimizing the function here so instead of
        # candidate - current we do -1 * (candidate - current) = -candidate + current.
        energy_differential: float = state.current_energy - candidate_energy
        # NOTE(scandido): We approximate the ratio of transition probabilities from
        # current to candidate vs. candidate to current to be equal, which is
        # approximately correct.
        boltzmann_factor = np.exp(energy_differential / temperature)
        accept_probability: float = np.clip(boltzmann_factor, a_min=None, a_max=1.0)
        # The stochastic step: worse candidates are still accepted sometimes.
        accept_candidate = np.random.uniform() < accept_probability

    if accept_candidate and verbose:
        print(f"Accepted {sequence} with energy {candidate_energy:.2f}.")

    best = (state.best_energy is None) or candidate_energy < state.best_energy

    if accept_candidate:
        next_program = candidate
        next_energy = candidate_energy
        next_term_values = candidate_energy_term_fn_values
    else:
        next_program = state.program
        next_energy = state.current_energy
        next_term_values = state.current_energy_term_fn_values

    if best:
        next_best_energy = candidate_energy
        next_best_term_values = candidate_energy_term_fn_values
    else:
        next_best_energy = state.best_energy
        next_best_term_values = state.best_energy_term_fn_values

    return MetropolisHastingsState(
        program=next_program,
        temperature=temperature,
        annealing_rate=state.annealing_rate,
        num_steps=state.num_steps + 1,
        candidate_energy=candidate_energy,
        candidate_energy_term_fn_values=candidate_energy_term_fn_values,
        current_energy=next_energy,
        current_energy_term_fn_values=next_term_values,
        best_energy=next_best_energy,
        best_energy_term_fn_values=next_best_term_values,
    )

# The temperature is lowered exponentially by annealing_rate; here the annealing rate is constant, but we could also make it decay as the iteration step increases.
# The following shows how n iteration steps are run and logged.
def run_simulated_annealing(
    program: ProgramNode,
    initial_temperature: float,
    annealing_rate: float,
    total_num_steps: int,
    folding_callback: FoldingCallback,
    display_progress: bool = True,
    progress_verbose_print: bool = False,
) -> ProgramNode:
    """Run ``total_num_steps`` Metropolis-Hastings steps starting from ``program``.

    Args:
        program: Starting program.
        initial_temperature: Temperature before the first step.
        annealing_rate: Factor applied to the temperature on every step.
        total_num_steps: Number of steps to run.
        folding_callback: Folds each candidate sequence.
        display_progress: If True, refresh a live table after every step.
        progress_verbose_print: Forwarded to each step as ``verbose``.

    Returns:
        The program held as "current" after the final step.
    """
    # TODO(scandido): Track accept rate.

    def _generate_table(state):
        """Build the progress table for ``state``."""
        table = Table()
        for heading in (
            "Energy name",
            "Weight",
            "Candidate Value",
            "Current Value",
            "Best Value",
        ):
            table.add_column(heading)
        if state.current_energy_term_fn_values is None:
            return table

        per_term = zip(
            state.candidate_energy_term_fn_values,
            state.current_energy_term_fn_values,
            state.best_energy_term_fn_values,
        )
        for candidate_term, current_term, best_term in per_term:
            name, weight, candidate_value = candidate_term
            _, _, current_value = current_term
            _, _, best_value = best_term
            table.add_row(
                name,
                f"{weight:.2f}",
                f"{candidate_value:.2f}",
                f"{current_value:.2f}",
                f"{best_value:.2f}",
            )

        table.add_row(
            "Energy",
            "",
            f"{state.candidate_energy:.2f}",
            f"{state.current_energy:.2f}",
            f"{state.best_energy:.2f}",
        )
        table.add_row("Iterations", "", f"{state.num_steps} / {total_num_steps}")
        return table

    # No energies are known before the first step.
    state = MetropolisHastingsState(
        program=program,
        temperature=initial_temperature,
        annealing_rate=annealing_rate,
        num_steps=0,
        candidate_energy=None,
        candidate_energy_term_fn_values=None,
        current_energy=None,
        current_energy_term_fn_values=None,
        best_energy=None,
        best_energy_term_fn_values=None,
    )

    with Live() as live:
        for _ in range(total_num_steps):
            state = metropolis_hastings_step(
                state,
                folding_callback,
                verbose=progress_verbose_print,
            )
            if display_progress:
                live.update(_generate_table(state))

    return state.program

### PART IV: How the protein is changed (mutation/deletion/insertion)

There are two types of modification on a sequence:

1. Fixed-length: only mutation
2. Variable-length: allows mutation, deletion, and insertion with the provided probabilities



In [13]:
from typing import Optional, List, Dict, Union


class SequenceSegmentFactory(ABC):
    """Abstract interface for a mutable stretch of sequence."""

    def __init__(self) -> None:
        """The interface itself keeps no state."""

    @abstractmethod
    def get(self) -> str:
        """Return the current sequence of this segment."""

    @abstractmethod
    def mutate(self) -> None:
        """Apply one mutation to the segment in place."""

    @abstractmethod
    def num_mutation_candidates(self) -> int:
        """Return the weight used when choosing which segment to mutate."""


class FixedLengthSequenceSegment(SequenceSegmentFactory):
    """Sequence segment whose only mutation is a single-residue substitution."""

    def __init__(
        self, initial_sequence: Union[str, int], disallow_mutations_to_cysteine=True,
    ) -> None:
        """
        Args:
            initial_sequence: Either the starting sequence, or an integer
                length for which a random sequence is generated.
            disallow_mutations_to_cysteine: If True, neither substitutions nor
                the random initial sequence use cysteine.
        """
        super().__init__()
        self.mutation_residue_types = (
            RESIDUE_TYPES_WITHOUT_CYSTEINE
            if disallow_mutations_to_cysteine
            else ALL_RESIDUE_TYPES
        )

        # isinstance (not an exact type comparison) so that str subclasses such
        # as numpy.str_ are treated as sequences rather than as a length.
        self.sequence = (
            initial_sequence
            if isinstance(initial_sequence, str)
            else random_sequence(
                length=initial_sequence, corpus=self.mutation_residue_types
            )
        )

    def get(self) -> str:
        """Return the current sequence."""
        return self.sequence

    def mutate(self) -> None:
        """Replace one randomly chosen residue; the length never changes."""
        self.sequence = substitute_one_amino_acid(
            self.sequence, self.mutation_residue_types
        )

    def num_mutation_candidates(self) -> int:
        """Return the sequence length (one substitution site per residue)."""
        return len(self.sequence)


def substitute_one_amino_acid(sequence: str, corpus: List[str]) -> str:
    """Return ``sequence`` with one random position replaced by a random corpus entry."""
    position = np.random.choice(len(sequence))
    replacement = np.random.choice(corpus)
    return "".join((sequence[:position], replacement, sequence[position + 1 :]))


def random_sequence(length: int, corpus: List[str]) -> str:
    """Build a sequence of ``length`` residues drawn at random from ``corpus``."""
    residues = []
    for _ in range(length):
        residues.append(np.random.choice(corpus))
    return "".join(residues)


def sequence_from_atomarray(atoms: AtomArray) -> str:
    """Return the one-letter sequence read from the CA atoms of ``atoms``."""
    alpha_carbons = atoms[atoms.atom_name == "CA"]
    one_letter_codes = [RESIDUE_TYPES_3to1[name] for name in alpha_carbons.res_name]
    return "".join(one_letter_codes)


class VariableLengthSequenceSegment(SequenceSegmentFactory):
    """Sequence segment mutated by substitution, deletion or insertion."""

    def __init__(
        self,
        initial_sequence: Union[str, int],
        disallow_mutations_to_cysteine=True,
        mutation_operation_probabilities: List[float] = [
            3., # Substitution weight.
            1., # Deletion weight.
            1., # Insertion weight.
        ],
    ) -> None:
        """
        Args:
            initial_sequence: Either the starting sequence, or an integer
                length for which a random sequence is generated.
            disallow_mutations_to_cysteine: If True, neither mutations nor the
                random initial sequence use cysteine.
            mutation_operation_probabilities: Relative weights of
                substitution, deletion and insertion, in that order. They are
                normalized to sum to one; the argument itself is not modified.
        """
        super().__init__()
        self.mutation_residue_types = (
            RESIDUE_TYPES_WITHOUT_CYSTEINE
            if disallow_mutations_to_cysteine
            else ALL_RESIDUE_TYPES
        )

        # isinstance (not an exact type comparison) so that str subclasses such
        # as numpy.str_ are treated as sequences rather than as a length.
        self.sequence = (
            initial_sequence
            if isinstance(initial_sequence, str)
            else random_sequence(
                length=initial_sequence, corpus=self.mutation_residue_types
            )
        )

        # Force a float array: normalizing an integer array in place raises.
        self.mutation_operation_probabilities = np.array(
            mutation_operation_probabilities, dtype=float
        )
        self.mutation_operation_probabilities /= self.mutation_operation_probabilities.sum()

    def get(self) -> str:
        """Return the current sequence."""
        return self.sequence

    def mutate(self) -> None:
        """Pick one mutation type by its probability and apply it."""
        mutation_operation = np.random.choice(
            [
                self._mutate_substitution,
                self._mutate_deletion,
                self._mutate_insertion,
            ],
            p=self.mutation_operation_probabilities,
        )
        mutation_operation()

    def _mutate_substitution(self) -> None:
        """Replace one randomly chosen residue."""
        self.sequence = substitute_one_amino_acid(
            self.sequence, self.mutation_residue_types
        )

    def _mutate_deletion(self) -> None:
        """Remove one randomly chosen residue."""
        self.sequence = delete_one_amino_acid(self.sequence)

    def _mutate_insertion(self) -> None:
        """Insert one random residue at a randomly chosen position."""
        self.sequence = insert_one_amino_acid(
            self.sequence, self.mutation_residue_types
        )

    def num_mutation_candidates(self) -> int:
        """Return the sequence length, used as this segment's mutation weight."""
        # NOTE(brianhie): This should be `3*len(self.sequence) + 1`,
        # since there are `len(self.sequence)` substitutions and
        # deletions, and `len(self.sequence) + 1` insertions.
        # However, as this is used to weight sequence segments for
        # mutations when combined into a multi-segment program, we
        # just weight by `len(self.sequence)` for now.
        return len(self.sequence)


def delete_one_amino_acid(sequence: str) -> str:
    """Return ``sequence`` with one randomly chosen residue removed."""
    position = np.random.choice(len(sequence))
    kept = [residue for i, residue in enumerate(sequence) if i != position]
    return "".join(kept)


def insert_one_amino_acid(sequence: str, corpus: List[str]) -> str:
    """Return ``sequence`` with one random corpus entry inserted.

    The insertion point is drawn uniformly from all ``len(sequence) + 1``
    positions, including the position after the last residue.
    """
    # randint's upper bound is exclusive, so it must be len(sequence) + 1 for
    # appending at the end to be possible. This also covers the empty sequence.
    index = np.random.randint(0, len(sequence) + 1)
    insertion = np.random.choice(corpus)
    return sequence[:index] + insertion + sequence[index:]

In [16]:
# ProgramNode is a representation of a protein.
# The protein can be a monomer (no children) or a homo-/hetero-multimer (has children); each child is itself a ProgramNode.
# sequence_segment: how the protein sequence will be modified (fixed or variable length)
# Returns the sequence and residue indices; for multiple chains, an offset is inserted between chains.
# The code walks the tree of nodes recursively.
from typing import Tuple
from typing import Callable

class ProgramNode:
    """One node of the tree that describes the protein being designed.

    A leaf node (``children is None``) owns a sequence segment. An inner node
    concatenates its children, optionally treating them as separate chains.
    Any node may carry energy terms that score its own residue range.
    """

    def __init__(
        self,
        children: Optional[List["ProgramNode"]] = None,
        sequence_segment: Optional[SequenceSegmentFactory] = None,
        children_are_different_chains: bool = False,
        energy_function_terms: Optional[List[EnergyTerm]] = None,
        energy_function_weights: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            children: Child nodes, or None for a leaf.
            sequence_segment: Sequence source of a leaf node.
            children_are_different_chains: If True, a gap is left in the
                residue numbering after each child.
            energy_function_terms: Energy terms evaluated on this node;
                defaults to no terms.
            energy_function_weights: One weight per energy term; defaults to
                1.0 for every term.
        """
        self.children: Optional[List["ProgramNode"]] = children
        self.sequence_segment: Optional[SequenceSegmentFactory] = sequence_segment
        self.children_are_different_chains: bool = children_are_different_chains
        # A fresh list per instance: a shared mutable default would leak terms
        # between nodes if any caller appended to it.
        self.energy_function_terms: List[EnergyTerm] = (
            [] if energy_function_terms is None else energy_function_terms
        )
        self.energy_function_weights: List[float] = (
            energy_function_weights
            if energy_function_weights
            else [1.0 for _ in self.energy_function_terms]
        )
        if self.energy_function_weights:
            assert len(self.energy_function_terms) == len(
                self.energy_function_weights
            ), "One must have the same number of energy function terms and weights on a node."

        # Half-open (start, end) range; set by
        # get_sequence_and_set_residue_index_ranges().
        self.residue_index_range: Optional[Tuple[int, int]] = None

    def get_sequence_and_set_residue_index_ranges(
        self, residue_index_offset: int = 1
    ) -> Tuple[str, List[int]]:
        """Assemble the full sequence and record each node's residue range.

        Args:
            residue_index_offset: Residue index given to the first residue of
                this node.

        Returns:
            The concatenated sequence and one residue index per residue.
        """
        if self.is_leaf_node():
            sequence = self.sequence_segment.get()
            self.residue_index_range = (
                residue_index_offset,
                residue_index_offset + len(sequence),
            )
            return sequence, list(range(*self.residue_index_range))

        offset: int = residue_index_offset
        sequence = ""
        residue_indices = []
        for child in self.children:
            (
                sequence_segment,
                residue_indices_segment,
            ) = child.get_sequence_and_set_residue_index_ranges(
                residue_index_offset=offset
            )
            sequence += sequence_segment
            residue_indices += residue_indices_segment
            # The next child starts right after the last index used so far.
            offset = residue_indices[-1] + 1
            if self.children_are_different_chains:
                # Leave a gap in the numbering between separate chains.
                offset += MULTIMER_RESIDUE_INDEX_SKIP_LENGTH
        self.residue_index_range = (residue_indices[0], residue_indices[-1] + 1)
        return sequence, residue_indices

    def get_residue_index_range(self) -> Tuple[int, int]:
        """Return the half-open (start, end) residue range of this node."""
        assert (
            self.residue_index_range
        ), "Must call get_sequence_and_set_residue_index_ranges() first."
        return self.residue_index_range

    def get_children(self) -> List["ProgramNode"]:
        """Return the child nodes (None for a leaf)."""
        return self.children

    def is_leaf_node(self) -> bool:
        """Tell whether this node has no children."""
        return self.children is None

    def get_energy_term_functions(
        self, name_prefix: str = ""
    ) -> List[Tuple[str, float, Callable[[FoldingResult], float]]]:
        """Collect the energy terms of this node and all of its descendants.

        Args:
            name_prefix: Dotted path of this node; "root" when empty.

        Returns:
            ``(name, weight, function)`` tuples, where ``function`` takes a
            FoldingResult and already has its node bound.
        """
        # Imported locally: nothing in this file imports partial at module
        # level, so without this the call below raises NameError.
        from functools import partial

        name_prefix = name_prefix if name_prefix else "root"

        terms = [
            (
                f"{name_prefix}:{type(term).__name__}",
                weight,
                partial(term.compute, self),
            )
            for weight, term in zip(
                self.energy_function_weights, self.energy_function_terms
            )
        ]

        if self.is_leaf_node():
            return terms

        for i, child in enumerate(self.children):
            terms += child.get_energy_term_functions(
                name_prefix=name_prefix + f".n{i+1}"
            )

        return terms

    def mutate(self) -> None:
        """Mutate one leaf, chosen with probability proportional to its weight."""
        if self.is_leaf_node():
            return self.sequence_segment.mutate()

        weights = np.array(
            [float(child.num_mutation_candidates()) for child in self.children]
        )
        assert (
            weights.sum() > 0
        ), "Some mutations should be possible if mutate() was called."
        child_to_mutate = np.random.choice(self.children, p=weights / weights.sum())
        child_to_mutate.mutate()

    def num_mutation_candidates(self) -> int:
        """Return the total mutation weight of the leaves below this node."""
        if self.is_leaf_node():
            return self.sequence_segment.num_mutation_candidates()

        return sum([child.num_mutation_candidates() for child in self.children])