In [1]:
from datetime import datetime
from pathlib import Path
import os

import click
import torch
import wandb
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import WandbLogger
import logging
import lightning as L

import pandas as pd
import seaborn as sns

from egfr_binder_rd2.datamodule import SequenceDataModule
from egfr_binder_rd2.bt import BTRegressionModule


%load_ext autoreload
%autoreload 2

# Trade some float32 matmul precision for speed on supported GPUs.
torch.set_float32_matmul_precision('medium')

seed = 42
debug = True

# Apply the seed so that data splitting, weight initialisation and training
# are repeatable across Restart & Run All; previously `seed` was defined but
# never used. `workers=True` also derives seeds for DataLoader workers.
L.seed_everything(seed, workers=True)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def setup_logging(debug: bool):
    """Configure root logging and return this module's logger.

    Args:
        debug: When true the root logger is set to ``DEBUG``; otherwise ``INFO``.

    Returns:
        The ``logging.Logger`` named after ``__name__``.
    """
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    # basicConfig() does nothing when the root logger already has handlers
    # (for example when this cell is re-run with a different `debug` value),
    # so apply the requested level explicitly.
    logging.getLogger().setLevel(level)
    logger = logging.getLogger(__name__)
    logger.info("Logging setup complete.")
    return logger

In [3]:
# Name of the target column used for regression.
# Other values this cell has been set to: 'Average_i_pAE', 'Average_i_pTM',
# 'Average_pLDDT' (these columns are not created anywhere in this notebook).
yvar = 'encoded_expression'

In [None]:
# NOTE(review): machine-specific absolute path. Both `fp` and `df` are
# reassigned by the later cell that reads replicate_summary.csv, so the frame
# loaded here is not what the rest of the notebook uses.
fp = '/home/naka/code/BindCraft/outputs/EGFR_single_domain/mpnn_design_stats.csv'
df = pd.read_csv(filepath_or_buffer=fp)


In [4]:
# Per-replicate results; the sequence for each design comes from result_summary.csv.
fp = 'https://raw.githubusercontent.com/adaptyvbio/egfr_competition_1/refs/heads/main/results/replicate_summary.csv'
seqs = pd.read_csv('https://raw.githubusercontent.com/adaptyvbio/egfr_competition_1/refs/heads/main/results/result_summary.csv')
# No `on=` is given, so pandas joins on every column the two frames share
# (inner join by default).
df = pd.merge(pd.read_csv(fp), seqs[['name', 'sequence']])

In [5]:
# Ordinal encoding of the categorical expression label.
# Labels that are not keys of this mapping become NaN after `.map`.
expression_map = dict(none=0, low=1, medium=2, high=3)
df['encoded_expression'] = df.loc[:, 'nc_adjusted_expression'].map(expression_map)

In [6]:
# Collapse rows sharing the same (name, sequence) to their mean encoded
# expression, then rename the sequence column to 'Sequence'.
df = (
    df.groupby(['name', 'sequence'], as_index=False)
    .agg({'encoded_expression': 'mean'})
    .rename(columns={'sequence': 'Sequence'})
)

In [7]:
# dropna() returns a new frame; assign it back so rows without a target value
# are actually removed before `df` is handed to the DataModule.
df = df.dropna(subset=[yvar])
df

Unnamed: 0,name,Sequence,encoded_expression
0,Cetuximab_scFv,QVQLKQSGPGLVQPSQSLSITCTVSGFSLTNYGVHWVRQSPGKGLE...,1.0
1,Razora712-sequence_10,EELKKALQALKKEYRDKQWAVVQEMLKQHAEIAKKKEAGEINEKEA...,3.0
2,Razora712-sequence_2,RVKELEEEAKRKADEAEELKKRIDALQAKFNELLAAAKASSDPRKS...,3.0
3,Razora712-sequence_3,KELEEARKKLKEEIIKEKKAIVDQELKNHAEIADLVEAGKINEKEA...,3.0
4,Razora712-sequence_6,EALEEALKALKAEHAKKRKAIYDELLESHSNIADKVEKGEINKEEA...,2.0
...,...,...,...
197,x.rustamov-s_11_5,MPELEAFKEEFEKFMKEFKKLSEEDIKDFKENLKKKGKPVTEEDIE...,3.0
198,x.rustamov-s_15_28,MKEKLNELADEAISFAKSIFGDHPSLATFTSFANSVADDLSKEDIS...,3.0
199,zalavi-egfr_binder3,SEEAKELKEKAKEKLKEALEKAKEALKDAEKAAEILKKIPEAKEAL...,3.0
200,zalavi-egfr_binder7,AQAAAKETIRAVLKAAAEAARKMAEEARKLAKELEKYNKEAAKHAL...,3.0


