## Conditional GFlowNets for mRNA sequences

In [1]:
from argparse import ArgumentParser
from typing import Tuple

import torch
from torch.optim import Adam
from tqdm import tqdm

from gfn.gflownet import (
    DBGFlowNet,
    FMGFlowNet,
    ModifiedDBGFlowNet,
    SubTBGFlowNet,
    TBGFlowNet,
)

from gfn.gym import HyperGrid

from gfn.modules import (
    ConditionalDiscretePolicyEstimator,
    ConditionalScalarEstimator,
    ScalarEstimator,
)
from gfn.utils.modules import MLP

DEFAULT_SEED: int = 4444

In [2]:
def build_conditional_pf_pb(
    env: HyperGrid,
) -> Tuple[ConditionalDiscretePolicyEstimator, ConditionalDiscretePolicyEstimator]:
    """Create the conditional forward (PF) and backward (PB) policy estimators.

    Each estimator is assembled from a state encoder, a conditioning encoder
    and a head whose input is twice the embedding width (one slot per
    encoder). Parameters are shared between the two directions: the PB state
    encoder and the PB head are built on the trunks of their PF counterparts,
    and the very same conditioning encoder instance is handed to both
    estimators.

    Args:
        env: The HyperGrid environment; its ``n_states`` and ``n_actions``
            size the networks.

    Returns:
        The pair ``(forward estimator, backward estimator)``.
    """
    embed_dim = 16  # Width of each embedding fed to the heads.
    hidden = 256

    # State encoders; the PB one is built on top of the PF trunk.
    # NOTE(review): both are sized with env.n_states -- confirm this matches
    # the width of the state tensor the estimators actually receive.
    state_encoder_pf = MLP(
        input_dim=env.n_states, output_dim=embed_dim, hidden_dim=hidden
    )
    state_encoder_pb = MLP(
        input_dim=env.n_states,
        output_dim=embed_dim,
        hidden_dim=hidden,
        trunk=state_encoder_pf.trunk,
    )

    # A single encoder for the 1-d conditioning input, used by PF and PB alike.
    cond_encoder = MLP(input_dim=1, output_dim=embed_dim, hidden_dim=hidden)

    # Heads consume both embeddings; the PB head has one output fewer than
    # the PF head (presumably because there is no backward exit action --
    # TODO confirm).
    head_in = 2 * embed_dim
    head_pf = MLP(input_dim=head_in, output_dim=env.n_actions)
    head_pb = MLP(
        input_dim=head_in,
        output_dim=env.n_actions - 1,
        trunk=head_pf.trunk,
    )

    forward = ConditionalDiscretePolicyEstimator(
        state_encoder_pf, cond_encoder, head_pf, env.n_actions, is_backward=False
    )
    backward = ConditionalDiscretePolicyEstimator(
        state_encoder_pb, cond_encoder, head_pb, env.n_actions, is_backward=True
    )
    return forward, backward

In [3]:
def build_conditional_logF_scalar_estimator(
    env: HyperGrid,
) -> ConditionalScalarEstimator:
    """Create the conditional log-flow (logF) estimator.

    Three single-hidden-layer MLPs (hidden width 256) are wired into a
    ``ConditionalScalarEstimator``: one encodes the state, one encodes the
    1-d conditioning input, and the last maps an input twice the embedding
    width to a single scalar.

    Args:
        env: The HyperGrid environment; ``env.n_states`` sizes the state
            encoder input.

    Returns:
        A conditional scalar estimator for log flow.
    """
    embed_dim = 16
    # Hyper-parameters common to all three networks.
    shared = {"hidden_dim": 256, "n_hidden_layers": 1}

    state_net = MLP(input_dim=env.n_states, output_dim=embed_dim, **shared)
    cond_net = MLP(input_dim=1, output_dim=embed_dim, **shared)
    joint_net = MLP(input_dim=2 * embed_dim, output_dim=1, **shared)

    return ConditionalScalarEstimator(state_net, cond_net, joint_net)

In [4]:
# Trajectory Balance GFlowNet: conditional PF/PB policies plus a logZ network.
def build_tb_gflownet(env: HyperGrid) -> TBGFlowNet:
    """Assemble a Trajectory Balance GFlowNet for ``env``.

    The forward/backward policies come from ``build_conditional_pf_pb``; logZ
    is a ``ScalarEstimator`` around a small MLP with a 1-d input and a 1-d
    output (two hidden layers of width 16).

    Args:
        env: The HyperGrid environment.

    Returns:
        A TBGFlowNet instance.
    """
    forward_policy, backward_policy = build_conditional_pf_pb(env)

    log_z = ScalarEstimator(
        MLP(input_dim=1, output_dim=1, hidden_dim=16, n_hidden_layers=2)
    )

    return TBGFlowNet(logZ=log_z, pf=forward_policy, pb=backward_policy)


