In [73]:
import lmdb
import json
import pandas as pd
import torch
import numpy as np
import dgeb
from pathlib import Path


# Load each HumanPPI split from its LMDB directory into a DataFrame, keyed by
# split name, then stack the three splits into one frame.
dfs_lmdb = {}
for data_type in ['train', 'valid', 'test']:
    # readonly=True: we only read, and a missing or mistyped path now raises
    # instead of silently creating an empty environment (the default behaviour
    # of lmdb.open), which would produce an empty DataFrame further down.
    env = lmdb.open(f'HumanPPI/normal/{data_type}/', readonly=True)
    try:
        with env.begin() as txn:
            # The 'info' and 'length' keys are skipped rather than parsed as
            # records (presumably dataset metadata -- TODO confirm).
            res = [
                json.loads(value)
                for key, value in txn.cursor()
                if key not in (b'info', b'length')
            ]
    finally:
        # Release the environment even if a record fails to decode.
        env.close()

    dfs_lmdb[data_type] = pd.DataFrame(res)

# NOTE: each split keeps its own 0-based index here, so index labels repeat
# across splits in the combined frame; iterate rows rather than using .loc.
df_lmdb = pd.concat(dfs_lmdb.values())

In [74]:
# Show the combined train/valid/test frame as this cell's output.
df_lmdb

Unnamed: 0,name_1,name_2,seq_1,seq_2,label
0,Q01780,Q9Y333,MAPPSTREPRVLSATSATKSDGEMVLPGFPDADSFVKFALGSVVAV...,MLFYSFFKSLVGKDVVVELKNDLSICGTLHSVDQYLNIKLTDISVT...,1
1,Q9P104,P06213,MASNFNDIVKQGYVRIRSRRLGIYQRCWLVFKKASSKGPKRLEKFS...,MATGGRRGAAAAPLLVAVAALLLGAAGHLYPGEVCPGMDIRNNLTR...,1
2,O00300,P04004,MNNLLCCALVFLDISIKWTTQETFPPKYLHYDEETSHQLLCDKCPP...,MAPLRPLLILALLAWVALADQESCKGRCTEGFNVDKKCQCDELCSY...,1
3,Q9UNY4,P22626,MEEVRCPEHGTFCFLKTGVRDGPNKGKSFYVCRADTCSFVRATDIP...,MEKTLETVPLERKKREKEQFRKLFIGGLSFETTEESLRNYYEQWGK...,1
4,Q15139,Q02156,MSAPPVLRPPSPLLPVAAAAAAAAAALVPGSGPGPAPFLAPVAAPV...,MVVFNGLLKIKICEAVSLKPTAWSLRHAVGPRPQTFLLDPYIALNV...,1
...,...,...,...,...,...
175,P09429,Q92552,MGKGDPKKPRGKMSSYAFFVQTCREEHKKKHPDASVNFSEFSKKCS...,MAASIVRRGMLLARQVVLPQLSPAGKRYLLSSAYVDSHKWEAREKE...,0
176,O94763,Q96M27,MEAPTVETPPDPSPPSAPAPALVPLRAPDVARLREEQEKVVTNCQE...,MMEESGIETTPPGTPPPNPAGLAATAMSSTPVPLAATSSFSSPNVS...,0
177,Q86TM3,Q96M27,MSHWAPEWKRAEANPRDLGASWDVRGSRGSGWSGPFGHQGPRAAGS...,MMEESGIETTPPGTPPPNPAGLAATAMSSTPVPLAATSSFSSPNVS...,0
178,O15151,Q9BY32,MTSFSTSAQCSTSDSACRISPGQINQVRPKLPLLKILHAAGAQGEM...,MAASLVGKKIVFVTGNAKKLEEVVQILGDKFPCTLVAQKIDLPEYQ...,0


In [4]:
def get_embeddings_df(sequences, model_name):
    model_name_for_file = model_name.replace('/', '-').replace(' ', '-').replace('_', '-')
    f = Path(f'embeddings_{model_name_for_file}.parquet')
    if f.exists():
        return pd.read_parquet(f)

    model = dgeb.get_model(model_name, layers="last", batch_size=1, max_seq_length=2048)

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        embeddings = model.encode(sequences)

    df = pd.DataFrame({'sequence': sequences, 'embedding': list(embeddings.squeeze())})
    df.to_parquet(f)
    return df

# Unique sequences over both interaction partners. Sorted so the order handed
# to the model (and written to the parquet cache) is the same on every kernel
# restart; plain set iteration order for strings is not stable across runs.
sequences = sorted(set(df_lmdb.seq_1) | set(df_lmdb.seq_2))

