In [2]:
# Print the command-line usage of the ESM representation-extraction script.
!python src/esm/extract.py -h

usage: extract.py [-h] [--toks_per_batch TOKS_PER_BATCH]
                  [--repr_layers REPR_LAYERS [REPR_LAYERS ...]] --include
                  {mean,per_tok,bos,contacts}
                  [{mean,per_tok,bos,contacts} ...]
                  [--truncation_seq_length TRUNCATION_SEQ_LENGTH] [--nogpu]
                  [--concatenate_dir CONCATENATE_DIR]
                  model_location fasta_file output_dir

Extract per-token representations and model outputs for sequences in a FASTA
file

positional arguments:
  model_location        PyTorch model file OR name of pretrained model to
                        download (see README for models)
  fasta_file            FASTA file on which to extract representations
  output_dir            output directory for extracted representations

options:
  -h, --help            show this help message and exit
  --toks_per_batch TOKS_PER_BATCH
                        maximum batch size
  --repr_layers REPR_LAYERS [REPR_LAYERS ...]
                  

In [1]:
# Extract the `mean` representation from the pretrained ESM-1b model
# (esm1b_t33_650M_UR50S) for every sequence in data/p1450.fasta.
# Representations are saved to data/esm_embedings/P1450, and a single combined
# CSV (one row per sequence) is written into --concatenate_dir.
# --toks_per_batch caps the batch size (see the usage printed above).
# NOTE(review): --concatenate_dir is an absolute, machine-specific path; a path
# relative to the project root (e.g. data/esm_embedings) would keep this cell
# portable — confirm the working directory before changing it.
!python src/esm/extract.py esm1b_t33_650M_UR50S data/p1450.fasta data/esm_embedings/P1450 --toks_per_batch 512 --include mean --concatenate_dir /data/home/maorunzegroup/Basepro/data/esm_embedings

download over
Transferred model to GPU
Read data/p1450.fasta with 3 sequences
Processing 1 of 2 batches (2 sequences)
Device: cuda:0
Processing 2 of 2 batches (1 sequences)
Device: cuda:0
Saved representations to data/esm_embedings/P1450
  file_data = torch.load(file_path)
Shape of concatenated DataFrame: (3, 1280)
Saved concatenated representations to /data/home/maorunzegroup/Basepro/data/esm_embedings/p1450_esm1b_t33_650M_UR50S.csv


In [2]:
import numpy as np
import torch
import pandas as pd
import os

### round_0

