In [1]:
from datasets import load_dataset

In [4]:
# Load all splits of the SS8 dataset by its Hub repository id; the bare
# `data` on the last line makes the notebook display the DatasetDict summary.
data = load_dataset(path='GleghornLab/SS8')
data

DatasetDict({
    train: Dataset({
        features: ['seqs', 'labels'],
        num_rows: 10792
    })
    valid: Dataset({
        features: ['seqs', 'labels'],
        num_rows: 626
    })
    test: Dataset({
        features: ['seqs', 'labels'],
        num_rows: 50
    })
})

In [6]:
# Gather every distinct symbol that appears in any label sequence of the
# training split, then display the resulting set.
vocab = {symbol for label in data['train']['labels'] for symbol in label}
vocab

{'B', 'C', 'D', 'E', 'G', 'H', 'I', 'S', 'T'}

In [5]:
# Convert the training split to a pandas DataFrame and write it to disk
# without the index column.
(
    data['train']
    .to_pandas()
    .to_csv('ss8_train.csv', index=False)
)

In [3]:
import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from models.modeling_esm_diff import ESM_Diff_Binders, ESMDiffConfig
from models.utils import wrap_lora

# Hub repository id used by the loading cell both to download
# `model.safetensors` and to load the ESMDiffConfig.
MODEL_PATH = 'lhallee/esm_diff_bind_150'
# NOTE(review): not referenced by any visible cell -- presumably the repo id of
# the base model the checkpoint above was derived from; confirm or remove.
base_path = 'GleghornLab/esm_diff_150'

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fetch the checkpoint file from the Hub repository named by MODEL_PATH.
local_weight_file = hf_hub_download(
    repo_id=MODEL_PATH,
    filename='model.safetensors',
    repo_type='model',
)

# Rebuild the architecture from its config, then apply the LoRA wrapper BEFORE
# loading: the checkpoint contains `lora_A` / `lora_B` tensors, so those
# parameters must already exist on the model to receive their values.
config = ESMDiffConfig.from_pretrained(MODEL_PATH)
model = ESM_Diff_Binders(config=config)
model = wrap_lora(model, r=config.lora_r, lora_alpha=config.lora_alpha, lora_dropout=config.lora_dropout)
state_dict = load_file(local_weight_file)


def _match_checkpoint_key(param_name, checkpoint):
    """Return the checkpoint key that names the same tensor as `param_name`.

    The wrapped model prefixes its parameter names (e.g. `model.esm...`) while
    the checkpoint stores them unprefixed (`esm...`). Leading dotted components
    are stripped one at a time and the first candidate present in `checkpoint`
    is returned, so a key only matches when it is the full name or a suffix
    that starts on a `.` boundary. The longest match wins. Unlike a plain
    substring test, a key can never match in the middle of a longer name.

    Args:
        param_name: Dotted parameter name from `model.named_parameters()`.
        checkpoint: Mapping of checkpoint key -> tensor.

    Returns:
        The matching checkpoint key, or None when no suffix is present.
    """
    parts = param_name.split('.')
    for start in range(len(parts)):
        candidate = '.'.join(parts[start:])
        if candidate in checkpoint:
            return candidate
    return None


# Track which parameters were loaded
loaded_params = set()
missing_params = set()
used_keys = set()

# Copy values in place under no_grad (the same mechanism load_state_dict uses)
# instead of rebinding `param.data`; this keeps each parameter's dtype/device
# and lets us reject tensors whose shape does not match.
with torch.no_grad():
    for name, param in model.named_parameters():
        key = _match_checkpoint_key(name, state_dict)
        if key is None:
            missing_params.add(name)
            continue
        tensor = state_dict[key]
        if tensor.shape != param.shape:
            raise ValueError(
                f"Shape mismatch for {name}: checkpoint key {key} has "
                f"{tuple(tensor.shape)}, model expects {tuple(param.shape)}"
            )
        param.copy_(tensor)
        loaded_params.add(name)
        used_keys.add(key)

# Verify all weights were loaded correctly
print(f"Loaded {len(loaded_params)} parameters")
print(f"Missing {len(missing_params)} parameters")
if missing_params:
    print("Missing parameters:")
    for param in sorted(missing_params):
        print(f"  - {param}")

# Also surface checkpoint tensors that no model parameter consumed, so a
# naming mismatch cannot go unnoticed.
unused_keys = sorted(set(state_dict) - used_keys)
if unused_keys:
    print(f"Unused checkpoint tensors: {len(unused_keys)}")
    for key in unused_keys:
        print(f"  - {key}")

