In [1]:
import os
import torch
import numpy as np
import pandas as pd

In [2]:
from Bio import SeqIO

In [3]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
# Load the pretrained ProSST-2048 masked-LM checkpoint and its residue tokenizer.
# NOTE: trust_remote_code=True executes the modelling/tokenizer code shipped in
# the Hub repository, so only use it with repositories you trust.
model = AutoModelForMaskedLM.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)

In [6]:
# Bare attribute access (the method is NOT called): the cell only displays the
# bound method's repr, which happens to include the tokenizer configuration and
# the ids of its special tokens.
tokenizer.add_tokens

<bound method SpecialTokensMixin.add_tokens of EsmTokenizer(name_or_path='AI4Protein/ProSST-2048', vocab_size=25, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	23: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	24: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)>

In [7]:
def read_seq(fasta):
    """Return the sequence of the first record in a FASTA file as a string.

    Args:
        fasta: File name or handle accepted by ``SeqIO.parse``.

    Returns:
        ``str(record.seq)`` for the first record; any further records are ignored.

    Raises:
        ValueError: If the file contains no FASTA records.
    """
    first_record = next(iter(SeqIO.parse(fasta, "fasta")), None)
    if first_record is None:
        # This case used to fall through and return None, which only surfaced
        # later as an unrelated AttributeError in the caller.
        raise ValueError(f"No FASTA records found in {fasta}")
    return str(first_record.seq)


def tokenize_structure_sequence(structure_sequence):
    """Convert raw structure tokens into a batched tensor of model input ids.

    Every token is shifted up by 3, then the sequence is wrapped with id 1 at
    the front and id 2 at the end.
    NOTE(review): 1 and 2 match the <cls>/<eos> ids of the residue tokenizer;
    presumably the shift reserves the low ids for special tokens -- confirm.

    Args:
        structure_sequence: Iterable of integer structure tokens.

    Returns:
        ``torch.long`` tensor of shape ``(1, len(structure_sequence) + 2)``.
    """
    start_id, end_id, shift = 1, 2, 3

    token_ids = [start_id]
    token_ids.extend(token + shift for token in structure_sequence)
    token_ids.append(end_id)

    # The extra outer list adds the leading batch dimension of size 1.
    return torch.tensor([token_ids], dtype=torch.long)


In [8]:
from pathlib import Path

In [9]:
# Assay to work on; both FASTA files are named after it.
name = "A0A2Z5U3Z0_9INFA_Wu_2014"

# Directories holding the amino-acid FASTA files and the structure-token FASTA files.
# NOTE(review): these are absolute, machine-specific paths.
residue_sequence_dir = "/home/yining_yang/Documents/lm/SSRAP/VenusREM/data/proteingym_v1/aa_seq"
structure_sequence_dir = "/home/yining_yang/Documents/lm/SSRAP/VenusREM/data/proteingym_v1/struc_seq/2048"

residue_fasta = Path(residue_sequence_dir, f"{name}.fasta")
structure_fasta = Path(structure_sequence_dir, f"{name}.fasta")



In [10]:
# Wild-type residue sequence as a plain string.
sequence = read_seq(residue_fasta)

# The structure FASTA record is a comma-separated list of integers; parse it
# into a list of ints before turning it into model input ids.
structure_sequence = list(map(int, read_seq(structure_fasta).split(",")))
ss_input_ids = tokenize_structure_sequence(structure_sequence).to(device)

# Tokenize the residue sequence as a batch of one and move both tensors to the device.
tokenized_results = tokenizer([sequence], return_tensors="pt")
input_ids, attention_mask = (
    tokenized_results[key].to(device) for key in ("input_ids", "attention_mask")
)

In [11]:
input_ids.shape

torch.Size([1, 567])

In [12]:
len("MKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGSGYAADQKSTQNAINGITNKVNSVIEKMNTQFTAVGKEFNNLEKRMENLNKKVDDGFLDIWTYNAELLVLLENERTLDFHDLNVKNLYEKVKSQLKNNAKEIGNGCFEFYHKCDNECMESVRNGTYDYPKYSEESKLNREKIDGVKLESMGVYQILAIYSTVASSLVLLVSLGAISFWMCSNGSLQCRICI")