In [9]:
def train(
    env,
    gflownet,
    seed=42,
    *,
    n_iterations=10,
    batch_size=10_000,
    lr=0.0005,
    exploration_rate=0.5,
):
    """Train a conditional GFlowNet on ``env`` with Adam.

    At every iteration a fresh batch of binary conditioning values (each 0.0
    or 1.0, drawn by thresholding uniform noise at 0.5) is sampled,
    trajectories are generated under that conditioning, and one optimizer
    step is taken on the trajectory loss.

    Args:
        env: Environment passed to the sampler and the loss; ``env.device``
            is where the conditioning tensor is moved.
        gflownet: The GFlowNet to train. Its *exact* type (subclasses are not
            matched) selects the optimizer layout:
            TBGFlowNet -> policy params at ``lr``, logZ params at ``100 * lr``;
            DBGFlowNet / SubTBGFlowNet -> policy params at ``lr``, logF
            params at ``100 * lr``;
            FMGFlowNet / ModifiedDBGFlowNet -> all params at ``lr``.
        seed: Seed given to ``torch.manual_seed`` before anything else.
        n_iterations: Number of optimization steps.
        batch_size: Number of trajectories (and conditioning rows) per step.
        lr: Base learning rate for the policy parameters.
        exploration_rate: Passed as ``epsilon`` to ``sample_trajectories``.

    Raises:
        ValueError: If ``gflownet`` is not one of the supported types. The
            check happens before any sampling, instead of failing later on
            an undefined optimizer.
    """
    torch.manual_seed(seed)

    # Policy parameters and logZ/logF get independent LRs (logF/Z typically higher).
    gfn_type = type(gflownet)
    if gfn_type is TBGFlowNet:
        optimizer = Adam(gflownet.pf_pb_parameters(), lr=lr)
        optimizer.add_param_group(
            {"params": gflownet.logz_parameters(), "lr": lr * 100}
        )
    elif gfn_type is DBGFlowNet or gfn_type is SubTBGFlowNet:
        optimizer = Adam(gflownet.pf_pb_parameters(), lr=lr)
        optimizer.add_param_group(
            {"params": gflownet.logF_parameters(), "lr": lr * 100}
        )
    elif gfn_type is FMGFlowNet or gfn_type is ModifiedDBGFlowNet:
        optimizer = Adam(gflownet.parameters(), lr=lr)
    else:
        # Without an optimizer the loop below cannot run; fail loudly here.
        raise ValueError("Unknown gflownet type: {}".format(gfn_type))

    print("+ Training Conditional {}!".format(gfn_type))

    for _ in (pbar := tqdm(range(n_iterations))):
        # Draw on the default device first, then move, so the seeded random
        # stream does not depend on env.device.
        conditioning = torch.rand((batch_size, 1)).to(env.device)  # type: ignore
        conditioning = (conditioning > 0.5).to(torch.float)  # Randomly 1 and zero.

        trajectories = gflownet.sample_trajectories(
            env,
            n=batch_size,
            conditioning=conditioning,
            save_logprobs=False,
            save_estimator_outputs=True,
            epsilon=exploration_rate,
        )

        optimizer.zero_grad()
        loss = gflownet.loss_from_trajectories(
            env, trajectories, recalculate_all_logprobs=False
        )
        loss.backward()
        optimizer.step()
        pbar.set_postfix({"loss": loss.item()})

    print("+ Training complete!")

In [7]:
# HyperGrid instance (ndim=5, height=2) used as the environment.
environment = HyperGrid(ndim=5, height=2)



In [1]:
# Presumably a protein sequence in one-letter amino-acid codes; the trailing
# '*' looks like a stop marker -- TODO confirm origin and identity.
t = 'MMFPQSRHSGSSHLPQQLKFTTSDSCDRIKDEFQLLQAQYHSLKLECDKLASEKSEMQRHYVMYYEMSYGLNIEMHKQAEIVKRLNGICAQVLPYLSQEHQQQVLGAIERAKQVTAPELNSIIRQQLQAHQLSQLQALALPLTPLPVGLQPPSLPAVSAGTGLLSLSALGSQAHLSKEDKNGHDGDTHQEDDGEKSD*'

In [2]:
# Character count of `t`; the trailing '*' is included in the total.
len(t)

198