# model name -> {sequence -> embedding}
embedding_by_sequence = {}
for model_name in ["esm3_sm_open_v1", "facebook/esm2_t33_650M_UR50D"]:
    embeddings_df = get_embeddings_df(sequences, model_name)
    embedding_by_sequence[model_name] = dict(
        zip(embeddings_df['sequence'], embeddings_df['embedding'])
    )

In [26]:
model_name = "facebook/esm2_t33_650M_UR50D"

# One row per protein pair: the two per-sequence embeddings of the selected
# model laid end to end, then all rows stacked into a single 2-D tensor.
embeddings = torch.stack([
    torch.cat([
        torch.from_numpy(embedding_by_sequence[model_name][pair.seq_1]),
        torch.from_numpy(embedding_by_sequence[model_name][pair.seq_2]),
    ])
    for pair in df_lmdb.itertuples()
])

In [27]:
# Show the stacked pair-embedding tensor as this cell's output.
embeddings

tensor([[ 1.2412, -3.2344,  1.7555,  ..., -5.0823, -2.1633, -7.6051],
        [ 0.3723,  1.4925, -3.8769,  ..., -1.4513,  3.1598,  0.7873],
        [ 0.5665,  0.7051, -1.1807,  ..., -1.1339, -0.8480, -2.6322],
        ...,
        [ 0.5372, -5.0282, -0.3123,  ...,  1.4255,  0.5592, -0.4750],
        [ 3.7329, -2.3684,  0.7646,  ..., -2.6803, -4.0307, -2.7534],
        [ 1.2119, -4.5896, -1.4423,  ..., -0.2077,  0.6540,  1.7106]])

In [76]:
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import TensorDataset

from sklearn.metrics import (
    accuracy_score,
    auc,
    f1_score,
    precision_recall_curve,
    roc_auc_score,
)


