In [17]:
MODELS = [
    "esm3_sm_open_v1",
    "esm3-large-2024-03",
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t30_150M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
    "facebook/esm2_t36_3B_UR50D",
    "facebook/esm2_t48_15B_UR50D",
]

In [18]:
import lmdb
import json
import pandas as pd
import torch
import numpy as np
import dgeb
from pathlib import Path


dfs_lmdb = {}
for data_type in ['train', 'valid', 'test']:
    env = lmdb.open(f'HumanPPI/normal/{data_type}/')
    res = []

    with env.begin() as txn:
        for key, value in txn.cursor():
            if key == b'info' or key == b'length':
                continue
            res.append(json.loads(value))
    
    env.close()

    dfs_lmdb[data_type] = pd.DataFrame(res)

df_lmdb = pd.concat(dfs_lmdb.values())

In [19]:
df_lmdb

Unnamed: 0,name_1,name_2,seq_1,seq_2,label
0,Q01780,Q9Y333,MAPPSTREPRVLSATSATKSDGEMVLPGFPDADSFVKFALGSVVAV...,MLFYSFFKSLVGKDVVVELKNDLSICGTLHSVDQYLNIKLTDISVT...,1
1,Q9P104,P06213,MASNFNDIVKQGYVRIRSRRLGIYQRCWLVFKKASSKGPKRLEKFS...,MATGGRRGAAAAPLLVAVAALLLGAAGHLYPGEVCPGMDIRNNLTR...,1
2,O00300,P04004,MNNLLCCALVFLDISIKWTTQETFPPKYLHYDEETSHQLLCDKCPP...,MAPLRPLLILALLAWVALADQESCKGRCTEGFNVDKKCQCDELCSY...,1
3,Q9UNY4,P22626,MEEVRCPEHGTFCFLKTGVRDGPNKGKSFYVCRADTCSFVRATDIP...,MEKTLETVPLERKKREKEQFRKLFIGGLSFETTEESLRNYYEQWGK...,1
4,Q15139,Q02156,MSAPPVLRPPSPLLPVAAAAAAAAAALVPGSGPGPAPFLAPVAAPV...,MVVFNGLLKIKICEAVSLKPTAWSLRHAVGPRPQTFLLDPYIALNV...,1
...,...,...,...,...,...
175,P09429,Q92552,MGKGDPKKPRGKMSSYAFFVQTCREEHKKKHPDASVNFSEFSKKCS...,MAASIVRRGMLLARQVVLPQLSPAGKRYLLSSAYVDSHKWEAREKE...,0
176,O94763,Q96M27,MEAPTVETPPDPSPPSAPAPALVPLRAPDVARLREEQEKVVTNCQE...,MMEESGIETTPPGTPPPNPAGLAATAMSSTPVPLAATSSFSSPNVS...,0
177,Q86TM3,Q96M27,MSHWAPEWKRAEANPRDLGASWDVRGSRGSGWSGPFGHQGPRAAGS...,MMEESGIETTPPGTPPPNPAGLAATAMSSTPVPLAATSSFSSPNVS...,0
178,O15151,Q9BY32,MTSFSTSAQCSTSDSACRISPGQINQVRPKLPLLKILHAAGAQGEM...,MAASLVGKKIVFVTGNAKKLEEVVQILGDKFPCTLVAQKIDLPEYQ...,0


In [20]:
def get_embeddings_df(sequences, model_name):
    model_name_for_file = model_name.replace('/', '-').replace(' ', '-').replace('_', '-')
    f = Path(f'embeddings_{model_name_for_file}.parquet')
    if f.exists():
        return pd.read_parquet(f)

    model = dgeb.get_model(model_name, layers="last", batch_size=1, max_seq_length=2048)

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        embeddings = model.encode(sequences)

    df = pd.DataFrame({'sequence': sequences, 'embedding': list(embeddings.squeeze())})
    df.to_parquet(f)
    return df

