In [1]:
!pip install peptdeep

Collecting peptdeep
  Downloading peptdeep-1.3.1-py3-none-any.whl.metadata (51 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/51.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.2/51.2 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Collecting biopython (from peptdeep)
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting streamlit-aggrid (from peptdeep)
  Downloading streamlit_aggrid-1.1.1-py3-none-any.whl.metadata (8.6 kB)
Collecting pyteomics (from peptdeep)
  Downloading pyteomics-4.7.5-py3-none-any.whl.metadata (6.5 kB)
Collecting streamlit>=1.23.0 (from peptdeep)
  Downloading streamlit-1.43.2-py2.py3-none-any.whl.metadata (8.9 kB)
Collecting alphabase>=1.1.0 (from peptdeep)
  Downloading alphabase-1.6.0-py3-none-any.whl.metadata (27 kB)
Collecting alpharaw>=0.2.0 (from peptdeep)
  Downloading alpharaw-0.4.6-py3-none-any.whl.metadata (2

# Library import

In [None]:
from peptdeep.pretrained_models import ModelManager
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, average_precision_score, precision_score, recall_score, f1_score
from torch.nn.utils.rnn import pad_sequence

import sys

# from google.colab import drive # delete this line if not used on colab
# drive.mount("/content/drive", force_remount = True) # delete this line if not used on colab

# sys.path.append("C:/Users/Walraff/OneDrive - Universite de Liege/Documents/Ulg/Master2/TFE/")
# sys.path.append("/content/drive/MyDrive/TFE/")
import utils

import wandb

2025-03-15 11:24:43> Downloading https://github.com/MannLabs/alphapeptdeep/releases/download/pre-trained-models/pretrained_models.zip ...
2025-03-15 11:24:43> The pretrained models had been downloaded in /root/peptdeep/pretrained_models/pretrained_models.zip
Mounted at /content/drive


In [4]:
# data_path = "C:/Users/Walraff/OneDrive - Universite de Liege/Documents/Ulg/Master2/TFE/data"
data_path = "/content/drive/MyDrive/TFE/"
original_df = pd.read_csv(f'{data_path}/final_status_SPARE.csv')
original_df

Unnamed: 0,ProteinName_SPARE,Peptide_SPARE,Status_SPARE
0,sp|P02751|FINC_HUMAN,VDVIPVNLPGEHGQR,bon
1,sp|P02751|FINC_HUMAN,STTPDITGYR,bon
2,sp|P02751|FINC_HUMAN,SYTITGLQPGTDYK,bon
3,sp|P02751|FINC_HUMAN,IYLYTLNDNAR,bon
4,sp|P04114|APOB_HUMAN,TGISPLALIK,bon
...,...,...,...
150,sp|P02743|SAMP_HUMAN,VGEYSLYIGR,bon
151,sp|P04004|VTNC_HUMAN,GQYCYELDEK,mauvais
152,sp|P04004|VTNC_HUMAN,FEDGVLDPDYPR,bon
153,sp|P04004|VTNC_HUMAN,DWHGVPGQVDAAMAGR,bon


# Dataset, model and loading data

In [None]:
# Creating a dataframe with the sequences and labels
df = pd.DataFrame()
df["sequence"] = original_df["Peptide_SPARE"]
# column required for RT model
df["mods"] = ''
df["mod_sites"] = ''
df["nAA"] = df["sequence"].str.len()
df["quantotypic"] = original_df.apply(lambda row: 0 if row['Status_SPARE'] == 'bon' else 1, axis=1)

class_counts = df['quantotypic'].value_counts()
num_pos = class_counts.get(1, 0)  # Nombre d'échantillons de la classe 1
num_neg = class_counts.get(0, 0)  # Nombre d'échantillons de la classe 0

pos_weight = torch.tensor(num_neg / num_pos, dtype=torch.float64)

# Affichage des résultats
print("Class counts:\n", class_counts)
print("Pos weight (for class 1):", pos_weight)

Class counts:
 quantotypic
0    117
1     38
Name: count, dtype: int64
Pos weight (for class 1): tensor(3.0789, dtype=torch.float64)


In [None]:
class RTPeptideDataset(Dataset):
    """
    Custom Dataset for RT models using amino acid indices and modification features.

    Args:
        aa_indices (torch.Tensor or list): Encoded amino acid sequences (e.g., integer indices).
        mod_x (torch.Tensor or list): Additional input features (e.g., modifications or physicochemical properties).
        labels (torch.Tensor or list): Corresponding labels for each sample.
    """
    def __init__(self, aa_indices, mod_x, labels):
        self.aa_indices = aa_indices
        self.mod_x = mod_x
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.aa_indices[idx], self.mod_x[idx], self.labels[idx]


In [None]:
class RTPeptideModel(nn.Module):
    """
    Model using RT encoder and MLP classifier.

    Args:
        encoder (nn.Module): Encoder module that takes (aa_indices, mod_x) as input and returns a feature vector.
        hidden_dim (int): Dimension of the first hidden layer in the classifier.
        dropout_prob (float): Dropout probability used after each hidden layer.
        num_layers (int): Number of hidden layers before the final output layer.
    """
    def __init__(self, encoder, hidden_dim, dropout_prob, num_layers):
        super(RTPeptideModel, self).__init__()
        self.encoder = encoder  # Encoder module used to process inputs

        layers = []
        input_dim = 256  # output dimension of encoder
        for _ in range(num_layers):
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_prob))
            input_dim = hidden_dim
            hidden_dim = max(hidden_dim // 2, 1)  # Gradually reduce dimensionality

        layers.append(nn.Linear(input_dim, 1))  # Final output layer for binary classification
        self.classifier = nn.Sequential(*layers)

    def forward(self, aa_indices, mod_x):
        X = self.encoder(aa_indices, mod_x)
        X = self.classifier(X)
        return X


In [8]:
model_mgr = ModelManager()
model_mgr.load_installed_models()
pretrained_model = model_mgr.rt_model

# Functions

In [None]:
def precompute_features(df, pretrained_model, model_name='CCS'):
    """
    Precomputes peptide features (`aa_indices`, `mod_x`, and `labels`) to accelerate training.

    Args:
        df (pd.DataFrame): DataFrame containing peptide sequences and labels.
        pretrained_model (nn.Module): Pretrained model with methods to extract amino acid and modification features.
        model_name (str, optional): Identifier for model-specific preprocessing logic (e.g., 'CCS', 'RT').

    Returns:
        tuple: Three lists:
            - aa_indices_list (list of torch.Tensor): Encoded amino acid index features.
            - mod_x_list (list of torch.Tensor): Additional modification features.
            - labels_list (list of torch.Tensor): Corresponding labels as float tensors.
    """
    aa_indices_list = []
    mod_x_list = []
    labels_list = []

    for _, row in df.iterrows():
        sequence = row["sequence"]
        label = row["quantotypic"]

        # Build temporary DataFrame for encoder compatibility
        row_df = pd.DataFrame({"sequence": [sequence]})
        row_df["nAA"] = row_df["sequence"].str.len()
        row_df["mod_sites"] = ""
        row_df["mods"] = ""

        if model_name == 'CCS':
            row_df["charge"] = 0  # Required for CCS input

        # Extract features using the pretrained model
        aa_indices = pretrained_model._get_26aa_indice_features(row_df).squeeze(0)
        mod_x = pretrained_model._get_mod_features(row_df).squeeze(0)

        # Store computed features
        aa_indices_list.append(aa_indices)
        mod_x_list.append(mod_x)
        labels_list.append(torch.tensor(label, dtype=torch.float32))

    return aa_indices_list, mod_x_list, labels_list

In [None]:
def collate_batch(batch):
    """
    Custom collate function to dynamically pad amino acid and modification feature sequences.
    
    Args:
        batch (list of tuples): Each element is a tuple (aa_indices, mod_x, y).

    Returns:
        tuple:
            - aa_indices_padded (torch.Tensor): Padded tensor of amino acid index sequences (batch_size, max_len).
            - mod_x_padded (torch.Tensor): Padded tensor of modification features (batch_size, max_len, feature_dim).
            - y_tensor (torch.Tensor): Tensor of labels (batch_size,).
    """
    aa_indices_list, mod_x_list, y_list = zip(*batch)

    # Pad variable-length amino acid index sequences
    aa_indices_padded = pad_sequence(aa_indices_list, batch_first=True, padding_value=0)

    # Pad variable-length modification feature vectors
    mod_x_padded = pad_sequence(mod_x_list, batch_first=True, padding_value=0)

    # Convert labels to tensor
    y_tensor = torch.tensor(y_list, dtype=torch.float32)

    return aa_indices_padded, mod_x_padded, y_tensor


In [None]:
# Initialize the model with the pretrained encoder
model = RTPeptideModel(pretrained_model.model.rt_encoder, hidden_dim=128, dropout_prob=0.5, num_layers=2)

# Select the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model and encoder to the selected device
model.to(device)
model.encoder.to(device)

# Precompute features for all peptides in the dataset
aa_indices_list, mod_x_list, labels_list = precompute_features(df, pretrained_model, model_name='RT')

# Create the dataset using precomputed features
dataset = RTPeptideDataset(aa_indices_list, mod_x_list, labels_list)

# Create the DataLoader with custom collate function for padding
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_batch)

