# Library import

In [37]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, accuracy_score

import sys
import os

# from google.colab import drive # delete this line if not used on colab
# drive.mount("/content/drive", force_remount = True) # delete this line if not used on colab

# sys.path.append("C:/Users/Walraff/OneDrive - Universite de Liege/Documents/Ulg/Master2/TFE/")
# sys.path.append("/content/drive/MyDrive/TFE/")
import utils

import wandb

In [39]:
# data_path = "C:/Users/Walraff/OneDrive - Universite de Liege/Documents/Ulg/Master2/TFE/data"
# data_path = "/content/drive/MyDrive/TFE"
data_path = "/home/jwalraff/TFE/data"
original_df = pd.read_csv(f'{data_path}/final_status_SPARE.csv')
original_df

Unnamed: 0,ProteinName_SPARE,Peptide_SPARE,Status_SPARE
0,sp|P02751|FINC_HUMAN,VDVIPVNLPGEHGQR,bon
1,sp|P02751|FINC_HUMAN,STTPDITGYR,bon
2,sp|P02751|FINC_HUMAN,SYTITGLQPGTDYK,bon
3,sp|P02751|FINC_HUMAN,IYLYTLNDNAR,bon
4,sp|P04114|APOB_HUMAN,TGISPLALIK,bon
...,...,...,...
150,sp|P02743|SAMP_HUMAN,VGEYSLYIGR,bon
151,sp|P04004|VTNC_HUMAN,GQYCYELDEK,mauvais
152,sp|P04004|VTNC_HUMAN,FEDGVLDPDYPR,bon
153,sp|P04004|VTNC_HUMAN,DWHGVPGQVDAAMAGR,bon


# Dataset, Model and loading data

In [None]:
# Define the vocabulary
amino_acid_vocab = {aa: idx+1 for idx, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}

class PeptideDataset(Dataset):
    def __init__(self, sequences, labels, vocab, max_len):
        """
        Dataset for peptide sequences.

        Args:
            sequences (list of str): List of amino acid sequences.
            labels (list of int): List of labels associated with the sequences.
            vocab (dict): Mapping dictionary {amino acid: index}.
            max_len (int): Maximum sequence length (applies padding).
        """
        self.sequences = sequences
        self.labels = labels
        self.vocab = vocab
        self.max_len = max_len

    def encode_sequence(self, sequence):
        """Encodes a sequence into integer indices with padding."""
        encoded = [self.vocab.get(aa, 0) for aa in sequence]  # Encoding
        encoded += [0] * (self.max_len - len(encoded))  # Padding
        return torch.tensor(encoded, dtype=torch.long)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Returns an encoded sequence and its label."""
        sequence = self.sequences[idx]
        label = self.labels[idx]
        encoded_sequence = self.encode_sequence(sequence)
        return encoded_sequence, torch.tensor(label, dtype=torch.float32)


In [41]:
LEARNING_RATE = 0.001
VOCAB_SIZE = len(amino_acid_vocab) + 1  # +1 for padding
EMBED_DIM = 64
HIDDEN_DIM = 16
OUTPUT_DIM = 1
N_EPOCHS = 100
BATCH_SIZE = 32
LSTM_LAYERS = 2
LSTM_HIDDEN_DIM = 32
N_SPLITS = 5
DROPOUT = 0.3

In [None]:
# Creating a dataframe with the sequences and labels
df = pd.DataFrame()
df["sequence"] = original_df["Peptide_SPARE"]
df["quantotypic"] = original_df.apply(lambda row: 0 if row['Status_SPARE'] == 'bon' else 1, axis=1)

positive_df = df[df['quantotypic'] == 0]
negative_df = df[df['quantotypic'] == 1]

class_counts = df['quantotypic'].value_counts()
max_len = df['sequence'].str.len().max()
print(class_counts)

quantotypic
0    117
1     38
Name: count, dtype: int64


In [None]:
class PeptideBiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, lstm_hidden_dim, output_dim, dropout_prob, num_layers):
        """
        Bidirectional LSTM-based classifier for peptide sequences.

        Args:
            vocab_size (int): Size of the vocabulary (number of amino acid + padding).
            embed_dim (int): Size of the embedding vectors.
            hidden_dim (int): Number of units in the fully connected hidden layer.
            lstm_hidden_dim (int): Number of hidden units in the BiLSTM.
            output_dim (int): Number of output classes (1 here).
            dropout_prob (float): Dropout rate applied after the first fully connected layer.
            num_layers (int): Number of LSTM layers.
        """
        super(PeptideBiLSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.bilstm = nn.LSTM(embed_dim, lstm_hidden_dim, num_layers=num_layers,
                              bidirectional=True, batch_first=True, dropout=dropout_prob)
        self.fc1 = nn.Linear(2 * lstm_hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, max_seq_len),
                              containing integer-encoded peptide sequences.

        Returns:
            torch.Tensor: Output logits of shape (batch_size, output_dim).
        """
        # Embed input sequences
        embedded = self.embedding(x)  # (batch_size, max_seq_len, embed_dim)

        # BiLSTM encoding
        lstm_out, _ = self.bilstm(embedded)  # (batch_size, max_seq_len, 2 * lstm_hidden_dim)

        # Use the output of the last time step
        lstm_out_last = lstm_out[:, -1, :]  # (batch_size, 2 * lstm_hidden_dim)

        # Fully connected layers with dropout
        hidden = F.relu(self.fc1(lstm_out_last))
        hidden = self.dropout(hidden)
        output = self.fc2(hidden)  # (batch_size, output_dim)

        return output