sequences = list(set(df_lmdb.seq_1) | set(df_lmdb.seq_2))

embedding_by_sequence = {}
for model_name in MODELS:
    embeddings_df = get_embeddings_df(sequences, model_name)
    embedding_by_sequence[model_name] = {}
    for elem in embeddings_df.itertuples():
        embedding_by_sequence[model_name][elem.sequence] = elem.embedding


datasets = {}
for model_name in MODELS:
    print(f'{model_name=}')

    datasets[model_name] = {}

    for data_type in ['train', 'valid', 'test']:
        embeddings = []
        for elem in dfs_lmdb[data_type].itertuples():
            emb1 = embedding_by_sequence[model_name][elem.seq_1]
            emb2 = embedding_by_sequence[model_name][elem.seq_2]
            emb = torch.cat([torch.from_numpy(emb1), torch.from_numpy(emb2)])
            embeddings.append(emb)

        datasets[model_name][data_type] = TensorDataset(
            torch.stack(embeddings).float(),
            torch.tensor(dfs_lmdb[data_type].label)
        )

model_name='esm3_sm_open_v1'
model_name='esm3-large-2024-03'
model_name='facebook/esm2_t6_8M_UR50D'
model_name='facebook/esm2_t12_35M_UR50D'
model_name='facebook/esm2_t30_150M_UR50D'
model_name='facebook/esm2_t33_650M_UR50D'
model_name='facebook/esm2_t36_3B_UR50D'
model_name='facebook/esm2_t48_15B_UR50D'


In [12]:
import wandb

from uuid import uuid4
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import TensorDataset

from sklearn.metrics import accuracy_score


torch.manual_seed(1488)


wandb.login(key="03c57b40de2b2f02c0d3c1357868f3ef656696ce")


class SimpleMLP(nn.Module):
    def __init__(self, input_size, output_size, dropout):
        super().__init__()

        self.project = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(input_size, output_size),
        )

    def forward(self, x):
        return self.project(x)


def train(dataloader, model, loss_fn, optimizer, device):
    model.train()
    loss_accum = 0
    for step, (embs, target) in enumerate(dataloader):
        embs = embs.to(device)
        target = target.to(device)
        pred = model(embs)
        loss = loss_fn(pred.squeeze(), target.float())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_accum += loss.detach().cpu().item()
    train_loss = loss_accum / (step + 1)
    scheduler.step(train_loss)
    return train_loss


def get_accuracy(dataloader, model, device):
    model.eval()
    preds = []
    targets = []
    with torch.no_grad():
        for step, eval_batch in enumerate(dataloader):
            embs, target = eval_batch
            embs = embs.to(device)
            target = target.to(device)
            pred = model(embs).squeeze(-1)
            pred = torch.sigmoid(pred)

            preds.append(pred.detach().cpu().numpy())
            targets.append(target.cpu().numpy())
    preds = np.concatenate(preds)
    targets = np.concatenate(targets)

    threshold = 0.5
    binary_predictions = (preds >= threshold).astype(int)
    accuracy = accuracy_score(targets, binary_predictions)

    return accuracy


# model_name = MODELS[0]

best_models_info_file = Path('best_models.csv')
best_models = []