# Example training loop iteration
for batch in train_loader:
    aa_idx, mod_x, y = batch

    # Move batch data to the selected device
    aa_idx, mod_x, y = aa_idx.to(device), mod_x.to(device), y.to(device)

    print(aa_idx)
    print(mod_x)
    print(y)
    print(model(aa_idx, mod_x))
    break


tensor([[ 0, 22, 22, 12,  8, 16, 14, 25,  8, 17, 22,  4,  9,  7, 12,  9, 11,  0,
          0,  0,  0,  0,  0],
        [ 0, 25,  5,  9, 20, 20,  9,  8, 14, 12,  6, 18,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0],
        [ 0, 17,  6, 16,  9, 12, 12,  4,  6, 11,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0],
        [ 0,  4, 23,  8,  7, 22, 16,  7, 17, 22,  4,  1,  1, 13,  1,  7, 18,  0,
          0,  0,  0,  0,  0],
        [ 0,  4, 14,  5, 12, 12, 22, 25, 11,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0],
        [ 0, 12, 19, 16,  9, 25, 14, 12, 22, 16, 22, 11,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0],
        [ 0, 20, 12, 12, 22,  5,  1,  5,  7,  9,  5, 17,  5, 11,  0,  0,  0,  0,
          0,  0,  0,  0,  0],
        [ 0,  7, 25, 19,  9,  6, 19, 25,  1, 20, 11,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0],
        [ 0, 17, 25, 25,  5,  7, 19,  5,  9, 22, 22,  1,  7, 18,  0,  0,  0,  0,
          0,  0,  0,  0,  0],
 

In [None]:
def early_stopping(val_losses, patience):
    """
    Checks if the last 'patience' validation losses have increased or remained the same.

    Args:
        val_losses (list): List of validation losses.
        patience (int): Number of consecutive epochs to check for lack of improvement.

    Returns:
        bool: True if the last 'patience' losses show no improvement, otherwise False.
    """
    # Not enough values to apply the early stopping criterion
    if len(val_losses) < patience + 1:
        return False

    # Check if each of the last 'patience' losses is greater than or equal to the previous one
    for i in range(-patience, -1):
        if val_losses[i] < val_losses[i + 1]:
            continue
        else:
            return False

    return True

In [None]:
def reordering(data, labels, batch_converter, device):
    """
    Prepares and converts a batch of sequence data and labels for ESM-2 model input.

    Args:
        data (tuple): Tuple containing two lists (e.g., sequence names and sequences).
        labels (torch.Tensor): Tensor of labels associated with each sequence.
        batch_converter (callable): Function to convert (name, sequence) pairs into tokenized format.
        device (torch.device): Device to move the tensors to (e.g., 'cuda' or 'cpu').

    Returns:
        tuple:
            - batch_tokens (torch.Tensor): Tensor of tokenized sequences.
            - batch_labels (torch.Tensor): Tensor of corresponding labels (shape: [batch_size]).
    """
    # Reformat data into a list of (name, sequence) tuples
    ReArrangeData = []
    for i in range(len(data[0])):
        ReArrangeData.append((data[0][i], data[1][i]))

    # Move labels to device and ensure float format
    labels = labels.float()
    batch_labels = labels.to(device).unsqueeze(1)

    # Convert sequences to tokens using the batch converter
    batch_tokens = batch_converter(ReArrangeData)[2].to(device)

    return batch_tokens, batch_labels.squeeze(1)

def train_epoch(model, dataloader, optimizer, criterion, device, model_name='Scratch', batch_converter=None):
    """
    Trains the model for a single epoch.

    Args:
        model (nn.Module): The model to train.
        dataloader (DataLoader): DataLoader providing training data batches.
        optimizer (torch.optim.Optimizer): Optimizer used for updating model weights.
        criterion (nn.Module): Loss function to compute the training loss.
        device (str or torch.device): Device on which computations are performed ('cuda' or 'cpu').
        model_name (str): Specifies the model type (e.g., 'Scratch', 'ESM', 'RT', 'CCS').
        batch_converter (callable, optional): Function used to convert ESM data batches.

    Returns:
        float: Mean training loss over the entire epoch.
    """
    model.train()
    train_losses = []

    for batch in dataloader:
        # Prepare inputs and targets based on the model type
        if model_name == 'ESM':
            X, y = batch
            X, y = reordering(X, y, batch_converter, device)
        elif model_name in ['RT', 'CCS']:
            aa_idx, mod_x, y = batch
            aa_idx, mod_x, y = aa_idx.to(device), mod_x.to(device), y.to(device)
        else:
            X, y = batch
            X, y = X.to(device), y.to(device)

        # Reset gradients
        optimizer.zero_grad()

        # Forward pass
        if model_name in ['RT', 'CCS']:
            output = model(aa_idx, mod_x).squeeze(1)
        else:
            output = model(X).squeeze(1)

        # Compute loss
        loss = criterion(output, y)

        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()

        # Record the loss
        train_losses.append(loss.item())
    
    # wandb.log({"training_loss": np.array(train_losses).mean()})

    return np.mean(train_losses)

def val_epoch(model, dataloader, criterion, device, threshold=0.5, model_name='Scratch', batch_converter=None):
    """
    Validates the model for one epoch and computes evaluation metrics.

    Args:
        model (nn.Module): The model to evaluate.
        dataloader (DataLoader): DataLoader providing validation data.
        criterion (nn.Module): Loss function used for evaluation.
        device (str or torch.device): Device on which computations are performed ('cuda' or 'cpu').
        threshold (float, optional): Threshold for binary classification (applied to predicted probabilities). Defaults to 0.5.
        model_name (str): Model type identifier (e.g., 'Scratch', 'ESM', 'RT', 'CCS').
        batch_converter (callable, optional): Function for converting ESM batches into tokenized format.

    Returns:
        dict: Dictionary containing the following metrics:
              - "loss": Mean validation loss.
              - "accuracy": Classification accuracy.
              - "roc_auc": Area under the ROC curve.
              - "pr_auc": Area under the precision-recall curve.
    """
    model.eval()
    val_losses = []
    all_targets = []
    all_predictions = []
    all_probabilities = []

    with torch.no_grad():
        for batch in dataloader:
            # Prepare inputs and targets based on model type
            if model_name == 'ESM':
                X, y = batch
                X, y = reordering(X, y, batch_converter, device)
            elif model_name in ['RT', 'CCS']:
                aa_idx, mod_x, y = batch
                aa_idx, mod_x, y = aa_idx.to(device), mod_x.to(device), y.to(device)
            else:
                X, y = batch
                X, y = X.to(device), y.to(device)

            # Forward pass
            if model_name in ['RT', 'CCS']:
                output = model(aa_idx, mod_x).squeeze(1)
            else:
                output = model(X).squeeze(1)

            probabilities = torch.sigmoid(output)  # Convert logits to probabilities
            loss = criterion(output, y)
            val_losses.append(loss.item())

            # Threshold probabilities to get binary predictions
            binary_output = (probabilities >= threshold).float()

            # Store results
            all_targets.extend(y.cpu().numpy())
            all_predictions.extend(binary_output.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    # Compute average loss and metrics
    mean_loss = np.mean(val_losses)
    accuracy = (np.array(all_predictions) == np.array(all_targets)).mean()
    roc_auc = roc_auc_score(all_targets, all_probabilities) if len(np.unique(all_targets)) > 1 else 0.0
    pr_auc = average_precision_score(all_targets, all_probabilities) if len(np.unique(all_targets)) > 1 else 0.0

    # Log metrics with wandb
    # wandb.log({
    #     "validation_loss": mean_loss,
    #     "validation_accuracy": accuracy,
    #     "validation_roc_auc": roc_auc,
    #     "validation_pr_auc": pr_auc,
    # })

    return {
        "loss": mean_loss,
        "accuracy": accuracy,
        "roc_auc": roc_auc,
        "pr_auc": pr_auc
    }

In [None]:
def train(model, train_loader, val_loader, n_epochs, lr, filename, threshold=0.5,
          do_early_stopping=True, model_name='Scratch', batch_converter=None, weight=None):
    """
    Trains the model over multiple epochs with validation after each epoch.

    Args:
        model (nn.Module): The model instance to train.
        train_loader (DataLoader): DataLoader providing training data.
        val_loader (DataLoader): DataLoader providing validation data.
        n_epochs (int): Number of training epochs.
        lr (float): Learning rate for the optimizer.
        filename (str): File name to save the final model.
        threshold (float, optional): Threshold for binary classification. Defaults to 0.5.
        do_early_stopping (bool, optional): Whether to apply early stopping. Defaults to True.
        model_name (str, optional): Identifier for model type (e.g., 'Scratch', 'ESM', 'RT', 'CCS').
        batch_converter (callable, optional): Function to convert ESM batches into token format.
        weight (torch.Tensor, optional): Weight for positive class in BCEWithLogitsLoss.

    Returns:
        dict: Training history containing:
            - "train_loss": List of training losses per epoch.
            - "val_loss": List of validation losses per epoch.
            - "val_accuracy": List of validation accuracies.
            - "val_roc_auc": List of ROC-AUC scores.
            - "val_pr_auc": List of PR-AUC scores.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.cuda.empty_cache()
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    if weight is not None:
        criterion = nn.BCEWithLogitsLoss(pos_weight=weight)
    else:
        criterion = nn.BCEWithLogitsLoss()

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": [],
        "val_roc_auc": [],
        "val_pr_auc": [],
    }

    for epoch in range(n_epochs):
        # Train for one epoch
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion, device,
            model_name=model_name, batch_converter=batch_converter
        )

        # Validate after each epoch
        val_metrics = val_epoch(
            model, val_loader, criterion, device, threshold,
            model_name=model_name, batch_converter=batch_converter
        )

        # Record metrics
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_metrics["loss"])
        history["val_accuracy"].append(val_metrics["accuracy"])
        history["val_roc_auc"].append(val_metrics["roc_auc"])
        history["val_pr_auc"].append(val_metrics["pr_auc"])

        # Early stopping
        if early_stopping(history["val_loss"], patience=5) and do_early_stopping:
            print(f"Early stopping triggered at epoch {epoch}.")
            break

    # wandb.log({"Early_stopping": epoch})

    return history

In [None]:
def compute_pos_weight(df, target_column):
    """
    Computes the pos_weight value for BCEWithLogitsLoss based on class imbalance.

    Args:
        df (pd.DataFrame): DataFrame containing the target labels.
        target_column (str): Name of the column with binary class labels.

    Returns:
        torch.Tensor: A tensor representing the positive class weight for use in BCEWithLogitsLoss.
                      If no positive samples are found, defaults to a weight of 1.0.
    """
    class_counts = df[target_column].value_counts()

    num_pos = class_counts.get(1, 0)  # Number of positive samples
    num_neg = class_counts.get(0, 0)  # Number of negative samples

    if num_pos == 0:
        pos_weight = torch.tensor(1.0, dtype=torch.float64)
    else:
        pos_weight = torch.tensor(num_neg / num_pos, dtype=torch.float64)

    return pos_weight

compute_pos_weight(df, 'quantotypic')

tensor(3.0789, dtype=torch.float64)

In [None]:
def compute_best(history, best_history, model, best_model):
    """
    Compares the current model to the best model so far based on validation PR AUC,
    and updates the best model and history if the current model performs better.

    Args:
        history (dict): Dictionary containing the current training and validation metrics.
        best_history (dict): Dictionary containing the best validation metrics so far.
        model (nn.Module): Current model instance.
        best_model (nn.Module): Best model instance so far.

    Returns:
        tuple: Updated best model and its corresponding history.
    """
    if best_history is None or (history["val_pr_auc"][-1] > best_history["val_pr_auc"][-1]):
        best_model = model
        best_history = history

    return best_model, best_history

def evaluate_rt_model(model, test_df):
    """
    Evaluates a trained RT model on the test dataset and returns performance metrics.

    Args:
        model (nn.Module): The trained RT model to evaluate.
        test_df (pd.DataFrame): Test data containing peptide sequences and labels.

    Returns:
        dict: Dictionary with evaluation metrics:
              - 'accuracy': Classification accuracy.
              - 'precision': Precision score.
              - 'recall': Recall score.
              - 'f1': F1 score.
              - 'roc_auc': Area under the ROC curve.
              - 'pr_auc': Area under the Precision-Recall curve.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    # Load the pretrained feature extractor
    model_mgr = ModelManager()
    model_mgr.load_installed_models()
    pretrained_model = model_mgr.rt_model

    # Precompute input features
    test_aa_indices, test_mod_x, test_labels = precompute_features(test_df, pretrained_model, model_name='RT')
    test_dataset = RTPeptideDataset(test_aa_indices, test_mod_x, test_labels)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_batch)

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for aa_indices, mod_x, labels in test_loader:
            aa_indices, mod_x, labels = aa_indices.to(device), mod_x.to(device), labels.to(device)
            outputs = model(aa_indices, mod_x)
            preds = torch.sigmoid(outputs).cpu().numpy()
            all_preds.extend(preds)
            all_targets.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds).flatten()
    all_targets = np.array(all_targets)

    predicted_classes = (all_preds >= 0.5).astype(int)

    accuracy = accuracy_score(all_targets, predicted_classes)
    precision = precision_score(all_targets, predicted_classes)
    recall = recall_score(all_targets, predicted_classes)
    f1 = f1_score(all_targets, predicted_classes)
    roc_auc = roc_auc_score(all_targets, all_preds)
    pr_auc = average_precision_score(all_targets, all_preds)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc
    }