In [3]:
# Presumably a protein sequence in one-letter amino-acid codes ('*' looks like
# a stop marker). Its start appears to be the translation of the `dna` coding
# sequence defined in a later cell (ATG GGG GCG TCC -> M G A S) -- TODO confirm.
s = 'MGASARLLRAVIMGAPGSGKGTVSSRITTHFELKHLSSGDLLRDNMLRGTEIGVLAKAFIDQGKLIPDDVMTRLALHELKNLTQYSWLLDGFPRTLPQAEALDRAYQIDTVINLNVPFEVIKQRLTARWIHPASGRVYNIEFNPPKTVGIDDLTGEPLIQREDDKPETVIKRLKAYEDQTKPVLEYYQKKGVLETFSGTETNKIWPYVYAFLQTKVPQRSQKASVTP*'

In [4]:
# Character count of `s`; the trailing '*' is included in the total.
len(s)

228

In [5]:
def dna_to_mrna(dna: str) -> str:
    """Convert a DNA sequence to an mRNA sequence by replacing T with U.

    The input is upper-cased and *all* whitespace (spaces, tabs, newlines) is
    removed first, so codon-spaced or line-wrapped sequences are accepted.
    The alphabet is not validated: characters other than ``T`` pass through
    unchanged.

    Args:
        dna: DNA sequence; case-insensitive, may contain whitespace.

    Returns:
        The whitespace-free, upper-case sequence with every ``T`` turned
        into ``U``.
    """
    # str.split() with no argument splits on any run of whitespace, so joining
    # the pieces drops spaces, tabs and line breaks in one pass.
    compact = "".join(dna.split())
    return compact.upper().replace("T", "U")

In [6]:
# DNA sequence written as space-separated triplets, beginning with ATG and
# ending with TGA (presumably a coding sequence from start to stop codon).
dna = 'ATG GGG GCG TCC GCG CGG CTG CTG CGA GCG GTG ATC ATG GGG GCC CCG GGC TCG GGC AAG GGC ACC GTG TCG TCG CGC ATC ACT ACA CAC TTC GAG CTG AAG CAC CTC TCC AGC GGG GAC CTG CTC CGG GAC AAC ATG CTG CGG GGC ACA GAA ATT GGC GTG TTA GCC AAG GCT TTC ATT GAC CAA GGG AAA CTC ATC CCA GAT GAT GTC ATG ACT CGG CTG GCC CTT CAT GAG CTG AAA AAT CTC ACC CAG TAT AGC TGG CTG TTG GAT GGT TTT CCA AGG ACA CTT CCA CAG GCA GAA GCC CTA GAT AGA GCT TAT CAG ATC GAC ACA GTG ATT AAC CTG AAT GTG CCC TTT GAG GTC ATT AAA CAA CGC CTT ACT GCT CGC TGG ATT CAT CCC GCC AGT GGC CGA GTC TAT AAC ATT GAA TTC AAC CCT CCC AAA ACT GTG GGC ATT GAT GAC CTG ACT GGG GAG CCT CTC ATT CAG CGT GAG GAT GAT AAA CCA GAG ACG GTT ATC AAG AGA CTA AAG GCT TAT GAA GAC CAA ACA AAG CCA GTC CTG GAA TAT TAC CAG AAA AAA GGG GTG CTG GAA ACA TTC TCC GGA ACA GAA ACC AAC AAG ATT TGG CCC TAT GTA TAT GCT TTC CTA CAA ACT AAA GTT CCA CAA AGA AGC CAG AAA GCT TCA GTT ACT CCA TGA'

In [9]:
# Display the mRNA form of the current `dna` value.
dna_to_mrna(dna)

'AUGGGGGCGUCCGCGCGGCUGCUGCGAGCGGUGAUCAUGGGGGCCCCGGGCUCGGGCAAGGGCACCGUGUCGUCGCGCAUCACUACACACUUCGAGCUGAAGCACCUCUCCAGCGGGGACCUGCUCCGGGACAACAUGCUGCGGGGCACAGAAAUUGGCGUGUUAGCCAAGGCUUUCAUUGACCAAGGGAAACUCAUCCCAGAUGAUGUCAUGACUCGGCUGGCCCUUCAUGAGCUGAAAAAUCUCACCCAGUAUAGCUGGCUGUUGGAUGGUUUUCCAAGGACACUUCCACAGGCAGAAGCCCUAGAUAGAGCUUAUCAGAUCGACACAGUGAUUAACCUGAAUGUGCCCUUUGAGGUCAUUAAACAACGCCUUACUGCUCGCUGGAUUCAUCCCGCCAGUGGCCGAGUCUAUAACAUUGAAUUCAACCCUCCCAAAACUGUGGGCAUUGAUGACCUGACUGGGGAGCCUCUCAUUCAGCGUGAGGAUGAUAAACCAGAGACGGUUAUCAAGAGACUAAAGGCUUAUGAAGACCAAACAAAGCCAGUCCUGGAAUAUUACCAGAAAAAAGGGGUGCUGGAAACAUUCUCCGGAACAGAAACCAACAAGAUUUGGCCCUAUGUAUAUGCUUUCCUACAAACUAAAGUUCCACAAAGAAGCCAGAAAGCUUCAGUUACUCCAUGA'