565

In [12]:
# Exploratory forward pass on the wild-type inputs, requesting the hidden states
# so they can be inspected in the following cells.
model = model.to(device)

# Nothing is back-propagated from this pass, so disable autograd to avoid
# building (and keeping alive) a graph over every hidden state.
with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        ss_input_ids=ss_input_ids,
        labels=input_ids,
        output_hidden_states=True,
        # output_attentions=True,  # Make sure to get attention outputs
        return_dict=True,
    )


In [12]:
len(outputs.hidden_states)

13

In [15]:
last_embedding = outputs.hidden_states[-1]

In [16]:
last_embedding.size()

torch.Size([1, 567, 768])

In [17]:
# Preview one score CSV from the results directory: the first ``.csv`` file in
# sorted filename order is read and printed, then the loop stops.
directory_path = "/home/yining_yang/Documents/lm/SSRAP/VenusREM/result/proteingym_v1_original/scores"

# Only needed by the optional concatenation at the bottom; nothing is appended
# to it while the loop merely previews a single file.
dataframes = []

# os.listdir() returns entries in arbitrary order, so sort the listing to make
# the previewed file the same on every run and machine.
for filename in sorted(os.listdir(directory_path)):
    if filename.endswith(".csv"):
        file_path = os.path.join(directory_path, filename)
        df = pd.read_csv(file_path)
        print(df)
        break

# Optionally combine them into one DataFrame
# combined_df = pd.concat(dataframes, ignore_index=True)


     mutant  DMS_score  DMS_score_bin  VenusREM
0      A46C   0.503519              1 -1.188972
1      A46D   0.162813              1 -0.495420
2      A46E   0.374461              1 -0.268819
3      A46F  -0.051768              1 -0.995978
4      A46G  -0.405567              0 -0.798013
...     ...        ...            ...       ...
1327   W38R   0.409986              1 -0.006912
1328   W38S  -0.026247              1 -0.241827
1329   W38T  -0.047858              1 -0.338707
1330   W38V   0.207099              1 -0.326499
1331   W38Y   0.085070              1 -0.460771

[1332 rows x 4 columns]


## Few-shot learning

In [12]:
from datasets import load_dataset


# CSV with the few-shot DMS measurements; it is passed to DMSDataset further
# down, which loads it with pandas.
fewshot_DMS_csv = '/home/yining_yang/Documents/lm/SSRAP/data/fewshot/A0A2Z5U3Z0_9INFA_Wu_2014.csv'
# Alternative loading path via the `datasets` library, currently disabled:
# dataset = load_dataset('csv', data_files=fewshot_DMS_csv)
# dataset = dataset['train'].train_test_split(test_size=0.1)

In [13]:
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import torch