for model_name in MODELS:
    print(f'{model_name=}')
    model_name_for_file = model_name.replace('/', '-').replace(' ', '-').replace('_', '-')


    batch_size = 1000
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_epochs = 100
    input_size = datasets['train'][0][0].shape[-1]
    output_size = 1
    lr = 1e-4
    dropout = 0.2


    train_loader = torch.utils.data.DataLoader(
        datasets[model_name]['train'], batch_size=batch_size, shuffle=True
    )
    valid_loader = torch.utils.data.DataLoader(
        datasets[model_name]['valid'], batch_size=batch_size, shuffle=False
    )


    model = SimpleMLP(input_size, output_size, dropout).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )
    loss_fn = nn.BCEWithLogitsLoss()

    best_valid_accuracy = 0

    model_uuid = str(uuid4())


    run = wandb.init(
        project="mint",
        config={
            "model_name": model_name,
            "model_uuid": model_uuid,
            "dropout": dropout,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "lr": lr,
        },
        settings=wandb.Settings(silent=True)
    )
    run.watch(model)

    for epoch in tqdm(range(num_epochs)):
        train_loss = train(train_loader, model, loss_fn, optimizer, device)
        valid_accuracy = get_accuracy(valid_loader, model, device)
        if valid_accuracy > best_valid_accuracy:
            best_valid_accuracy = valid_accuracy
            model_filename = f"best-model_{model_name_for_file}_{model_uuid}.pth"
            torch.save(model.state_dict(), model_filename)

            model_info = {
                "model_name": model_name,
                "model_uuid": model_uuid,
                "dropout": dropout,
                "batch_size": batch_size,
                "num_epochs": num_epochs,
                "lr": lr,
                "model_filename": model_filename,
                "train_loss": train_loss,
                "valid_accuracy": valid_accuracy,
                "epoch": epoch,
                "wandb_url": run.url,
                "input_size": input_size,
                "output_size": output_size,
            }

        run.log({
            "train_loss": train_loss,
            "valid_accuracy": valid_accuracy,
        })

    run.finish()

    best_models.append(model_info)

pd.DataFrame(best_models).to_csv(
    best_models_info_file,
    index=False,
    mode='a',
    header=not best_models_info_file.exists()
)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ubuntu/.netrc


model_name='esm3_sm_open_v1'


  0%|          | 0/100 [00:00<?, ?it/s]

model_name='esm3-large-2024-03'


  0%|          | 0/100 [00:00<?, ?it/s]

model_name='facebook/esm2_t6_8M_UR50D'


  0%|          | 0/100 [00:00<?, ?it/s]

model_name='facebook/esm2_t12_35M_UR50D'


  0%|          | 0/100 [00:00<?, ?it/s]

model_name='facebook/esm2_t30_150M_UR50D'


  0%|          | 0/100 [00:00<?, ?it/s]

model_name='facebook/esm2_t33_650M_UR50D'


  0%|          | 0/100 [00:00<?, ?it/s]

model_name='facebook/esm2_t36_3B_UR50D'


  0%|          | 0/100 [00:00<?, ?it/s]

model_name='facebook/esm2_t48_15B_UR50D'


  0%|          | 0/100 [00:00<?, ?it/s]

In [22]:
df = pd.read_csv(best_models_info_file)
df

