In [1]:
import ankh
import h5py
import numpy as np
import pandas as pd
import torch
import yaml
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    matthews_corrcoef,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import GroupKFold
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler

In [2]:
import clearml
from clearml import Logger, Task

In [3]:
# Authenticate against the ClearML server via the browser (interactive step).
clearml.browser_login()
# Register this notebook run as a ClearML experiment.
# NOTE(review): output_uri=True asks ClearML to upload output models/artifacts
# it captures; the training-loop output further down shows the automatic upload
# of checkpoints (models/best_model_0.pth) failing with HTTP 413 "Request
# Entity Too Large". Consider a storage URI with a larger size limit, or
# disabling automatic PyTorch model capture, if those uploads are wanted.
task = Task.init(
    project_name="DBPs_search",
    task_name="Finetune Ankh v3",
    output_uri=True,
)
# Handle used below to report scalars, text and tables to the ClearML UI.
logger = Logger.current_logger()

ClearML Task: created new task id=9b912b2281484adca0cba97e26b1a38d
2024-06-20 14:05:45,131 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/45ba7ff7a93646a8a76d1950065cf1d5/experiments/9b912b2281484adca0cba97e26b1a38d/output/log
ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start


In [3]:
# Read the model and training hyper-parameters from the local YAML file.
with open("config.yml", mode="r") as config_file:
    config = yaml.safe_load(config_file)

In [5]:
# Attach the config to the ClearML task. Keep the *returned* dict: when the
# task is executed/overridden from ClearML, that is the configuration that
# should drive the run. The bare `config` below displays it as cell output.
config = task.connect_configuration(config)
config

{'model_config': {'input_dim': 1536,
  'nhead': 4,
  'hidden_dim': 1536,
  'num_hidden_layers': 1,
  'num_layers': 1,
  'kernel_size': 7,
  'dropout': 0.2,
  'pooling': 'max'},
 'training_config': {'epochs': 10,
  'lr': '2e-4',
  'seed': 42,
  'factor': 0.5,
  'patience': 2,
  'min_lr': '1e-6',
  'batch_size': 64,
  'num_workers': 4,
  'optimizer': 'adamw'}}

In [5]:
# Model hyper-parameters (forwarded to ankh.ConvBertForBinaryClassification).
input_dim = config["model_config"]["input_dim"]
nhead = config["model_config"]["nhead"]
hidden_dim = config["model_config"]["hidden_dim"]
num_hidden_layers = config["model_config"]["num_hidden_layers"]
num_layers = config["model_config"]["num_layers"]
kernel_size = config["model_config"]["kernel_size"]
dropout = config["model_config"]["dropout"]
pooling = config["model_config"]["pooling"]


# Training hyper-parameters.
epochs = config["training_config"]["epochs"]
# PyYAML (YAML 1.1) reads exponent literals without a decimal point, such as
# 2e-4, as *strings* (the connected config shows 'lr': '2e-4'). Cast here so
# every consumer receives a real float.
lr = float(config["training_config"]["lr"])
factor = config["training_config"]["factor"]
patience = config["training_config"]["patience"]
min_lr = float(config["training_config"]["min_lr"])
batch_size = config["training_config"]["batch_size"]
seed = config["training_config"]["seed"]
num_workers = config["training_config"]["num_workers"]

In [6]:
# GPU index used for the original training runs.
GPU_INDEX = 2

if torch.cuda.is_available():
    # Fall back to the last visible GPU when fewer than GPU_INDEX + 1 exist,
    # instead of failing later with an "invalid device ordinal" error.
    DEVICE = torch.device(f"cuda:{min(GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    DEVICE = torch.device("cpu")

In [7]:
def set_seed(seed, deterministic=False):
    """Seed every random number generator this notebook can touch.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU generator plus all CUDA devices).

    Args:
        seed: integer seed applied to every generator.
        deterministic: if True, also force deterministic cuDNN kernels and
            disable cuDNN autotuning (slower, but repeatable on GPU).
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# Seed all RNGs once, before any data splitting or model initialisation.
set_seed(seed=seed)

In [10]:
def load_dict_from_hdf5(filename):
    """
    Read every top-level dataset of an HDF5 file fully into memory.

    Parameters:
    filename (str): Path of the HDF5 file to read.

    Returns:
    dict: Mapping from each top-level dataset name to its contents as a
    NumPy array. The file is closed before returning.
    """
    with h5py.File(filename, "r") as handle:
        return {name: handle[name][:] for name in handle.keys()}

In [11]:
def calculate_metrics(
    all_labels: list, all_preds: list, logits: list
) -> dict[str, float]:
    """Compute binary-classification metrics for one evaluation pass.

    Args:
        all_labels: ground-truth labels (0/1). Elements may be scalars or
            length-1 arrays (as collected from ``(batch, 1)`` tensors); the
            inputs are flattened to 1-D.
        all_preds: hard 0/1 predictions, aligned with ``all_labels``.
        logits: continuous scores for ROC-AUC. AUC is rank-based, so raw
            logits and probabilities give the same value.

    Returns:
        dict with Accuracy, Sensitivity (recall of class 1), Specificity
        (recall of class 0), Precision, AUC, F1 and MCC.

        - AUC is NaN when ``all_labels`` contains a single class.
        - Precision, recall and F1 fall back to 0 when undefined, instead of
          emitting warnings.
    """
    labels = np.asarray(all_labels).ravel()
    preds = np.asarray(all_preds).ravel()
    scores = np.asarray(logits).ravel()

    # roc_auc_score raises ValueError when only one class is present; report
    # NaN so a degenerate split does not abort a whole training run.
    if np.unique(labels).size < 2:
        auc = float("nan")
    else:
        auc = roc_auc_score(labels, scores)

    accuracy = accuracy_score(labels, preds)
    mcc = matthews_corrcoef(labels, preds)
    precision = precision_score(labels, preds, zero_division=0)
    recall = recall_score(labels, preds, zero_division=0)
    # Specificity is the recall of the negative class.
    specificity = recall_score(labels, preds, pos_label=0, zero_division=0)
    f1 = f1_score(labels, preds, zero_division=0)

    metrics = {
        "Accuracy": accuracy,
        "Sensitivity": recall,
        "Specificity": specificity,
        "Precision": precision,
        "AUC": auc,
        "F1": f1,
        "MCC": mcc,
    }

    return metrics

In [12]:
# Per-protein Ankh embeddings for the training set, keyed by identifier.
embeddings = load_dict_from_hdf5(
    "../../../../ssd2/dbp_finder/ankh_embeddings/" "train_p2_2d.h5"
)

In [12]:
# Drop singleton axes from every stored array (updates the dict in place).
for identifier, array in list(embeddings.items()):
    embeddings[identifier] = np.squeeze(array)

In [13]:
# One row per protein: (identifier, embedding array).
embed_df = pd.DataFrame(
    [(identifier, array) for identifier, array in embeddings.items()],
    columns=["identifier", "embedding"],
)

In [14]:
# Training table with identifier, sequence, label and cluster columns.
train_df = pd.read_csv(filepath_or_buffer="../data/ready_data/train_pdb2272.csv")

In [15]:
# Inner join: only proteins that have an embedding are kept.
train_df = pd.merge(train_df, embed_df, on="identifier")

In [16]:
# Class balance of the merged training set.
train_df["label"].value_counts()

1    31803
0    31803
Name: label, dtype: int64

In [17]:
# Length (first axis) of one example embedding.
len(train_df["embedding"].iloc[1])

368

In [18]:
# Inputs to GroupKFold: sequences, labels and the cluster id used as group key.
X, y, groups = (
    train_df[column].tolist() for column in ("sequence", "label", "cluster")
)

In [19]:
# 5-fold CV where all rows that share a `cluster` value fall into the same fold
# (groups are passed to gkf.split below), so no cluster appears in both the
# training and the validation part of a split.
gkf = GroupKFold(n_splits=5)

In [20]:
# Materialise each GroupKFold split as a pair of DataFrame slices;
# train_folds[i] pairs with valid_folds[i].
train_folds, valid_folds = [], []

for train_idx, valid_idx in gkf.split(X, y, groups=groups):
    train_idx, valid_idx = train_idx.tolist(), valid_idx.tolist()
    train, valid = train_df.iloc[train_idx], train_df.iloc[valid_idx]
    train_folds.append(train)
    valid_folds.append(valid)

In [21]:
# Preview of the first training fold (pandas truncates the rendered table).
train_folds[0]

Unnamed: 0,identifier,sequence,label,cluster,embedding
0,A0A096MJY4,MGRKKIQITRIMDERNRQVTFTKRKFGLMKKAYELSVLCDCEIALI...,1,1984,"[[0.011778823, 0.004910938, 0.007578383, 0.011..."
1,A0A0D2UG83,MAKSKKIVAATSGSRSRSSRAGLAFPVGRVHRLLRKGHFADRIGSG...,1,23507,"[[0.029099701, -0.014861202, 0.019306818, 0.02..."
2,A0A0G2JTZ2,MSSKQATSPFACTVDGEETMTQDLTSREKEEGSDQHPASHLPLHPI...,1,11680,"[[0.018502971, 0.0045635537, -0.006065971, 0.0..."
4,A0A0G2L7I0,MMEDEDFLLALRLQEQFDQETPAAGWPDEDCPSSKRRRVDPSGGLD...,1,14394,"[[0.027536467, -0.016558107, -0.0020665708, 0...."
5,A0A0G2Q9D6,MGKNEARRSALAPDHGTVVCDPLRRLNRMHATPEESIRIVAAQKKK...,1,23923,"[[0.016459255, -0.0053423513, -0.0066505796, 0..."
...,...,...,...,...,...
63601,Q8WYP3,MTAWTMGARGLDKRGSFFKLIDTIASEIGELKQEMVRTDVNLENGL...,0,575,"[[0.0080460785, -0.017565874, -0.015078276, 0...."
63602,O35550,MAQPGPAPQPDVSLQQRVAELEKINAEFLRAQQQLEQEFNQKRAKF...,0,27336,"[[0.016553048, -0.023833824, -0.0103855645, 0...."
63603,Q9Z0W5,MSGPYDEASEEITDSFWEVGNYKRTVKRIDDGHRLCNDLMSCVQER...,0,3049,"[[0.01263333, 0.0013439438, 0.012199068, 0.010..."
63604,Q7TQ32,MGQSPSPRSPHGSPPTLSTLTLLLLLCGQAHSQCKILRCNAEYVSS...,0,18137,"[[0.0138585605, -0.016503282, -0.004226524, 0...."


In [22]:
class SequenceDataset(Dataset):
    """Map-style dataset over precomputed per-residue embeddings.

    Expects a DataFrame-like object with an ``embedding`` column (one array per
    sequence) and a ``label`` column. ``lengths`` stores ``len(embedding)`` for
    each row so a length-aware batch sampler can group similar sizes.
    """

    def __init__(self, df):
        embeds = df.embedding.tolist()
        self.embeds = embeds
        self.labels = df.label.tolist()
        self.lengths = list(map(len, embeds))

    def __len__(self):
        return len(self.embeds)

    def __getitem__(self, index):
        # Both the features and the target are returned as float32 tensors.
        features = torch.tensor(self.embeds[index], dtype=torch.float)
        target = torch.tensor(self.labels[index], dtype=torch.float)
        return features, target

In [23]:
def custom_collate_fn(batch):
    """Collate ``(embedding, label)`` samples into one padded batch.

    Embeddings are right-padded with zeros to the longest sequence in the
    batch (``batch_first=True``). Labels are stacked into a float tensor.
    """
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.float)
    padded = pad_sequence([sample[0] for sample in batch], batch_first=True)
    return padded, labels

In [24]:
class CustomBatchSampler(BatchSampler):
    """Yield index batches of similar-length sequences to minimise padding.

    Indices are sorted by ``dataset.lengths``, longest first; ties keep their
    original order. They are then cut into consecutive chunks of
    ``batch_size``, and the final chunk may be smaller.

    Args:
        dataset: object with ``__len__`` and a ``lengths`` sequence, e.g.
            ``SequenceDataset``.
        batch_size: maximum number of indices per batch.
        shuffle: if True, randomise the order of the batches on every
            iteration, using torch's global RNG. Each batch stays
            length-homogeneous. The default, False, keeps the deterministic
            longest-first order.
    """

    def __init__(self, dataset, batch_size, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sampler = SequentialSampler(dataset)

    def _length_sorted_batches(self):
        """Chunk indices, ordered by descending sequence length, into batches."""
        indices = sorted(
            self.sampler, key=lambda i: self.dataset.lengths[i], reverse=True
        )
        return [
            indices[start : start + self.batch_size]
            for start in range(0, len(indices), self.batch_size)
        ]

    def __iter__(self):
        batches = self._length_sorted_batches()
        if self.shuffle:
            order = torch.randperm(len(batches)).tolist()
            batches = [batches[j] for j in order]
        yield from batches

    def __len__(self):
        # Ceil division: __iter__ also yields the trailing partial batch, so
        # floor division would under-report len(DataLoader) by one.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size
        return len(self.dataset) // self.batch_size

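Note: `train` and `valid` (used in the commented-out cells below) are pandas DataFrames, each holding one cross-validation fold.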

In [25]:
# train_dataset = SequenceDataset(train)
# train_sampler = CustomBatchSampler(train_dataset, batch_size)
# train_dataloader = DataLoader(
#     train_dataset,
#     num_workers=num_workers,
#     batch_sampler=train_sampler,
#     collate_fn=custom_collate_fn,
# )

In [26]:
# valid_dataset = SequenceDataset(valid)
# valid_sampler = CustomBatchSampler(valid_dataset, batch_size)
# valid_dataloader = DataLoader(
#     valid_dataset,
#     num_workers=num_workers,
#     batch_sampler=valid_sampler,
#     collate_fn=custom_collate_fn,
# )

In [9]:
# Standalone instantiation, handy for interactive inspection. The
# cross-validation loop below creates its own fresh model per fold and rebinds
# this name, so this instance is not the one that gets trained.
binary_classification_model = ankh.ConvBertForBinaryClassification(
    input_dim=input_dim,
    nhead=nhead,
    hidden_dim=hidden_dim,
    num_hidden_layers=num_hidden_layers,
    num_layers=num_layers,
    kernel_size=kernel_size,
    dropout=dropout,
    pooling=pooling,
)

In [28]:
# binary_classification_model = binary_classification_model.to(DEVICE)

In [29]:
# a, b = next(iter(train_dataloader))

In [30]:
# a.shape

In [31]:
# output = binary_classification_model(a.to(DEVICE), b.to(DEVICE).unsqueeze(1))

In [32]:
# output = binary_classification_model(a.to(DEVICE), b.to(DEVICE).unsqueeze(1))

In [33]:
# optimizer = AdamW(binary_classification_model.parameters(), lr=float(lr))
# scheduler = ReduceLROnPlateau(
#     optimizer, mode="min", factor=factor, patience=patience, min_lr=float(min_lr)
# )

In [34]:
def train_fn(binary_classification_model, train_dataloader, optimizer, device=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        binary_classification_model: model called as ``model(x, y)`` that
            returns an object with a differentiable ``.loss``.
        train_dataloader: iterable of ``(x, y)`` batches. ``y`` is unsqueezed
            to ``(batch, 1)`` before being passed to the model.
        optimizer: torch optimizer over the model's parameters.
        device: device for the batches; defaults to the notebook-wide
            ``DEVICE``.

    Returns:
        Mean of the per-batch losses over the batches actually iterated.
        ``len(train_dataloader)`` is not used as the divisor because a batch
        sampler's ``__len__`` can disagree with the number of batches it
        yields.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    device = DEVICE if device is None else device
    binary_classification_model.train()
    total_loss = 0.0
    n_batches = 0

    for x, y in train_dataloader:
        x = x.to(device)
        y = y.to(device).unsqueeze(1)

        optimizer.zero_grad()
        output = binary_classification_model(x, y)
        output.loss.backward()
        optimizer.step()

        total_loss += output.loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_dataloader yielded no batches")
    return total_loss / n_batches

In [35]:
def validate_fn(
    binary_classification_model,
    test_dataloader,
    scheduler,
    device=None,
    threshold=0.5,
):
    """Evaluate the model for one epoch, step the LR scheduler, compute metrics.

    Args:
        binary_classification_model: model called as ``model(x, y)`` that
            returns an object with ``.loss`` and ``.logits``.
        test_dataloader: iterable of ``(x, y)`` batches. ``y`` is unsqueezed
            to ``(batch, 1)`` before being passed to the model.
        scheduler: scheduler whose ``step`` receives the mean validation loss
            (e.g. ``ReduceLROnPlateau``). Pass ``None`` to skip stepping.
        device: device for the batches; defaults to the notebook-wide
            ``DEVICE``.
        threshold: probability cut-off for predicting the positive class.

    Returns:
        ``(epoch_loss, metrics)``: the mean per-batch loss and the dict
        produced by ``calculate_metrics``.

    Raises:
        ValueError: if ``test_dataloader`` yields no batches.
    """
    device = DEVICE if device is None else device
    binary_classification_model.eval()
    total_loss = 0.0
    n_batches = 0
    all_preds = []
    all_labels = []
    logits = []

    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)

            output = binary_classification_model(x, y)
            total_loss += output.loss.item()
            n_batches += 1

            # NOTE(review): `.logits` is treated as raw pre-sigmoid scores. The
            # name and its use as the AUC score suggest this, and ankh's
            # binary head is believed to train with BCE-with-logits; confirm
            # against the installed ankh version. Thresholding raw logits at
            # 0.5 would mean p > 0.62, so convert to probabilities first.
            probs = torch.sigmoid(output.logits)
            preds = (probs > threshold).float()

            logits.extend(output.logits.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    if n_batches == 0:
        raise ValueError("test_dataloader yielded no batches")
    epoch_loss = total_loss / n_batches
    if scheduler is not None:
        scheduler.step(epoch_loss)
    metrics = calculate_metrics(all_labels, all_preds, logits)
    return epoch_loss, metrics

In [36]:
import logging
import os

# Directory for the per-fold best checkpoints. torch.save does not create
# missing directories, so it is created below before training starts.
CHECKPOINT_DIR = "checkpoints"


def _make_dataloader(fold_df):
    """Wrap one fold's DataFrame in a length-bucketed, zero-padded DataLoader."""
    dataset = SequenceDataset(fold_df)
    return DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=CustomBatchSampler(dataset, batch_size),
        collate_fn=custom_collate_fn,
    )


def _build_model():
    """Instantiate an untrained ConvBERT binary classifier from the config values."""
    return ankh.ConvBertForBinaryClassification(
        input_dim=input_dim,
        nhead=nhead,
        hidden_dim=hidden_dim,
        num_hidden_layers=num_hidden_layers,
        num_layers=num_layers,
        kernel_size=kernel_size,
        dropout=dropout,
        pooling=pooling,
    )


os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for i in range(len(train_folds)):
    train_dataloader = _make_dataloader(train_folds[i])
    valid_dataloader = _make_dataloader(valid_folds[i])

    # Fresh model, optimizer and scheduler for every fold.
    binary_classification_model = _build_model().to(DEVICE)
    optimizer = AdamW(binary_classification_model.parameters(), lr=float(lr))
    scheduler = ReduceLROnPlateau(
        optimizer, mode="min", factor=factor, patience=patience, min_lr=float(min_lr)
    )

    best_val_loss = float("inf")
    best_model_path = os.path.join(CHECKPOINT_DIR, f"best_model_{i}.pth")

    for epoch in range(epochs):
        train_loss = train_fn(binary_classification_model, train_dataloader, optimizer)
        # validate_fn also steps the LR scheduler on the validation loss.
        valid_loss, metrics = validate_fn(
            binary_classification_model, valid_dataloader, scheduler
        )

        logger.report_scalar(
            title=f"Loss model {i}", series="train loss", value=train_loss, iteration=epoch
        )
        logger.report_scalar(
            title=f"Loss model {i}", series="valid loss", value=valid_loss, iteration=epoch
        )

        for metric_name, metric_value in metrics.items():
            logger.report_scalar(
                title=f"Metrics model {i}", series=metric_name, value=metric_value, iteration=epoch
            )

        # Keep only the checkpoint with the lowest validation loss for this fold.
        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            torch.save(binary_classification_model.state_dict(), best_model_path)
            message = f'Saved Best Model on epoch {epoch} with Validation Loss: {best_val_loss}'
            logger.report_text(message, level=logging.DEBUG, print_console=False)

2024-06-20 14:15:22,329 - clearml.storage - ERROR - Exception encountered while uploading Failed uploading object /DBPs_search/Finetune Ankh v3.9b912b2281484adca0cba97e26b1a38d/models/best_model_0.pth (413): <!doctype html>
<html lang=en>
<title>413 Request Entity Too Large</title>
<h1>Request Entity Too Large</h1>
<p>The data value transmitted exceeds the capacity limit.</p>

2024-06-20 14:15:22,336 - clearml.Task - INFO - Failed model upload
2024-06-20 14:19:39,146 - clearml.storage - ERROR - Exception encountered while uploading Failed uploading object /DBPs_search/Finetune Ankh v3.9b912b2281484adca0cba97e26b1a38d/models/best_model_0.pth (413): <!doctype html>
<html lang=en>
<title>413 Request Entity Too Large</title>
<h1>Request Entity Too Large</h1>
<p>The data value transmitted exceeds the capacity limit.</p>

2024-06-20 14:19:39,154 - clearml.Task - INFO - Failed model upload
2024-06-20 14:27:31,679 - clearml.storage - ERROR - Exception encountered while uploading Failed uploadi

Inference: ensemble the best checkpoint from each fold by averaging their logits

In [52]:
# One freshly initialised ConvBERT binary-classification model per CV fold.
# The per-fold best weights are loaded into them in the next cell. The models
# are built in one comprehension so the architecture is specified once,
# replacing five copy-pasted constructor calls.
model_0, model_1, model_2, model_3, model_4 = [
    ankh.ConvBertForBinaryClassification(
        input_dim=input_dim,
        nhead=nhead,
        hidden_dim=hidden_dim,
        num_hidden_layers=num_hidden_layers,
        num_layers=num_layers,
        kernel_size=kernel_size,
        dropout=dropout,
        pooling=pooling,
    )
    for _ in range(5)
]

In [53]:
# Load each fold's best checkpoint from where the CV training loop writes it
# (checkpoints/best_model_{fold}.pth). map_location="cpu" lets checkpoints
# saved from a specific GPU load on any machine; evaluate_fn moves the models
# to DEVICE before scoring.
for fold, fold_model in enumerate([model_0, model_1, model_2, model_3, model_4]):
    state_dict = torch.load(f"checkpoints/best_model_{fold}.pth", map_location="cpu")
    fold_model.load_state_dict(state_dict)



<All keys matched successfully>

In [54]:
# Ensemble members, ordered by CV fold index.
models = list((model_0, model_1, model_2, model_3, model_4))

In [37]:
# task.close()

Testing on the PDB2272 benchmark

In [55]:
# Ankh embeddings for the PDB2272 benchmark proteins, keyed by identifier.
test_embed = load_dict_from_hdf5(
    "../../../../ssd2/dbp_finder/ankh_embeddings/" "pdb2272_2d.h5"
)

In [56]:
# Drop singleton axes from every benchmark embedding (updates the dict in place).
for identifier, array in list(test_embed.items()):
    test_embed[identifier] = np.squeeze(array)

In [57]:
# Build a separate DataFrame rather than rebinding `test_embed`. Rebinding made
# the cell non-idempotent: on a re-run, DataFrame.items() iterates *columns*
# and the merge below silently matched nothing.
test_embed_df = pd.DataFrame(
    list(test_embed.items()), columns=["identifier", "embedding"]
)
test_df = pd.read_csv("../data/embeddings/input_csv/pdb2272.csv")
# Inner join: only benchmark proteins that have an embedding are kept.
test_df = test_df.merge(test_embed_df, on="identifier")

In [58]:
# Class balance of the benchmark after merging with embeddings.
test_df["label"].value_counts()

1    1153
0    1119
Name: label, dtype: int64

In [59]:
# batch_size=1: each protein is scored at its own length, so no padding is
# involved. The default collate_fn just adds a leading batch axis.
testing_set = SequenceDataset(test_df)
testing_dataloader = DataLoader(
    testing_set,
    batch_size=1,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:
# Debug peek: fetch a single (x, y) batch to sanity-check shapes. Not needed
# for the evaluation below; evaluate_fn iterates the loader itself.
x, y = next(iter(testing_dataloader))

In [66]:
def evaluate_fn(models, testing_dataloader, device=None, threshold=0.5):
    """Score a dataset with an ensemble of models and compute metrics.

    For every batch the models' logits are averaged (logit-space ensemble).
    The averaged logits are the AUC scores, and hard predictions are
    ``sigmoid(mean_logit) > threshold``.

    Args:
        models: iterable of models called as ``model(x, y)`` that return an
            object with a ``.logits`` tensor. It is materialised to a list, so
            a generator is not exhausted after the first batch.
        testing_dataloader: iterable of ``(x, y)`` batches. ``y`` is
            unsqueezed to ``(batch, 1)`` before being passed to the models.
        device: device for models and batches; defaults to the notebook-wide
            ``DEVICE``.
        threshold: probability cut-off for predicting the positive class.

    Returns:
        dict of metric name -> value, as produced by ``calculate_metrics``.
    """
    device = DEVICE if device is None else device
    models = list(models)
    # Loop-invariant setup, done once instead of once per batch.
    for model in models:
        model.eval()
        model.to(device)

    all_labels = []
    all_logits = []
    all_preds = []

    with torch.no_grad():
        for x, y in testing_dataloader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)

            ens_logits = torch.stack(
                [model(x, y).logits for model in models], dim=0
            ).mean(dim=0)

            # NOTE(review): `.logits` is treated as raw pre-sigmoid scores, so
            # map to probabilities before applying the 0.5 threshold. Raw
            # logit > 0.5 corresponds to p > 0.62. Confirm against the
            # installed ankh version.
            preds = (torch.sigmoid(ens_logits) > threshold).float()

            all_logits.extend(ens_logits.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    metrics = calculate_metrics(all_labels, all_preds, all_logits)
    return metrics

In [67]:
# Ensemble metrics on the PDB2272 benchmark.
metrics = evaluate_fn(models=models, testing_dataloader=testing_dataloader)

In [70]:
# One-row table: a column per metric, indexed by the benchmark name.
metrics_df = pd.DataFrame([metrics], index=["pdb2272"])

In [71]:
# Publish the benchmark metrics table to the ClearML experiment.
logger.report_table(title="pdb2272", series="Metrics", table_plot=metrics_df)

In [72]:
# Close the ClearML task opened at the top of the notebook.
task.close()