In [1]:
import os
import pickle
import math
import random
import gc
import matplotlib.colors as mcolors
from math import ceil
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
                             precision_recall_fscore_support, roc_auc_score,
                             average_precision_score, confusion_matrix,
                             precision_recall_curve, roc_curve)
from sklearn.metrics import f1_score
from sklearn.metrics import matthews_corrcoef
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap

In [2]:
folds_dir = "/content/folds"
os.makedirs(folds_dir, exist_ok=True)
print("Directory created at:", folds_dir)

Directory created at: /content/folds


In [3]:
# Paths of files coming from notebooks 1 and 2
dict_path = '/content/lncrna-pseudo_dictionary_pooled_embeddings_RNAFM.p'   # embeddings pickle file
folds_dir = f"/content/folds" # dataset file, divided in train/test

In [4]:
# Cross-fold training script.
# MAX AND AVG POOLING
# ----------------------------
# Hyperparameters
# ----------------------------
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Original hyperparameters (kept)
input_dimension = 2560  # final input expected (we construct from embeddings)
batch_size = 512
train_negative_reduction_factor = 1   # set to >1 to downsample negatives
validation_split_ratio = 0.2
epochs = 50
learning_rate = 0.0005
num_layers = 4
dropout = 0.2
hidden_dimension = 1024
warmup_epochs = 4
patience = 10
min_lr = 5e-5


out_models_dir = "models_by_fold_lncrna_pseudo_denovo_concat"
os.makedirs(out_models_dir, exist_ok=True)
save_metrics_path = "cv_training_results_full_lncrna_pseudo_denovo_concat.pkl"
save_plots_dir = "cv_plots/lncrna_pseudo_denovo/concat_pool"
os.makedirs(save_plots_dir, exist_ok=True)


# embeddings
if not os.path.exists(dict_path):
    raise FileNotFoundError(f"Embeddings dict not found at {dict_path}")
with open(dict_path, 'rb') as fh:
    embeddings_dict = pickle.load(fh)
print("Loaded embeddings dict with", len(embeddings_dict), "entries")


# Colors / fonts (kept)
color2 = '#9999CC'
color3 = '#00BBD8'
color_green = '#88B04B'
color_red = '#FF765B'
font_ticks = 16
font_labels = 18
heat0 = color_green
heat2 = '#204900'
heat3 = '#102A00'
cv_folds_path = ''

custom_cmap = LinearSegmentedColormap.from_list("soft_oranges", [heat0, heat2, heat3])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

# ----------------------------
# Load folds and embeddings
# ----------------------------
if os.path.isdir(folds_dir):
    files = sorted([p for p in os.listdir(folds_dir) if p.endswith('.pkl')])
    folds = []
    for fn in files:
        with open(os.path.join(folds_dir, fn), 'rb') as fh:
            folds.append(pickle.load(fh))
    print(f"Loaded {len(folds)} per-fold pickles from {folds_dir}")
else:
    raise FileNotFoundError(f"Neither {cv_folds_path} nor directory {folds_dir} found.")


# Convert pair lists into numpy arrays (embeddings) and labels
def pairs_to_arrays(pairs_list, embeddings_dict, label):
    X_list = []
    y_list = []
    missing = 0
    for (s1, s2), (t1, t2) in pairs_list:
        if s1 not in embeddings_dict or s2 not in embeddings_dict:
            missing += 1
            continue
        e1 = np.asarray(embeddings_dict[s1], dtype=np.float32)
        e2 = np.asarray(embeddings_dict[s2], dtype=np.float32)
        # keep last 640 of each like original notebook
        e1s = e1
        e2s = e2
        if e1s.shape[0] < 1280 or e2s.shape[0] < 1280:
            # if shorter than 640, pad with zeros to 640 to keep shape consistent
            e1s = np.pad(e1s, (1280 - e1s.shape[0], 0), mode='constant') if e1s.shape[0] < 1280 else e1s
            e2s = np.pad(e2s, (1280 - e2s.shape[0], 0), mode='constant') if e2s.shape[0] < 1280 else e2s
        X_list.append(np.concatenate([e1s, e2s], axis=0))
        y_list.append(int(label))
    if missing:
        print(f"  Skipped {missing} pairs due to missing embeddings.")
    if len(X_list) == 0:
        return None, None
    X = np.stack(X_list)
    y = np.array(y_list, dtype=np.int64)
    return X, y




# Balanced batch sampler (same logic as notebook)
def get_data_loaders_without_replacement_from_arrays(X_train, y_train, X_valid, y_valid, batch_size, neg_batch_ratio=0.7, num_workers=2):
    """
    Build DataLoaders using BalancedBatchSampler logic but from numpy arrays.
    """
    train_tensor = torch.from_numpy(X_train).float()
    train_labels = torch.from_numpy(y_train)
    train_dataset = TensorDataset(train_tensor, train_labels)

    valid_tensor = torch.from_numpy(X_valid).float()
    valid_labels = torch.from_numpy(y_valid)
    valid_dataset = TensorDataset(valid_tensor, valid_labels)

    class BalancedBatchSampler:
        def __init__(self, labels, batch_size, neg_batch_ratio):
            self.labels = labels.numpy() if torch.is_tensor(labels) else labels
            self.batch_size = batch_size
            self.neg_batch_ratio = neg_batch_ratio
            self.neg_indices = np.where(self.labels == 0)[0]
            self.pos_indices = np.where(self.labels == 1)[0]
            self.num_neg_per_batch = int(batch_size * neg_batch_ratio)
            self.num_batches = len(self.neg_indices) // max(1, self.num_neg_per_batch)
            self.leftover_negatives = len(self.neg_indices) % max(1, self.num_neg_per_batch)

        def __iter__(self):
            neg_idx = np.random.permutation(self.neg_indices)
            pos_idx = self.pos_indices
            batch_start = 0
            for batch_num in range(self.num_batches):
                if batch_num < self.num_batches - 1 or self.leftover_negatives == 0:
                    num_neg_in_batch = self.num_neg_per_batch
                else:
                    num_neg_in_batch = self.leftover_negatives
                neg_batch = neg_idx[batch_start:batch_start + num_neg_in_batch]
                batch_start += num_neg_in_batch
                num_pos_in_batch = self.batch_size - num_neg_in_batch
                pos_batch = np.random.choice(pos_idx, num_pos_in_batch, replace=True)
                batch_indices = np.concatenate([neg_batch, pos_batch])
                np.random.shuffle(batch_indices)
                yield batch_indices.tolist()

        def __len__(self):
            return self.num_batches if self.num_batches > 0 else 1

    train_sampler = BalancedBatchSampler(train_labels, batch_size, neg_batch_ratio=0.7)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=4)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    return train_loader, valid_loader

# InteractionNN (kept same)
class InteractionNN(nn.Module):
    def __init__(self, input_dim=input_dimension, hidden_dim=hidden_dimension, num_layers=num_layers, dropout=dropout):
        super(InteractionNN, self).__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(p=dropout)]
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(p=dropout))
        self.hidden_layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.hidden_layers(x)
        x = self.output_layer(x)
        return torch.sigmoid(x).view(-1)

# learning rate scheduler with warmup
def cosine_annealing_with_warmup(epoch, warmup_epochs, max_epochs, min_lr=1e-5):
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    else:
        cosine_decay = 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / (max_epochs - warmup_epochs)))
        return (cosine_decay * (1 - min_lr) + min_lr)

# ----------------------------
# Cross-fold training
# ----------------------------

# collect confusion matrices across folds
cms = []
all_test_probs_all = []
all_test_labels_all = []
roc_curves = []
pr_curves = []