Unnamed: 0,model_name,model_uuid,dropout,batch_size,num_epochs,lr,model_filename,train_loss,valid_accuracy,epoch,wandb_url,input_size,output_size
0,esm3_sm_open_v1,cf052e28-82d5-42a1-8123-be62ad4054dc,0.2,1000,100,0.0001,best-model_esm3-sm-open-v1_cf052e28-82d5-42a1-...,0.369579,0.854701,17,https://wandb.ai/tehadawest-no/mint/runs/4mex867u,3072,1
1,esm3-large-2024-03,b92c64dd-5a22-462c-8a9c-aa578a196a8b,0.2,1000,100,0.0001,best-model_esm3-large-2024-03_b92c64dd-5a22-46...,0.137056,0.884615,8,https://wandb.ai/tehadawest-no/mint/runs/2rasvum2,12288,1
2,facebook/esm2_t6_8M_UR50D,78a93caf-3d60-49b2-af25-19dd4150a6d1,0.2,1000,100,0.0001,best-model_facebook-esm2-t6-8M-UR50D_78a93caf-...,0.222356,0.858974,33,https://wandb.ai/tehadawest-no/mint/runs/dutx7pnw,640,1
3,facebook/esm2_t12_35M_UR50D,1a5c176b-cb71-4542-817d-6e42d4f59ea0,0.2,1000,100,0.0001,best-model_facebook-esm2-t12-35M-UR50D_1a5c176...,0.154534,0.871795,21,https://wandb.ai/tehadawest-no/mint/runs/qxyft3fy,960,1
4,facebook/esm2_t30_150M_UR50D,16eab1a4-98ea-4748-9838-de30db793603,0.2,1000,100,0.0001,best-model_facebook-esm2-t30-150M-UR50D_16eab1...,0.202061,0.888889,7,https://wandb.ai/tehadawest-no/mint/runs/tamt86iu,1280,1
5,facebook/esm2_t33_650M_UR50D,dd3dbacf-5a27-4429-a454-a63064d75719,0.2,1000,100,0.0001,best-model_facebook-esm2-t33-650M-UR50D_dd3dba...,0.218106,0.876068,3,https://wandb.ai/tehadawest-no/mint/runs/rylvtagc,2560,1
6,facebook/esm2_t36_3B_UR50D,54b9075d-8e09-44c2-900f-7161f2fc66a5,0.2,1000,100,0.0001,best-model_facebook-esm2-t36-3B-UR50D_54b9075d...,0.281358,0.91453,2,https://wandb.ai/tehadawest-no/mint/runs/3wglz8in,5120,1
7,facebook/esm2_t48_15B_UR50D,c7cd1cb6-0f0f-4e02-bf1e-233348ef347e,0.2,1000,100,0.0001,best-model_facebook-esm2-t48-15B-UR50D_c7cd1cb...,0.190633,0.863248,1,https://wandb.ai/tehadawest-no/mint/runs/3ewccthb,10240,1


In [21]:
for elem in df.itertuples():
    test_loader = torch.utils.data.DataLoader(
        datasets[elem.model_name]['test'], batch_size=elem.batch_size, shuffle=False
    )

    model = SimpleMLP(elem.input_size, elem.output_size, elem.dropout)
    model.load_state_dict(torch.load(elem.model_filename))
    model.to(device)
    test_accuracy = get_accuracy(test_loader, model, device)
    print(f'{elem.model_name}: {test_accuracy=}')

esm3_sm_open_v1: test_accuracy=0.8444444444444444
esm3-large-2024-03: test_accuracy=0.8888888888888888
facebook/esm2_t6_8M_UR50D: test_accuracy=0.8055555555555556
facebook/esm2_t12_35M_UR50D: test_accuracy=0.8166666666666667
facebook/esm2_t30_150M_UR50D: test_accuracy=0.8722222222222222
facebook/esm2_t33_650M_UR50D: test_accuracy=0.8611111111111112
facebook/esm2_t36_3B_UR50D: test_accuracy=0.8611111111111112
facebook/esm2_t48_15B_UR50D: test_accuracy=0.8611111111111112


In [16]:
elem

Pandas(Index=0, model_name='esm3_sm_open_v1', model_uuid='cf052e28-82d5-42a1-8123-be62ad4054dc', dropout=0.2, batch_size=1000, num_epochs=100, lr=0.0001, model_filename='best-model_esm3-sm-open-v1_cf052e28-82d5-42a1-8123-be62ad4054dc.pth', train_loss=0.3695792693782735, valid_accuracy=0.8547008547008547, epoch=17, wandb_url='https://wandb.ai/tehadawest-no/mint/runs/4mex867u', input_size=3072, output_size=1)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output


class LivePlot:
    def __init__(self, num_epochs):
        self.x_data = []
        self.y_data = []
        self.num_epochs = num_epochs

    def plot_data(self, x, x_label, y, y_label, model_name):
        self.x_data.append(x)
        self.y_data.append(y)

        clear_output(wait=True)
        plt.figure(figsize=(10, 4))
        plt.plot(self.x_data, self.y_data, marker='o')
        plt.title(model_name)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.grid(True)
        plt.xlim(0, self.num_epochs)
        plt.tight_layout()
        plt.show()


my_live_plot = LivePlot(num_epochs)
plt.ion()

my_live_plot.plot_data(epoch, "epoch", train_loss, "train_loss", model_name)