# Functions

In [None]:
def early_stopping(val_losses, patience):
    """
    Checks if the last 'patience' validation losses have increased or remained the same.

    Args:
        val_losses (list): List of validation losses.
        patience (int): Number of consecutive epochs to check for lack of improvement.

    Returns:
        bool: True if the last 'patience' losses show no improvement, otherwise False.
    """
    # Not enough values to apply the early stopping criterion
    if len(val_losses) < patience + 1:
        return False

    # Check if each of the last 'patience' losses is greater than or equal to the previous one
    for i in range(-patience, -1):
        if val_losses[i] < val_losses[i + 1]:
            continue
        else:
            return False

    return True

In [None]:
def reordering(data, labels, batch_converter, device):
    """
    Prepares and converts a batch of sequence data and labels for ESM-2 model input.

    Args:
        data (tuple): Tuple containing two lists (e.g., sequence names and sequences).
        labels (torch.Tensor): Tensor of labels associated with each sequence.
        batch_converter (callable): Function to convert (name, sequence) pairs into tokenized format.
        device (torch.device): Device to move the tensors to (e.g., 'cuda' or 'cpu').

    Returns:
        tuple:
            - batch_tokens (torch.Tensor): Tensor of tokenized sequences.
            - batch_labels (torch.Tensor): Tensor of corresponding labels (shape: [batch_size]).
    """
    # Reformat data into a list of (name, sequence) tuples
    ReArrangeData = []
    for i in range(len(data[0])):
        ReArrangeData.append((data[0][i], data[1][i]))

    # Move labels to device and ensure float format
    labels = labels.float()
    batch_labels = labels.to(device).unsqueeze(1)

    # Convert sequences to tokens using the batch converter
    batch_tokens = batch_converter(ReArrangeData)[2].to(device)

    return batch_tokens, batch_labels.squeeze(1)

def train_epoch(model, dataloader, optimizer, criterion, device, model_name='Scratch', batch_converter=None):
    """
    Trains the model for a single epoch.

    Args:
        model (nn.Module): The model to train.
        dataloader (DataLoader): DataLoader providing training data batches.
        optimizer (torch.optim.Optimizer): Optimizer used for updating model weights.
        criterion (nn.Module): Loss function to compute the training loss.
        device (str or torch.device): Device on which computations are performed ('cuda' or 'cpu').
        model_name (str): Specifies the model type (e.g., 'Scratch', 'ESM', 'RT', 'CCS').
        batch_converter (callable, optional): Function used to convert ESM data batches.

    Returns:
        float: Mean training loss over the entire epoch.
    """
    model.train()
    train_losses = []

    for batch in dataloader:
        # Prepare inputs and targets based on the model type
        if model_name == 'ESM':
            X, y = batch
            X, y = reordering(X, y, batch_converter, device)
        elif model_name in ['RT', 'CCS']:
            aa_idx, mod_x, y = batch
            aa_idx, mod_x, y = aa_idx.to(device), mod_x.to(device), y.to(device)
        else:
            X, y = batch
            X, y = X.to(device), y.to(device)

        # Reset gradients
        optimizer.zero_grad()

        # Forward pass
        if model_name in ['RT', 'CCS']:
            output = model(aa_idx, mod_x).squeeze(1)
        else:
            output = model(X).squeeze(1)

        # Compute loss
        loss = criterion(output, y)

        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()

        # Record the loss
        train_losses.append(loss.item())
    
    # wandb.log({"training_loss": np.array(train_losses).mean()})

    return np.mean(train_losses)