all_fold_metrics = []
for fi, fold in enumerate(folds):
    print("\n=== Fold", fi, "===")
    train_pos = fold['train']['positives']
    train_neg = fold['train']['negatives']
    test_pos = fold['test']['positives']
    test_neg = fold['test']['negatives']

    # Build full training arrays from pairs
    Xp, yp = pairs_to_arrays(train_pos, embeddings_dict, label=1)
    Xn, yn = pairs_to_arrays(train_neg, embeddings_dict, label=0)

    if Xp is None and Xn is None:
        print("No training data for fold", fi); continue
    if Xp is None:
        X_all = Xn; y_all = yn
    elif Xn is None:
        X_all = Xp; y_all = yp
    else:
        X_all = np.concatenate([Xp, Xn], axis=0)
        y_all = np.concatenate([yp, yn], axis=0)

    # Optionally reduce negatives like original script
    # Separate positives and negatives to apply reduction factor conveniently
    pos_idx = np.where(y_all == 1)[0]
    neg_idx = np.where(y_all == 0)[0]
    # Shuffle negatives
    np.random.shuffle(neg_idx)
    reduced_neg_count = max(1, int(len(neg_idx) // train_negative_reduction_factor))
    neg_selected = neg_idx[:reduced_neg_count]
    selected_idx = np.concatenate([pos_idx, neg_selected])
    np.random.shuffle(selected_idx)
    X_sel = X_all[selected_idx]
    y_sel = y_all[selected_idx]

    print(f" Train samples after reduction: {len(y_sel)} (pos={sum(y_sel==1)}, neg={sum(y_sel==0)})")

    # split into train/val
    n_total = len(y_sel)
    n_train = int((1 - validation_split_ratio) * n_total)
    n_val = n_total - n_train
    indices = np.arange(n_total)
    np.random.shuffle(indices)
    train_inds = indices[:n_train]
    val_inds = indices[n_train:]

    X_train = X_sel[train_inds]; y_train = y_sel[train_inds]
    X_val = X_sel[val_inds]; y_val = y_sel[val_inds]

    # Standardize per fold using training data
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)

    # Prepare test set arrays
    Xp_test, yp_test = pairs_to_arrays(test_pos, embeddings_dict, label=1)
    Xn_test, yn_test = pairs_to_arrays(test_neg, embeddings_dict, label=0)
    if Xp_test is None and Xn_test is None:
        print("No test data for fold", fi); continue
    if Xp_test is None:
        X_test = Xn_test; y_test = yn_test
    elif Xn_test is None:
        X_test = Xp_test; y_test = yp_test
    else:
        X_test = np.concatenate([Xp_test, Xn_test], axis=0)
        y_test = np.concatenate([yp_test, yn_test], axis=0)
    X_test = scaler.transform(X_test)

    # Build DataLoaders using balanced sampler approach
    train_loader, val_loader = get_data_loaders_without_replacement_from_arrays(X_train, y_train, X_val, y_val, batch_size, neg_batch_ratio=0.5)

    test_ds = TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test))
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4)

    # model, optimizer
    model = InteractionNN(input_dim=input_dimension, hidden_dim=hidden_dimension, num_layers=num_layers, dropout=dropout).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: cosine_annealing_with_warmup(e, warmup_epochs, epochs, min_lr))
    criterion = nn.BCELoss()

    # training loop with early stopping on validation loss
    best_val_loss = float('inf')
    best_state = None
    epochs_no_improv = 0

    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            xb, yb = batch
            xb = xb.to(device)
            yb = yb.to(device, dtype=torch.float32)
            optimizer.zero_grad()
            outputs = model(xb).view(-1)
            loss = criterion(outputs, yb)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss = running_loss / (len(train_loader) if len(train_loader)>0 else 1)
        train_losses.append(train_loss)

        # validation
        model.eval()
        total_val_loss = 0.0
        all_val_probs = []
        all_val_labels = []
        with torch.no_grad():
            for batch in val_loader:
                xb, yb = batch
                xb = xb.to(device)
                yb = yb.to(device, dtype=torch.float32)
                outputs = model(xb).view(-1)
                loss = criterion(outputs, yb)
                total_val_loss += loss.item()
                all_val_probs.extend(outputs.cpu().numpy())
                all_val_labels.extend(yb.cpu().numpy())
        val_loss = total_val_loss / (len(val_loader) if len(val_loader)>0 else 1)
        val_losses.append(val_loss)

        # scheduler step & early stop logic
        if val_loss < best_val_loss - 1e-8:
            best_val_loss = val_loss
            best_state = model.state_dict()
            epochs_no_improv = 0
            # save best per-fold
            per_model_path = os.path.join(out_models_dir, f"best_model_fold{fi}.pth")
            torch.save(best_state, per_model_path)
        else:
            epochs_no_improv += 1

        scheduler.step()
        lr_now = optimizer.param_groups[0]['lr']
        print(f"Fold {fi} Epoch {epoch+1}/{epochs} TrainLoss={train_loss:.6f} ValLoss={val_loss:.6f} LR={lr_now:.2e}")

        if epochs_no_improv >= patience:
            print(f"Early stopping on fold {fi} at epoch {epoch+1}")
            break

    # restore best state
    if best_state is not None:
        model.load_state_dict(best_state)

    # Compute optimal threshold on validation (maximize MCC, same as your notebook)
    model.eval()
    val_probs = np.array(all_val_probs)
    val_labels = np.array(all_val_labels)
    thresholds = np.linspace(0, 1, 101)
    mcc_scores = []
    for thr in thresholds:
        preds_thr = (val_probs > thr).astype(np.int32)
        mcc_scores.append(matthews_corrcoef(val_labels, preds_thr))
    best_thr = thresholds[int(np.nanargmax(mcc_scores))]
    print(" Fold", fi, " optimal threshold (val MCC):", best_thr)

    # Evaluate on test set using best_thr
    all_test_probs = []
    all_test_labels = []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb = xb.to(device)
            outputs = model(xb).view(-1)
            all_test_probs.extend(outputs.cpu().numpy())
            all_test_labels.extend(yb.numpy())
    all_test_probs = np.array(all_test_probs)
    all_test_labels = np.array(all_test_labels)
    test_preds = (all_test_probs > best_thr).astype(int)

    # compute metrics on test
    acc = accuracy_score(all_test_labels, test_preds)
    bal_acc = balanced_accuracy_score(all_test_labels, test_preds)
    prec, rec, f1, _ = precision_recall_fscore_support(all_test_labels, test_preds, average='binary', zero_division=0)
    try:
        auroc = roc_auc_score(all_test_labels, all_test_probs)
    except Exception:
        auroc = float('nan')
    try:
        auprc = average_precision_score(all_test_labels, all_test_probs)
    except Exception:
        auprc = float('nan')

    fold_metrics = {
        'fold': fi,
        'n_train': len(train_loader.dataset),
        'n_val': len(val_loader.dataset),
        'n_test': len(test_loader.dataset),
        'threshold': float(best_thr),
        'accuracy': float(acc),
        'balanced_accuracy': float(bal_acc),
        'precision': float(prec),
        'recall': float(rec),
        'f1': float(f1),
        'auroc': float(auroc),
        'auprc': float(auprc)
    }
    all_fold_metrics.append(fold_metrics)
    print(" Fold", fi, "metrics:", fold_metrics)

    # X boxplots and confusion matrices
    cm = confusion_matrix(all_test_labels, test_preds, normalize="true")
    cms.append(cm)
    all_test_probs_all.extend(all_test_probs)
    all_test_labels_all.extend(all_test_labels)

    try:
        fpr, tpr, _ = roc_curve(all_test_labels, all_test_probs)
        roc_curves.append((fi, fpr, tpr, auroc))
    except Exception:
        pass

    try:
        prec_vals, rec_vals, _ = precision_recall_curve(all_test_labels, all_test_probs)
        pr_curves.append((fi, rec_vals, prec_vals, auprc))
    except Exception:
        pass

    # clear memory
    del model, optimizer, scheduler, train_loader, val_loader, test_loader
    gc.collect()
    torch.cuda.empty_cache()

# ----------------------------
# Aggregate metrics across folds
# ----------------------------
df_res = pd.DataFrame(all_fold_metrics)
print("\nPer-fold results:")
display(df_res)

agg = df_res.drop(columns=['fold','n_train','n_val','n_test','threshold']).agg(['mean','std']).T
agg['mean'] = agg['mean'].round(4)
agg['std'] = agg['std'].round(4)
print("\nAggregated metrics (mean ± std):")
display(agg)

# =========================
# Plot metrics across folds
# =========================
colors = plt.cm.tab10.colors  # distinct colors
from matplotlib.backends.backend_pdf import PdfPages
metrics_to_plot = ["accuracy", "balanced_accuracy", "precision", "recall", "f1"]
# Custom display names
metric_display_names = {
    "accuracy": "ACCURACY",
    "balanced_accuracy": "BAL. ACC.",
    "precision": "PRECISION",
    "recall": "RECALL",
    "f1": "F1"
}

pdf_path = os.path.join(save_plots_dir, "metrics_summary_table.pdf")
with PdfPages(pdf_path) as pdf:
    plt.figure(figsize=(8, 4))
    ax = plt.gca()
    ax.axis('off')

    # Header row with display names
    header = ["Fold"] + [metric_display_names[m] for m in metrics_to_plot]
    table_data = [header]

    for _, row in df_res.iterrows():
        table_data.append(
            [int(row["fold"])] + [f"{row[m]:.3f}" for m in metrics_to_plot]
        )

    # Mean ± Std row
    mean_std_vals = [f"{df_res[m].mean():.3f} ± {df_res[m].std():.3f}" for m in metrics_to_plot]
    table_data.append(["Mean ± Std"] + mean_std_vals)

    # Draw table
    table = ax.table(cellText=table_data, loc="center", cellLoc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.2)

    pdf.savefig()
    plt.close()


