In [11]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import logging


In [12]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

from app.helpers.finetuning.training import train_per_protein, score_sequences
from app.helpers.sequence_util import (
    get_seq_ids_for_deep_mutational_scan,
    seq_id_to_seq,
    maybe_get_seq_id_error_message,
    process_and_validate_evolve_input_files,
)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
import re

# Wild-type amino-acid sequence paired with the `cas_activity` measurements.
cas_wt_aa_seq = 'MIKVYRYEIVKPLDLDWKEFGTILRQLQQETRFALNKATQLAWEWMGFSSDYKDNHGEYPKSKDILGYTNVHGYAYHTIKTKAYRLNSGNLSQTIKRATDRFKAYQKEILRGDMSIPSYKRDIPLDLIKENISVNRMNHGDYIASLSLLSNPAKQEMNVKRKISVIIIVRGAGKTIMDRILSGEYQVSASQIIHDDRKNKWYLNISYDFEPQTRVLDLNKIMGIDLGVAVAVYMAFQHTPARYKLEGGEIENFRRQVESRRISMLRQGKYAGGARGGHGRDKRIKPIEQLRDKIANFRDTTNHRYSRYIVDMAIKEGCGTIQMEDLTNIRDIGSRFLQNWTYYDLQQKIIYKAEEAGIKVIKIDPQYTSQRCSECGNIDSGNRIGQAIFKCRACGYEANADYNAARNIAIPNIDKIIAESIK'

# Load the measurements, rename the two columns of interest to the
# (seq_id, activity) names used by the rest of the notebook, and drop the rest.
cas_activity = (
    pd.read_excel('notebooks/data/DMS_AsCas12f_preprocessed.xlsx')
    .rename(columns={'variant': 'seq_id', 'avg_activity': 'activity'})
    [['seq_id', 'activity']]
)


# tem1 wild-type amino-acid sequence. The literal is wrapped over several
# lines for readability; the newlines are removed to give one contiguous string.
tem1_wt_aa_seq = """MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRP
EERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVREL
CSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTM
PAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGS
RGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW""".replace('\n', '')

# Load the 'simplified' sheet, rename the two columns of interest to the
# (seq_id, activity) names used by the rest of the notebook, and drop the rest.
tem1_activity = (
    pd.read_excel('notebooks/data/tem1_activity.xlsx', sheet_name='simplified')
    .rename(columns={'Mutation': 'seq_id', 'activity_at_2500': 'activity'})
    [['seq_id', 'activity']]
)
def fix_tem1_seq_id(seq_id):
  """Renumber the locus embedded in a tem1 mutation id.

  The id must begin with one uppercase letter followed by a decimal locus;
  whatever follows the digits is carried over unchanged. The locus is lowered
  by 2, and lowered by a further 2 when the once-lowered value is 245 or more.

  Args:
    seq_id: mutation id string, e.g. 'A10G'.

  Returns:
    The id with the same leading letter and trailing text but the shifted locus.

  Raises:
    AssertionError: if `seq_id` does not start with a letter followed by
      digits; the offending id is used as the assertion message.
  """
  parsed = re.match(r'([A-Z])(\d+)(.*)', seq_id)
  assert parsed, seq_id
  prefix, locus_text, suffix = parsed.groups()

  # NOTE(review): the offsets presumably map the spreadsheet's numbering onto
  # positions in tem1_wt_aa_seq -- confirm against the data source.
  shifted = int(locus_text) - 2
  if shifted >= 245:
    shifted -= 2

  return f'{prefix}{shifted}{suffix}'
# Renumber every mutation id with fix_tem1_seq_id. Not idempotent: running this
# line again without reloading tem1_activity shifts the loci a second time.
tem1_activity['seq_id'] = tem1_activity['seq_id'].apply(fix_tem1_seq_id)


In [14]:
# Choose which protein the rest of the notebook works on by binding
# `wt_aa_seq` / `raw_activity_df` to one dataset. To switch to tem1, comment
# out the cas pair and uncomment the tem1 pair. `.copy()` keeps later edits to
# `raw_activity_df` from touching the loaded frame.
# wt_aa_seq = tem1_wt_aa_seq
# raw_activity_df = tem1_activity.copy()
wt_aa_seq = cas_wt_aa_seq
raw_activity_df = cas_activity.copy()