def val_epoch(model, dataloader, criterion, device, threshold=0.5, model_name='Scratch', batch_converter=None):
    """
    Validate the model for a single epoch and compute additional metrics.

    Args:
        model (nn.Module): The model to validate.
        dataloader (Dataloader): Dataloader providing validation data.
        criterion (nn.Module): Loss function.
        device (str): Device to use ('cuda' or 'cpu').
        threshold (float, optional): Threshold for binary classification. Defaults to 0.5.

    Returns:
        dict: Dictionary with mean loss, accuracy, precision, recall, F1, and AUC-ROC.
    """
    model.eval()  # Set the model to evaluation mode
    val_losses = []
    all_targets = []
    all_predictions = []
    all_probabilities = []

    with torch.no_grad():  # Disable gradient computation
        for batch in dataloader:
            if model_name == 'ESM':
                X, y = batch
                X, y = reordering(X, y, batch_converter, device)
            elif model_name == 'RT' or model_name == 'CCS':
                aa_idx, mod_x, y = batch
                aa_idx, mod_x, y = aa_idx.to(device), mod_x.to(device), y.to(device)
            else:
                X, y = batch
                X, y = X.to(device), y.to(device)

            # Forward pass
            if model_name == 'RT' or model_name == 'CCS':
                output = model(aa_idx, mod_x).squeeze(1)
            else:
                output = model(X).squeeze(1)

            probabilities = torch.sigmoid(output)  # Convert logits to probabilities

            # Compute loss
            loss = criterion(output, y)
            val_losses.append(loss.item())

            # Convert probabilities to binary predictions
            binary_output = (probabilities >= threshold).float()

            # Store all targets, predictions, and probabilities
            all_targets.extend(y.cpu().numpy())
            all_predictions.extend(binary_output.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    # Compute mean loss
    mean_loss = np.array(val_losses).mean()

    # Compute metrics using sklearn
    accuracy = (np.array(all_predictions) == np.array(all_targets)).mean()
    roc_auc = roc_auc_score(all_targets, all_probabilities) if len(np.unique(all_targets)) > 1 else 0.0
    pr_auc = average_precision_score(all_targets, all_probabilities) if len(np.unique(all_targets)) > 1 else 0.0

    # Log metrics with wandb
    # wandb.log({
    #     "validation_loss": mean_loss,
    #     "validation_accuracy": accuracy,
    #     "validation_roc_auc": roc_auc,
    #     "validation_pr_auc": pr_auc,
    # })

    return {
        "loss": mean_loss,
        "accuracy": accuracy,
        "roc_auc": roc_auc,
        "pr_auc": pr_auc
    }

def val_epoch(model, dataloader, criterion, device, threshold=0.5, model_name='Scratch', batch_converter=None):
    """
    Validates the model for one epoch and computes evaluation metrics.

    Args:
        model (nn.Module): The model to evaluate.
        dataloader (DataLoader): DataLoader providing validation data.
        criterion (nn.Module): Loss function used for evaluation.
        device (str or torch.device): Device on which computations are performed ('cuda' or 'cpu').
        threshold (float, optional): Threshold for binary classification (applied to predicted probabilities). Defaults to 0.5.
        model_name (str): Model type identifier (e.g., 'Scratch', 'ESM', 'RT', 'CCS').
        batch_converter (callable, optional): Function for converting ESM batches into tokenized format.

    Returns:
        dict: Dictionary containing the following metrics:
              - "loss": Mean validation loss.
              - "accuracy": Classification accuracy.
              - "roc_auc": Area under the ROC curve.
              - "pr_auc": Area under the precision-recall curve.
    """
    model.eval()
    val_losses = []
    all_targets = []
    all_predictions = []
    all_probabilities = []

    with torch.no_grad():
        for batch in dataloader:
            # Prepare inputs and targets based on model type
            if model_name == 'ESM':
                X, y = batch
                X, y = reordering(X, y, batch_converter, device)
            elif model_name in ['RT', 'CCS']:
                aa_idx, mod_x, y = batch
                aa_idx, mod_x, y = aa_idx.to(device), mod_x.to(device), y.to(device)
            else:
                X, y = batch
                X, y = X.to(device), y.to(device)

            # Forward pass
            if model_name in ['RT', 'CCS']:
                output = model(aa_idx, mod_x).squeeze(1)
            else:
                output = model(X).squeeze(1)

            probabilities = torch.sigmoid(output)  # Convert logits to probabilities
            loss = criterion(output, y)
            val_losses.append(loss.item())

            # Threshold probabilities to get binary predictions
            binary_output = (probabilities >= threshold).float()

            # Store results
            all_targets.extend(y.cpu().numpy())
            all_predictions.extend(binary_output.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    # Compute average loss and metrics
    mean_loss = np.mean(val_losses)
    accuracy = (np.array(all_predictions) == np.array(all_targets)).mean()
    roc_auc = roc_auc_score(all_targets, all_probabilities) if len(np.unique(all_targets)) > 1 else 0.0
    pr_auc = average_precision_score(all_targets, all_probabilities) if len(np.unique(all_targets)) > 1 else 0.0

    # Log metrics with wandb
    # wandb.log({
    #     "validation_loss": mean_loss,
    #     "validation_accuracy": accuracy,
    #     "validation_roc_auc": roc_auc,
    #     "validation_pr_auc": pr_auc,
    # })

    return {
        "loss": mean_loss,
        "accuracy": accuracy,
        "roc_auc": roc_auc,
        "pr_auc": pr_auc
    }


In [None]:
def train(model, train_loader, val_loader, n_epochs, lr, filename, threshold=0.5,
          do_early_stopping=True, model_name='Scratch', batch_converter=None, weight=None):
    """
    Trains the model over multiple epochs with validation after each epoch.

    Args:
        model (nn.Module): The model instance to train.
        train_loader (DataLoader): DataLoader providing training data.
        val_loader (DataLoader): DataLoader providing validation data.
        n_epochs (int): Number of training epochs.
        lr (float): Learning rate for the optimizer.
        filename (str): File name to save the final model.
        threshold (float, optional): Threshold for binary classification. Defaults to 0.5.
        do_early_stopping (bool, optional): Whether to apply early stopping. Defaults to True.
        model_name (str, optional): Identifier for model type (e.g., 'Scratch', 'ESM', 'RT', 'CCS').
        batch_converter (callable, optional): Function to convert ESM batches into token format.
        weight (torch.Tensor, optional): Weight for positive class in BCEWithLogitsLoss.

    Returns:
        dict: Training history containing:
            - "train_loss": List of training losses per epoch.
            - "val_loss": List of validation losses per epoch.
            - "val_accuracy": List of validation accuracies.
            - "val_roc_auc": List of ROC-AUC scores.
            - "val_pr_auc": List of PR-AUC scores.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.cuda.empty_cache()
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    if weight is not None:
        criterion = nn.BCEWithLogitsLoss(pos_weight=weight)
    else:
        criterion = nn.BCEWithLogitsLoss()

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": [],
        "val_roc_auc": [],
        "val_pr_auc": [],
    }

    for epoch in range(n_epochs):
        # Train for one epoch
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion, device,
            model_name=model_name, batch_converter=batch_converter
        )

        # Validate after each epoch
        val_metrics = val_epoch(
            model, val_loader, criterion, device, threshold,
            model_name=model_name, batch_converter=batch_converter
        )

        # Record metrics
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_metrics["loss"])
        history["val_accuracy"].append(val_metrics["accuracy"])
        history["val_roc_auc"].append(val_metrics["roc_auc"])
        history["val_pr_auc"].append(val_metrics["pr_auc"])

        # Early stopping
        if early_stopping(history["val_loss"], patience=5) and do_early_stopping:
            print(f"Early stopping triggered at epoch {epoch}.")
            break

    # wandb.log({"Early_stopping": epoch})

    return history

In [None]:
def compute_pos_weight(df, target_column):
    """
    Computes the pos_weight value for BCEWithLogitsLoss based on class imbalance.

    Args:
        df (pd.DataFrame): DataFrame containing the target labels.
        target_column (str): Name of the column with binary class labels (0 for negative, 1 for positive).

    Returns:
        torch.Tensor: A tensor representing the positive class weight for use in BCEWithLogitsLoss.
                      If no positive samples are found, defaults to a weight of 1.0.
    """
    class_counts = df[target_column].value_counts()

    num_pos = class_counts.get(1, 0)  # Number of positive samples
    num_neg = class_counts.get(0, 0)  # Number of negative samples

    if num_pos == 0:
        pos_weight = torch.tensor(1.0, dtype=torch.float64)
    else:
        pos_weight = torch.tensor(num_neg / num_pos, dtype=torch.float64)

    return pos_weight

compute_pos_weight(df, 'quantotypic')

tensor(3.0789, dtype=torch.float64)

In [None]:
def compute_best(history, best_history, model, best_model):
    """
    Compares the current model to the best model so far based on validation PR AUC,
    and updates the best model and history if the current model performs better.

    Args:
        history (dict): Dictionary containing the current training and validation metrics.
        best_history (dict): Dictionary containing the best validation metrics so far.
        model (nn.Module): Current model instance.
        best_model (nn.Module): Best model instance so far.

    Returns:
        tuple: Updated best model and its corresponding history.
    """
    if best_history is None or (history["val_pr_auc"][-1] > best_history["val_pr_auc"][-1]):
        best_model = model
        best_history = history

    return best_model, best_history

def evaluate_bilstm_model(model, test_df, amino_acid_vocab, max_len):
    """
    Evaluates a trained BiLSTM model on the test dataset and returns performance metrics.

    Args:
        model (nn.Module): The trained model to evaluate.
        test_df (pd.DataFrame): DataFrame containing test sequences and labels.
        amino_acid_vocab (dict): Mapping from amino acids to integer indices.
        max_len (int): Maximum sequence length (used for padding/truncation).

    Returns:
        dict: Dictionary containing the following evaluation metrics:
              - 'accuracy': Classification accuracy.
              - 'precision': Precision score.
              - 'recall': Recall score.
              - 'f1': F1 score.
              - 'roc_auc': ROC AUC score.
              - 'pr_auc': Precision-Recall AUC score.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu" # Get the device
    model.eval()
    model.to(device) # Move the model to the device
    test_dataset = PeptideDataset(test_df['sequence'].values, test_df['quantotypic'].values, amino_acid_vocab, max_len)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for sequences, labels in test_loader:
            sequences = sequences.to(device) # Move the sequences to the device
            labels = labels.to(device) # Move the labels to the device
            outputs = model(sequences)
            preds = torch.sigmoid(outputs).cpu().numpy()
            all_preds.extend(preds)
            all_targets.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds).flatten()
    all_targets = np.array(all_targets)

    # Convertir les probabilités en classes pour l'accuracy
    predicted_classes = (all_preds >= 0.5).astype(int)

    accuracy = accuracy_score(all_targets, predicted_classes)
    precision = precision_score(all_targets, predicted_classes)
    recall = recall_score(all_targets, predicted_classes)
    f1 = f1_score(all_targets, predicted_classes)
    roc_auc = roc_auc_score(all_targets, all_preds)
    pr_auc = average_precision_score(all_targets, all_preds)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc
    }

