In [11]:
import modal
from typing import List, Optional, Dict
import logging
import pandas as pd
import torch
from datetime import datetime
import numpy as np

from egfr_binder_rd2.bt import BTEnsemble
from egfr_binder_rd2 import (
    LOGGING_CONFIG, MODAL_VOLUME_PATH, OUTPUT_DIRS, 
    ExpertType, ExpertConfig, EvolutionMetadata, PartialEnsembleExpertConfig)
from egfr_binder_rd2.sampler import DirectedEvolution
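# TODO: configure logging - `logging` and `LOGGING_CONFIG` are imported above but no setup call follows in this cell.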

In [16]:
# --- Starting population -------------------------------------------------
# Ten copies of a single parent binder sequence.
parent_binder_seqs = 10 * ['PSFSACPSNYDGVCCNGGVCHLAESLTSYTCQCILGYSGHRVQTFDLRYTELRRR']

# --- Evolution schedule --------------------------------------------------
generations = 80
n_to_fold = 50        # total sequences to fold per generation
num_parents = 25      # number of parents to keep
top_k = 50            # top sequences to consider
retrain_frequency = 3
select_from_current_gen_only = False

# --- Sampling chains -----------------------------------------------------
n_parallel_chains = 4  # parallel chains per sequence
n_serial_chains = 1    # sequential runs per sequence
n_steps = 10           # steps per chain
max_mutations = 5      # max mutations per sequence

# --- Selection / sampling knobs ------------------------------------------
evoprotgrad_top_fraction = 0.25
parent_selection_temperature = 0.5
sequence_sampling_temperature = 0.5

# --- Reproducibility -----------------------------------------------------
seed = 42

In [4]:
# Mutable working state for the run: an independent copy of the starting parents,
# a placeholder for the expert configuration, and an (initially empty) result list.
all_final_sequences = []
expert_configs = None
current_parent_seqs = list(parent_binder_seqs)

# Run hyperparameters gathered into one dict
# (presumably handed to the metadata tracker later - not shown in this cell).
config = dict(
    generations=generations,
    n_to_fold=n_to_fold,
    num_parents=num_parents,
    top_k=top_k,
    n_parallel_chains=n_parallel_chains,
    n_serial_chains=n_serial_chains,
    n_steps=n_steps,
    max_mutations=max_mutations,
    evoprotgrad_top_fraction=evoprotgrad_top_fraction,
    parent_selection_temperature=parent_selection_temperature,
    sequence_sampling_temperature=sequence_sampling_temperature,
    retrain_frequency=retrain_frequency,
    seed=seed,
)

In [39]:
# Handles to the Modal functions used by this notebook, grouped by hosting app.
# The lookups are issued in the same order as a line-by-line listing would.

# "bt-training" app
sample_sequences, train_bt_model = (
    modal.Function.lookup("bt-training", fn_name)
    for fn_name in ("sample_sequences", "train_bt_model")
)

# "esm2-inference" app
process_sequences, update_pll_metrics = (
    modal.Function.lookup("esm2-inference", fn_name)
    for fn_name in ("process_sequences", "update_pll_metrics")
)

# "simplefold" app (note: `update_metrics` is bound to "update_metrics_for_all_folded")
fold_binder, parallel_fold_binder, update_metrics = (
    modal.Function.lookup("simplefold", fn_name)
    for fn_name in ("fold_binder", "parallel_fold_binder", "update_metrics_for_all_folded")
)

In [8]:
# Fit the PAE-interaction expert via the remote training function.
# The call's result is kept as `pae_model_path`.
pae_model_path = train_bt_model.remote(
    # target variable and how it is transformed
    yvar="pae_interaction",
    transform_type="standardize",
    make_negative=True,
    # Weights & Biases destination
    wandb_entity="anaka_personal",
    wandb_project="egfr-binder-rd2",
)


In [48]:
# Fit the iPTM expert via the remote training function.
# The call's result is kept as `iptm_model_path`.
iptm_model_path = train_bt_model.remote(
    # target variable and how it is transformed
    yvar="i_ptm",
    transform_type="standardize",
    make_negative=False,
    # training budget for this expert
    max_epochs=3,
    # Weights & Biases destination
    wandb_entity="anaka_personal",
    wandb_project="egfr-binder-rd2",
)

In [9]:
# Current generation number; it is added to `seed` when calling `sample_sequences`.
gen = 1

In [54]:
# Experts used for sampling: one ESM expert followed by two partial-ensemble
# experts (iPAE, then iPTM). The ensemble experts differ only in their type and
# in whether the target is negated, so they are generated from a small table.
expert_configs = [
    ExpertConfig(type=ExpertType.ESM, temperature=1.0),
    *(
        PartialEnsembleExpertConfig(
            type=expert_type,
            temperature=1.0,
            make_negative=negate_target,
            transform_type="standardize",
        )
        for expert_type, negate_target in (
            (ExpertType.iPAE, True),
            (ExpertType.iPTM, False),
        )
    ),
]

In [55]:
# Accumulator for the variants sampled across all parents.
all_variants = []

# Split the folding budget evenly across the current parents, with a floor of
# one sequence per parent.
seqs_per_parent = n_to_fold // len(current_parent_seqs)
if seqs_per_parent < 1:
    seqs_per_parent = 1


In [61]:
# Run the remote sampler once for the whole parent list (all parents in one call).
evoprotgrad_df = sample_sequences.remote(
    # what to sample from, and which experts guide the sampling
    sequences=current_parent_seqs,
    expert_configs=expert_configs,
    # chain layout
    n_parallel_chains=n_parallel_chains,
    n_serial_chains=n_serial_chains,
    n_steps=n_steps,
    max_mutations=max_mutations,
    # the base seed is offset by the generation number
    seed=seed + gen,
    run_inference=True,
)

In [63]:
# Convert each objective to a percentile rank so the three share a common scale.
# `rank` is ascending by default, so a larger raw value gets a larger rank.
# NOTE(review): this assumes larger is better for all three columns (the iPAE
# expert is configured with make_negative=True) - confirm.
evoprotgrad_df['i_ptm_ucb_rank'] = evoprotgrad_df['i_ptm_ucb'].rank(pct=True)
evoprotgrad_df['pae_interaction_ucb_rank'] = evoprotgrad_df['pae_interaction_ucb'].rank(pct=True)
evoprotgrad_df['sequence_log_pll_rank'] = evoprotgrad_df['sequence_log_pll'].rank(pct=True)

# Fitness is the mean of the three percentile ranks. Averaging the raw UCB values
# together with a rank (as was done before) mixes incompatible scales and leaves
# the two *_ucb_rank columns computed above unused.
evoprotgrad_df['fitness_ucb'] = (
    evoprotgrad_df['i_ptm_ucb_rank']
    + evoprotgrad_df['pae_interaction_ucb_rank']
    + evoprotgrad_df['sequence_log_pll_rank']
) / 3

In [59]:
# List the column names of `evoprotgrad_df`.
# The saved output below has no *_rank / fitness_ucb columns, i.e. it was
# produced before the ranking cell was run.
evoprotgrad_df.columns


Index(['run', 'parent_idx', 'parent_seq', 'parent_hash', 'chain', 'step',
       'score', 'sequence', 'sequence_hash', 'length', 'sequence_log_pll',
       'pae_interaction_mean', 'pae_interaction_std', 'pae_interaction_ucb',
       'pae_interaction_head_0', 'pae_interaction_head_1',
       'pae_interaction_head_2', 'pae_interaction_head_3',
       'pae_interaction_head_4', 'i_ptm_mean', 'i_ptm_std', 'i_ptm_ucb',
       'i_ptm_head_0', 'i_ptm_head_1', 'i_ptm_head_2', 'i_ptm_head_3',
       'i_ptm_head_4'],
      dtype='object')

In [68]:
# Sample sequences from the top fraction, now considering parent information.
# Both accumulators are (re)initialised in this cell so that re-running it does
# not append a second copy of the samples to `all_variants` and leave the two
# lists out of step with each other.
all_variants = []
all_variants_with_parents = []  # (variant, parent_seq) pairs
for parent_idx, parent_seq in enumerate(current_parent_seqs):
    # Rows are matched on the parent *sequence*, so identical entries in
    # `current_parent_seqs` each select the same rows.
    # NOTE(review): `parent_idx` is unused here while the frame has a
    # 'parent_idx' column - confirm whether matching on it was intended.
    parent_variants = evoprotgrad_df[evoprotgrad_df['parent_seq'] == parent_seq]
    if len(parent_variants) > 0:
        sampled_variants = DirectedEvolution.sample_from_evoprotgrad_sequences(
            parent_variants,
            top_fraction=evoprotgrad_top_fraction,
            sample_size=seqs_per_parent,
            temperature=sequence_sampling_temperature
        )
        # Store variants with their parent information
        all_variants_with_parents.extend([(variant, parent_seq) for variant in sampled_variants])
        all_variants.extend(sampled_variants)