def write_into_json(dict, filename):
    import json
    with open(filename, 'w') as f:
        json.dump(dict, f)

import warnings
warnings.simplefilter("ignore", category=FutureWarning)

# Training

In [None]:
LEARNING_RATE = 3e-4
EMBED_DIM = 128
HIDDEN_DIM = 128
OUTPUT_DIM = 1
N_EPOCHS = 100
N_SPLITS = 5
DROPOUT = 0.3
BATCH_SIZE = 32

# Define paths and useful variables
count = 1
do_early_stopping = False
models_path = "/content/drive/MyDrive/TFE/Models"
results_path = "/content/drive/MyDrive/TFE/Results"
images_path = "/content/drive/MyDrive/TFE/Images"

imbalanced_name_dict = {
    "Imbalance": {
        (False, False): "N",
        (True, False): "W",
        (False, True): "O"
    },
    "Early Stopping": {
        True: "ES",
        False: "NES"
    },
    "Transfer Learning": {
        True: "FT",
        False: "FE",
        None: "N"
    }
}

kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)

# Loop over all configurations of variant (here only different class imbalance technique) in order to compute all in once
for FT in [False, True]:
    for weighting, oversampling in [(False, False), (True, False), (False, True)]:
        temp_results_list = []
        # Loop over the folds
        for fold, (train_idx, test_idx) in enumerate(kf.split(df, df['quantotypic'])):
            # Split the data and creating dataloaders
            print(f"Fold {fold+1} on 5")
            tmp_train_df = df.iloc[train_idx]
            test_df = df.iloc[test_idx]

            train_df, val_df = train_test_split(tmp_train_df, test_size=0.2, stratify=tmp_train_df['quantotypic'], random_state=42)

            if oversampling:
                train_df = utils.balance_classes_with_oversampling(train_df)

            pos_weight = compute_pos_weight(train_df, 'quantotypic') if weighting else None

            model_mgr = ModelManager()
            pretrained_model = model_mgr.rt_model

            train_aa_indices_list, train_mod_x_list, train_labels_list = precompute_features(train_df, pretrained_model, model_name = 'RT')
            validation_aa_indices_list, validation_mod_x_list, validation_labels_list = precompute_features(val_df, pretrained_model, model_name = 'RT')

            train_dataset = RTPeptideDataset(train_aa_indices_list, train_mod_x_list, train_labels_list)
            validation_dataset = RTPeptideDataset(validation_aa_indices_list, validation_mod_x_list, validation_labels_list)

            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
            validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

            best_model = None
            best_history = None

            # GridSearch for hyperparameters
            config_count = 1
            for LR in [0.0001, 0.001, 0.01, 0.1]:
                for n_layers in [1, 2, 3]:
                    for dropout in [0.2, 0.3]:
                        for hidden_size, embed_size in [(64, 64), (128, 128)]:
                            for n_epoch in [25, 50, 100]:
                                print(f"Testing config {config_count} on {4*3*2*2*3}")
                                config_count += 1

                                # Loading the model
                                model_mgr = ModelManager()
                                pretrained_model = model_mgr.rt_model

                                # Freeze the encoder if not fine-tuning
                                for param in pretrained_model.model.rt_encoder.parameters():
                                    param.requires_grad = FT

                                # Initialize the model with the pretrained encoder
                                model = RTPeptideModel(pretrained_model.model.rt_encoder, hidden_dim=hidden_size, dropout_prob=dropout, num_layers=n_layers)

                                # wandb.init(
                                #     project="MLP",
                                #     name=f"MLP_Scratch_config_hyperparameters",
                                #     config={
                                #         "learning_rate": LR,
                                #         "architecture": "MLP",
                                #         "dataset": "SPARE",
                                #         "epochs": N_EPOCHS,
                                #         "Batch_size": batch_size,
                                #         "Dropout": dropout,
                                #         "Hidden_size": hidden_size,
                                #         "Embedding_size": embed_size,
                                #         "Number of layers": n_layers,
                                #         "Oversampling": oversampling,
                                #         "Weighting": weighting,
                                #         "Early stopping": do_early_stopping
                                #     }
                                # )


                                # Train the model
                                history = train(model, train_loader, validation_loader, n_epoch, LR, "test", do_early_stopping=do_early_stopping,
                                                model_name='RT', weight=pos_weight)

                                # Evaluate the model on the validation set
                                best_model, best_history = compute_best(history, best_history, model, best_model)

                                # wandb.finish()

            print("\n--------------------------------------Config Done--------------------------------------------\n")
            # Evaluate the best model on the test set
            results = evaluate_rt_model(best_model, test_df)

            temp_results_list.append({
                'accuracy': results['accuracy'],
                'precision': results['precision'],
                'recall': results['recall'],
                'f1': results['f1'],
                'roc_auc': results['roc_auc'],
                'pr_auc': results['pr_auc']
            })

            wandb.init(project="RT", name=f"ValidProtocol_RT_Scratch_fold_{fold+1}",
                    config={
                            "weighting": weighting,
                            "oversampling": oversampling,
                            "early_stopping": do_early_stopping,
                            "Fine tuning": FT
                    })

            wandb.log(results)
            wandb.finish()

            # Save the best model
            text = f"{imbalanced_name_dict['Early Stopping'][do_early_stopping]}_{imbalanced_name_dict['Transfer Learning'][FT]}_{imbalanced_name_dict['Imbalance'][(weighting, oversampling)]}"
            filename = f"Best_RT_{text}_cvfold{fold+1}"
            torch.save(best_model.state_dict(), f"{models_path}/{filename}.pth")

        accuracy_list = [res['accuracy'] for res in temp_results_list]
        precision_list = [res['precision'] for res in temp_results_list]
        recall_list = [res['recall'] for res in temp_results_list]
        f1_list = [res['f1'] for res in temp_results_list]
        roc_auc_list = [res['roc_auc'] for res in temp_results_list]
        pr_auc_list = [res['pr_auc'] for res in temp_results_list]

        metrics_summary = {
            'accuracy_mean': np.mean(accuracy_list),
            'accuracy_std': np.std(accuracy_list),
            'precision_mean': np.mean(precision_list),
            'precision_std': np.std(precision_list),
            'recall_mean': np.mean(recall_list),
            'recall_std': np.std(recall_list),
            'f1_mean': np.mean(f1_list),
            'f1_std': np.std(f1_list),
            'roc_auc_mean': np.mean(roc_auc_list),
            'roc_auc_std': np.std(roc_auc_list),
            'pr_auc_mean': np.mean(pr_auc_list),
            'pr_auc_std': np.std(pr_auc_list),
        }

        wandb.init(project="RT", name=f"ValidProtocol_RT_Scratch_CrossValidation_Summary",
                config={
                        "weighting": weighting,
                        "oversampling": oversampling,
                        "early_stopping": do_early_stopping,
                        "Fine tuning": FT
                })
        wandb.log(metrics_summary)
        wandb.finish()

        dict_save = {
            'weighting': weighting,
            'oversampling': oversampling,
            'early_stopping': do_early_stopping,
            "Fine tuning": FT,
            'metrics': metrics_summary
        }

        # Save the results into a JSON file
        write_into_json(dict_save, f"{results_path}/ValidProtocol_RT_Experiment_{count}.json")
        count += 1