class DMSDataset(Dataset):
    """Dataset over the rows of a DMS CSV file.

    Each item exposes the row's ``mutant`` string and its ``DMS_score`` as a
    float tensor. The wild-type sequence, the structure sequence, the tokenizer
    and its vocabulary are stored on the instance but are not used when an item
    is built.
    """

    def __init__(self, csv_file, sequence, structure_sequence, tokenizer):
        """Load ``csv_file`` with pandas and keep the shared inputs on the instance."""
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.vocab = tokenizer.get_vocab()
        self.sequence = sequence
        self.structure_sequence = structure_sequence

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'mutant': str, 'dms_score': float tensor}`` for row ``idx``."""
        row = self.data.iloc[idx]
        mutant_label = row['mutant']
        score_tensor = torch.tensor(row['DMS_score'], dtype=torch.float)
        return {'mutant': mutant_label, 'dms_score': score_tensor}


In [22]:
# from transformers import AutoTokenizer

# # tokenizer = AutoTokenizer.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)

# def tokenize_function(examples):
#     return tokenizer(examples["mutated_sequence"], padding="max_length", truncation=True)

# tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [23]:
# import torch.nn as nn

# class SpearmanLoss(nn.Module):
#     def __init__(self, regularization_strength=1.0):
#         super(SpearmanLoss, self).__init__()
#         self.regularization_strength = regularization_strength

#     def forward(self, preds, target):
#         preds = torchsort.soft_rank(preds, regularization_strength=self.regularization_strength)
#         target = torchsort.soft_rank(target, regularization_strength=self.regularization_strength)
#         preds = preds - preds.mean()
#         preds = preds / preds.norm()
#         target = target - target.mean()
#         target = target / target.norm()
#         return 1 - (preds * target).sum()

In [None]:
# # model = AutoModelForMaskedLM.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)
# num_epochs =5
# model.train()

# optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# loss_fn = SpearmanLoss()

# for epoch in range(num_epochs):
#     for batch in train_dataloader:
#         optimizer.zero_grad()
#         outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
#         logits = outputs.logits
#         # Compute your predictions and targets here
#         loss = loss_fn(predictions, targets)
#         loss.backward()
#         optimizer.step()

In [14]:
# Wrap the few-shot DMS table in a Dataset. `sequence`, `structure_sequence`
# and `tokenizer` are the notebook-level objects created in earlier cells.
dataset = DMSDataset(fewshot_DMS_csv, sequence, structure_sequence, tokenizer)
# Unsplit loader, superseded by the train/validation loaders built next:
# dataloader = DataLoader(dataset, batch_size=20, shuffle=True)

In [15]:
# 90% of the rows go to training, the remainder to validation.
n_samples = len(dataset)
train_size = int(0.9 * n_samples)
val_size = n_samples - train_size

# Seed the global torch RNG so the random split below is reproducible.
torch.manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Shuffle only the training batches; validation keeps its order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=20)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=20)

In [18]:
import torch.nn as nn
import torch.optim as optim
import itertools

# Few-shot fine-tuning: make the model's log-odds scores (mutant vs. wild type)
# track the measured DMS scores, comparing both after per-batch standardisation.
model.to(device)

num_epochs = 5
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Every mutant is scored against the same wild-type inputs, so build them once.
ss_input_ids = tokenize_structure_sequence(structure_sequence).to(device)
tokenized_results = tokenizer([sequence], return_tensors="pt")
input_ids = tokenized_results["input_ids"].to(device)
attention_mask = tokenized_results["attention_mask"].to(device)


def wild_type_log_probs():
    """Run the model on the wild-type inputs and return per-position log-probabilities.

    The first and last positions of the logits are dropped, so position ``k`` of
    the result corresponds to the 0-based residue index used in ``score_mutants``.
    """
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        ss_input_ids=ss_input_ids,
        labels=input_ids,  # kept from the original call; the returned LM loss is unused
    )
    return torch.log_softmax(outputs.logits[:, 1:-1, :], dim=-1)


def score_mutants(log_probs, mutants):
    """Score each mutant as the summed log-probability difference (mutant - wild type).

    Args:
        log_probs: Tensor returned by ``wild_type_log_probs``.
        mutants: Sequence of mutant strings such as ``"A46C"``; several
            substitutions in one mutant are separated by ``":"``.

    Returns:
        1-D tensor with one score per mutant. It stays attached to the autograd
        graph of ``log_probs`` when there is one.
    """
    scores = []
    for mutant in mutants:
        total = 0
        for sub_mutant in mutant.split(":"):
            # "<wt><1-based position><mt>", e.g. "A46C" -> ("A", 45, "C").
            wt, idx, mt = sub_mutant[0], int(sub_mutant[1:-1]) - 1, sub_mutant[-1]
            mt_id = tokenizer.convert_tokens_to_ids(mt)
            wt_id = tokenizer.convert_tokens_to_ids(wt)
            total = total + (log_probs[0, idx, mt_id] - log_probs[0, idx, wt_id])
        scores.append(total)
    return torch.stack(scores)


def standardize(values):
    """Z-score a tensor with its population std; 1e-8 guards against division by zero."""
    return (values - values.mean()) / (values.std(unbiased=False) + 1e-8)


for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        dms_scores = batch['dms_score'].to(device)
        mutants = batch['mutant']

        # A single sample cannot be standardised meaningfully, so skip it
        # before spending a forward pass on it.
        if len(mutants) <= 1:
            continue

        pred_scores = score_mutants(wild_type_log_probs(), mutants)
        loss = loss_fn(standardize(pred_scores).view(-1), standardize(dms_scores).view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation phase
    model.eval()
    val_losses = []
    with torch.no_grad():
        # The inputs never change and model.eval() disables dropout, so a single
        # forward pass per epoch serves every validation mutant (previously the
        # identical pass was repeated once per mutant).
        val_log_probs = wild_type_log_probs()
        for batch in val_loader:
            dms_scores = batch['dms_score'].to(device)
            mutants = batch['mutant']

            # Same rule as in training: batches of one sample are skipped.
            if len(mutants) <= 1:
                continue

            pred_scores = score_mutants(val_log_probs, mutants)
            loss = loss_fn(standardize(pred_scores).view(-1), standardize(dms_scores).view(-1))
            val_losses.append(loss.item())

    avg_val_loss = sum(val_losses) / len(val_losses) if val_losses else float('nan')
    print(f"Validation Loss: {avg_val_loss:.4f}")


Validation Loss: 1.1986
Validation Loss: 1.1765
Validation Loss: 1.1556
Validation Loss: 1.1267
Validation Loss: 1.1298


In [24]:
# Persist the fine-tuned weights and the tokenizer side by side in one directory.
# The path is relative to the notebook's working directory.
save_dir = f"../model/prosst_finetuned_{name}_epoch{num_epochs}/model"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

('../model/prosst_finetuned_A0A2Z5U3Z0_9INFA_Wu_2014_epoch5/model/tokenizer_config.json',
 '../model/prosst_finetuned_A0A2Z5U3Z0_9INFA_Wu_2014_epoch5/model/special_tokens_map.json',
 '../model/prosst_finetuned_A0A2Z5U3Z0_9INFA_Wu_2014_epoch5/model/vocab.txt',
 '../model/prosst_finetuned_A0A2Z5U3Z0_9INFA_Wu_2014_epoch5/model/added_tokens.json')

In [42]:
val_losses

[]

In [23]:
# Reload a saved fine-tuned checkpoint from disk.
# NOTE(review): this is an absolute path, whereas the save cell writes to a
# relative "../model/..." path -- confirm both resolve to the same directory.
# trust_remote_code=True again executes the model code stored with the checkpoint.
model_less  = AutoModelForMaskedLM.from_pretrained("/home/yining_yang/Documents/lm/SSRAP/model/prosst_finetuned_A0A2Z5U3Z0_9INFA_Wu_2014_epoch5/model", trust_remote_code=True)


In [22]:
model_less

ProSSTForMaskedLM(
  (prosst): ProSSTModel(
    (embeddings): ProSSTEmbeddings(
      (word_embeddings): Embedding(25, 768, padding_idx=0)
      (ss_embeddings): Embedding(2051, 768)
      (ss_layer_norm): ProSSTLayerNorm()
      (LayerNorm): ProSSTLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ProSSTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ProSSTLayer(
          (attention): ProSSTAttention(
            (self): DisentangledSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): Dropout(p=0.1, inplace=False)
              (pos_proj): Linear(in_features=768, out_features=768, bias=False)
              (pos_q_proj): Linear(in_features=768, out_features=768, bias=True)
              (ss_proj): Linear(in_features=768, out_features=

In [31]:
# Debug check. If the training cell ran last, `loss` is the value assigned
# inside its `torch.no_grad()` validation loop, so it carries no autograd graph
# and this prints None; it does not mean the training losses lacked gradients.
print(loss.grad_fn)

None


In [33]:
model.enable_adapters

<bound method PeftAdapterMixin.enable_adapters of ProSSTForMaskedLM(
  (prosst): ProSSTModel(
    (embeddings): ProSSTEmbeddings(
      (word_embeddings): Embedding(25, 768, padding_idx=0)
      (ss_embeddings): Embedding(2051, 768)
      (ss_layer_norm): ProSSTLayerNorm()
      (LayerNorm): ProSSTLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ProSSTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ProSSTLayer(
          (attention): ProSSTAttention(
            (self): DisentangledSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): Dropout(p=0.1, inplace=False)
              (pos_proj): Linear(in_features=768, out_features=768, bias=False)
              (pos_q_proj): Linear(in_features=768, out_features=768, bias=True)
            