# boxplot and normalised confusion matrix averaged across folds
cms = np.array(cms)
cm_mean = cms.mean(axis=0)
cm_std = cms.std(axis=0)

plt.figure(figsize=(6, 5))
ax = sns.heatmap(cm_mean, annot=True, fmt=".2%", cmap=custom_cmap,
                 xticklabels=['Negative', 'Positive'], yticklabels=['Negative', 'Positive'],
                 annot_kws={"size": font_ticks}, cbar_kws={"shrink": 0.9})

plt.xlabel('Predicted', fontsize=font_labels)
plt.ylabel('Actual', fontsize=font_labels)
plt.xticks(fontsize=font_ticks)
plt.yticks(fontsize=font_ticks)

cbar = ax.collections[0].colorbar
cbar.ax.yaxis.set_tick_params(labelsize=font_ticks)

plt.savefig(os.path.join(save_plots_dir, "confusion_matrix_normalized_mean_std.pdf"), bbox_inches="tight")
plt.close()

# boxoplot now
all_test_probs_all = np.array(all_test_probs_all)
all_test_labels_all = np.array(all_test_labels_all)

plt.figure(figsize=(8, 6))
box = plt.boxplot([all_test_probs_all[all_test_labels_all == 0],
                   all_test_probs_all[all_test_labels_all == 1]],
                  labels=["Negative", "Positive"], patch_artist=True)

# Set colors: red for negatives, green for positives
box['boxes'][0].set(facecolor=color_red)   # Negative
box['boxes'][1].set(facecolor=color_green) # Positive

# Median lines
for median in box['medians']:
    median.set(color='#FFD260', linewidth=2)

plt.ylabel("Predicted Probability", fontsize=font_labels)
plt.xticks(fontsize=font_ticks)
plt.yticks(fontsize=font_ticks)
plt.grid()

plt.savefig(os.path.join(save_plots_dir, "boxplot_probabilities_all_folds.pdf"), bbox_inches="tight")
plt.close()

# =========================
# Plot AUROC curves all folds
# =========================
plt.figure(figsize=(6, 5))
for fi, fpr, tpr, auroc in roc_curves:
    plt.plot(fpr, tpr, color=colors[fi % len(colors)],
             label=f"AUROC={auroc:.3f}")
plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
plt.xlabel('False Positive Rate', fontsize=font_labels)
plt.ylabel('True Positive Rate', fontsize=font_labels)
plt.grid(True)
plt.legend(fontsize=12, loc="lower right")  # bigger legend inside, bottom-right
plt.tight_layout()
plt.savefig(os.path.join(save_plots_dir, "all_folds_roc_curves.pdf"), bbox_inches="tight")
plt.close()

# =========================
# Plot AUPRC curves all folds
# =========================
plt.figure(figsize=(6, 5))
for fi, rec_vals, prec_vals, auprc in pr_curves:
    plt.plot(rec_vals, prec_vals, color=colors[fi % len(colors)],
             label=f"AUPRC={auprc:.3f}")
plt.xlabel('Recall', fontsize=font_labels)
plt.ylabel('Precision', fontsize=font_labels)
plt.grid(True)
plt.legend(fontsize=12, loc="lower left")  # bigger legend inside, bottom-left
plt.tight_layout()
plt.savefig(os.path.join(save_plots_dir, "all_folds_pr_curves.pdf"), bbox_inches="tight")
plt.close()

# Save metrics and per-fold models metadata
with open(save_metrics_path, 'wb') as fh:
    pickle.dump({'per_fold': all_fold_metrics, 'aggregate': agg.to_dict()}, fh)
print("Saved metrics to", save_metrics_path)

Loaded embeddings dict with 262 entries
Device: cuda
Loaded 1 per-fold pickles from /content/folds

=== Fold 0 ===
 Train samples after reduction: 19086 (pos=3472, neg=15614)




Fold 0 Epoch 1/50 TrainLoss=0.343408 ValLoss=0.137649 LR=2.50e-04
Fold 0 Epoch 2/50 TrainLoss=0.144254 ValLoss=0.112628 LR=3.75e-04
Fold 0 Epoch 3/50 TrainLoss=0.116640 ValLoss=0.132385 LR=5.00e-04
Fold 0 Epoch 4/50 TrainLoss=0.103295 ValLoss=0.118541 LR=5.00e-04
Fold 0 Epoch 5/50 TrainLoss=0.091342 ValLoss=0.105212 LR=4.99e-04
Fold 0 Epoch 6/50 TrainLoss=0.072733 ValLoss=0.149574 LR=4.98e-04
Fold 0 Epoch 7/50 TrainLoss=0.062397 ValLoss=0.097462 LR=4.95e-04
Fold 0 Epoch 8/50 TrainLoss=0.052935 ValLoss=0.103418 LR=4.91e-04
Fold 0 Epoch 9/50 TrainLoss=0.048461 ValLoss=0.112553 LR=4.86e-04
Fold 0 Epoch 10/50 TrainLoss=0.042739 ValLoss=0.112111 LR=4.79e-04
Fold 0 Epoch 11/50 TrainLoss=0.036200 ValLoss=0.121387 LR=4.72e-04
Fold 0 Epoch 12/50 TrainLoss=0.039029 ValLoss=0.143365 LR=4.64e-04
Fold 0 Epoch 13/50 TrainLoss=0.034459 ValLoss=0.113704 LR=4.54e-04
Fold 0 Epoch 14/50 TrainLoss=0.028340 ValLoss=0.144261 LR=4.44e-04
Fold 0 Epoch 15/50 TrainLoss=0.031335 ValLoss=0.148082 LR=4.33e-04
Fold



 Fold 0 metrics: {'fold': 0, 'n_train': 15268, 'n_val': 3818, 'n_test': 6205, 'threshold': 0.96, 'accuracy': 0.9911361804995971, 'balanced_accuracy': 0.9765179440286876, 'precision': 0.9799291617473436, 'recall': 0.956221198156682, 'f1': 0.967930029154519, 'auroc': 0.9964997422566916, 'auprc': 0.9868173280095116}

Per-fold results:


Unnamed: 0,fold,n_train,n_val,n_test,threshold,accuracy,balanced_accuracy,precision,recall,f1,auroc,auprc
0,0,15268,3818,6205,0.96,0.991136,0.976518,0.979929,0.956221,0.96793,0.9965,0.986817



Aggregated metrics (mean ± std):


Unnamed: 0,mean,std
accuracy,0.9911,
balanced_accuracy,0.9765,
precision,0.9799,
recall,0.9562,
f1,0.9679,
auroc,0.9965,
auprc,0.9868,


  box = plt.boxplot([all_test_probs_all[all_test_labels_all == 0],


Saved metrics to cv_training_results_full_lncrna_pseudo_denovo_concat.pkl


In [5]:
import os
import pickle
import math
import random
import gc
from math import ceil
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
                             precision_recall_fscore_support, roc_auc_score,
                             average_precision_score, confusion_matrix,
                             precision_recall_curve, roc_curve)
from sklearn.metrics import f1_score
from sklearn.metrics import matthews_corrcoef
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap

# ==========================================
# HYPERPARAMETER CONFIGURATIONS TO TEST
# ==========================================

hyperparameter_configs = [
    {
        'name': 'Baseline',
        'learning_rate': 0.0005,
        'num_layers': 4,
        'dropout': 0.2,
        'hidden_dimension': 1024,
        'batch_size': 512,
        'warmup_epochs': 4,
    },
    {
        'name': 'Deeper_HighDropout',
        'learning_rate': 0.0003,
        'num_layers': 6,
        'dropout': 0.3,
        'hidden_dimension': 1024,
        'batch_size': 512,
        'warmup_epochs': 5,
    },
    {
        'name': 'Wider_LowLR',
        'learning_rate': 0.0001,
        'num_layers': 4,
        'dropout': 0.2,
        'hidden_dimension': 2048,
        'batch_size': 512,
        'warmup_epochs': 6,
    },
    {
        'name': 'Smaller_HighLR',
        'learning_rate': 0.001,
        'num_layers': 3,
        'dropout': 0.15,
        'hidden_dimension': 512,
        'batch_size': 256,
        'warmup_epochs': 3,
    },
    {
        'name': 'Balanced',
        'learning_rate': 0.0002,
        'num_layers': 5,
        'dropout': 0.25,
        'hidden_dimension': 1536,
        'batch_size': 384,
        'warmup_epochs': 5,
    },
]

seed = 42
input_dimension = 2560
epochs = 50
min_lr = 5e-5
patience = 10
validation_split_ratio = 0.2
device = 'cuda' if torch.cuda.is_available() else 'cpu'

np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

color2 = '#9999CC'
color3 = '#00BBD8'
color_green = '#88B04B'
color_red = '#FF765B'
font_ticks = 16
font_labels = 18
heat0 = color_green
heat2 = '#204900'
heat3 = '#102A00'

custom_cmap = LinearSegmentedColormap.from_list("soft_oranges", [heat0, heat2, heat3])

blue_cmap = LinearSegmentedColormap.from_list(
    "soft_blues",
    ["#dceeff", "#6aaed6", "#1b6ca8", "#08306b"]
)

class InteractionNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        super(InteractionNN, self).__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(p=dropout)]
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(p=dropout))
        self.hidden_layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = self.hidden_layers(x)
        x = self.output_layer(x)
        return torch.sigmoid(x).view(-1)

