In [2]:
import torch
print(torch.__version__)

2.9.1+cu128


In [1]:
# import torch
import transformers
import peft
# print(torch.__version__)
print(transformers.__version__)
print(peft.__version__)

4.39.3
0.10.0


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import EsmForMaskedLM, EsmTokenizer, EsmConfig
import os
import argparse
from pathlib import Path
import accelerate
from accelerate import Accelerator

from confit.data_utils import Mutation_Set, split_train, sample_data
from confit.stat_utils import spearman, compute_score, BT_loss, KLloss

from Bio import SeqIO
from transformers import EsmTokenizer

from scipy.stats import spearmanr

In [4]:
# pip install transformers==4.39.3 peft==0.10.0

In [None]:
# os.environ["HF_TOKEN"] = "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"

In [4]:
temp_df = pd.read_csv("data/BRCA1_HUMAN_Fields2015-e3/data.csv")

In [5]:
temp_df.describe()

Unnamed: 0.1,Unnamed: 0,log_fitness,n_mut,PID,mutated_position
count,1271.0,1271.0,1271.0,1271.0,1271.0
mean,693.125098,0.546818,1.0,693.125098,52.129819
std,398.213603,0.484519,0.0,398.213603,22.110417
min,0.0,-0.437145,1.0,0.0,15.0
25%,348.5,0.237671,1.0,348.5,33.0
50%,697.0,0.520609,1.0,697.0,51.0
75%,1040.5,0.834807,1.0,1040.5,71.0
max,1381.0,2.791509,1.0,1381.0,90.0


In [6]:
temp_df.head()

Unnamed: 0.1,Unnamed: 0,seq,log_fitness,n_mut,mutant,PID,mutated_position
0,0,MDLSALRVEEVQNVIAAMQKILECPICLELIKEPVSTKCDHIFCKF...,0.908475,1,N16A,0,15
1,1,MDLSALRVEEVQNVICAMQKILECPICLELIKEPVSTKCDHIFCKF...,0.156238,1,N16C,1,15
2,2,MDLSALRVEEVQNVIEAMQKILECPICLELIKEPVSTKCDHIFCKF...,1.287431,1,N16E,2,15
3,3,MDLSALRVEEVQNVIDAMQKILECPICLELIKEPVSTKCDHIFCKF...,1.074311,1,N16D,3,15
4,4,MDLSALRVEEVQNVIGAMQKILECPICLELIKEPVSTKCDHIFCKF...,1.177073,1,N16G,4,15


### Load data

In [7]:
dataset_name = "BRCA1_HUMAN_Fields2015-e3"

In [8]:
data_dir = f'data/{dataset_name}'
data = pd.read_csv(f"{data_dir}/data.csv", index_col=0)

In [9]:
# load wild type
wt_path = f"{data_dir}/wt.fasta"
wt_seq = str(next(SeqIO.parse(wt_path, "fasta")).seq)

In [11]:
print(f'Wild-type sequence length: {len(wt_seq)}')

Wild-type sequence length: 110


In [15]:
print(wt_seq[:50])

MDLSALRVEEVQNVINAMQKILECPICLELIKEPVSTKCDHIFCKFCMLK


In [9]:
# load VAE
vae_elbo = pd.read_csv(f"{data_dir}/vae_elbo.csv", index_col=0)

In [17]:
vae_elbo.head()

Unnamed: 0,seq,PID,elbo
0,MDLSALRVEEVQNVIAAMQKILECPICLELIKEPVSTKCDHIFCKF...,0,0.629923
1,MDLSALRVEEVQNVICAMQKILECPICLELIKEPVSTKCDHIFCKF...,1,-2.12385
2,MDLSALRVEEVQNVIEAMQKILECPICLELIKEPVSTKCDHIFCKF...,2,0.151109
3,MDLSALRVEEVQNVIDAMQKILECPICLELIKEPVSTKCDHIFCKF...,3,-0.040975
4,MDLSALRVEEVQNVIGAMQKILECPICLELIKEPVSTKCDHIFCKF...,4,-1.026119


### Prepare Tokenization

In [35]:
# from huggingface_hub import login
# login()