Fold 1 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjimmy-walraff02[0m ([33mTFE-proteomics[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.64516
f1,0.15385
pr_auc,0.36549
precision,0.16667
recall,0.14286
roc_auc,0.5


Fold 2 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.77419
f1,0.0
pr_auc,0.29761
precision,0.0
recall,0.0
roc_auc,0.60714


Fold 3 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.64516
f1,0.0
pr_auc,0.28677
precision,0.0
recall,0.0
roc_auc,0.47826


Fold 4 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.74194
f1,0.0
pr_auc,0.35361
precision,0.0
recall,0.0
roc_auc,0.50543


Fold 5 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.74194
f1,0.0
pr_auc,0.25432
precision,0.0
recall,0.0
roc_auc,0.40761


0,1
accuracy_mean,▁
accuracy_std,▁
f1_mean,▁
f1_std,▁
pr_auc_mean,▁
pr_auc_std,▁
precision_mean,▁
precision_std,▁
recall_mean,▁
recall_std,▁

0,1
accuracy_mean,0.70968
accuracy_std,0.05398
f1_mean,0.03077
f1_std,0.06154
pr_auc_mean,0.31156
pr_auc_std,0.04186
precision_mean,0.03333
precision_std,0.06667
recall_mean,0.02857
recall_std,0.05714


Fold 1 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.74194
f1,0.0
pr_auc,0.25256
precision,0.0
recall,0.0
roc_auc,0.50595


Fold 2 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.22581
f1,0.36842
pr_auc,0.2776
precision,0.22581
recall,1.0
roc_auc,0.57738


Fold 3 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.58065
f1,0.13333
pr_auc,0.28536
precision,0.14286
recall,0.125
roc_auc,0.44022


Fold 4 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.45161
f1,0.41379
pr_auc,0.44577
precision,0.28571
recall,0.75
roc_auc,0.60326


Fold 5 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.51613
f1,0.11765
pr_auc,0.25622
precision,0.11111
recall,0.125
roc_auc,0.3587


0,1
accuracy_mean,▁
accuracy_std,▁
f1_mean,▁
f1_std,▁
pr_auc_mean,▁
pr_auc_std,▁
precision_mean,▁
precision_std,▁
recall_mean,▁
recall_std,▁

0,1
accuracy_mean,0.50323
accuracy_std,0.16898
f1_mean,0.20664
f1_std,0.15816
pr_auc_mean,0.3035
pr_auc_std,0.07221
precision_mean,0.1531
precision_std,0.09816
recall_mean,0.4
recall_std,0.39843


Fold 1 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.67742
f1,0.375
pr_auc,0.42176
precision,0.33333
recall,0.42857
roc_auc,0.59524


Fold 2 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.70968
f1,0.30769
pr_auc,0.33505
precision,0.33333
recall,0.28571
roc_auc,0.66667


Fold 3 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.54839
f1,0.22222
pr_auc,0.34131
precision,0.2
recall,0.25
roc_auc,0.42391


Fold 4 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.6129
f1,0.33333
pr_auc,0.40538
precision,0.3
recall,0.375
roc_auc,0.6087


Fold 5 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.70968
f1,0.18182
pr_auc,0.2524
precision,0.33333
recall,0.125
roc_auc,0.34239


0,1
accuracy_mean,▁
accuracy_std,▁
f1_mean,▁
f1_std,▁
pr_auc_mean,▁
pr_auc_std,▁
precision_mean,▁
precision_std,▁
recall_mean,▁
recall_std,▁

0,1
accuracy_mean,0.65161
accuracy_std,0.06255
f1_mean,0.28401
f1_std,0.07146
pr_auc_mean,0.35118
pr_auc_std,0.06006
precision_mean,0.3
precision_std,0.05164
recall_mean,0.29286
recall_std,0.1051


Fold 1 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.67742
f1,0.16667
pr_auc,0.28191
precision,0.2
recall,0.14286
roc_auc,0.59524


Fold 2 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.54839
f1,0.22222
pr_auc,0.21674
precision,0.18182
recall,0.28571
roc_auc,0.43452


Fold 3 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.67742
f1,0.28571
pr_auc,0.34273
precision,0.33333
recall,0.25
roc_auc,0.49457


Fold 4 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.6129
f1,0.25
pr_auc,0.2961
precision,0.25
recall,0.25
roc_auc,0.44565


Fold 5 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.58065
f1,0.13333
pr_auc,0.34212
precision,0.14286
recall,0.125
roc_auc,0.42391


0,1
accuracy_mean,▁
accuracy_std,▁
f1_mean,▁
f1_std,▁
pr_auc_mean,▁
pr_auc_std,▁
precision_mean,▁
precision_std,▁
recall_mean,▁
recall_std,▁

0,1
accuracy_mean,0.61935
accuracy_std,0.05161
f1_mean,0.21159
f1_std,0.05518
pr_auc_mean,0.29592
pr_auc_std,0.04646
precision_mean,0.2216
precision_std,0.06564
recall_mean,0.21071
recall_std,0.06429


Fold 1 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.70968
f1,0.0
pr_auc,0.26482
precision,0.0
recall,0.0
roc_auc,0.53571


Fold 2 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.77419
f1,0.22222
pr_auc,0.46334
precision,0.5
recall,0.14286
roc_auc,0.69643


Fold 3 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.58065
f1,0.13333
pr_auc,0.25394
precision,0.14286
recall,0.125
roc_auc,0.45652


Fold 4 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.70968
f1,0.18182
pr_auc,0.56195
precision,0.33333
recall,0.125
roc_auc,0.76087


Fold 5 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.74194
f1,0.0
pr_auc,0.3114
precision,0.0
recall,0.0
roc_auc,0.33152


0,1
accuracy_mean,▁
accuracy_std,▁
f1_mean,▁
f1_std,▁
pr_auc_mean,▁
pr_auc_std,▁
precision_mean,▁
precision_std,▁
recall_mean,▁
recall_std,▁

0,1
accuracy_mean,0.70323
accuracy_std,0.06579
f1_mean,0.10747
f1_std,0.09216
pr_auc_mean,0.37109
pr_auc_std,0.12126
precision_mean,0.19524
precision_std,0.19541
recall_mean,0.07857
recall_std,0.06448


Fold 1 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.64516
f1,0.26667
pr_auc,0.34572
precision,0.25
recall,0.28571
roc_auc,0.60119


Fold 2 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.74194
f1,0.5
pr_auc,0.63368
precision,0.44444
recall,0.57143
roc_auc,0.80952


Fold 3 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.67742
f1,0.28571
pr_auc,0.27724
precision,0.33333
recall,0.25
roc_auc,0.41304


Fold 4 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.70968
f1,0.0
pr_auc,0.19537
precision,0.0
recall,0.0
roc_auc,0.26087


Fold 5 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.6129
f1,0.14286
pr_auc,0.23376
precision,0.16667
recall,0.125
roc_auc,0.42935


0,1
accuracy_mean,▁
accuracy_std,▁
f1_mean,▁
f1_std,▁
pr_auc_mean,▁
pr_auc_std,▁
precision_mean,▁
precision_std,▁
recall_mean,▁
recall_std,▁

0,1
accuracy_mean,0.67742
accuracy_std,0.04562
f1_mean,0.23905
f1_std,0.16579
pr_auc_mean,0.33715
pr_auc_std,0.15645
precision_mean,0.23889
precision_std,0.15072
recall_mean,0.24643
recall_std,0.19113


In [None]:
model_mgr = ModelManager()
model_mgr.load_installed_models()
pretrained_model = model_mgr.rt_model

dir(pretrained_model.model.rt_encoder)

  torch.load(stream, map_location=self.device), strict=False


['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_compiled_call_impl',
 '_forward_hooks',
 '_forward_hooks_always_called',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_se

In [None]:
print(pretrained_model.model.rt_encoder)

Encoder_26AA_Mod_CNN_LSTM_AttnSum(
  (mod_nn): Mod_Embedding_FixFirstK(
    (nn): Linear(in_features=103, out_features=2, bias=False)
  )
  (input_cnn): SeqCNN(
    (cnn_short): Conv1d(35, 35, kernel_size=(3,), stride=(1,), padding=(1,))
    (cnn_medium): Conv1d(35, 35, kernel_size=(5,), stride=(1,), padding=(2,))
    (cnn_long): Conv1d(35, 35, kernel_size=(7,), stride=(1,), padding=(3,))
  )
  (hidden_nn): SeqLSTM(
    (rnn): LSTM(140, 128, num_layers=2, batch_first=True, bidirectional=True)
  )
  (attn_sum): SeqAttentionSum(
    (attn): Sequential(
      (0): Linear(in_features=256, out_features=1, bias=False)
      (1): Softmax(dim=1)
    )
  )
)


In [None]:
print(pretrained_model.model)

Model_RT_LSTM_CNN(
  (dropout): Dropout(p=0.1, inplace=False)
  (rt_encoder): Encoder_26AA_Mod_CNN_LSTM_AttnSum(
    (mod_nn): Mod_Embedding_FixFirstK(
      (nn): Linear(in_features=103, out_features=2, bias=False)
    )
    (input_cnn): SeqCNN(
      (cnn_short): Conv1d(35, 35, kernel_size=(3,), stride=(1,), padding=(1,))
      (cnn_medium): Conv1d(35, 35, kernel_size=(5,), stride=(1,), padding=(2,))
      (cnn_long): Conv1d(35, 35, kernel_size=(7,), stride=(1,), padding=(3,))
    )
    (hidden_nn): SeqLSTM(
      (rnn): LSTM(140, 128, num_layers=2, batch_first=True, bidirectional=True)
    )
    (attn_sum): SeqAttentionSum(
      (attn): Sequential(
        (0): Linear(in_features=256, out_features=1, bias=False)
        (1): Softmax(dim=1)
      )
    )
  )
  (rt_decoder): Decoder_Linear(
    (nn): Sequential(
      (0): Linear(in_features=256, out_features=64, bias=True)
      (1): PReLU(num_parameters=1)
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
  )
)


In [None]:
dir(pretrained_model)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_as_tensor',
 '_check_predict_in_order',
 '_device',
 '_device_ids',
 '_device_type',
 '_fixed_sequence_len',
 '_get_26aa_indice_features',
 '_get_aa_features',
 '_get_aa_indice_features',
 '_get_aa_indice_features_padding_zeros',
 '_get_aa_mod_features',
 '_get_features_from_batch_df',
 '_get_lr_schedule_with_warmup',
 '_get_mod_features',
 '_get_targets_from_batch_df',
 '_init_for_training',
 '_load_model_from_pytorchfile',
 '_load_model_from_stream',
 '_load_model_from_zipfile',
 '_min_pred_value',
 '_model_to_device',
 '_pad_zeros_if_fixed_len',
 '_predict_in_order',
 '_predict_one_batch',
 '_prepare_predict_data_df'

In [None]:
print(pretrained_model.model_params)

{'fixed_sequence_len': 0, 'min_pred_value': 0.0, 'dropout': 0.1}