In [2]:
def random_sample_csv(input_file_path, saved_file_path, sample_size=200, seed=42):
    """
    Randomly sample rows from a CSV file and save them as round-0 data.

    The input CSV must contain ``variant`` and ``fitness`` columns. The saved
    CSV has three columns: ``variant``, ``fitness`` and ``indices``. The
    ``indices`` column holds the 0-based row positions of the sampled rows in
    the input file, in the order they were drawn.

    Parameters:
    input_file_path (str): Path to input CSV file
    saved_file_path (str): Path to save the sampled CSV file. Missing parent
        directories are created.
    sample_size (int): Number of rows to sample (default: 200). If the file has
        fewer rows, it is clamped to the row count and a warning is printed.
    seed (int): Seed for the sampler (default: 42). Sampling uses a private
        ``np.random.RandomState(seed)``, so the global NumPy RNG is not touched.
        The draws are identical to ``np.random.seed(seed)`` followed by
        ``np.random.choice``.

    Returns:
    pandas.DataFrame or None: The sampled rows on success. Returns None if any
    step failed; the error is printed rather than raised, so check the result.
    """
    try:
        # Read the CSV file
        print(f"Reading file: {os.path.basename(input_file_path)}...")
        df = pd.read_csv(input_file_path)

        # Validate file size
        if len(df) < sample_size:
            print(f"Warning: File has only {len(df)} rows, less than requested sample size {sample_size}")
            sample_size = len(df)

        # Fixed seed for reproducibility; a local RNG avoids mutating global state
        rng = np.random.RandomState(seed)
        round0_indices = rng.choice(len(df), size=sample_size, replace=False)

        # Positional selection (.iloc) so the result does not depend on index labels
        sampled_df = pd.DataFrame({
            'variant': df['variant'].iloc[round0_indices].to_numpy(),
            'fitness': df['fitness'].iloc[round0_indices].to_numpy(),
            'indices': round0_indices,
        })

        # Save sampled data, creating the target directory if needed
        out_dir = os.path.dirname(saved_file_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        sampled_df.to_csv(saved_file_path, index=False)
        print(f"✓ Sampling complete! Saved to: {saved_file_path}")
        print(f"Original rows: {len(df)}, Sampled rows: {len(sampled_df)}")
        return sampled_df

    except Exception as e:
        # Best-effort: report and return None instead of raising
        print(f"Error: {str(e)}")
        print("Operation failed. Please check file path and format")
        return None


In [3]:
# Draw the random round-0 training set from the full GB1 fitness table
# (the sampling seed is fixed inside random_sample_csv).
random_sample_csv(
    input_file_path='data/GB1/fitness.csv',
    saved_file_path='rounds_data/GB1/GB1_round_0.csv',
    sample_size=200,
)

Reading file: fitness.csv...
✓ Sampling complete! Saved to: rounds_data/GB1/GB1_round_0.csv
Original rows: 149361, Sampled rows: 200


In [3]:
# Experiment settings shared by every round below.
protein_name = 'GB1'
number_of_variants = 90            # variants proposed for each next round
output_dir = 'output'

# Input locations: precomputed embeddings and per-round training data.
embeddings_base_path = 'data/GB1'
embeddings_file_name = 'ESM2_x.pt'
round_base_path = 'rounds_data/GB1'

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


### round_1

In [4]:
from src.model import run_directed_evolution

In [5]:

# Round 1 trains on the random round-0 sample only.
# round_base_path comes from the configuration cell; it is deliberately not
# re-assigned here so that changing the configuration is not silently overridden.
round_index = 1
round_name = f'round_{round_index}'
# Training data = every earlier round: round_0 .. round_{round_index - 1}.
round_data_filenames = [f'{protein_name}_round_{i}.csv' for i in range(round_index)]
# Fail early if an input round file has not been produced yet.
missing_rounds = [name for name in round_data_filenames
                  if not os.path.exists(os.path.join(round_base_path, name))]
assert not missing_rounds, f"Missing round data in {round_base_path}: {missing_rounds}"


In [6]:
# Candidate pool for ranking: every variant listed in the GB1 fitness table.
fitness = pd.read_csv('data/GB1/fitness.csv')
all_variants = fitness[['variant']].copy()

In [7]:

# Run one round of model-guided directed evolution via
# src.model.run_directed_evolution, using regression_model='xgboost', the round
# files in round_data_filenames, and the candidate pool in all_variants.
# Arguments are passed positionally; keep the order in sync with the callee.
# NOTE(review): return meanings are inferred from names and printed output.
# df_next_round appears to be the top `number_of_variants` predictions
# (variant/fitness/indices). df_pre_all_sorted is presumably the full ranked
# prediction table. Confirm against src.model.
# The printed output reports results saved under output_dir/protein_name/round_name.
df_next_round, df_pre_all_sorted = run_directed_evolution(
    protein_name,
    round_name,
    embeddings_base_path,
    embeddings_file_name,
    round_base_path,
    round_data_filenames,
    number_of_variants,
    output_dir,
    regression_model='xgboost',
    all_variants=all_variants
)

Processing GB1 - round_1
Using device: cuda


  embeddings = torch.load(file_path, map_location=device)


Loaded embeddings from data/GB1/ESM2_x.pt with shape torch.Size([149361, 5120])
Embeddings loaded: torch.Size([149361, 5120])
Loaded: GB1_round_0.csv (Round 0)

Top 90 variants predicted by the modelf or next round: 90
       variant   fitness  indices
104538    WYAG  2.455739   104538
82283     YIAG  2.291723    82283
30244     WFAG  2.289244    30244
80659     YFAG  2.273627    80659
35767     TIAG  2.248921    35767
...        ...       ...      ...
115498    IGAG  1.632279   115498
7548      LGAG  1.628683     7548
20533     KVAG  1.624723    20533
10104     IVGG  1.621714    10104
77161     KICG  1.621399    77161

[90 rows x 3 columns]

Data saved to output/GB1/round_1


### round_2

In [8]:

# Round 2 trains on the round-0 sample plus the round-1 data.
# NOTE(review): no visible cell writes GB1_round_1.csv into round_base_path
# (round 1 reports saving to output/GB1/round_1), so it is presumably created
# in a separate step. The check below fails early if that step was skipped.
# round_base_path comes from the configuration cell and is not re-assigned here.
round_index = 2
round_name = f'round_{round_index}'
# Training data = every earlier round: round_0 .. round_{round_index - 1}.
round_data_filenames = [f'{protein_name}_round_{i}.csv' for i in range(round_index)]
missing_rounds = [name for name in round_data_filenames
                  if not os.path.exists(os.path.join(round_base_path, name))]
assert not missing_rounds, f"Missing round data in {round_base_path}: {missing_rounds}"


In [9]:

# Run the next round of model-guided directed evolution via
# src.model.run_directed_evolution, using regression_model='xgboost'. The model
# now sees every file in round_data_filenames (round 0 and round 1).
# Arguments are passed positionally; keep the order in sync with the callee.
# NOTE(review): this overwrites df_next_round / df_pre_all_sorted from round 1.
# Their meanings are inferred from names and printed output (top
# `number_of_variants` predictions / presumably the full ranked table); confirm
# against src.model.
df_next_round, df_pre_all_sorted = run_directed_evolution(
    protein_name,
    round_name,
    embeddings_base_path,
    embeddings_file_name,
    round_base_path,
    round_data_filenames,
    number_of_variants,
    output_dir,
    regression_model='xgboost',
    all_variants=all_variants
)

Processing GB1 - round_2
Using device: cuda


  embeddings = torch.load(file_path, map_location=device)


Loaded embeddings from data/GB1/ESM2_x.pt with shape torch.Size([149361, 5120])
Embeddings loaded: torch.Size([149361, 5120])
Loaded: GB1_round_0.csv (Round 0)
Loaded: GB1_round_1.csv (Round 1)

Top 90 variants predicted by the modelf or next round: 90
       variant   fitness  indices
102915    WIGG  1.971513   102915
97910     IYIG  1.936826    97910
126925    ITVG  1.916372   126925
24555     YTAG  1.910432    24555
89858     YTPG  1.904741    89858
...        ...       ...      ...
119817    YNPG  1.760194   119817
92992     WCNG  1.759691    92992
84732     GKVG  1.759637    84732
86641     WEGG  1.758777    86641
50471     WSPG  1.757703    50471

[90 rows x 3 columns]

Data saved to output/GB1/round_2