def cosine_annealing_with_warmup(epoch, warmup_epochs, max_epochs, min_lr=1e-5):
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    else:
        cosine_decay = 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / (max_epochs - warmup_epochs)))
        return (cosine_decay * (1 - min_lr) + min_lr)

class BalancedBatchSampler:
    def __init__(self, labels, batch_size, neg_batch_ratio=0.7):
        self.labels = labels.numpy() if torch.is_tensor(labels) else labels
        self.batch_size = batch_size
        self.neg_batch_ratio = neg_batch_ratio
        self.neg_indices = np.where(self.labels == 0)[0]
        self.pos_indices = np.where(self.labels == 1)[0]
        self.num_neg_per_batch = int(batch_size * neg_batch_ratio)
        self.num_batches = len(self.neg_indices) // max(1, self.num_neg_per_batch)

    def __iter__(self):
        neg_idx = np.random.permutation(self.neg_indices)
        pos_idx = self.pos_indices
        batch_start = 0
        for batch_num in range(self.num_batches):
            num_neg_in_batch = self.num_neg_per_batch
            neg_batch = neg_idx[batch_start:batch_start + num_neg_in_batch]
            batch_start += num_neg_in_batch
            num_pos_in_batch = self.batch_size - num_neg_in_batch
            pos_batch = np.random.choice(pos_idx, num_pos_in_batch, replace=True)
            batch_indices = np.concatenate([neg_batch, pos_batch])
            np.random.shuffle(batch_indices)
            yield batch_indices.tolist()

    def __len__(self):
        return self.num_batches if self.num_batches > 0 else 1

def train_with_config(config, X_train, y_train, X_val, y_val, X_test, y_test, fold_id=0, output_dir='hyperparameter_tuning_results'):
    print(f"\n{'='*60}")
    print(f"Training Config: {config['name']}")
    print(f"{'='*60}")

    config_plot_dir = os.path.join(output_dir, f"plots_{config['name']}")
    os.makedirs(config_plot_dir, exist_ok=True)

    train_tensor = torch.from_numpy(X_train).float()
    train_labels = torch.from_numpy(y_train)
    train_dataset = TensorDataset(train_tensor, train_labels)

    val_tensor = torch.from_numpy(X_val).float()
    val_labels = torch.from_numpy(y_val)
    val_dataset = TensorDataset(val_tensor, val_labels)

    train_sampler = BalancedBatchSampler(train_labels, config['batch_size'], neg_batch_ratio=0.7)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2)

    test_ds = TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test))
    test_loader = DataLoader(test_ds, batch_size=config['batch_size'], shuffle=False, num_workers=2)

    model = InteractionNN(
        input_dim=input_dimension,
        hidden_dim=config['hidden_dimension'],
        num_layers=config['num_layers'],
        dropout=config['dropout']
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda e: cosine_annealing_with_warmup(e, config['warmup_epochs'], epochs, min_lr)
    )
    criterion = nn.BCELoss()

    best_val_loss = float('inf')
    best_state = None
    epochs_no_improv = 0

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            xb, yb = batch
            xb = xb.to(device)
            yb = yb.to(device, dtype=torch.float32)
            optimizer.zero_grad()
            outputs = model(xb).view(-1)
            loss = criterion(outputs, yb)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        total_val_loss = 0.0
        all_val_probs = []
        all_val_labels = []
        with torch.no_grad():
            for batch in val_loader:
                xb, yb = batch
                xb = xb.to(device)
                yb = yb.to(device, dtype=torch.float32)
                outputs = model(xb).view(-1)
                loss = criterion(outputs, yb)
                total_val_loss += loss.item()
                all_val_probs.extend(outputs.cpu().numpy())
                all_val_labels.extend(yb.cpu().numpy())
        val_loss = total_val_loss / len(val_loader)
        val_losses.append(val_loss)

        if val_loss < best_val_loss - 1e-8:
            best_val_loss = val_loss
            best_state = model.state_dict()
            epochs_no_improv = 0
        else:
            epochs_no_improv += 1

        scheduler.step()
        lr_now = optimizer.param_groups[0]['lr']

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | LR: {lr_now:.2e}")

        if epochs_no_improv >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    model.eval()
    val_probs = np.array(all_val_probs)
    val_labels = np.array(all_val_labels)
    thresholds = np.linspace(0, 1, 101)
    f1_scores = []
    for thr in thresholds:
        preds_thr = (val_probs > thr).astype(np.int32)
        _, _, f1, _ = precision_recall_fscore_support(val_labels, preds_thr, average='binary', zero_division=0)
        f1_scores.append(f1)
    best_thr = thresholds[int(np.nanargmax(f1_scores))]
    print(f"Optimal threshold (val F1): {best_thr:.3f}")

    all_test_probs = []
    all_test_labels = []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb = xb.to(device)
            outputs = model(xb).view(-1)
            all_test_probs.extend(outputs.cpu().numpy())
            all_test_labels.extend(yb.numpy())

    all_test_probs = np.array(all_test_probs)
    all_test_labels = np.array(all_test_labels)
    test_preds = (all_test_probs > best_thr).astype(int)

    accuracy = accuracy_score(all_test_labels, test_preds)
    balanced_acc = balanced_accuracy_score(all_test_labels, test_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_test_labels, test_preds, average='binary', zero_division=0
    )
    auroc = roc_auc_score(all_test_labels, all_test_probs)
    auprc = average_precision_score(all_test_labels, all_test_probs)

    fpr, tpr, _ = roc_curve(all_test_labels, all_test_probs)
    plt.figure(figsize=(6, 5))
    plt.plot(fpr, tpr, color=color_green, linewidth=2, label=f'AUROC={auroc:.3f}')
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve - {config["name"]}')
    plt.xticks(fontsize=font_ticks)
    plt.yticks(fontsize=font_ticks)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=14, loc="lower right")
    plt.tight_layout()
    plt.savefig(os.path.join(config_plot_dir, "roc_curve.pdf"), bbox_inches="tight")
    plt.close()

    prec_vals, rec_vals, _ = precision_recall_curve(all_test_labels, all_test_probs)
    plt.figure(figsize=(6, 5))
    plt.plot(rec_vals, prec_vals, color=color3, linewidth=2, label=f'AUPRC={auprc:.3f}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve - {config["name"]}')
    plt.xticks(fontsize=font_ticks)
    plt.yticks(fontsize=font_ticks)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=14, loc="lower left")
    plt.tight_layout()
    plt.savefig(os.path.join(config_plot_dir, "pr_curve.pdf"), bbox_inches="tight")
    plt.close()

    cm = confusion_matrix(all_test_labels, test_preds, normalize="true")
    plt.figure(figsize=(6, 5))

    ax = sns.heatmap(
        cm,
        annot=True,
        fmt=".2%",
        cmap=blue_cmap,
        xticklabels=['Negative', 'Positive'],
        yticklabels=['Negative', 'Positive'],
        annot_kws={"size": font_ticks},
        cbar_kws={"shrink": 0.9}
    )

    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(f'Confusion Matrix - {config["name"]}')
    plt.xticks(fontsize=font_ticks)
    plt.yticks(fontsize=font_ticks)
    cbar = ax.collections[0].colorbar
    cbar.ax.yaxis.set_tick_params(labelsize=font_ticks)
    plt.tight_layout()
    plt.savefig(os.path.join(config_plot_dir, "confusion_matrix.pdf"), bbox_inches="tight")
    plt.close()

    plt.figure(figsize=(8, 6))
    box = plt.boxplot([all_test_probs[all_test_labels == 0],
                       all_test_probs[all_test_labels == 1]],
                      labels=["Negative", "Positive"], patch_artist=True)
    box['boxes'][0].set(facecolor=color_red)
    box['boxes'][1].set(facecolor=color_green)

    for median in box['medians']:
        median.set(color='#FFD260', linewidth=2)

    plt.ylabel("Predicted Probability")
    plt.title(f'Predicted Probabilities - {config["name"]}')
    plt.xticks(fontsize=font_ticks)
    plt.yticks(fontsize=font_ticks)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(config_plot_dir, "boxplot_probabilities.pdf"), bbox_inches="tight")
    plt.close()

    results = {
        'config_name': config['name'],
        'fold': fold_id,
        'accuracy': float(accuracy),
        'balanced_accuracy': float(balanced_acc),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'auroc': float(auroc),
        'auprc': float(auprc),
        'threshold': float(best_thr),
        'total_params': total_params,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'test_probs': all_test_probs,
        'test_labels': all_test_labels,
        'test_preds': test_preds,
        **config
    }

    print(f"\n{'='*60}")
    print(f"Results for {config['name']}:")
    print(f"  Accuracy:          {accuracy:.4f}")
    print(f"  Balanced Accuracy: {balanced_acc:.4f}")
    print(f"  Precision:         {precision:.4f}")
    print(f"  Recall:            {recall:.4f}")
    print(f"  F1 Score:          {f1:.4f}")
    print(f"  AUROC:             {auroc:.4f}")
    print(f"  AUPRC:             {auprc:.4f}")
    print(f"  Plots saved to:    {config_plot_dir}")
    print(f"{'='*60}\n")

    return results, model