In [None]:
# from transformers import AutoModelForMaskedLM, AutoTokenizer

# model = AutoModelForMaskedLM.from_pretrained("facebook/esm1b_t33_650M_UR50S")
# tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")

In [5]:
# export HF_HOME=/work/commons/huggingface/hub

# "facebook/esm1v_t33_650M_UR90S_1" not working

from transformers import EsmForMaskedLM, EsmTokenizer

model_name = "facebook/esm2_t33_650M_UR50D"
model = EsmForMaskedLM.from_pretrained(model_name)
tokenizer = EsmTokenizer.from_pretrained(model_name)


seq_len = 1024



In [11]:
sequences = list(data["seq"])
fitness = torch.tensor(data["log_fitness"].values, dtype=torch.float32)
positions = [[int(p) for p in str(pos).split(',')] if isinstance(pos, str) else [pos] for pos in data['mutated_position']]
pids = data["PID"].values

In [12]:
len(sequences), len(fitness), len(positions), len(pids)

(1271, 1271, 1271, 1271)

In [52]:
# positions

In [12]:
wt_path

'data/BRCA1_HUMAN_Fields2015-e3/wt.fasta'

In [16]:
for seq_record in SeqIO.parse(wt_path, "fasta"):
    wt = str(seq_record.seq)
    print(wt)
    break

MDLSALRVEEVQNVINAMQKILECPICLELIKEPVSTKCDHIFCKFCMLKLLNQKKGPSQCPLCKNDITKRSLQESTRFSQLVEELLKIICAFQLDTGLEYANSYNFAKK


### Train