def write_into_json(dict, filename):
    import json
    with open(filename, 'w') as f:
        json.dump(dict, f)

import warnings
warnings.simplefilter("ignore", category=UserWarning)

# Training

In [None]:
LEARNING_RATE = 0.001
VOCAB_SIZE = len(amino_acid_vocab) + 1  # +1 for padding
EMBED_DIM = 64
HIDDEN_DIM = 16
OUTPUT_DIM = 1
N_EPOCHS = 100
BATCH_SIZE = 32
LSTM_LAYERS = 2
LSTM_HIDDEN_DIM = 32
N_SPLITS = 5
DROPOUT = 0.3

# Define paths and useful variables
count = 1
do_early_stopping = False
models_path = "/home/jwalraff/TFE/Models"
results_path = "/home/jwalraff/TFE/Results"
images_path = "/home/jwalraff/TFE/Images"

imbalanced_name_dict = {
    "Imbalance": {
        (False, False): "N",
        (True, False): "W",
        (False, True): "O"
    },
    "Early Stopping": {
        True: "ES",
        False: "NES"
    },
    "Transfer Learning": {
        True: "FT",
        False: "FE",
        None: "N"
    }
}

# Stratified K-Fold
kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)

# Loop over all configurations of variant (here only different class imbalance technique) in order to compute all in once
for weighting, oversampling in [(False, False), (True, False), (False, True)]:
    temp_results_list = []
    # Loop over the folds
    for fold, (train_idx, test_idx) in enumerate(kf.split(df, df['quantotypic'])):
        # Split the data and creating dataloaders
        print(f"Fold {fold+1} on 5")
        tmp_train_df = df.iloc[train_idx]
        test_df = df.iloc[test_idx]

        train_df, val_df = train_test_split(tmp_train_df, test_size=0.2, stratify=tmp_train_df['quantotypic'], random_state=42)

        if oversampling:
            train_df = utils.balance_classes_with_oversampling(train_df)

        train_dataset = PeptideDataset(train_df['sequence'].values, train_df['quantotypic'].values, amino_acid_vocab, max_len)
        val_dataset = PeptideDataset(val_df['sequence'].values, val_df['quantotypic'].values, amino_acid_vocab, max_len)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        validation_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

        pos_weight = compute_pos_weight(train_df, 'quantotypic') if weighting else None

        best_model = None
        best_history = None

        # GridSearch for hyperparameters
        config_count = 1
        for LR in [0.0001, 0.001, 0.01, 0.1]:
            for n_layers in [1, 2, 3]:
                for dropout in [0.2, 0.3]:
                    for hidden_size, embed_size in [(64, 64), (128, 128)]:
                        for n_epoch in [25, 50, 100]:
                            print(f"Testing config {config_count} on {4*3*2*2*3}")
                            config_count += 1
                            model = PeptideBiLSTMClassifier(VOCAB_SIZE, embed_size, hidden_size, hidden_size, OUTPUT_DIM, dropout, n_layers)

                            # wandb.init(
                            #     project="BiLSTM",
                            #     name=f"BiLSTM_Scratch_config_hyperparameters",
                            #     config={
                            #         "learning_rate": LR,
                            #         "architecture": "BiLSTM",
                            #         "dataset": "SPARE",
                            #         "epochs": n_epoch,
                            #         "Batch_size": batch_size,
                            #         "Dropout": dropout,
                            #         "Hidden_size": hidden_size,
                            #         "Embedding_size": embed_size,
                            #         "Number of layers": n_layers,
                            #         "Oversampling": oversampling,
                            #         "Weighting": weighting,
                            #         "Early stopping": do_early_stopping
                            #     }
                            # )

                            history = train(model, train_loader, validation_loader, n_epoch, LR, "test", do_early_stopping=do_early_stopping,
                                            model_name='Scratch', weight=pos_weight)

                            best_model, best_history = compute_best(history, best_history, model, best_model)

                            # wandb.finish()


        print("\n--------------------------------------Config Done--------------------------------------------\n")
        results = evaluate_bilstm_model(best_model, test_df, amino_acid_vocab, max_len)

        temp_results_list.append({
                'accuracy': results['accuracy'],
                'precision': results['precision'],
                'recall': results['recall'],
                'f1': results['f1'],
                'roc_auc': results['roc_auc'],
                'pr_auc': results['pr_auc']
            })

        wandb.init(project="BiLSTM", name=f"ValidProtocol_BiLSTM_Scratch_fold_{fold+1}",
                    config={
                        "weighting": weighting,
                        "oversampling": oversampling,
                        "early_stopping": do_early_stopping
                    })

        wandb.log(results)
        wandb.finish()
        
        # Save the best model
        text = f"{imbalanced_name_dict['Early Stopping'][False]}_{imbalanced_name_dict['Transfer Learning'][None]}_{imbalanced_name_dict['Imbalance'][(weighting, oversampling)]}"
        filename = f"Best_BiLSTM_{text}_cvfold{fold+1}"
        torch.save(best_model.state_dict(), f"{models_path}/{filename}.pth")

    accuracy_list = [res['accuracy'] for res in temp_results_list]
    precision_list = [res['precision'] for res in temp_results_list]
    recall_list = [res['recall'] for res in temp_results_list]
    f1_list = [res['f1'] for res in temp_results_list]
    roc_auc_list = [res['roc_auc'] for res in temp_results_list]
    pr_auc_list = [res['pr_auc'] for res in temp_results_list]

    metrics_summary = {
        'accuracy_mean': np.mean(accuracy_list),
        'accuracy_std': np.std(accuracy_list),
        'precision_mean': np.mean(precision_list),
        'precision_std': np.std(precision_list),
        'recall_mean': np.mean(recall_list),
        'recall_std': np.std(recall_list),
        'f1_mean': np.mean(f1_list),
        'f1_std': np.std(f1_list),
        'roc_auc_mean': np.mean(roc_auc_list),
        'roc_auc_std': np.std(roc_auc_list),
        'pr_auc_mean': np.mean(pr_auc_list),
        'pr_auc_std': np.std(pr_auc_list),
    }

    wandb.init(project="BiLSTM", name=f"ValidProtocol_BiLSTM_Scratch_CrossValidation_Summary",
                config={
                    "weighting": weighting,
                    "oversampling": oversampling,
                    "early_stopping": do_early_stopping
                })
    wandb.log(metrics_summary)
    wandb.finish()

    dict_save = {
        'weighting': weighting,
        'oversampling': oversampling,
        'early_stopping': do_early_stopping,
        'metrics': metrics_summary
    }

    # Save the results in a JSON file
    write_into_json(dict_save, f"{results_path}/ValidProtocol_BiLSTM_Experiment_{count}.json")
    count += 1