In [16]:
# Table of DMS assays; its rows supply the DMS_id, DMS_filename and target_seq
# values used when building training pairs.
pg_df = pd.read_csv('notebooks/data/DMS_substitutions.csv')
import numpy as np
# Seed NumPy's global RNG so code drawing from np.random (the training-pair
# sampling) gives the same result when the notebook is run top to bottom.
np.random.seed(42)


def get_training_pairs_for_dms(dms_id, dms_fname, dms_target_seq, num_examples):
  """Sample (winner, loser) mutant pairs from one DMS assay file.

  The assay CSV is read from
  `notebooks/data/DMS_ProteinGym_substitutions/{dms_fname}` and ranked by
  `DMS_score`, best first. For each pair, the winner is drawn uniformly from
  the top 10% of ranks and the loser uniformly from ranks at least 10% of the
  table below the winner. Draws use NumPy's global RNG (`np.random`), so seed
  it beforehand for repeatable output.

  Args:
    dms_id: identifier copied into the `dms_id` column of every row.
    dms_fname: CSV file name inside the DMS directory; the file must provide
      `mutant` and `DMS_score` columns.
    dms_target_seq: value copied into the `sequence` column of every row.
    num_examples: number of pairs to draw; values <= 0 give an empty frame.

  Returns:
    DataFrame with one row per pair, indexed 0..n-1, with columns `dms_id`,
    `sequence`, `seq_id_w`, `seq_id_l`, `rank_w`, `rank_l`, `DMS_score_w`
    and `DMS_score_l`.

  Raises:
    ValueError: if the file has fewer than 10 rows, so the top 10% is empty.
  """
  dms_df = pd.read_csv(f'notebooks/data/DMS_ProteinGym_substitutions/{dms_fname}')
  dms_df = dms_df.sort_values('DMS_score', ascending=False).reset_index(drop=True)

  columns = [
    'dms_id', 'sequence', 'seq_id_w', 'seq_id_l',
    'rank_w', 'rank_l', 'DMS_score_w', 'DMS_score_l',
  ]
  if num_examples <= 0:
    # Nothing to sample; still return the documented schema.
    return pd.DataFrame(columns=columns)

  num_rows = len(dms_df)
  decile = int(num_rows * 0.1)
  if decile < 1:
    raise ValueError(
      f'{dms_fname} has {num_rows} rows; at least 10 are required to sample '
      'winner/loser pairs'
    )

  # NOTE: pairs whose winner and loser mutate the same loci are not filtered out.
  records = []
  for _ in range(num_examples):
    # Winner: any rank in the top decile. Loser: any rank at least one decile
    # further down. The draw order (winner, then loser) is kept so a seeded
    # run reproduces the same pairs.
    winner_rank = np.random.randint(0, decile)
    loser_rank = np.random.randint(winner_rank + decile, num_rows)
    winner = dms_df.iloc[winner_rank]
    loser = dms_df.iloc[loser_rank]

    records.append({
      'dms_id': dms_id,
      'sequence': dms_target_seq,
      'seq_id_w': winner['mutant'],
      'seq_id_l': loser['mutant'],
      'rank_w': winner_rank,
      'rank_l': loser_rank,
      'DMS_score_w': winner['DMS_score'],
      'DMS_score_l': loser['DMS_score'],
    })
  # Build the frame once from plain records rather than concatenating one
  # single-row frame per pair; this also gives unique 0..n-1 index labels.
  return pd.DataFrame(records, columns=columns)


In [17]:
# Small training frame: draw 2 winner/loser pairs from the first assay listed
# in pg_df, then stack 5 copies of them (10 rows in total).
train_df = pd.concat(
  [
    get_training_pairs_for_dms(
      dms_id=pg_df.iloc[0].DMS_id,
      dms_fname=pg_df.iloc[0].DMS_filename,
      dms_target_seq=pg_df.iloc[0].target_seq,
      num_examples=2,
    )
  ] * 5
)

In [18]:
NUM_EPOCHS = 10
# The run below is pinned to the CPU, so mixed precision stays switched off.
gpu_available = False

# Send INFO-level log records to the notebook output. force=True replaces any
# logging configuration set up earlier in this kernel session.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    force=True,
)