In [15]:
class Mutation_Set(Dataset):
    def __init__(self, data, fname, tokenizer, sep_len=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.seq_len = sep_len
        self.seq, self.attention_mask = tokenizer(list(self.data['seq']), padding='max_length',
                                                  truncation=True,
                                                  max_length=self.seq_len).values()
        wt_path = os.path.join('data', fname, 'wt.fasta')
        for seq_record in SeqIO.parse(wt_path, "fasta"):
            wt = str(seq_record.seq)
        target = [wt]*len(self.data)
        self.target, self.tgt_mask = tokenizer(target, padding='max_length', truncation=True,
                                               max_length=self.seq_len).values()
        self.score = torch.tensor(np.array(self.data['log_fitness']))
        self.pid = np.asarray(data['PID'])

        if type(list(self.data['mutated_position'])[0]) != str:
            self.position = [[u] for u in self.data['mutated_position']]

        else:

            temp = [u.split(',') for u in self.data['mutated_position']]
            self.position = []
            for u in temp:
                pos = [int(v) for v in u]
                self.position.append(pos)

    def __getitem__(self, idx):
        return [self.seq[idx], self.attention_mask[idx], self.target[idx],self.tgt_mask[idx] ,self.position[idx], self.score[idx], self.pid[idx]]

    def __len__(self):
        return len(self.score)

    def collate_fn(self, data):
        seq = torch.tensor(np.array([u[0] for u in data]))
        att_mask = torch.tensor(np.array([u[1] for u in data]))
        tgt = torch.tensor(np.array([u[2] for u in data]))
        tgt_mask = torch.tensor(np.array([u[3] for u in data]))
        pos = [torch.tensor(u[4]) for u in data]
        score = torch.tensor(np.array([u[5] for u in data]), dtype=torch.float32)
        pid = torch.tensor(np.array([u[6] for u in data]))
        return seq, att_mask, tgt, tgt_mask, pos, score, pid

In [16]:
dataset = Mutation_Set(data, dataset_name, tokenizer, sep_len=seq_len)

In [17]:
len(dataset.attention_mask[0]), len(dataset.attention_mask)

(1024, 1271)

In [18]:
def sample_data(dataset_name, seed, shot, frac=0.2):
    '''
    sample the train data and test data
    :param seed: sample seed
    :param frac: the fraction of testing data, default to 0.2
    :param shot: the size of training data
    '''

    data = pd.read_csv(f'data/{dataset_name}/data.csv', index_col=0)
    test_data = data.sample(frac=frac, random_state=seed)
    train_data = data.drop(test_data.index)
    
    # low-N training
    # prepares the few labeled mutants for fine-tuning while leaving the rest for testing or validation
    kshot_data = train_data.sample(n=shot, random_state=seed)
    
    assert len(kshot_data) == shot, (
        f'expected {shot} train examples, received {len(train_data)}')

    kshot_data.to_csv(f'data/{dataset_name}/train.csv')
    test_data.to_csv(f'data/{dataset_name}/test.csv')


def split_train(dataset_name):
    '''
    five equal split training data, one of which will be used as validation set when training ConFit
    '''
    train = pd.read_csv(f'data/{dataset_name}/train.csv', index_col=0)
    tlen = int(np.ceil(len(train) / 5))
    start = 0
    for i in range(1, 5):
        csv = train[start:start + tlen]
        start += tlen
        csv.to_csv(f'data/{dataset_name}/train_{i}.csv')
    csv = train[start:]
    csv.to_csv(f'data/{dataset_name}/train_{5}.csv')





def spearman(y_pred, y_true):
    if np.var(y_pred) < 1e-6 or np.var(y_true) < 1e-6:
        return 0.0
    return spearmanr(y_pred, y_true)[0]

def compute_stat(sr):
    sr = np.asarray(sr)
    mean = np.mean(sr)
    std = np.std(sr)
    sr = (sr,)
    ci = list(bootstrap(sr, np.mean).confidence_interval)
    return mean, std, ci

In [27]:
def BT_loss(scores, golden_score):
    loss = torch.tensor(0.)
    loss = loss.cuda()
    for i in range(len(scores)):
        for j in range(i, len(scores)):
            if golden_score[i] > golden_score[j]:
                loss += torch.log(1+torch.exp(scores[j]-scores[i]))
            else:
                loss += torch.log(1+torch.exp(scores[i]-scores[j]))
    return loss


def KLloss(logits, logits_reg, seq, att_mask):

    creterion_reg = torch.nn.KLDivLoss(reduction='mean')
    batch_size = int(seq.shape[0])

    loss = torch.tensor(0.)
    loss = loss.cuda()
    probs = torch.softmax(logits, dim=-1)
    probs_reg = torch.softmax(logits_reg, dim=-1)
    for i in range(batch_size):

        probs_i = probs[i]
        probs_reg_i = probs_reg[i]


        seq_len = torch.sum(att_mask[i])

        reg = probs_reg_i[torch.arange(0, seq_len), seq[i, :seq_len]]
        pred = probs_i[torch.arange(0, seq_len), seq[i, :seq_len]]

        loss += creterion_reg(reg.log(), pred)
    return loss

def evaluate(model, testloader, tokenizer, accelerator, istest=False):
    model.eval()
    seq_list = []
    score_list = []
    gscore_list = []
    device = next(model.parameters()).device  # Get model's device (e.g., 'cuda:0')
    with torch.no_grad():
        for step, data in enumerate(testloader):
            seq, mask = data[0].to(device), data[1].to(device)
            wt, wt_mask = data[2].to(device), data[3].to(device)
            pos = [p.to(device) for p in data[4]]
            golden_score = data[5].to(device)
            pid = data[6].to(device)
            if istest:
                if accelerator is not None:
                    pid = accelerator.gather(pid)
                for s in pid:
                    seq_list.append(s.cpu())

            score, logits = compute_score(model, seq, mask, wt, pos, tokenizer)

            if accelerator is not None:
                score = accelerator.gather(score)
                golden_score = accelerator.gather(golden_score)
            score = np.asarray(score.cpu())
            golden_score = np.asarray(golden_score.cpu())
            score_list.extend(score)
            gscore_list.extend(golden_score)
    score_list = np.asarray(score_list)
    gscore_list = np.asarray(gscore_list)
    sr = spearman(score_list, gscore_list)

    if istest:
        seq_list = np.asarray(seq_list)

        return sr, score_list, seq_list
    else:
        return sr

In [20]:
# training size
shot = 100
seed = 0
sample_data(dataset_name, seed, shot)
split_train(dataset_name)

In [21]:
train_data = pd.read_csv(f'data/{dataset_name}/train.csv', index_col=0)
val_data = pd.read_csv(f'data/{dataset_name}/train_1.csv', index_col=0)
test_data = pd.read_csv(f'data/{dataset_name}/test.csv', index_col=0)

In [22]:
train_dataset = Mutation_Set(train_data, dataset_name, tokenizer, sep_len=seq_len)
val_dataset = Mutation_Set(val_data, dataset_name, tokenizer, sep_len=seq_len)
test_dataset = Mutation_Set(test_data, dataset_name, tokenizer, sep_len=seq_len)

In [28]:
batch_size = 2
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          collate_fn=train_dataset.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False,
                        collate_fn=val_dataset.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False,
                         collate_fn=test_dataset.collate_fn)