Fold 1 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjimmy-walraff02[0m ([33mTFE-proteomics[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111273252301746, max=1.0)…

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.64516
f1,0.26667
pr_auc,0.31459
precision,0.25
recall,0.28571
roc_auc,0.59524


Fold 2 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.015 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.7970033583053474, max=1.0…

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.64516
f1,0.26667
pr_auc,0.40924
precision,0.25
recall,0.28571
roc_auc,0.64286


Fold 3 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.014 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.7631837706360296, max=1.0…

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.51613
f1,0.0
pr_auc,0.21814
precision,0.0
recall,0.0
roc_auc,0.36413


Fold 4 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.67742
f1,0.44444
pr_auc,0.3369
precision,0.4
recall,0.5
roc_auc,0.54891


Fold 5 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.74194
f1,0.0
pr_auc,0.38805
precision,0.0
recall,0.0
roc_auc,0.54348


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112283873889182, max=1.0…

VBox(children=(Label(value='0.004 MB of 0.019 MB uploaded\r'), FloatProgress(value=0.2222902633190447, max=1.0…

0,1
accuracy_mean,▁
accuracy_std,▁
f1_mean,▁
f1_std,▁
pr_auc_mean,▁
pr_auc_std,▁
precision_mean,▁
precision_std,▁
recall_mean,▁
recall_std,▁

0,1
accuracy_mean,0.64516
accuracy_std,0.07356
f1_mean,0.19556
f1_std,0.17236
pr_auc_mean,0.33339
pr_auc_std,0.06692
precision_mean,0.18
precision_std,0.15684
recall_mean,0.21429
recall_std,0.19166


Fold 1 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.9859416994004548, max=1.0…

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.22581
f1,0.36842
pr_auc,0.50746
precision,0.22581
recall,1.0
roc_auc,0.66071


Fold 2 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.22581
f1,0.36842
pr_auc,0.43645
precision,0.22581
recall,1.0
roc_auc,0.69643


Fold 3 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.51613
f1,0.11765
pr_auc,0.25427
precision,0.11111
recall,0.125
roc_auc,0.41848


Fold 4 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.74194
f1,0.0
pr_auc,0.33023
precision,0.0
recall,0.0
roc_auc,0.53804


Fold 5 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.77419
f1,0.22222
pr_auc,0.32741
precision,1.0
recall,0.125
roc_auc,0.39674


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112533219986492, max=1.0…

VBox(children=(Label(value='0.004 MB of 0.019 MB uploaded\r'), FloatProgress(value=0.2223639889762172, max=1.0…

0,1
accuracy_mean,▁
accuracy_std,▁
f1_mean,▁
f1_std,▁
pr_auc_mean,▁
pr_auc_std,▁
precision_mean,▁
precision_std,▁
recall_mean,▁
recall_std,▁

0,1
accuracy_mean,0.49677
accuracy_std,0.23845
f1_mean,0.21534
f1_std,0.14341
pr_auc_mean,0.37117
pr_auc_std,0.08955
precision_mean,0.31254
precision_std,0.35381
recall_mean,0.45
recall_std,0.45139


Fold 1 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.70968
f1,0.30769
pr_auc,0.36675
precision,0.33333
recall,0.28571
roc_auc,0.71429


Fold 2 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.004 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.22830643470970446, max=1.…

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.77419
f1,0.36364
pr_auc,0.6128
precision,0.5
recall,0.28571
roc_auc,0.76786


Fold 3 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.004 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.225129265770424, max=1.0)…

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.41935
f1,0.1
pr_auc,0.19345
precision,0.08333
recall,0.125
roc_auc,0.26087


Fold 4 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.015 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.7970348176464511, max=1.0…

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.67742
f1,0.54545
pr_auc,0.5651
precision,0.42857
recall,0.75
roc_auc,0.74457


Fold 5 on 5
Testing config 1 on 144
Testing config 2 on 144
Testing config 3 on 144
Testing config 4 on 144
Testing config 5 on 144
Testing config 6 on 144
Testing config 7 on 144
Testing config 8 on 144
Testing config 9 on 144
Testing config 10 on 144
Testing config 11 on 144
Testing config 12 on 144
Testing config 13 on 144
Testing config 14 on 144
Testing config 15 on 144
Testing config 16 on 144
Testing config 17 on 144
Testing config 18 on 144
Testing config 19 on 144
Testing config 20 on 144
Testing config 21 on 144
Testing config 22 on 144
Testing config 23 on 144
Testing config 24 on 144
Testing config 25 on 144
Testing config 26 on 144
Testing config 27 on 144
Testing config 28 on 144
Testing config 29 on 144
Testing config 30 on 144
Testing config 31 on 144
Testing config 32 on 144
Testing config 33 on 144
Testing config 34 on 144
Testing config 35 on 144
Testing config 36 on 144
Testing config 37 on 144
Testing config 38 on 144
Testing config 39 on 144
Testing config 40 on 1

VBox(children=(Label(value='0.015 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.7981988509911495, max=1.0…

0,1
accuracy,▁
f1,▁
pr_auc,▁
precision,▁
recall,▁
roc_auc,▁

0,1
accuracy,0.67742
f1,0.16667
pr_auc,0.35702
precision,0.25
recall,0.125
roc_auc,0.5


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112323693103261, max=1.0…

VBox(children=(Label(value='0.004 MB of 0.019 MB uploaded\r'), FloatProgress(value=0.22169498443798152, max=1.…

0,1
accuracy_mean,▁
accuracy_std,▁
f1_mean,▁
f1_std,▁
pr_auc_mean,▁
pr_auc_std,▁
precision_mean,▁
precision_std,▁
recall_mean,▁
recall_std,▁

0,1
accuracy_mean,0.65161
accuracy_std,0.12139
f1_mean,0.29669
f1_std,0.15626
pr_auc_mean,0.41902
pr_auc_std,0.15254
precision_mean,0.31905
precision_std,0.1451
recall_mean,0.31429
recall_std,0.22941
