In [None]:
# Install dependencies (run once per environment).
# NOTE(review): this notebook uses PyTorch Lightning 1.x APIs that later releases removed:
#   `training_epoch_end` / `validation_epoch_end`, `ModelCheckpoint(period=...)`,
#   `Trainer(gpus=..., progress_bar_refresh_rate=...)`.
# Pin a release that still accepts all of them (looks like 1.3.x-1.4.x) - confirm before relying on an unpinned install.
%pip install pytorch-lightning
%pip install git+https://github.com/facebookresearch/esm.git

In [None]:
# settings
EMBEDDING_SIZE: int = 768  # feature width of the layer-12 ESM representations
BATCH_SIZE: int = 8  # sequences per batch when using fixed max-length batching
MAX_LENGTH: int = 510  # for fixed max length batching (longer sequences are filtered out in get_split)
MAX_TOKENS_PER_BATCH: int = 4096  # for dynamic batching (token budget per batch)

# **Imports and Model**

In [None]:
from esm import Alphabet, FastaBatchedDataset, pretrained
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
import pickle

def read_fasta(fastafile):
    """Parse a file with sequences in FASTA format and store in a dict.

    A record's key is its header line without the leading ``>``. Its value is
    the concatenation of the (whitespace-stripped) lines that follow, up to
    the next header. Blank lines are ignored.

    Args:
        fastafile: Path to the FASTA file.

    Returns:
        dict[str, str]: ``{header: sequence}`` in file order. Records with an
        empty sequence are skipped consistently (including the last one). Any
        text before the first header is ignored, and an empty file yields ``{}``.
    """
    records = {}
    header, chunks = None, []

    def _store():
        # Keep only records that have a header and actual sequence data.
        if header is not None and chunks:
            records[header] = "".join(chunks)

    with open(fastafile, "r") as handle:
        for raw in handle:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                _store()
                # Strip only the leading '>' so '>' inside a description survives.
                header, chunks = line[1:], []
            else:
                chunks.append(line)
    _store()
    return records

In [None]:
# @title Fine-tuning model (mean pooling)

class ESMFinetune(pl.LightningModule):
    """Fine-tunes the 12-layer ESM-1 model for binary solubility prediction.

    Layer-12 per-residue representations are standardised with precomputed
    statistics (``ESM12_Layer12_Norm.pkl``) and passed through a LayerNorm.
    They are then mean-pooled over residue positions and mapped to a single
    logit by a linear head.
    """

    def __init__(self):
        super().__init__()
        model, alphabet = pretrained.load_model_and_alphabet("esm1_t12_85M_UR50S")
        self.model = model
        self.clf_head = nn.Linear(768, 1)

        # The ESM 12 model does not have a layer norm before MLM. Therefore the 768 feature output has spikes.
        # We found no difference in performance by adding this.
        # NOTE(review): pickle.load can execute arbitrary code - only load a trusted statistics file.
        with open("ESM12_Layer12_Norm.pkl", "rb") as f:
            final_scaling = pickle.load(f)
        # Non-persistent buffers follow the module through .to()/.cuda(), so CPU-only
        # construction works. They stay out of the state_dict, so checkpoints saved
        # when these were plain attributes still load strictly.
        self.register_buffer(
            "scaling_mean",
            torch.as_tensor(final_scaling["mean"], dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            "scaling_std",
            torch.as_tensor(final_scaling["std"], dtype=torch.float32),
            persistent=False,
        )
        self.final_ln = nn.LayerNorm(768)
        self.lr = 2e-5

    def forward(self, toks, lens, non_mask):
        """Return one solubility logit per sequence.

        Args:
            toks: token ids from the batch converters, shape (B, T).
            lens: number of residues per sequence, shape (B,).
            non_mask: bool mask, True only on residue positions, shape (B, T).

        Returns:
            Logits of shape (B,). ``squeeze()`` makes this 0-d when B == 1.
        """
        # in lightning, forward defines the prediction/inference actions
        x = self.model(toks, repr_layers=[12])
        x = x["representations"][12]
        x = (x - self.scaling_mean) / self.scaling_std
        x = self.final_ln(x)
        # Masked mean over residues: padding/cls/eos positions contribute zero.
        x_mean = (x * non_mask[:, :, None]).sum(1) / lens[:, None]
        x = self.clf_head(x_mean)
        return x.squeeze()

    def configure_optimizers(self):
        """AdamW with a small LR for the ESM backbone and a larger LR for the new head + LayerNorm."""
        grouped_parameters = [
            {"params": list(self.model.parameters()), 'lr': 3e-6},
            {"params": list(self.clf_head.parameters()) + list(self.final_ln.parameters()), 'lr': 2e-5},
        ]
        # Every group sets its own lr, so this default is not used by either group.
        optimizer = torch.optim.AdamW(grouped_parameters, lr=self.lr)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, l, n, y, _ = batch
        # reshape undoes squeeze() collapsing a size-1 batch to 0-d, which
        # binary_cross_entropy_with_logits rejects (input/target size mismatch).
        y_pred = self.forward(x, l, n).reshape(y.shape)
        loss = F.binary_cross_entropy_with_logits(y_pred, y)
        self.log('train_loss_batch', loss)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('train_loss', avg_loss, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        x, l, n, y, _ = batch
        y_pred = self.forward(x, l, n).reshape(y.shape)
        correct = ((y_pred > 0) == y).sum()
        count = y.size(0)
        loss = F.binary_cross_entropy_with_logits(y_pred, y)
        self.log('val_loss_batch', loss)
        return {'loss': loss, 'correct': correct, "count": count}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('val_loss', avg_loss, prog_bar=True)
        # Pool correct/count over all batches rather than averaging per-batch
        # accuracies. stack() keeps the 0-d counts on their own device.
        total_correct = torch.stack([x['correct'] for x in outputs]).sum()
        total_count = sum(x['count'] for x in outputs)
        avg_acc = total_correct.float() / total_count
        self.log('val_acc', avg_acc, prog_bar=True)

# **TRAINING**

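Pick one for training: run either the dynamic-batching cell or the fixed-max-length cell below, not both. Both cells define `BatchConverter`, `Alphabet` and `NewAlphabet`, so whichever cell runs last determines the implementation in use. Note that `get_split` builds its DataLoaders with `FastaDataset`, which only the fixed-max-length cell defines.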

In [None]:
# @title Dynamic batching
import re
import torch
import random

class FastaBatchedDataset(torch.utils.data.Dataset):
    """Map-style dataset over a DataFrame with ``fasta``, ``solubility`` and ``sid`` columns.

    ``get_batch_indices`` groups rows of similar length so each batch stays
    within a token budget (dynamic batching). Pass its result to a DataLoader
    as ``batch_sampler``.
    """

    def __init__(self, data_df):
        self.data_df = data_df

    def __len__(self):
        return len(self.data_df)

    def shuffle(self):
        """Replace ``data_df`` with a row-shuffled copy indexed 0..n-1."""
        shuffled = self.data_df.sample(frac=1)
        self.data_df = shuffled.reset_index(drop=True)

    def __getitem__(self, idx):
        """Return ``(sequence, solubility, sid)`` for index label ``idx``."""
        frame = self.data_df
        return frame["fasta"][idx], frame["solubility"][idx], frame["sid"][idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        """Bucket row positions into batches whose padded size fits ``toks_per_batch``.

        Rows are visited shortest-first. A batch is closed when adding the next
        row would make ``longest_len * n_rows`` exceed the budget. A single row
        longer than the budget still gets its own batch.

        Args:
            toks_per_batch: token budget per batch.
            extra_toks_per_seq: tokens added to each sequence length (e.g. special tokens).

        Returns:
            list[list[int]]: batches of row positions.
        """
        ordered = sorted((len(seq), pos) for pos, seq in enumerate(self.data_df["fasta"]))
        batches, current, widest = [], [], 0
        for seq_len, pos in ordered:
            padded = seq_len + extra_toks_per_seq
            if current and max(padded, widest) * (len(current) + 1) > toks_per_batch:
                batches.append(current)
                current, widest = [], 0
            widest = max(widest, padded)
            current.append(pos)
        if current:
            batches.append(current)
        return batches

class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.

    Each raw item is ``(sequence, target, label)``. Rows are padded to the
    longest sequence in the batch (dynamic width).
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet

    def __call__(self, raw_batch):
        """Tokenise and pad a batch.

        Returns:
            tokens (B, T) int64, lengths (B,), a bool mask that is True only on
            residue tokens (not padding/cls/eos), float32 targets (B,), and
            the list of labels.
        """
        # RoBERTa uses an eos token, while ESM-1 does not.
        abc = self.alphabet
        bos = int(abc.prepend_bos)
        longest = max(len(seq_str) for seq_str, _, _ in raw_batch)
        width = longest + bos + int(abc.append_eos)

        tokens = torch.full((len(raw_batch), width), abc.padding_idx, dtype=torch.int64)
        targets = torch.zeros((len(raw_batch),), dtype=torch.float32)
        lengths, labels = [], []

        for row, (seq_str, target, label) in enumerate(raw_batch):
            labels.append(label)
            lengths.append(len(seq_str))
            targets[row] = target
            if abc.prepend_bos:
                tokens[row, 0] = abc.cls_idx
            encoded = torch.tensor([abc.get_idx(ch) for ch in seq_str], dtype=torch.int64)
            tokens[row, bos:bos + len(seq_str)] = encoded
            if abc.append_eos:
                tokens[row, bos + len(seq_str)] = abc.eos_idx

        special = tokens.eq(abc.padding_idx) | tokens.eq(abc.cls_idx) | tokens.eq(abc.eos_idx)
        non_pad_mask = ~special  # B, T
        return tokens, torch.tensor(lengths), non_pad_mask, targets, labels

class Alphabet(object):
    """Token vocabulary: special tokens, standard tokens, filler, then appended specials.

    Layout of ``all_toks``: ``prepend_toks`` + ``standard_toks`` + ``<null_k>``
    fillers (so the count before ``append_toks`` is a multiple of 8) +
    ``append_toks``. Unknown tokens map to ``<unk>``.

    NOTE(review): this class shadows ``esm.Alphabet`` imported at the top of the notebook.
    """
    prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
    append_toks = ("<cls>", "<mask>", "<sep>")
    prepend_bos = True
    append_eos = False

    def __init__(self, standard_toks):
        self.standard_toks = list(standard_toks)

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        # Pad with <null_k> tokens so the vocabulary before append_toks is a multiple of 8.
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i  + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        """Index of ``tok``, or ``unk_idx`` if it is not in the vocabulary."""
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        """Serialise to a dict that ``from_dict`` can rebuild the alphabet from."""
        # Store the standard tokens (there is no ``self.toks`` attribute); copy so
        # callers cannot mutate this alphabet through the returned dict.
        return {"toks": list(self.standard_toks)}

    def get_batch_converter(self):
        return BatchConverter(self)

    @classmethod
    def from_dict(cls, d):
        return cls(standard_toks=d["toks"])


class NewAlphabet(Alphabet):
    """Adapter that hands out a BatchConverter built around a wrapped alphabet.

    ``Alphabet.__init__`` is not called: the instance only stores ``alphabet``
    and forwards it, so other ``Alphabet`` attributes are not set on it.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet

    def get_batch_converter(self):
        """Return a BatchConverter for the wrapped alphabet."""
        wrapped = self.alphabet
        return BatchConverter(wrapped)

In [None]:
# @title Fixed Max Length
import re
import torch
import random

class FastaDataset(object):
    """Expose a DataFrame with ``fasta``, ``solubility`` and ``sid`` columns as (sequence, target, id) items."""

    def __init__(self, data_df):
        self.data_df = data_df

    def __len__(self):
        return self.data_df.shape[0]

    def __getitem__(self, idx):
        """Return ``(sequence, solubility, sid)`` for index label ``idx``."""
        frame = self.data_df
        return tuple(frame[column][idx] for column in ("fasta", "solubility", "sid"))

class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.

    Every batch is padded to the same fixed width (``max_length`` residues
    plus optional bos/eos), regardless of the sequences it contains.

    Args:
        alphabet: object exposing ``prepend_bos``, ``append_eos``,
            ``padding_idx``, ``cls_idx``, ``eos_idx`` and ``get_idx``.
        crop: if True, sequences longer than ``max_length`` are truncated.
            If False (default), they raise ``ValueError``. The default lets
            ``Alphabet``/``NewAlphabet.get_batch_converter`` call
            ``BatchConverter(alphabet)``.
        max_length: residues per row. ``None`` means use the notebook-level
            ``MAX_LENGTH``, looked up at call time.
    """

    def __init__(self, alphabet, crop=False, max_length=None):
        self.alphabet = alphabet
        self.crop = crop
        self.max_length = max_length

    def __call__(self, raw_batch):
        # RoBERTa uses an eos token, while ESM-1 does not.
        max_length = MAX_LENGTH if self.max_length is None else self.max_length
        bos = int(self.alphabet.prepend_bos)
        batch_size = len(raw_batch)
        tokens = torch.full(
            (batch_size, max_length + bos + int(self.alphabet.append_eos)),
            self.alphabet.padding_idx,
            dtype=torch.int64,
        )
        labels = []
        lengths = []
        targets = torch.zeros((batch_size,), dtype=torch.float32)
        for i, (seq_str, target, label) in enumerate(raw_batch):
            if len(seq_str) > max_length:
                if not self.crop:
                    raise ValueError(
                        f"sequence {label!r} has {len(seq_str)} residues, more than "
                        f"max_length={max_length}; filter it out or use crop=True"
                    )
                seq_str = seq_str[:max_length]
            labels.append(label)
            # Record the (possibly cropped) length so mean pooling divides by
            # the number of residues actually present in the row.
            lengths.append(len(seq_str))
            targets[i] = target
            if self.alphabet.prepend_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor([self.alphabet.get_idx(s) for s in seq_str], dtype=torch.int64)
            tokens[i, bos:bos + len(seq_str)] = seq
            if self.alphabet.append_eos:
                tokens[i, bos + len(seq_str)] = self.alphabet.eos_idx

        non_pad_mask = (
            ~tokens.eq(self.alphabet.padding_idx)
            & ~tokens.eq(self.alphabet.cls_idx)
            & ~tokens.eq(self.alphabet.eos_idx)
        )  # B, T

        return tokens, torch.tensor(lengths), non_pad_mask, targets, labels

class Alphabet(object):
    """Token vocabulary: special tokens, standard tokens, filler, then appended specials.

    Layout of ``all_toks``: ``prepend_toks`` + ``standard_toks`` + ``<null_k>``
    fillers (so the count before ``append_toks`` is a multiple of 8) +
    ``append_toks``. Unknown tokens map to ``<unk>``.

    NOTE(review): this class shadows ``esm.Alphabet`` imported at the top of the notebook.
    """
    prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
    append_toks = ("<cls>", "<mask>", "<sep>")
    prepend_bos = True
    append_eos = False

    def __init__(self, standard_toks):
        self.standard_toks = list(standard_toks)

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        # Pad with <null_k> tokens so the vocabulary before append_toks is a multiple of 8.
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i  + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        """Index of ``tok``, or ``unk_idx`` if it is not in the vocabulary."""
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        """Serialise to a dict that ``from_dict`` can rebuild the alphabet from."""
        # Store the standard tokens (there is no ``self.toks`` attribute); copy so
        # callers cannot mutate this alphabet through the returned dict.
        return {"toks": list(self.standard_toks)}

    def get_batch_converter(self):
        return BatchConverter(self)

    @classmethod
    def from_dict(cls, d):
        return cls(standard_toks=d["toks"])


class NewAlphabet(Alphabet):
    """Adapter that hands out a BatchConverter built around a wrapped alphabet.

    ``Alphabet.__init__`` is not called: the instance only stores ``alphabet``
    and forwards it, so other ``Alphabet`` attributes are not set on it.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet

    def get_batch_converter(self):
        """Return a BatchConverter for the wrapped alphabet."""
        wrapped = self.alphabet
        return BatchConverter(wrapped)

The input should be a DataFrame with three columns: `sid`, `solubility` and `fasta`.
If you chose dynamic batching, comment out the sequence-length filter in `get_split`; it is only needed for fixed max-length batching.

In [None]:
# @title Dataset
import pandas as pd 
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np

FASTA_FILE = "../Datasets/PSI_Biology/pET_full_without_his_tag.fa"
LABELS_FILE = "../Datasets/PSI_Biology/class.txt"
CLUSTERS_FILE = "../Datasets/PSI_Biology/psi_biology_nesg_partitioning_wl_th025_amT.csv"

labels_df = pd.read_csv(LABELS_FILE, delimiter="\t")
labels_df.columns = ["sid", "solubility"]
labels_df.solubility = labels_df.solubility - 1

fasta_dict = read_fasta(FASTA_FILE)
fasta_df = pd.DataFrame(fasta_dict.items(), columns=['Accession', 'fasta'])
# The sample id is the header text before the first underscore.
fasta_df["sid"] = fasta_df.Accession.apply(lambda x: x.split("_")[0])
print(len(fasta_df))

data_df = labels_df.merge(fasta_df)

clusters_df = pd.read_csv(CLUSTERS_FILE)
clusters_df.columns = ["sid", "priority", "label-val", "between_connectivity", "cluster"]

data_df = data_df.merge(clusters_df)

print(len(data_df))

# `alphabet` must exist at notebook level: ESMFinetune keeps its alphabet local
# to __init__, while get_split / predict read this global. Use the same alphabet
# as the pretrained model being fine-tuned, and drop the weights that come with it.
_esm_model, alphabet = pretrained.load_model_and_alphabet("esm1_t12_85M_UR50S")
del _esm_model
newalphabet_train = NewAlphabet(alphabet)
newalphabet_val = NewAlphabet(alphabet)

def get_split(i):
    """Build train/validation DataLoaders for cross-validation fold ``i``.

    Rows whose ``cluster`` equals ``i`` are held out. The remaining rows are
    split, stratified by solubility, into a training set and a fixed
    512-sequence validation set (``random_state=0``).

    Returns:
        (train_dataloader, val_dataloader)
    """
    train_df = data_df[data_df.cluster != i]

    # Ignore if dynamic batching: fixed max-length batches cannot hold longer sequences.
    # assign() returns a new frame, avoiding SettingWithCopyWarning on the filtered slice.
    train_df = train_df.assign(lengths=train_df["fasta"].str.len())
    train_df = train_df[train_df["lengths"] <= MAX_LENGTH].reset_index(drop=True)

    X, y = np.stack(train_df["sid"].to_numpy()), np.stack(train_df["solubility"].to_numpy())
    sss_tt = StratifiedShuffleSplit(n_splits=1, test_size=512, random_state=0)

    (split_train_idx, split_val_idx) = next(sss_tt.split(X, y))
    split_train_df = train_df.iloc[split_train_idx].reset_index(drop=True)
    split_val_df = train_df.iloc[split_val_idx].reset_index(drop=True)
    print(len(split_train_df))

    train_dataset = FastaDataset(split_train_df)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=newalphabet_train.get_batch_converter(),
        shuffle=True,
        batch_size=BATCH_SIZE,
        num_workers=4,
        #pin_memory=True,
        drop_last=True)

    val_dataset = FastaDataset(split_val_df)
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        collate_fn=newalphabet_val.get_batch_converter(),
        #num_workers=4,
        shuffle=False,
        batch_size=BATCH_SIZE)

    return train_dataloader, val_dataloader

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

path = "./models/"

def train_model(idx):
    """Fine-tune a fresh ESMFinetune on fold ``idx`` and return its Trainer.

    Early stopping and checkpointing both monitor ``val_loss``. After fitting,
    the best checkpoint is also copied to ``{path}{idx}PSISplit.ckpt``, which
    is the file name the evaluation cells load.
    """
    import shutil

    train_dataloader, val_dataloader = get_split(idx)

    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min'
    )

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=path,
        filename=f"{idx}" + 'PSISplit-{epoch:02d}-{val_loss:.2f}',
        period=1,
        save_top_k=1,
        save_last=False
    )

    # Initialize trainer
    trainer = pl.Trainer(max_epochs=3,
                         check_val_every_n_epoch=1,
                         default_root_dir=path + f"{idx}",
                         callbacks=[early_stop_callback, checkpoint_callback],
                         precision=16,
                         progress_bar_refresh_rate=1000,
                         accumulate_grad_batches=4,
                         gpus=1)
    clf = ESMFinetune()
    print(f"Training clf {idx}")
    trainer.fit(clf, train_dataloader, val_dataloader)

    # The checkpoint name embeds epoch/val_loss, but the evaluation cells load
    # models/{idx}PSISplit.ckpt - publish the best checkpoint under that name.
    best_ckpt = checkpoint_callback.best_model_path
    if best_ckpt:
        shutil.copyfile(best_ckpt, os.path.join(path, f"{idx}PSISplit.ckpt"))
    return trainer

In [None]:
# Train one model per held-out cluster (folds 0-4).
for i in range(5):
    train_model(i)

# **TESTING**

In [None]:
# @title Dynamic batching
import re
import torch
import random

class FastaBatchedDataset(torch.utils.data.Dataset):
    """Map-style dataset over a DataFrame with ``fasta``, ``solubility`` and ``sid`` columns.

    ``get_batch_indices`` groups rows of similar length so each batch stays
    within a token budget (dynamic batching). Pass its result to a DataLoader
    as ``batch_sampler``.
    """

    def __init__(self, data_df):
        self.data_df = data_df

    def __len__(self):
        return len(self.data_df)

    def shuffle(self):
        """Replace ``data_df`` with a row-shuffled copy indexed 0..n-1."""
        shuffled = self.data_df.sample(frac=1)
        self.data_df = shuffled.reset_index(drop=True)

    def __getitem__(self, idx):
        """Return ``(sequence, solubility, sid)`` for index label ``idx``."""
        frame = self.data_df
        return frame["fasta"][idx], frame["solubility"][idx], frame["sid"][idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        """Bucket row positions into batches whose padded size fits ``toks_per_batch``.

        Rows are visited shortest-first. A batch is closed when adding the next
        row would make ``longest_len * n_rows`` exceed the budget. A single row
        longer than the budget still gets its own batch.

        Args:
            toks_per_batch: token budget per batch.
            extra_toks_per_seq: tokens added to each sequence length (e.g. special tokens).

        Returns:
            list[list[int]]: batches of row positions.
        """
        ordered = sorted((len(seq), pos) for pos, seq in enumerate(self.data_df["fasta"]))
        batches, current, widest = [], [], 0
        for seq_len, pos in ordered:
            padded = seq_len + extra_toks_per_seq
            if current and max(padded, widest) * (len(current) + 1) > toks_per_batch:
                batches.append(current)
                current, widest = [], 0
            widest = max(widest, padded)
            current.append(pos)
        if current:
            batches.append(current)
        return batches

class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.

    Each raw item is ``(sequence, target, label)``. Rows are padded to the
    longest sequence in the batch (dynamic width).
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet

    def __call__(self, raw_batch):
        """Tokenise and pad a batch.

        Returns:
            tokens (B, T) int64, lengths (B,), a bool mask that is True only on
            residue tokens (not padding/cls/eos), float32 targets (B,), and
            the list of labels.
        """
        # RoBERTa uses an eos token, while ESM-1 does not.
        abc = self.alphabet
        bos = int(abc.prepend_bos)
        longest = max(len(seq_str) for seq_str, _, _ in raw_batch)
        width = longest + bos + int(abc.append_eos)

        tokens = torch.full((len(raw_batch), width), abc.padding_idx, dtype=torch.int64)
        targets = torch.zeros((len(raw_batch),), dtype=torch.float32)
        lengths, labels = [], []

        for row, (seq_str, target, label) in enumerate(raw_batch):
            labels.append(label)
            lengths.append(len(seq_str))
            targets[row] = target
            if abc.prepend_bos:
                tokens[row, 0] = abc.cls_idx
            encoded = torch.tensor([abc.get_idx(ch) for ch in seq_str], dtype=torch.int64)
            tokens[row, bos:bos + len(seq_str)] = encoded
            if abc.append_eos:
                tokens[row, bos + len(seq_str)] = abc.eos_idx

        special = tokens.eq(abc.padding_idx) | tokens.eq(abc.cls_idx) | tokens.eq(abc.eos_idx)
        non_pad_mask = ~special  # B, T
        return tokens, torch.tensor(lengths), non_pad_mask, targets, labels

class Alphabet(object):
    """Token vocabulary: special tokens, standard tokens, filler, then appended specials.

    Layout of ``all_toks``: ``prepend_toks`` + ``standard_toks`` + ``<null_k>``
    fillers (so the count before ``append_toks`` is a multiple of 8) +
    ``append_toks``. Unknown tokens map to ``<unk>``.

    NOTE(review): this class shadows ``esm.Alphabet`` imported at the top of the notebook.
    """
    prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
    append_toks = ("<cls>", "<mask>", "<sep>")
    prepend_bos = True
    append_eos = False

    def __init__(self, standard_toks):
        self.standard_toks = list(standard_toks)

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        # Pad with <null_k> tokens so the vocabulary before append_toks is a multiple of 8.
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i  + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        """Index of ``tok``, or ``unk_idx`` if it is not in the vocabulary."""
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        """Serialise to a dict that ``from_dict`` can rebuild the alphabet from."""
        # Store the standard tokens (there is no ``self.toks`` attribute); copy so
        # callers cannot mutate this alphabet through the returned dict.
        return {"toks": list(self.standard_toks)}

    def get_batch_converter(self):
        return BatchConverter(self)

    @classmethod
    def from_dict(cls, d):
        return cls(standard_toks=d["toks"])


class NewAlphabet(Alphabet):
    """Adapter that hands out a BatchConverter built around a wrapped alphabet.

    ``Alphabet.__init__`` is not called: the instance only stores ``alphabet``
    and forwards it, so other ``Alphabet`` attributes are not set on it.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet

    def get_batch_converter(self):
        """Return a BatchConverter for the wrapped alphabet."""
        wrapped = self.alphabet
        return BatchConverter(wrapped)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, roc_auc_score, matthews_corrcoef, roc_curve
import numpy as np
optimal_thresholds = []  # one Youden-optimal threshold appended per evaluate_split call

def evaluate_split(split_i, test_df):
    """Print and return accuracy, precision, MCC and ROC-AUC for one fold.

    ``test_df`` must hold ``preds`` (predicted probabilities) and
    ``solubility`` (binary labels). Label metrics threshold at 0.5.

    Side effect: appends the fold's Youden-optimal ROC threshold to the
    module-level ``optimal_thresholds`` list.

    Returns:
        (acc, pre, mcc, auc)
    """
    probs = np.stack(test_df["preds"].to_numpy())
    y_test = np.stack(test_df["solubility"].to_numpy())

    preds = probs > 0.5

    # Youden index: the ROC threshold that maximises TPR - FPR for this fold's probabilities.
    fpr, tpr, thresholds = roc_curve(y_test, probs)
    optimal_idx = np.argmax(tpr - fpr)
    optimal_threshold = thresholds[optimal_idx]
    optimal_thresholds.append(optimal_threshold)

    acc = accuracy_score(y_test, preds)
    pre = precision_score(y_test, preds)
    mcc = matthews_corrcoef(y_test, preds)
    auc = roc_auc_score(y_test, probs)
    print(f"Fold{split_i}- Acc: {acc:.3f}, Pre: {pre:.3f}, MCC: {mcc:.3f}, AUC: {auc:.3f}\n")
    return acc, pre, mcc, auc

def predict(df, clf):
    """Score every sequence in ``df`` with ``clf`` and attach probabilities.

    Batches dynamically with ``FastaBatchedDataset`` under the notebook-level
    ``MAX_TOKENS_PER_BATCH`` budget, tokenising with the global ``alphabet``.

    Args:
        df: DataFrame with ``sid``, ``solubility`` and ``fasta`` columns.
        clf: trained ESMFinetune, already in eval mode.

    Returns:
        ``df`` merged on ``sid`` with a ``preds`` column of sigmoid probabilities.
    """
    test_df = df
    print(len(test_df))
    newalphabet = NewAlphabet(alphabet)
    embed_dataset = FastaBatchedDataset(test_df)
    # extra_toks_per_seq=1 presumably budgets for the prepended bos/cls token.
    embed_batches = embed_dataset.get_batch_indices(MAX_TOKENS_PER_BATCH, extra_toks_per_seq=1)
    embed_dataloader = torch.utils.data.DataLoader(
        embed_dataset, collate_fn=newalphabet.get_batch_converter(), batch_sampler=embed_batches)

    # Send inputs to wherever the model's weights live instead of assuming CUDA.
    device = next(clf.parameters()).device
    embed_dict = {}
    with torch.no_grad():
        for toks, lengths, np_mask, _targets, labels in embed_dataloader:
            logits = clf(toks.to(device), lengths.to(device), np_mask.to(device))
            # atleast_1d undoes forward()'s squeeze() on single-sequence batches.
            probs = np.atleast_1d(torch.sigmoid(logits).cpu().numpy())
            for label, prob in zip(labels, probs):
                embed_dict[label] = float(prob)

    pred_df = pd.DataFrame(embed_dict.items(), columns=['sid', 'preds'])
    test_df = test_df.merge(pred_df)

    return test_df

def get_dataset_split(split_i):
    """Load the PSI Biology data and return the rows of cluster ``split_i``.

    Joins the labels (``class.txt``, solubility shifted down by 1), the FASTA
    sequences (keyed by the header text before the first ``_``) and the cluster
    partitioning. Keeps only rows whose ``cluster`` equals ``split_i``.
    """
    FASTA_FILE = "../Datasets/PSI_Biology/pET_full_without_his_tag.fa"
    LABELS_FILE = "../Datasets/PSI_Biology/class.txt"
    CLUSTERS_FILE = "../Datasets/PSI_Biology/psi_biology_nesg_partitioning_wl_th025_amT.csv"

    labels_df = pd.read_csv(LABELS_FILE, delimiter="\t")
    labels_df.columns = ["sid", "solubility"]
    labels_df.solubility = labels_df.solubility - 1

    fasta_dict = read_fasta(FASTA_FILE)
    fasta_df = pd.DataFrame(fasta_dict.items(), columns=['Accession', 'fasta'])
    fasta_df["sid"] = fasta_df.Accession.apply(lambda x: x.split("_")[0])
    print(len(fasta_df))

    data_df = labels_df.merge(fasta_df)
    clusters_df = pd.read_csv(CLUSTERS_FILE)
    clusters_df.columns = ["sid", "priority", "label-val", "between_connectivity", "cluster"]
    data_df = data_df.merge(clusters_df)
    # Filter on the argument, not on a notebook-global loop variable `i`.
    data_df = data_df[data_df.cluster == split_i].reset_index(drop=True)
    return data_df

def test_psi_split(split_i, clf):
    """Predict cluster ``split_i`` with ``clf``, print its metrics, and return the scored frame."""
    # Input should be a dataframe with these 3 columns: sid, solubility, fasta
    fold_df = get_dataset_split(split_i)
    scored_df = predict(fold_df, clf)
    evaluate_split(split_i, scored_df)
    return scored_df

def test_independent(split_i, clf):
    """Score the NESG independent test set with ``clf`` and return the frame with ``preds``.

    ``split_i`` is accepted for symmetry with ``test_psi_split`` but not used.
    """
    return predict(pd.read_csv("../Datasets/NESG/NESG_testset.csv"), clf)

In [None]:
# cross validation
accs = []
pres = []
mccs = []
aucs = []

for i in range(5):
    # Separate name so the training cell's global `path` ("./models/") is not clobbered.
    ckpt_path = f"models/{i}PSISplit.ckpt"
    clf = ESMFinetune.load_from_checkpoint(ckpt_path)
    clf.eval().cuda()
    # test_psi_split returns only the prediction frame, so take the metrics
    # from evaluate_split directly.
    fold_df = predict(get_dataset_split(i), clf)
    acc, pre, mcc, auc = evaluate_split(i, fold_df)
    accs.append(acc)
    pres.append(pre)
    mccs.append(mcc)
    aucs.append(auc)


def _mean_std(values):
    """Format 'mean + std' (rounded to 2 decimals) for a LaTeX table cell."""
    arr = np.array(values)
    return f"{round(arr.mean(), 2)} + {round(arr.std(), 2)}"


print(" & ".join(_mean_std(v) for v in (accs, pres, mccs, aucs)) + " \\\\ ")

In [None]:
# independent validation: ensemble the five fold models on the NESG test set
preds_list = []
for i in range(5):
    # Separate name so the training cell's global `path` ("./models/") is not clobbered.
    ckpt_path = f"models/{i}PSISplit.ckpt"
    clf = ESMFinetune.load_from_checkpoint(ckpt_path)
    clf.eval().cuda()
    # Every model scores the same NESG set, so the per-model prediction arrays
    # should line up row-for-row (test_psi_split would score a different fold each time).
    test_df = test_independent(i, clf)
    preds_list.append(np.stack(test_df.preds.to_numpy()))


probs = np.mean(preds_list, axis=0)
y_test = np.stack(test_df["solubility"].to_numpy())
# Threshold at the mean of the per-fold Youden thresholds collected by
# evaluate_split in the cross-validation cell.
# NOTE(review): the list grows on every evaluate_split call; re-run the testing cells from the top.
assert optimal_thresholds, "run the cross-validation cell first to collect thresholds"
preds = probs > np.mean(optimal_thresholds)

acc = accuracy_score(y_test, preds)
pre = precision_score(y_test, preds)
mcc = matthews_corrcoef(y_test, preds)
auc = roc_auc_score(y_test, probs)
print(f"Acc: {acc:.3f}, Pre: {pre:.3f}, MCC: {mcc:.3f}, AUC: {auc:.3f}\n")