In [29]:
batch = next(iter(train_loader))

In [73]:
len(batch), batch[0].shape, batch[1].shape

(7, torch.Size([8, 1024]), torch.Size([8, 1024]))

In [31]:
base_model = model

lora_config = LoraConfig(
    task_type=None,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"])

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

trainable params: 2,027,520 || all params: 654,384,054 || trainable%: 0.3098364007506821


In [None]:
# model_name = "facebook/esm2_t48_15B_UR50D"
# base_model = EsmForMaskedLM.from_pretrained(model_name, 
#                                             torch_dtype=torch.float16,
#                                             # low_cpy_mem_usage=True,
#                                             use_auth_token=True)

# lora_config = LoraConfig(
#     task_type="CAUSAL_LM",
#     r=8,
#     lora_alpha=16,
#     lora_dropout=0.1,
#     target_modules=["query", "value"])

# model = get_peft_model(base_model, lora_config)
# model.print_trainable_parameters()

Loading checkpoint shards: 100%|██████████| 7/7 [01:42<00:00, 14.65s/it]


trainable params: 7,864,320 || all params: 15,142,217,634 || trainable%: 0.0519363820418327


In [32]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accelerator = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
lambda_reg = 0.1
device = "cuda"
model.to(device)

PeftModel(
  (base_model): LoraModel(
    (model): EsmForMaskedLM(
      (esm): EsmModel(
        (embeddings): EsmEmbeddings(
          (word_embeddings): Embedding(33, 1280, padding_idx=1)
          (dropout): Dropout(p=0.0, inplace=False)
          (position_embeddings): Embedding(1026, 1280, padding_idx=1)
        )
        (encoder): EsmEncoder(
          (layer): ModuleList(
            (0-32): 33 x EsmLayer(
              (attention): EsmAttention(
                (self): EsmSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=1280, out_features=1280, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1280, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_f

In [33]:
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0
    
    for batch in train_loader:
        seq, mask, wt, wt_mask, pos, golden_score, pid = batch
        seq, mask, wt, wt_mask, golden_score, pid = seq.to(device), mask.to(device), wt.to(device), wt_mask.to(device), golden_score.to(device), pid.to(device)
        pos = [p.to(device) for p in pos]
        score, logits = compute_score(model, seq, mask, wt, pos, tokenizer)
        l_bt = BT_loss(score, golden_score)
        l_reg = KLloss(logits, logits, seq, mask)
        loss = l_bt + lambda_reg * l_reg
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch + 1}, Training Loss: {total_loss / len(train_loader)}')
    
    val_sr = evaluate(model, val_loader, tokenizer, None)
    print(f'Epoch {epoch + 1}, Validation Spearman: {val_sr}')



Epoch 1, Training Loss: 2.368814775943756
Epoch 1, Validation Spearman: 0.675187969924812
Epoch 2, Training Loss: 2.4839784908294678
Epoch 2, Validation Spearman: 0.7473684210526315
Epoch 3, Training Loss: 2.216451325416565
Epoch 3, Validation Spearman: 0.7684210526315788
Epoch 4, Training Loss: 1.897665684223175
Epoch 4, Validation Spearman: 0.7428571428571428
Epoch 5, Training Loss: 2.2363506412506102
Epoch 5, Validation Spearman: 0.7639097744360903


In [None]:
# import torch, gc
# torch.cuda.empty_cache()
# gc.collect()

# print(torch.cuda.memory_summary())

## Integrate with SPURS

In [4]:
from Bio import SeqIO
from pathlib import Path

In [4]:
base_dir = Path("/work/yunan/PsiFit/data/proteingym")
datasets = [d.name for d in base_dir.iterdir() if d.is_dir()]

data = []

for dataset in datasets:
    fasta_path = base_dir / dataset / "wildtype.fasta"
    if fasta_path.exists():
        try:
            record = next(SeqIO.parse(fasta_path, "fasta"))
            seq = str(record.seq)
            length = len(seq)
            data.append({'dms_id': dataset, 'seq_length': length})
        except Exception as e:
            print(f"Error reading {fasta_path}: {e}")
            
df = pd.DataFrame(data)

df.sort_values(by='seq_length', ascending=True, inplace=True)

In [5]:
df

Unnamed: 0,dms_id,seq_length
50,VG08_BPP22_Tsuboyama_2023_2GP8,40
65,SQSTM_MOUSE_Tsuboyama_2023_2RRU,40
139,OTU7A_HUMAN_Tsuboyama_2023_2L2D,42
25,HCP_LAMBD_Tsuboyama_2023_2L6Q,55
6,DN7A_SACS2_Tsuboyama_2023_1JIC,55
...,...,...
39,NPC1_HUMAN_Erwood_2022_RPE1,1278
118,NPC1_HUMAN_Erwood_2022_HEK293T,1278
84,BRCA1_HUMAN_Findlay_2018,1863
117,SCN5A_HUMAN_Glazer_2019,2016


In [6]:
df[~df['dms_id'].str.contains('Tsuboyama')]

Unnamed: 0,dms_id,seq_length
110,ENVZ_ECOLI_Ghose_2023,60
91,IF1_ECOLI_Kelsic_2016,72
93,TAT_HV1BR_Fernandes_2016,86
47,A0A247D711_LISMN_Stadelmann_2021,87
32,CCDB_ECOLI_Tripathi_2016,101
...,...,...
39,NPC1_HUMAN_Erwood_2022_RPE1,1278
118,NPC1_HUMAN_Erwood_2022_HEK293T,1278
84,BRCA1_HUMAN_Findlay_2018,1863
117,SCN5A_HUMAN_Glazer_2019,2016


In [7]:
dms_id = "ENVZ_ECOLI_Ghose_2023"
# dms_id = "IF1_ECOLI_Kelsic_2016"
spurs_path = base_dir / dms_id / "spurs_prediction.tsv"
spurs_df = pd.read_csv(spurs_path, sep='\t', index_col=0)
spurs_ddg = torch.tensor(spurs_df.values, dtype=torch.float32)

In [8]:
# spurs_df.head()
spurs_df.columns

Index(['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q',
       'R', 'S', 'T', 'V', 'W', 'Y'],
      dtype='object')

In [9]:
wt = next(SeqIO.parse(base_dir / dms_id / "wildtype.fasta", "fasta"))
wt_seq = str(wt.seq)

In [10]:
wt_seq

'LADDRTLLMAGVSHDLRTPLTRIRLATEMMSEQDGYLAESINKDIEECNAIIEQFIDYLR'

#### PsiFit Class

In [11]:
from transformers import EsmForMaskedLM, EsmTokenizer

In [12]:
model_name = "facebook/esm2_t33_650M_UR50D"
esm_model = EsmForMaskedLM.from_pretrained(model_name)
esm_tokenizer = EsmTokenizer.from_pretrained(model_name)



In [13]:
aa_list = list("ACDEFGHIKLMNPQRSTVWY")

In [14]:
spurs_df.columns.tolist() == aa_list

True

In [18]:
# esm natural aa order
vocab = esm_tokenizer.get_vocab()
aa_tokens_sorted = sorted((v, k) for k, v in vocab.items() if len(k) == 1 and k.isupper())
print("ESM natural amino acid order:", [t[1] for t in aa_tokens_sorted])

ESM natural amino acid order: ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O']


In [19]:
aa_list == [t[1] for t in aa_tokens_sorted]

False

In [16]:
positions = spurs_df.index.tolist()
esm_probs_df =  pd.DataFrame(index=positions, columns=aa_list)

In [15]:
aa_token_ids = [esm_tokenizer.convert_tokens_to_ids(aa) for aa in aa_list]

In [23]:
mask_token = esm_tokenizer.mask_token

In [21]:
from tqdm import tqdm

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
esm_model.to(device)

EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((1280,), eps=1e-05, 

In [26]:
import torch
import torch.nn as nn

In [27]:
class PsiFit(nn.Module):
    def __init__(self, esm_model, spurs_ddg, aa_token_ids):
        super().__init__()
        self.esm_model = esm_model
        self.ddg = spurs_ddg.to(device)
        self.aa_token_ids = aa_token_ids
        
    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        logits = outputs.logits
        
        seq_len = input_ids.shape[1] - 2  
        aligned_logits = logits[:, 1:seq_len+1, self.aa_token_ids] 
        aligned_logits += self.ddg.unsqueeze(0).to(aligned_logits.device)
        
        if logits is not None:
            full_logits = logits.clone()
            full_logits[:, 1:seq_len+1, self.aa_token_ids] = aligned_logits
            outputs.logits = full_logits
            
        return outputs

In [28]:
psifit = PsiFit(esm_model, spurs_ddg, aa_token_ids)
psifit.to(device)

PsiFit(
  (esm_model): EsmForMaskedLM(
    (esm): EsmModel(
      (embeddings): EsmEmbeddings(
        (word_embeddings): Embedding(33, 1280, padding_idx=1)
        (dropout): Dropout(p=0.0, inplace=False)
        (position_embeddings): Embedding(1026, 1280, padding_idx=1)
      )
      (encoder): EsmEncoder(
        (layer): ModuleList(
          (0-32): 33 x EsmLayer(
            (attention): EsmAttention(
              (self): EsmSelfAttention(
                (query): Linear(in_features=1280, out_features=1280, bias=True)
                (key): Linear(in_features=1280, out_features=1280, bias=True)
                (value): Linear(in_features=1280, out_features=1280, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
                (rotary_embeddings): RotaryEmbedding()
              )
              (output): EsmSelfOutput(
                (dense): Linear(in_features=1280, out_features=1280, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
      

In [29]:
psifit.train()

PsiFit(
  (esm_model): EsmForMaskedLM(
    (esm): EsmModel(
      (embeddings): EsmEmbeddings(
        (word_embeddings): Embedding(33, 1280, padding_idx=1)
        (dropout): Dropout(p=0.0, inplace=False)
        (position_embeddings): Embedding(1026, 1280, padding_idx=1)
      )
      (encoder): EsmEncoder(
        (layer): ModuleList(
          (0-32): 33 x EsmLayer(
            (attention): EsmAttention(
              (self): EsmSelfAttention(
                (query): Linear(in_features=1280, out_features=1280, bias=True)
                (key): Linear(in_features=1280, out_features=1280, bias=True)
                (value): Linear(in_features=1280, out_features=1280, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
                (rotary_embeddings): RotaryEmbedding()
              )
              (output): EsmSelfOutput(
                (dense): Linear(in_features=1280, out_features=1280, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
      

In [30]:
optimizer = torch.optim.Adam(psifit.parameters(), lr=1e-5)
num_epochs = 5

In [32]:
for epoch in range(num_epochs):
    total_loss = 0
    
    for idx, pos in enumerate(positions):
        seq_list = list(wt_seq)
        seq_list[pos - 1] = mask_token 
        masked_seq = ''.join(seq_list)
        
        inputs = esm_tokenizer(masked_seq, return_tensors='pt').to(device)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        
        labels = input_ids.clone()
        labels[:, :] = -100  
        orig_token_id = esm_tokenizer.convert_tokens_to_ids(wt_seq[pos - 1])
        labels[:, pos] = orig_token_id  
        
        outputs = psifit(input_ids, attention_mask, labels)
        loss = outputs.loss  
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}: Avg Loss {total_loss / len(positions):.4f}")

Epoch 1: Avg Loss 2.0279
Epoch 2: Avg Loss 0.5987
Epoch 3: Avg Loss 0.1215
Epoch 4: Avg Loss 0.0298
Epoch 5: Avg Loss 0.0086