def run_hyperparameter_tuning(folds_dir, dict_path, output_dir='hyperparameter_tuning_results'):
    """
    Main function to run hyperparameter tuning across all configurations.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load embeddings
    with open(dict_path, 'rb') as fh:
        embeddings_dict = pickle.load(fh)
    print(f"Loaded embeddings: {len(embeddings_dict)} sequences")

    # Load fold data
    fold_files = sorted([f for f in os.listdir(folds_dir) if f.endswith('.pkl')])
    if not fold_files:
        raise FileNotFoundError(f"No fold files found in {folds_dir}")

    # Use first fold
    fold_path = os.path.join(folds_dir, fold_files[0])
    with open(fold_path, 'rb') as fh:
        fold_data = pickle.load(fh)

    print(f"Loaded fold data from {fold_path}")

    # Prepare data
    def pairs_to_arrays(pairs_list, embeddings_dict, label):
        X_list = []
        y_list = []
        missing = 0
        for (s1, s2), (t1, t2) in pairs_list:
            if s1 not in embeddings_dict or s2 not in embeddings_dict:
                missing += 1
                continue
            e1 = np.asarray(embeddings_dict[s1], dtype=np.float32)
            e2 = np.asarray(embeddings_dict[s2], dtype=np.float32)
            e1s = e1 if e1.shape[0] >= 1280 else np.pad(e1, (1280 - e1.shape[0], 0), mode='constant')
            e2s = e2 if e2.shape[0] >= 1280 else np.pad(e2, (1280 - e2.shape[0], 0), mode='constant')
            X_list.append(np.concatenate([e1s, e2s], axis=0))
            y_list.append(int(label))
        if missing:
            print(f"  Skipped {missing} pairs due to missing embeddings.")
        if len(X_list) == 0:
            return None, None
        return np.stack(X_list), np.array(y_list, dtype=np.int64)

    # Extract train and test data
    train_pos = fold_data['train']['positives']
    train_neg = fold_data['train']['negatives']
    test_pos = fold_data['test']['positives']
    test_neg = fold_data['test']['negatives']

    Xp, yp = pairs_to_arrays(train_pos, embeddings_dict, label=1)
    Xn, yn = pairs_to_arrays(train_neg, embeddings_dict, label=0)
    X_all = np.concatenate([Xp, Xn], axis=0)
    y_all = np.concatenate([yp, yn], axis=0)

    # Train/val split
    n_total = len(y_all)
    n_train = int((1 - validation_split_ratio) * n_total)
    indices = np.arange(n_total)
    np.random.shuffle(indices)
    train_inds = indices[:n_train]
    val_inds = indices[n_train:]

    X_train_raw = X_all[train_inds]
    y_train = y_all[train_inds]
    X_val_raw = X_all[val_inds]
    y_val = y_all[val_inds]

    # Test data
    Xp_test, yp_test = pairs_to_arrays(test_pos, embeddings_dict, label=1)
    Xn_test, yn_test = pairs_to_arrays(test_neg, embeddings_dict, label=0)
    X_test_raw = np.concatenate([Xp_test, Xn_test], axis=0)
    y_test = np.concatenate([yp_test, yn_test], axis=0)

    # Standardize
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train_raw)
    X_val = scaler.transform(X_val_raw)
    X_test = scaler.transform(X_test_raw)

    print(f"\nData shapes:")
    print(f"  Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")

    # Train all configurations
    all_results = []
    for config in hyperparameter_configs:
        results, model = train_with_config(
            config, X_train, y_train, X_val, y_val, X_test, y_test, fold_id=0, output_dir=output_dir
        )
        all_results.append(results)

        # Save model
        model_path = os.path.join(output_dir, f"model_{config['name']}.pth")
        torch.save(model.state_dict(), model_path)

    # Save results
    results_df = pd.DataFrame(all_results)
    results_df.to_csv(os.path.join(output_dir, 'hyperparameter_results.csv'), index=False)

    with open(os.path.join(output_dir, 'all_results.pkl'), 'wb') as f:
        pickle.dump(all_results, f)

    # Generate comparison plots
    generate_comparison_plots(all_results, output_dir)

    return results_df, all_results

def generate_comparison_plots(all_results, output_dir):
    """Generate comparison plots for all configurations."""

    for result in all_results:
        result['test_probs'] = np.array(result['test_probs'])
        result['test_labels'] = np.array(result['test_labels'])
        result['test_preds'] = np.array(result['test_preds'])

    df = pd.DataFrame(all_results)
    config_order = ['Baseline', 'Deeper_HighDropout', 'Wider_LowLR', 'Smaller_HighLR', 'Balanced']
    df['config_name'] = pd.Categorical(df['config_name'], categories=config_order, ordered=True)
    df = df.sort_values('config_name')

    # 1. Bar plot comparison with 7 metrics
    fig, axes = plt.subplots(4, 2, figsize=(16, 18))
    axes = axes.flatten()

    metrics = ['accuracy', 'balanced_accuracy', 'precision', 'recall', 'f1', 'auroc', 'auprc']
    titles = ['Accuracy', 'Balanced Accuracy', 'Precision', 'Recall', 'F1 Score', 'AUROC', 'AUPRC']

    for idx, (metric, title) in enumerate(zip(metrics, titles)):
        ax = axes[idx]
        bar_width = 0.45
        x = np.arange(len(df))
        bars = ax.bar(range(len(df)), df[metric])

        for bar in bars:
          bar.set_color('#1f77b4')

        # Color best bar
        best_idx = df[metric].idxmax()
        best_position = df.index.get_loc(best_idx)
        bars[best_position].set_color('#c16dff')

        ax.set_xticks(range(len(df)))
        ax.set_xticklabels(df['config_name'])
        ax.set_ylabel(title)
        ax.set_title(f'{title} Comparison')
        ax.grid(axis='y', alpha=0.3)

        for bar, val in zip(bars, df[metric]):
          height = bar.get_height()
          ax.text(
              bar.get_x() + bar.get_width() / 2.,
              height,
              f'{val:.3f}',
              ha='center',
              va='bottom',
              fontsize=9
              )

    for idx in range(len(metrics), len(axes)):
        fig.delaxes(axes[idx])

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'metrics_comparison.pdf'), bbox_inches='tight')
    plt.close()

    # 2. Summary table
    fig, ax = plt.subplots(figsize=(16, 6))
    ax.axis('tight')
    ax.axis('off')

    table_data = []
    table_data.append(['Config', 'Acc', 'Bal Acc', 'Prec', 'Rec', 'F1', 'AUROC', 'AUPRC', 'Params', 'Hidden', 'Layers', 'LR'])

    for _, row in df.iterrows():
        table_data.append([
            row['config_name'],
            f"{row['accuracy']:.4f}",
            f"{row['balanced_accuracy']:.4f}",
            f"{row['precision']:.4f}",
            f"{row['recall']:.4f}",
            f"{row['f1']:.4f}",
            f"{row['auroc']:.4f}",
            f"{row['auprc']:.4f}",
            f"{row['total_params']/1e6:.2f}M",
            str(row['hidden_dimension']),
            str(row['num_layers']),
            f"{row['learning_rate']:.4f}"
        ])

    table = ax.table(cellText=table_data, loc='center', cellLoc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(8)
    table.scale(1.2, 2)

    # Highlight header
    for i in range(len(table_data[0])):
        table[(0, i)].set_facecolor('#88B04B')
        table[(0, i)].set_text_props(weight='bold', color='white')

    plt.savefig(os.path.join(output_dir, 'results_table.pdf'), bbox_inches='tight')
    plt.close()

    # 3. Combined ROC curves (all configs on one plot)
    plt.figure(figsize=(8, 6))
    colors_list = plt.cm.tab10.colors
    for idx, row in df.iterrows():
        fpr, tpr, _ = roc_curve(row['test_labels'], row['test_probs'])
        plt.plot(fpr, tpr, color=colors_list[idx % len(colors_list)], linewidth=2,
                label=f"{row['config_name']} (AUROC={row['auroc']:.3f})")
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate', fontsize=font_labels)
    plt.title('ROC Curves - All Configurations')
    plt.xticks(fontsize=font_ticks)
    plt.yticks(fontsize=font_ticks)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=10, loc="lower right")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'combined_roc_curves.pdf'), bbox_inches="tight")
    plt.close()

    # 4. Combined PR curves (all configs on one plot)
    plt.figure(figsize=(8, 6))
    for idx, row in df.iterrows():
        prec_vals, rec_vals, _ = precision_recall_curve(row['test_labels'], row['test_probs'])
        plt.plot(rec_vals, prec_vals, color=colors_list[idx % len(colors_list)], linewidth=2,
                label=f"{row['config_name']} (AUPRC={row['auprc']:.3f})")
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curves - All Configurations')
    plt.xticks(fontsize=font_ticks)
    plt.yticks(fontsize=font_ticks)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=10, loc="lower left")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'combined_pr_curves.pdf'), bbox_inches="tight")
    plt.close()

    print(f"\nPlots saved to {output_dir}")
    print(f"\nBest configurations:")
    print(f"  Accuracy:          {df.loc[df['accuracy'].idxmax(), 'config_name']}")
    print(f"  Balanced Accuracy: {df.loc[df['balanced_accuracy'].idxmax(), 'config_name']}")
    print(f"  Precision:         {df.loc[df['precision'].idxmax(), 'config_name']}")
    print(f"  Recall:            {df.loc[df['recall'].idxmax(), 'config_name']}")
    print(f"  F1 Score:          {df.loc[df['f1'].idxmax(), 'config_name']}")
    print(f"  AUROC:             {df.loc[df['auroc'].idxmax(), 'config_name']}")
    print(f"  AUPRC:             {df.loc[df['auprc'].idxmax(), 'config_name']}")

folds_dir = "/content/folds"
dict_path = "/content/lncrna-pseudo_dictionary_pooled_embeddings_RNAFM.p"

results_df, all_results = run_hyperparameter_tuning(folds_dir, dict_path)

print("\n" + "="*60)
print("HYPERPARAMETER TUNING COMPLETE!")
print("="*60)
print("\nResults Summary:")
print(results_df[['config_name', 'accuracy', 'balanced_accuracy', 'precision',
                      'recall', 'f1', 'auroc', 'auprc']].to_string(index=False))

Loaded embeddings: 262 sequences
Loaded fold data from /content/folds/fold_0.pkl

Data shapes:
  Train: (15268, 2560), Val: (3818, 2560), Test: (6205, 2560)

Training Config: Baseline
Model parameters: 5,772,289
Epoch 1/50 | Train Loss: 0.339980 | Val Loss: 0.138299 | LR: 2.50e-04
Epoch 2/50 | Train Loss: 0.138897 | Val Loss: 0.113758 | LR: 3.75e-04
Epoch 3/50 | Train Loss: 0.112389 | Val Loss: 0.110320 | LR: 5.00e-04
Epoch 4/50 | Train Loss: 0.110109 | Val Loss: 0.105497 | LR: 5.00e-04
Epoch 5/50 | Train Loss: 0.087465 | Val Loss: 0.126174 | LR: 4.99e-04
Epoch 6/50 | Train Loss: 0.081738 | Val Loss: 0.144462 | LR: 4.98e-04
Epoch 7/50 | Train Loss: 0.059249 | Val Loss: 0.139402 | LR: 4.95e-04
Epoch 8/50 | Train Loss: 0.052637 | Val Loss: 0.159912 | LR: 4.91e-04
Epoch 9/50 | Train Loss: 0.044964 | Val Loss: 0.166289 | LR: 4.86e-04
Epoch 10/50 | Train Loss: 0.056304 | Val Loss: 0.158880 | LR: 4.79e-04
Epoch 11/50 | Train Loss: 0.040926 | Val Loss: 0.148454 | LR: 4.72e-04
Epoch 12/50 | Tr

  box = plt.boxplot([all_test_probs[all_test_labels == 0],



Results for Baseline:
  Accuracy:          0.9892
  Balanced Accuracy: 0.9701
  Precision:         0.9785
  Recall:            0.9435
  F1 Score:          0.9607
  AUROC:             0.9930
  AUPRC:             0.9721
  Plots saved to:    hyperparameter_tuning_results/plots_Baseline


Training Config: Deeper_HighDropout
Model parameters: 7,871,489
Epoch 1/50 | Train Loss: 0.524994 | Val Loss: 0.279716 | LR: 1.20e-04
Epoch 2/50 | Train Loss: 0.277121 | Val Loss: 0.143919 | LR: 1.80e-04
Epoch 3/50 | Train Loss: 0.131912 | Val Loss: 0.119006 | LR: 2.40e-04
Epoch 4/50 | Train Loss: 0.114730 | Val Loss: 0.111811 | LR: 3.00e-04
Epoch 5/50 | Train Loss: 0.103465 | Val Loss: 0.128974 | LR: 3.00e-04
Epoch 6/50 | Train Loss: 0.087116 | Val Loss: 0.107416 | LR: 3.00e-04
Epoch 7/50 | Train Loss: 0.079163 | Val Loss: 0.111132 | LR: 2.99e-04
Epoch 8/50 | Train Loss: 0.065700 | Val Loss: 0.121760 | LR: 2.97e-04
Epoch 9/50 | Train Loss: 0.071145 | Val Loss: 0.122885 | LR: 2.94e-04
Epoch 10/50 | Train

  box = plt.boxplot([all_test_probs[all_test_labels == 0],



Results for Deeper_HighDropout:
  Accuracy:          0.9871
  Balanced Accuracy: 0.9650
  Precision:         0.9724
  Recall:            0.9343
  F1 Score:          0.9530
  AUROC:             0.9941
  AUPRC:             0.9778
  Plots saved to:    hyperparameter_tuning_results/plots_Deeper_HighDropout


Training Config: Wider_LowLR
Model parameters: 17,836,033
Epoch 1/50 | Train Loss: 0.554338 | Val Loss: 0.287642 | LR: 3.33e-05
Epoch 2/50 | Train Loss: 0.238479 | Val Loss: 0.134599 | LR: 5.00e-05
Epoch 3/50 | Train Loss: 0.130047 | Val Loss: 0.122571 | LR: 6.67e-05
Epoch 4/50 | Train Loss: 0.104104 | Val Loss: 0.105530 | LR: 8.33e-05
Epoch 5/50 | Train Loss: 0.088998 | Val Loss: 0.101837 | LR: 1.00e-04
Epoch 6/50 | Train Loss: 0.075455 | Val Loss: 0.107226 | LR: 1.00e-04
Epoch 7/50 | Train Loss: 0.059765 | Val Loss: 0.106129 | LR: 9.99e-05
Epoch 8/50 | Train Loss: 0.049783 | Val Loss: 0.093430 | LR: 9.95e-05
Epoch 9/50 | Train Loss: 0.038357 | Val Loss: 0.137501 | LR: 9.89e-05
Epoch

  box = plt.boxplot([all_test_probs[all_test_labels == 0],



Results for Wider_LowLR:
  Accuracy:          0.9908
  Balanced Accuracy: 0.9744
  Precision:         0.9822
  Recall:            0.9516
  F1 Score:          0.9666
  AUROC:             0.9923
  AUPRC:             0.9694
  Plots saved to:    hyperparameter_tuning_results/plots_Wider_LowLR


Training Config: Smaller_HighLR
Model parameters: 1,837,057
Epoch 1/50 | Train Loss: 0.212075 | Val Loss: 0.124576 | LR: 6.67e-04
Epoch 2/50 | Train Loss: 0.125948 | Val Loss: 0.119068 | LR: 1.00e-03
Epoch 3/50 | Train Loss: 0.108117 | Val Loss: 0.126852 | LR: 1.00e-03
Epoch 4/50 | Train Loss: 0.094235 | Val Loss: 0.110742 | LR: 9.99e-04
Epoch 5/50 | Train Loss: 0.070736 | Val Loss: 0.175592 | LR: 9.96e-04
Epoch 6/50 | Train Loss: 0.065563 | Val Loss: 0.167913 | LR: 9.90e-04
Epoch 7/50 | Train Loss: 0.053215 | Val Loss: 0.144298 | LR: 9.82e-04
Epoch 8/50 | Train Loss: 0.043787 | Val Loss: 0.152323 | LR: 9.72e-04
Epoch 9/50 | Train Loss: 0.043918 | Val Loss: 0.156572 | LR: 9.60e-04
Epoch 10/50 | Tra

  box = plt.boxplot([all_test_probs[all_test_labels == 0],



Results for Smaller_HighLR:
  Accuracy:          0.9868
  Balanced Accuracy: 0.9696
  Precision:         0.9591
  Recall:            0.9459
  F1 Score:          0.9524
  AUROC:             0.9914
  AUPRC:             0.9744
  Plots saved to:    hyperparameter_tuning_results/plots_Smaller_HighLR


Training Config: Balanced
Model parameters: 13,378,561
Epoch 1/50 | Train Loss: 0.457891 | Val Loss: 0.198632 | LR: 8.00e-05
Epoch 2/50 | Train Loss: 0.162637 | Val Loss: 0.117979 | LR: 1.20e-04
Epoch 3/50 | Train Loss: 0.113791 | Val Loss: 0.105756 | LR: 1.60e-04
Epoch 4/50 | Train Loss: 0.098777 | Val Loss: 0.106591 | LR: 2.00e-04
Epoch 5/50 | Train Loss: 0.091797 | Val Loss: 0.123597 | LR: 2.00e-04
Epoch 6/50 | Train Loss: 0.076269 | Val Loss: 0.117209 | LR: 2.00e-04
Epoch 7/50 | Train Loss: 0.062210 | Val Loss: 0.132459 | LR: 1.99e-04
Epoch 8/50 | Train Loss: 0.056495 | Val Loss: 0.100896 | LR: 1.98e-04
Epoch 9/50 | Train Loss: 0.046083 | Val Loss: 0.110896 | LR: 1.96e-04
Epoch 10/50 | Tr

  box = plt.boxplot([all_test_probs[all_test_labels == 0],



Results for Balanced:
  Accuracy:          0.9908
  Balanced Accuracy: 0.9725
  Precision:         0.9868
  Recall:            0.9470
  F1 Score:          0.9665
  AUROC:             0.9935
  AUPRC:             0.9783
  Plots saved to:    hyperparameter_tuning_results/plots_Balanced


Plots saved to hyperparameter_tuning_results

Best configurations:
  Accuracy:          Wider_LowLR
  Balanced Accuracy: Wider_LowLR
  Precision:         Balanced
  Recall:            Wider_LowLR
  F1 Score:          Wider_LowLR
  AUROC:             Deeper_HighDropout
  AUPRC:             Balanced

HYPERPARAMETER TUNING COMPLETE!

Results Summary:
       config_name  accuracy  balanced_accuracy  precision   recall       f1    auroc    auprc
          Baseline  0.989202           0.970088   0.978495 0.943548 0.960704 0.993013 0.972139
Deeper_HighDropout  0.987107           0.965011   0.972422 0.934332 0.952996 0.994085 0.977763
       Wider_LowLR  0.990814           0.974401   0.982164 0.951613 0.966647 0

optional

In [6]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from scipy.stats import gaussian_kde
import umap
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap

# ==========================================
# MODEL DEFINITION
# ==========================================
class InteractionNN(nn.Module):
    """Neural network model with embedding extraction capability"""
    def __init__(self, input_dim=2560, hidden_dim=1024, num_layers=4, dropout=0.2):
        super(InteractionNN, self).__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(p=dropout)]
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(p=dropout))
        self.hidden_layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = self.hidden_layers(x)
        x = self.output_layer(x)
        return torch.sigmoid(x).view(-1)

    def get_embeddings(self, x):
        """Extract embeddings from second-to-last layer"""
        with torch.no_grad():
            embeddings = self.hidden_layers(x)
        return embeddings

# ==========================================
# EMBEDDING EXTRACTION
# ==========================================
def extract_embeddings(model, data_loader, device='cuda'):
    """Extract embeddings, labels, and predictions from model"""
    model.eval()
    all_embeddings = []
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x = batch_x.to(device)
            embeddings = model.get_embeddings(batch_x)
            all_embeddings.append(embeddings.cpu().numpy())
            outputs = model(batch_x)
            all_predictions.append(outputs.cpu().numpy())
            all_labels.append(batch_y.numpy())

    embeddings = np.vstack(all_embeddings)
    labels = np.concatenate(all_labels)
    predictions = np.concatenate(all_predictions)

    return embeddings, labels, predictions

# ==========================================
# DIMENSIONALITY REDUCTION
# ==========================================
def apply_umap(embeddings, n_neighbors=15, min_dist=0.1, random_state=42):
    """Apply UMAP dimensionality reduction"""
    print(f"Applying UMAP (n_neighbors={n_neighbors}, min_dist={min_dist})...")
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric='euclidean',
        random_state=random_state,
        n_components=2
    )
    reduced = reducer.fit_transform(embeddings)
    print(f"UMAP complete. Shape: {reduced.shape}")
    return reduced

def apply_tsne(embeddings, perplexity=30, n_iter=1000, random_state=42):
    """Apply t-SNE dimensionality reduction"""
    print(f"Applying t-SNE (perplexity={perplexity}, n_iter={n_iter})...")
    reducer = TSNE(
        n_components=2,
        perplexity=perplexity,
        n_iter=n_iter,
        random_state=random_state
    )
    reduced = reducer.fit_transform(embeddings)
    print(f"t-SNE complete. Shape: {reduced.shape}")
    return reduced

# ==========================================
# INDIVIDUAL PLOT FUNCTIONS
# ==========================================
def plot_true_labels(reduced_embeddings, labels, method_name, save_path):
    """Create plot colored by true labels"""
    fig, ax = plt.subplots(figsize=(10, 8))

    color_positive = '#0e1bcc'
    color_negative = '#c370ff'

    # Plot negative samples
    mask_neg = labels == 0
    ax.scatter(
        reduced_embeddings[mask_neg, 0],
        reduced_embeddings[mask_neg, 1],
        c=color_negative,
        label='Negative',
        alpha=0.6,
        s=30,
        edgecolors='none'
    )

    # Plot positive samples
    mask_pos = labels == 1
    ax.scatter(
        reduced_embeddings[mask_pos, 0],
        reduced_embeddings[mask_pos, 1],
        c=color_positive,
        label='Positive',
        alpha=0.6,
        s=30,
        edgecolors='none'
    )

    ax.set_xlabel(f'{method_name} 1')
    ax.set_ylabel(f'{method_name} 2')
    ax.set_title(f'{method_name} - True Labels')
    ax.legend(fontsize=12, markerscale=1.5)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved: {save_path}")


def plot_prediction_confidence(reduced_embeddings, predictions, method_name, save_path):
    """Create plot colored by prediction confidence"""
    plasma = plt.cm.get_cmap("plasma")
    colors = plasma(np.linspace(0, 0.8, 256))
    plasma_no_yellow = LinearSegmentedColormap.from_list(
        "plasma_no_yellow", colors
        )
    plasma_no_yellow_r = plasma_no_yellow.reversed()
    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(
        reduced_embeddings[:, 0],
        reduced_embeddings[:, 1],
        c=predictions,
        cmap=plasma_no_yellow_r,
        alpha=0.6,
        s=30,
        edgecolors='none',
        vmin=0,
        vmax=1
        )
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label('Prediction Probability')
    cbar.ax.tick_params(labelsize=10)

    ax.set_xlabel(f'{method_name} 1')
    ax.set_ylabel(f'{method_name} 2')
    ax.set_title(f'{method_name} - Prediction Confidence')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved: {save_path}")

def plot_density(reduced_embeddings, method_name, save_path):
    """Create density plot"""
    fig, ax = plt.subplots(figsize=(10, 8))

    # Calculate density using Gaussian KDE
    xy = reduced_embeddings.T
    try:
        z = gaussian_kde(xy)(xy)
    except:
        # If KDE fails (e.g., singular matrix), use simple density
        print("KDE failed, using uniform density")
        z = np.ones(reduced_embeddings.shape[0])

    scatter = ax.scatter(
        reduced_embeddings[:, 0],
        reduced_embeddings[:, 1],
        c=z,
        cmap='viridis',
        alpha=0.6,
        s=30,
        edgecolors='none'
    )

    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label('Density')
    cbar.ax.tick_params(labelsize=10)

    ax.set_xlabel(f'{method_name} 1')
    ax.set_ylabel(f'{method_name} 2')
    ax.set_title(f'{method_name} - Density Map')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved: {save_path}")

# ==========================================
# MAIN VISUALIZATION PIPELINE
# ==========================================
def visualize_embeddings_separately(
    model_path,
    fold_data_path,
    embeddings_dict_path,
    output_dir='embedding_visualizations_separate',
    device='cuda',
    batch_size=512,
    # UMAP parameters
    umap_n_neighbors=15,
    umap_min_dist=0.1,
    # t-SNE parameters
    tsne_perplexity=30,
    tsne_n_iter=1000
):
    """
    Complete pipeline to extract and visualize embeddings with separate plots

    Creates 6 individual plots:
    - UMAP: true labels, prediction confidence, density
    - t-SNE: true labels, prediction confidence, density
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load data
    print("="*60)
    print("LOADING DATA")
    print("="*60)

    with open(embeddings_dict_path, 'rb') as f:
        embeddings_dict = pickle.load(f)
    print(f"Loaded embeddings dictionary")

    with open(fold_data_path, 'rb') as f:
        fold_data = pickle.load(f)
    print(f"Loaded fold data")

    # Prepare test data
    def pairs_to_arrays(pairs_list, embeddings_dict, label):
        X_list = []
        y_list = []
        missing = 0
        for (s1, s2), (t1, t2) in pairs_list:
            if s1 not in embeddings_dict or s2 not in embeddings_dict:
                missing += 1
                continue
            e1 = np.asarray(embeddings_dict[s1], dtype=np.float32)
            e2 = np.asarray(embeddings_dict[s2], dtype=np.float32)
            e1s = e1 if e1.shape[0] >= 1280 else np.pad(e1, (1280 - e1.shape[0], 0), mode='constant')
            e2s = e2 if e2.shape[0] >= 1280 else np.pad(e2, (1280 - e2.shape[0], 0), mode='constant')
            X_list.append(np.concatenate([e1s, e2s], axis=0))
            y_list.append(int(label))
        if missing:
            print(f"  Skipped {missing} pairs due to missing embeddings")
        return np.stack(X_list), np.array(y_list, dtype=np.int64)

    print("\nPreparing test data...")
    test_pos = fold_data['test']['positives']
    test_neg = fold_data['test']['negatives']

    Xp_test, yp_test = pairs_to_arrays(test_pos, embeddings_dict, label=1)
    Xn_test, yn_test = pairs_to_arrays(test_neg, embeddings_dict, label=0)
    X_test = np.concatenate([Xp_test, Xn_test], axis=0)
    y_test = np.concatenate([yp_test, yn_test], axis=0)

    # Standardize
    scaler = StandardScaler()
    X_test = scaler.fit_transform(X_test)

    print(f"Test data shape: {X_test.shape}")
    print(f"Positive samples: {np.sum(y_test == 1)}")
    print(f"Negative samples: {np.sum(y_test == 0)}")

    # Create DataLoader
    test_dataset = TensorDataset(
        torch.from_numpy(X_test).float(),
        torch.from_numpy(y_test)
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Load model
    print("\n" + "="*60)
    print("LOADING MODEL")
    print("="*60)
    model = InteractionNN(input_dim=2560, hidden_dim=1024, num_layers=4, dropout=0.2)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    print(f"Model loaded from: {model_path}")

    # Extract embeddings
    print("\n" + "="*60)
    print("EXTRACTING EMBEDDINGS")
    print("="*60)
    embeddings, labels, predictions = extract_embeddings(model, test_loader, device)
    print(f"Extracted embeddings shape: {embeddings.shape}")

    # Save raw embeddings
    np.savez(
        os.path.join(output_dir, 'raw_embeddings.npz'),
        embeddings=embeddings,
        labels=labels,
        predictions=predictions
    )
    print(f"Saved raw embeddings")

    # Apply UMAP and create plots
    print("\n" + "="*60)
    print("UMAP VISUALIZATION")
    print("="*60)
    umap_embeddings = apply_umap(
        embeddings,
        n_neighbors=umap_n_neighbors,
        min_dist=umap_min_dist
    )

    print("\nCreating UMAP plots...")
    plot_true_labels(
        umap_embeddings, labels, 'UMAP',
        os.path.join(output_dir, 'umap_true_labels.pdf')
    )
    plot_prediction_confidence(
        umap_embeddings, predictions, 'UMAP',
        os.path.join(output_dir, 'umap_prediction_confidence.pdf')
    )
    plot_density(
        umap_embeddings, 'UMAP',
        os.path.join(output_dir, 'umap_density.pdf')
    )

    # Apply t-SNE and create plots
    print("\n" + "="*60)
    print("t-SNE VISUALIZATION")
    print("="*60)
    tsne_embeddings = apply_tsne(
        embeddings,
        perplexity=tsne_perplexity,
        n_iter=tsne_n_iter
    )

    print("\nCreating t-SNE plots...")
    plot_true_labels(
        tsne_embeddings, labels, 't-SNE',
        os.path.join(output_dir, 'tsne_true_labels.pdf')
    )
    plot_prediction_confidence(
        tsne_embeddings, predictions, 't-SNE',
        os.path.join(output_dir, 'tsne_prediction_confidence.pdf')
    )
    plot_density(
        tsne_embeddings, 't-SNE',
        os.path.join(output_dir, 'tsne_density.pdf')
    )

    # Save reduced embeddings
    np.savez(
        os.path.join(output_dir, 'reduced_embeddings.npz'),
        umap=umap_embeddings,
        tsne=tsne_embeddings,
        labels=labels,
        predictions=predictions
    )

    # Summary
    print("\n" + "="*60)
    print("VISUALIZATION COMPLETE!")
    print("="*60)
    print(f"\nAll visualizations saved to: {output_dir}/")
    print("\nUMAP plots:")
    print("  - umap_true_labels.pdf")
    print("  - umap_prediction_confidence.pdf")
    print("  - umap_density.pdf")
    print("\nt-SNE plots:")
    print("  - tsne_true_labels.pdf")
    print("  - tsne_prediction_confidence.pdf")
    print("  - tsne_density.pdf")
    print("\nData files:")
    print("  - raw_embeddings.npz")
    print("  - reduced_embeddings.npz")

# ==========================================
# EXAMPLE USAGE
# ==========================================
if __name__ == "__main__":
    # PATHS - MODIFY THESE FOR YOUR SETUP
    model_path = "hyperparameter_tuning_results/model_Baseline.pth"
    fold_data_path = "/content/folds/fold_0.pkl"
    embeddings_dict_path = "/content/lncrna-pseudo_dictionary_pooled_embeddings_RNAFM.p"
    output_dir = "embedding_visualizations_separate"

    # Check device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}\n")

    # Run visualization
    visualize_embeddings_separately(
        model_path=model_path,
        fold_data_path=fold_data_path,
        embeddings_dict_path=embeddings_dict_path,
        output_dir=output_dir,
        device=device,
        batch_size=512,
        # UMAP parameters
        umap_n_neighbors=15,
        umap_min_dist=0.1,
        # t-SNE parameters
        tsne_perplexity=30,
        tsne_n_iter=1000
    )


Using device: cuda

LOADING DATA
Loaded embeddings dictionary
Loaded fold data

Preparing test data...
Test data shape: (6205, 2560)
Positive samples: 868
Negative samples: 5337

LOADING MODEL
Model loaded from: hyperparameter_tuning_results/model_Baseline.pth

EXTRACTING EMBEDDINGS
Extracted embeddings shape: (6205, 1024)
Saved raw embeddings

UMAP VISUALIZATION
Applying UMAP (n_neighbors=15, min_dist=0.1)...


  warn(


UMAP complete. Shape: (6205, 2)

Creating UMAP plots...
Saved: embedding_visualizations_separate/umap_true_labels.pdf


  plasma = plt.cm.get_cmap("plasma")


Saved: embedding_visualizations_separate/umap_prediction_confidence.pdf
Saved: embedding_visualizations_separate/umap_density.pdf

t-SNE VISUALIZATION
Applying t-SNE (perplexity=30, n_iter=1000)...




t-SNE complete. Shape: (6205, 2)

Creating t-SNE plots...
Saved: embedding_visualizations_separate/tsne_true_labels.pdf


  plasma = plt.cm.get_cmap("plasma")


Saved: embedding_visualizations_separate/tsne_prediction_confidence.pdf
Saved: embedding_visualizations_separate/tsne_density.pdf

VISUALIZATION COMPLETE!

All visualizations saved to: embedding_visualizations_separate/

UMAP plots:
  - umap_true_labels.pdf
  - umap_prediction_confidence.pdf
  - umap_density.pdf

t-SNE plots:
  - tsne_true_labels.pdf
  - tsne_prediction_confidence.pdf
  - tsne_density.pdf

Data files:
  - raw_embeddings.npz
  - reduced_embeddings.npz