In [10]:
# Presumably a protein sequence in one-letter amino-acid codes; the trailing
# '*' looks like a stop marker -- TODO confirm origin and identity.
f = "MCSLGLFPPPPPRGQVTLYEHNNELVTGSSYESPPPDFRGQWINLPVLQLTKDPLKTPGRLDHGTRTAFIHHREQVWKRCINIWRDVGLFGVLNEIANSEEEVFEWVKTASGWALALCRWASSLHGSLFPHLSLRSEDLIAEFAQVTNWSSCCLRVFAWHPHTNKFAVALLDDSVRVYNASSTIVPSLKHRLQRNVASLAWKPLSASVLAVACQSCILIWTLDPTSLSTRPSSGCAQVLSHPGHTPVTSLAWAPSGGRLLSASPVDAAIRVWDVSTETCVPLPWFRGGGVTNLLWSPDGSKILATTPSAVFRVWEAQMWTCERWPTLSGRCQTGCWSPDGSRLLFTVLGEPLIYSLSFPERCGEGKGCVGGAKSATIVADLSETTIQTPDGEERLGGEAHSMVWDPSGERLAVLMKGKPRVQDGKPVILLFRTRNSPVFELLPCGIIQGEPGAQPQLITFHPSFNKGALLSVGWSTGRIAHIPLYFVNAQFPRFSPVLGRAQEPPAGGGGSIHDLPLFTETSPTSAPWDPLPGPPPVLPHSPHSHL*"

In [11]:
# Character count of `f`; the trailing '*' is included in the total.
len(f)

547

In [12]:
# Presumably a protein sequence in one-letter amino-acid codes ('*' looks like
# a stop marker). Its start appears to be the translation of the `dna` value
# assigned in the cell after next (ATG GAC AGT GAG -> M D S E) -- TODO confirm.
g = 'MDSEVQRDGRILDLIDDAWREDKLPYEDVAIPLNELPEPEQDNGGTTESVKEQEMKWTDLALQYLHENVPPIGN*'

In [13]:
# Character count of `g`; the trailing '*' is included in the total.
len(g)

75

In [14]:
# NOTE: rebinds `dna`, replacing the sequence assigned earlier in the notebook.
# Space-separated triplets beginning with ATG and ending with TGA (presumably
# a coding sequence from start to stop codon).
dna = 'ATG GAC AGT GAG GTT CAG AGA GAT GGA AGG ATC TTG GAT TTG ATT GAT GAT GCT TGG CGA GAA GAC AAG CTG CCT TAT GAG GAT GTC GCA ATA CCA CTG AAT GAG CTT CCT GAA CCT GAA CAA GAC AAT GGT GGC ACC ACA GAA TCT GTC AAA GAA CAA GAA ATG AAG TGG ACA GAC TTA GCC TTA CAG TAC CTC CAT GAG AAT GTT CCC CCC ATT GGA AAC TGA'

In [15]:
# Display the mRNA form of the current `dna` value.
dna_to_mrna(dna)

'AUGGACAGUGAGGUUCAGAGAGAUGGAAGGAUCUUGGAUUUGAUUGAUGAUGCUUGGCGAGAAGACAAGCUGCCUUAUGAGGAUGUCGCAAUACCACUGAAUGAGCUUCCUGAACCUGAACAAGACAAUGGUGGCACCACAGAAUCUGUCAAAGAACAAGAAAUGAAGUGGACAGACUUAGCCUUACAGUACCUCCAUGAGAAUGUUCCCCCCAUUGGAAACUGA'

In [16]:
# NOTE: rebinds `s`, replacing the sequence assigned earlier in the notebook.
# Presumably a protein sequence in one-letter amino-acid codes; the trailing
# '*' looks like a stop marker -- TODO confirm origin and identity.
s = 'MTSMTQSLREVIKAMTKARNFERVLGKITLVSAAPGKVICEMKVEEEHTNAIGTLHGGLTATLVDNISTMALLCTERGAPGVSVDMNITYMSPAKLGEDIVITAHVLKQGKTLAFTSVDLTNKATGKLIAQGRHTKHLGN*'

In [17]:
# Character count of the current `s`; the trailing '*' is included.
len(s)

141