# NOTE(review): valid_df is the same frame as train_df, so the validation loss
# and Spearman values reported during training are not held-out metrics.
tokenizer2, model2, history2 = train_per_protein(
    checkpoint="facebook/esm2_t6_8M_UR50D",
    loss='dpo',
    train_df=train_df,
    valid_df=train_df,
    device='cpu',
    mixed_precision=gpu_available,
    train_full=True,
    epochs=NUM_EPOCHS,
    learning_rate=3e-4,
    train_batch_size=5,
    val_batch_size=5,
    grad_accum_steps=1,
    seed=42,
)

2025-03-03 00:35:58,112 - root - INFO - Loading model from facebook/esm2_t6_8M_UR50D
2025-03-03 00:35:58,449 - root - INFO - Creating training arguments
2025-03-03 00:35:58,542 - root - INFO - Creating trainer
2025-03-03 00:35:58,546 - root - INFO - Training model


Step,Training Loss,Validation Loss,Spearmanr
2,0.5797,0.483345,1.0
4,0.4772,0.346859,1.0
6,0.278,0.158172,1.0
8,0.083,0.039611,-1.0
10,0.0315,0.011136,-1.0
12,0.0059,0.002068,-1.0
14,0.0007,0.000327,1.0
16,0.0001,7.9e-05,1.0
18,0.0,3.1e-05,1.0
20,0.0,1.4e-05,1.0


2025-03-03 00:36:18,779 - root - INFO - Step 1: loss=0.4475, learning_rate=0.0000, epoch=0.5000
2025-03-03 00:36:37,431 - root - INFO - Step 2: loss=0.5797, learning_rate=0.0000, epoch=1.0000
2025-03-03 00:36:59,507 - root - INFO - Step 2: eval_loss=0.4833, eval_spearmanr=1.0000, eval_runtime=22.0730, eval_samples_per_second=0.4530, eval_steps_per_second=0.0910, epoch=1.0000
2025-03-03 00:36:59,508 - root - INFO - Evaluation at step 2: eval_loss=0.4833, eval_spearmanr=1.0000, eval_runtime=22.0730, eval_samples_per_second=0.4530, eval_steps_per_second=0.0910, epoch=1.0000
2025-03-03 00:37:18,435 - root - INFO - Step 3: loss=0.4224, learning_rate=0.0000, epoch=1.5000
2025-03-03 00:37:37,026 - root - INFO - Step 4: loss=0.4772, learning_rate=0.0000, epoch=2.0000
2025-03-03 00:37:59,298 - root - INFO - Step 4: eval_loss=0.3469, eval_spearmanr=1.0000, eval_runtime=22.2702, eval_samples_per_second=0.4490, eval_steps_per_second=0.0900, epoch=2.0000
2025-03-03 00:37:59,298 - root - INFO - Eval

In [9]:
# Enumerate scan ids for the selected wild type and keep only the first 100 so
# scoring stays quick. (Meaning of the ['WT'] and [] arguments is defined by
# the helper -- TODO confirm.)
dms_seq_ids = get_seq_ids_for_deep_mutational_scan(wt_aa_seq, ['WT'], [])[:100]
print(f"Scoring {len(dms_seq_ids)} sequences", flush=True)

# Needs model2 / tokenizer2 from the training cell; run that cell first.
scores_df = score_sequences(model2, tokenizer2, wt_aa_seq, dms_seq_ids)


Scoring 100 sequences


NameError: name 'model2' is not defined

In [10]:
# Randomly sample 16, 64, 256 rows of cas_activity and save to excel
from pathlib import Path
Path('notebooks/data/cas_activity').mkdir(parents=True, exist_ok=True)
cas_activity.sample(n=16).to_excel('notebooks/data/cas_activity/cas_activity_16.xlsx', index=False)
cas_activity.sample(n=64).to_excel('notebooks/data/cas_activity/cas_activity_64.xlsx', index=False)
cas_activity.sample(n=256).to_excel('notebooks/data/cas_activity/cas_activity_256.xlsx', index=False)


In [13]:
from scipy.stats import spearmanr

# Inner-join model scores with measured activity on seq_id, then report the
# rank correlation between the two as the cell's displayed result.
df = scores_df.merge(raw_activity_df, on='seq_id')
spearmanr(df['wt_marginal_score'], df['activity'])

SignificanceResult(statistic=0.10474405319766142, pvalue=0.3046862456357439)