class SimpleMLP(nn.Module):
    """Two-layer perceptron head: Linear -> ReLU -> Dropout(0.2) -> Linear.

    The hidden layer has the same width as the input; the final layer maps to
    ``output_size`` raw outputs (no activation is applied to them).
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        # Layers are created in network order so parameter initialisation
        # draws from the RNG in the same sequence as the layer stack.
        hidden_layer = nn.Linear(input_size, input_size)
        output_layer = nn.Linear(input_size, output_size)

        self.project = nn.Sequential(
            hidden_layer,
            nn.ReLU(),
            nn.Dropout(0.2),
            output_layer,
        )

    def forward(self, x):
        """Apply the layer stack to ``x`` and return its raw outputs."""
        out = self.project(x)
        return out


def classification_metrics(targets, predictions, threshold=0.5):
    """Compute binary-classification metrics from scores and true labels.

    Args:
        targets: ground-truth labels, passed straight to the sklearn metrics.
        predictions: array of scores; must support ``>=`` and ``.astype``
            (e.g. a numpy array).
        threshold: scores at or above this value count as the positive class
            for the thresholded metrics (accuracy and F1).

    Returns:
        dict with keys ``"Accuracy"``, ``"AUPRC"``, ``"F1 Score"``, ``"AUROC"``.
        AUPRC is the area under the precision-recall curve as computed by
        ``auc(recall, precision)``.
    """
    hard_labels = (predictions >= threshold).astype(int)

    # Thresholded metrics use the hard labels; ranking metrics use raw scores.
    acc = accuracy_score(targets, hard_labels)
    f1_value = f1_score(targets, hard_labels)
    auroc = roc_auc_score(targets, predictions)

    pr_precision, pr_recall, _ = precision_recall_curve(targets, predictions)
    pr_area = auc(pr_recall, pr_precision)

    results = {}
    results["Accuracy"] = acc
    results["AUPRC"] = pr_area
    results["F1 Score"] = f1_value
    results["AUROC"] = auroc
    return results


@torch.no_grad()
def evaluate(model, loader, name, device):
    """Score ``model`` on every batch of ``loader`` and return prefixed metrics.

    The model is switched to eval mode (and left there). Each batch is a
    ``(embeddings, target)`` pair; model outputs are passed through a sigmoid
    before being handed to ``classification_metrics``.

    Args:
        model: module producing one output per sample in its last dimension.
        loader: iterable of ``(embeddings, target)`` batches.
        name: prefix for the returned keys, e.g. ``"valid"`` -> ``"valid_AUROC"``.
        device: device the batches are moved to before the forward pass.

    Returns:
        dict mapping ``f"{name}_{metric}"`` to the metric value.
    """
    model.eval()

    batch_probs, batch_labels = [], []
    for features, labels in loader:
        logits = model(features.to(device)).squeeze(-1)
        probs = torch.sigmoid(logits)
        batch_probs.append(probs.detach().cpu().numpy())
        batch_labels.append(labels.to(device).cpu().numpy())

    all_probs = np.concatenate(batch_probs)
    all_labels = np.concatenate(batch_labels)

    scores = classification_metrics(all_labels, all_probs)
    return {f"{name}_{metric}": value for metric, value in scores.items()}


def calculate_mean_std(metrics):
    """Aggregate a list of metric dicts into per-metric mean and std.

    Args:
        metrics: iterable of dicts mapping metric name to a numeric value,
            typically one dict per repetition.

    Returns:
        dict with ``<name>_mean`` and ``<name>_std`` entries for every metric
        name seen, in first-seen order. The std is ``np.std`` with its default
        settings.
    """
    collected = {}
    for run in metrics:
        for metric_name, score in run.items():
            collected.setdefault(metric_name, []).append(score)

    summary = {}
    for metric_name, scores in collected.items():
        summary[metric_name + "_mean"] = np.mean(scores)
        summary[metric_name + "_std"] = np.std(scores)
    return summary


# Select which embedding model feeds the classifier (change the index to switch).
model_name = ["esm3_sm_open_v1", "facebook/esm2_t33_650M_UR50D"][0]
print(f'{model_name=}')

# One TensorDataset per split: features are the two partner embeddings
# concatenated end to end, targets come from the split's `label` column.
datasets = {}
for data_type in ['train', 'valid', 'test']:
    embeddings = [
        torch.cat([
            torch.from_numpy(embedding_by_sequence[model_name][pair.seq_1]),
            torch.from_numpy(embedding_by_sequence[model_name][pair.seq_2]),
        ])
        for pair in dfs_lmdb[data_type].itertuples()
    ]

    datasets[data_type] = TensorDataset(
        torch.stack(embeddings),
        torch.tensor(dfs_lmdb[data_type].label),
    )

# ---- Experiment configuration ----
all_metrics = []  # one dict of best test metrics per repetition
batch_size = 1000
reps = 1  # number of training runs; run `rep` is seeded with `rep`

monitor_metric = "Accuracy"  # validation metric that selects the reported epoch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 100
input_size = datasets['train'][0][0].shape[-1]  # width of one feature vector
output_size = 1  # single logit for the binary label
lr = 1e-4


for rep in range(reps):
    train_loader = torch.utils.data.DataLoader(
        datasets['train'], batch_size=batch_size, shuffle=True
    )
    valid_loader = torch.utils.data.DataLoader(
        datasets['valid'], batch_size=batch_size, shuffle=False
    )
    test_loader = torch.utils.data.DataLoader(
        datasets['test'], batch_size=batch_size, shuffle=False
    )

    torch.manual_seed(rep)

    model = SimpleMLP(input_size, output_size).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Halve the learning rate when the mean training loss stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )

    loss_fn = nn.BCEWithLogitsLoss()

    best_valid_metric = 0

    for epoch in tqdm(range(num_epochs)):
        loss_accum = 0
        model.train()
        for embs, target in train_loader:
            optimizer.zero_grad()

            embs = embs.to(device)
            target = target.to(device)

            pred = model(embs)
            # squeeze(-1), not squeeze(): a bare squeeze() also removes the
            # batch axis when the final batch holds a single sample, and the
            # loss then rejects the (,) vs (1,) shape mismatch. This matches
            # what evaluate() does.
            loss = loss_fn(pred.squeeze(-1), target.float())

            loss.backward()
            optimizer.step()

            loss_accum += loss.item()

        test_metrics_dict = evaluate(model, test_loader, "test", device)
        valid_metrics_dict = evaluate(model, valid_loader, "valid", device)

        # Report the test metrics of the epoch with the best validation score;
        # `>=` means ties go to the later epoch.
        if valid_metrics_dict[f"valid_{monitor_metric}"] >= best_valid_metric:
            best_test_metrics = test_metrics_dict
            best_valid_metric = valid_metrics_dict[f"valid_{monitor_metric}"]

        # Mean loss over the batches of this epoch drives the LR schedule.
        train_loss = loss_accum / len(train_loader)
        scheduler.step(train_loss)

    best_test_metrics = {f"best_{k}": i for k, i in best_test_metrics.items()}

    all_metrics.append(best_test_metrics)

print(calculate_mean_std(all_metrics))

model_name='esm3_sm_open_v1'


  0%|          | 0/100 [00:00<?, ?it/s]

{'best_test_Accuracy_mean': np.float64(0.7944444444444444), 'best_test_Accuracy_std': np.float64(0.0), 'best_test_AUPRC_mean': np.float64(0.8571600308454443), 'best_test_AUPRC_std': np.float64(0.0), 'best_test_F1 Score_mean': np.float64(0.7784431137724551), 'best_test_F1 Score_std': np.float64(0.0), 'best_test_AUROC_mean': np.float64(0.8927625772285968), 'best_test_AUROC_std': np.float64(0.0)}
