In [4]:
import pandas as pd
from tqdm import tqdm

In [10]:
def extract_esm_feature(dataset):
    """Extract ESM-2 embeddings for every unique target sequence of a dataset.

    Reads the ``target_sequence`` column from
    ``./data/{dataset}/raw/data_train.csv`` and
    ``./data/{dataset}/raw/data_test.csv``, de-duplicates the sequences, runs
    each one through the ``facebook/esm2_t12_35M_UR50D`` encoder and saves a
    dict mapping ``sequence -> last_hidden_state`` (NumPy array) to
    ``./data/{dataset}/raw/esm2.pth`` with ``torch.save``.

    Args:
        dataset: Name of the dataset sub-directory under ``./data``.

    Returns:
        None. The result is written to disk as a side effect.
    """
    # Heavy dependencies are imported lazily so that merely defining the
    # function does not require them to be loaded.
    from transformers import AutoTokenizer, EsmModel
    import torch

    model_name = "facebook/esm2_t12_35M_UR50D"
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmModel.from_pretrained(model_name)
    model.to(device)
    # Make inference mode explicit (disables dropout).
    model.eval()

    train_path = f"./data/{dataset}/raw/data_train.csv"
    test_path = f"./data/{dataset}/raw/data_test.csv"

    train_df = pd.read_csv(train_path, usecols=['target_sequence'])
    test_df = pd.read_csv(test_path, usecols=['target_sequence'])
    # The same protein appears in many rows; embed each sequence only once.
    df = pd.concat([train_df, test_df], axis=0).drop_duplicates()

    seqs = df['target_sequence'].tolist()

    esm2_feature_dict = {}

    # No gradients are needed for feature extraction; disabling autograd
    # avoids building a graph for every forward pass (less memory, faster).
    with torch.no_grad():
        for sequence in tqdm(seqs):
            # Every sequence is padded/truncated to exactly 1200 tokens
            # (special tokens included), so all embeddings share one shape.
            inputs = tokenizer(
                sequence,
                add_special_tokens=True,
                max_length=1200,
                padding='max_length',
                truncation=True,
                return_tensors='pt',
            )
            inputs = inputs.to(device)
            outputs = model(**inputs)
            # Keep the leading batch dimension of size 1, as before.
            embedding = outputs.last_hidden_state.cpu().numpy()
            esm2_feature_dict[sequence] = embedding

    torch.save(esm2_feature_dict, f'./data/{dataset}/raw/esm2.pth')

In [11]:
# Generate and save the ESM-2 sequence embeddings for the "davis" dataset.
extract_esm_feature(dataset="davis")

Some weights of the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing EsmModel: ['lm_head.decoder.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


379


In [12]:
# Generate and save the ESM-2 sequence embeddings for the "kiba" dataset.
extract_esm_feature(dataset="kiba")

Some weights of the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing EsmModel: ['lm_head.decoder.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


229