In [8]:

# Build the DataModule from the aggregated frame and prepare its datasets.
# `yvar` selects the target column; any other column of `df` can be used instead.
data_module = SequenceDataModule(
    df,
    yvar=yvar,
    tokenizer_name="facebook/esm2_t33_650M_UR50D",
    max_length=512,
    batch_size=6,
)
data_module.setup()



Map: 100%|██████████| 144/144 [00:00<00:00, 6744.53 examples/s]
Map: 100%|██████████| 17/17 [00:00<00:00, 4519.15 examples/s]
Map: 100%|██████████| 41/41 [00:00<00:00, 6234.06 examples/s]


In [9]:

# 3. Instantiate the regression module on the same ESM-2 checkpoint and
#    max_length that the DataModule's tokenizer was configured with.
#    NOTE(review): peft_r / peft_alpha are presumably the LoRA rank and scaling
#    factor — confirm against BTRegressionModule.
model = BTRegressionModule(
    model_name="facebook/esm2_t33_650M_UR50D",
    label=yvar,
    max_length=512,
    lr=5e-4,
    peft_r=8,
    peft_alpha=16,
)

# 4. Early stopping: halt once 'val_spearman' (higher is better) has failed to
#    improve for 30 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_spearman',
    mode='max',
    patience=30,
    min_delta=0.0,
    verbose=False,
)

# 5. Weights & Biases logger for this run
wandb_logger = WandbLogger(name="bt_regression_run", project="bt_regression")

# 6. Trainer: one GPU, validation run four times per epoch, no checkpoints saved
trainer = Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=40,
    val_check_interval=0.25,
    log_every_n_steps=10,
    enable_checkpointing=False,
    logger=wandb_logger,
    callbacks=[early_stop_callback],
)


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/home/naka/code/egfr_binder_rd2/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/naka/code/egfr_binder_rd2/.venv/lib/python3.11 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:

# 7. Fit the model using the dataloaders supplied by the DataModule
trainer.fit(model, datamodule=data_module)


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33manaka[0m ([33manaka_personal[0m). Use [1m`wandb login --relogin`[0m to force relogin


Map: 100%|██████████| 144/144 [00:00<00:00, 7681.09 examples/s]
Map: 100%|██████████| 17/17 [00:00<00:00, 4492.95 examples/s]
Map: 100%|██████████| 41/41 [00:00<00:00, 5835.31 examples/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                               | Params | Mode 
-----------------------------------------------------------------------------
0 | esm_model     | PeftModelForSequenceClassification | 656 M  | train
1 | bt_loss       | BradleyTerryLoss                   | 0      | train
2 | train_metrics | MetricCollection                   | 0      | train
3 | val_mae       | MeanAbsoluteError                  | 0      | train
4 | val_spearman  | SpearmanCorrCoef                   | 0      | train
-----------------------------------------------------------------------------
3.7 M     Trainable params
652 M     Non-trainable params
656 M     Total params
2,624.095 Total estimated model params size (MB)
1000      Modules in train mode
614       Modules i

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/home/naka/code/egfr_binder_rd2/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


                                                                           

/home/naka/code/egfr_binder_rd2/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Epoch 8:  50%|█████     | 12/24 [00:04<00:04,  2.44it/s, v_num=ws7o]


In [None]:

# 8. Evaluate on the DataModule's test split and report the returned metrics
test_result = trainer.test(model, datamodule=data_module)
print(f"Test result: {test_result}")