# Move model to device
model = model.to(device)
model

Loaded 856 parameters
Missing 0 parameters


LoraModel(
  (model): ESM_Diff_Binders(
    (esm): FAST_ESM_ENCODER(
      (embeddings): EsmEmbeddings(
        (word_embeddings): Embedding(33, 640, padding_idx=1)
      )
      (encoder): EsmEncoder(
        (layer): ModuleList(
          (0-29): 30 x EsmLayer(
            (attention): EsmAttention(
              (self): EsmSelfAttention(
                (query): lora.Linear(
                  (base_layer): Linear(in_features=640, out_features=640, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.01, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=640, out_features=8, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=8, out_features=640, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
            

In [12]:
# List the name of every entry in the wrapped model's state dict, one per line.
# Iterating the mapping directly yields its keys in order.
for entry_name in model.state_dict():
    print(entry_name)

model.esm.embeddings.word_embeddings.weight
model.esm.encoder.layer.0.attention.self.query.base_layer.weight
model.esm.encoder.layer.0.attention.self.query.base_layer.bias
model.esm.encoder.layer.0.attention.self.query.lora_A.default.weight
model.esm.encoder.layer.0.attention.self.query.lora_B.default.weight
model.esm.encoder.layer.0.attention.self.key.base_layer.weight
model.esm.encoder.layer.0.attention.self.key.base_layer.bias
model.esm.encoder.layer.0.attention.self.key.lora_A.default.weight
model.esm.encoder.layer.0.attention.self.key.lora_B.default.weight
model.esm.encoder.layer.0.attention.self.value.base_layer.weight
model.esm.encoder.layer.0.attention.self.value.base_layer.bias
model.esm.encoder.layer.0.attention.self.value.lora_A.default.weight
model.esm.encoder.layer.0.attention.self.value.lora_B.default.weight
model.esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq
model.esm.encoder.layer.0.attention.output.dense.base_layer.weight
model.esm.encoder.layer.0.atten

In [11]:
# List every tensor name stored in the loaded checkpoint, one per line, for
# comparison against the model's own state-dict names.
for checkpoint_key in state_dict:
    print(checkpoint_key)


esm.contact_head.regression.bias
esm.contact_head.regression.weight
esm.embeddings.word_embeddings.weight
esm.encoder.emb_layer_norm_after.bias
esm.encoder.emb_layer_norm_after.weight
esm.encoder.layer.0.LayerNorm.bias
esm.encoder.layer.0.LayerNorm.weight
esm.encoder.layer.0.attention.LayerNorm.bias
esm.encoder.layer.0.attention.LayerNorm.weight
esm.encoder.layer.0.attention.output.dense.base_layer.bias
esm.encoder.layer.0.attention.output.dense.base_layer.weight
esm.encoder.layer.0.attention.output.dense.lora_A.default.weight
esm.encoder.layer.0.attention.output.dense.lora_B.default.weight
esm.encoder.layer.0.attention.self.key.base_layer.bias
esm.encoder.layer.0.attention.self.key.base_layer.weight
esm.encoder.layer.0.attention.self.key.lora_A.default.weight
esm.encoder.layer.0.attention.self.key.lora_B.default.weight
esm.encoder.layer.0.attention.self.query.base_layer.bias
esm.encoder.layer.0.attention.self.query.base_layer.weight
esm.encoder.layer.0.attention.self.query.lora_A.defa

In [1]:
from transformers import AutoModelForMaskedLM
# Two short example sequences of different lengths, so padding is exercised.
sequences = ['MPRTEIN', 'MSEQWENCE']

# trust_remote_code=True allows transformers to execute the modeling code
# shipped inside the model repository.
model = AutoModelForMaskedLM.from_pretrained(
    'Synthyra/ESMplusplus_small',
    trust_remote_code=True,
)
# The tokenizer is taken from the loaded model object rather than loaded separately.
tokenizer = model.tokenizer

# Tokenize the batch with padding enabled, returning PyTorch tensors.
tokenized = tokenizer(sequences, return_tensors='pt', padding=True)

In [4]:
# Display the ids the tokenizer reports for its special tokens.
tokenizer.all_special_ids

[2, 3, 1, 0, 32, 31]

In [2]:
# Names of the subtypes to render.
subtypes = ['affinity', 'binding']

# Render a list copy of `subtypes` as its string representation; the cell
# displays the resulting string.
str(list(subtypes))

['affinity', 'binding']
