In [1]:
import torch
# Load the embedding dictionaries (embeddings from any model can be substituted here).
# map_location="cpu" makes loading independent of the device the tensors were saved
# from; the notebook moves the data to the compute device in a later cell.
# NOTE(review): torch.load unpickles the file - only load files from a trusted source.
# NOTE(review): presumably each dict maps a sequence id to one embedding vector - confirm.
TF_train_dict = torch.load("esm2_t6_8M_UR50D_TF_Training_cls.pt", map_location="cpu")
NTF_train_dict = torch.load("esm2_t6_8M_UR50D_NTF_training_cls.pt", map_location="cpu")

# Stack the dictionary values (in insertion order) into one tensor per class
TF_train_tensor = torch.stack(list(TF_train_dict.values()))
NTF_train_tensor = torch.stack(list(NTF_train_dict.values()))

# Combine positive and negative samples: TF rows first, then NTF rows
X_train = torch.cat([TF_train_tensor, NTF_train_tensor], dim=0)

# Create labels in the same row order as X_train: 1 for TF, 0 for NTF
y_train = torch.cat([
    torch.ones(TF_train_tensor.size(0), dtype=torch.long),
    torch.zeros(NTF_train_tensor.size(0), dtype=torch.long)
])


In [2]:
# Load the independent (held-out) embedding dictionaries.
# map_location="cpu" makes loading independent of the device the tensors were saved
# from; the notebook moves the data to the compute device in a later cell.
# NOTE(review): torch.load unpickles the file - only load files from a trusted source.
TF_ind_dict = torch.load("esm2_t6_8M_UR50D_TF_Ind_cls.pt", map_location="cpu")
NTF_ind_dict = torch.load("esm2_t6_8M_UR50D_NTF_Ind_cls.pt", map_location="cpu")

# Stack the dictionary values (in insertion order) into one tensor per class
TF_ind_tensor = torch.stack(list(TF_ind_dict.values()))
NTF_ind_tensor = torch.stack(list(NTF_ind_dict.values()))

# TF rows first, then NTF rows; labels follow the same order (1 = TF, 0 = NTF)
X_test = torch.cat([TF_ind_tensor, NTF_ind_tensor], dim=0)
y_test = torch.cat([
    torch.ones(TF_ind_tensor.size(0), dtype=torch.long),
    torch.zeros(NTF_ind_tensor.size(0), dtype=torch.long)
])


In [3]:
# Number of features per sample (second dimension of the training matrix)
X_train.shape[1]

320

In [4]:
# Shapes of the training and independent features/labels, in that order
tuple(t.shape for t in (X_train, y_train, X_test, y_test))

(torch.Size([829, 320]),
 torch.Size([829]),
 torch.Size([212, 320]),
 torch.Size([212]))

In [5]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Move every feature/label tensor to the selected device
X_train = X_train.to(device)
y_train = y_train.to(device)
X_test = X_test.to(device)
y_test = y_test.to(device)

In [6]:
from sklearn.model_selection import train_test_split

# train_test_split is given NumPy arrays, so copy the tensors back to host memory first
X_train_cpu, y_train_cpu = X_train.cpu().numpy(), y_train.cpu().numpy()

# Stratified hold-out split: 80% for fitting, 20% for validation
X_train_split_np, X_val_np, y_train_split_np, y_val_np = train_test_split(
    X_train_cpu,
    y_train_cpu,
    test_size=0.2,
    stratify=y_train_cpu,
    random_state=66,
)

# Rebuild torch tensors on the compute device: float32 features, int64 labels
X_train_split, X_val = (
    torch.tensor(arr, dtype=torch.float32, device=device)
    for arr in (X_train_split_np, X_val_np)
)
y_train_split, y_val = (
    torch.tensor(arr, dtype=torch.long, device=device)
    for arr in (y_train_split_np, y_val_np)
)


In [7]:
from torch.utils.data import TensorDataset, DataLoader

# Pair each feature row with its label
train_dataset = TensorDataset(X_train_split, y_train_split)
val_dataset = TensorDataset(X_val, y_val)

# Mini-batches of 16; only the training batches are reshuffled on each pass
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=16)

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmsPredictor(nn.Module):
    """One-hidden-layer classifier over fixed-size embedding vectors.

    Pipeline: Linear(embedding_size -> hidden_size) -> ReLU -> Dropout
    -> Linear(hidden_size -> num_classes). The output is raw logits; no
    softmax is applied inside the module.

    Args:
        embedding_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        dropout: Dropout probability applied after the hidden activation.
        num_classes: Number of output logits per sample.
    """

    def __init__(self, embedding_size, hidden_size, dropout, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(embedding_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of embeddings ``x``."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)  # logits

In [9]:
# Classifier sized to the embedding width, moved to the compute device
model = EmsPredictor(
    embedding_size=X_train.shape[1],
    hidden_size=128,
    dropout=0.3,
    num_classes=2,
).to(device)

# Cross-entropy over the two logits, optimised with Adam
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Single seed training

In [13]:
# epochs = 250  # adjust as needed

# for epoch in range(epochs):
#     # --- Training ---
#     model.train()
#     total_loss = 0
#     correct_train = 0
#     total_train = 0
    
#     for batch_X, batch_y in train_loader:
#         optimizer.zero_grad()
#         outputs = model(batch_X)
#         loss = criterion(outputs, batch_y)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
        
#         preds = torch.argmax(outputs, dim=1)
#         correct_train += (preds == batch_y).sum().item()
#         total_train += batch_y.size(0)
    
#     train_acc = correct_train / total_train
#     avg_loss = total_loss / len(train_loader)
    
#     # --- Validation ---
#     model.eval()
#     correct_val = 0
#     total_val = 0
#     with torch.no_grad():
#         for val_X, val_y in val_loader:
#             outputs = model(val_X)
#             preds = torch.argmax(outputs, dim=1)
#             correct_val += (preds == val_y).sum().item()
#             total_val += val_y.size(0)
    
#     val_acc = correct_val / total_val
    
#     print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")




# import torch
# import numpy as np
# from sklearn.metrics import confusion_matrix, f1_score, matthews_corrcoef, roc_auc_score

# # Predictions and labels on independent set
# model.eval()
# with torch.no_grad():
#     logits = model(X_test)
#     probs = torch.softmax(logits, dim=1)
#     preds = torch.argmax(probs, dim=1).cpu().numpy()
#     prob_pos = probs[:, 1].cpu().numpy()
#     labels = y_test.cpu().numpy()

# # Bootstrap parameters
# n_bootstrap = 1000
# rng = np.random.default_rng(seed=42)

# # Lists to store metrics
# accs, senss, specs, mccs, f1s, aucs = [], [], [], [], [], []

# for _ in range(n_bootstrap):
#     # Sample with replacement
#     idx = rng.integers(0, len(labels), len(labels))
#     sample_labels = labels[idx]
#     sample_preds = preds[idx]
#     sample_probs = prob_pos[idx]

#     # Confusion matrix
#     tn, fp, fn, tp = confusion_matrix(sample_labels, sample_preds, labels=[0,1]).ravel()

#     # Metrics
#     acc = np.mean(sample_preds == sample_labels)
#     sens = tp / (tp + fn) if (tp + fn) > 0 else 0.0
#     spec = tn / (tn + fp) if (tn + fp) > 0 else 0.0
#     mcc = matthews_corrcoef(sample_labels, sample_preds)
#     f1 = f1_score(sample_labels, sample_preds)
#     auc = roc_auc_score(sample_labels, sample_probs) if len(np.unique(sample_labels)) > 1 else np.nan

#     # Store
#     accs.append(acc)
#     senss.append(sens)
#     specs.append(spec)
#     mccs.append(mcc)
#     f1s.append(f1)
#     if not np.isnan(auc):
#         aucs.append(auc)

# # Compute mean ± std
# print(f"Accuracy       : {np.mean(accs):.3f} ± {np.std(accs):.3f}")
# print(f"Sensitivity    : {np.mean(senss):.3f} ± {np.std(senss):.3f}")
# print(f"Specificity    : {np.mean(specs):.3f} ± {np.std(specs):.3f}")
# print(f"MCC            : {np.mean(mccs):.3f} ± {np.std(mccs):.3f}")
# print(f"F1-score       : {np.mean(f1s):.3f} ± {np.std(f1s):.3f}")
# print(f"AUC-ROC        : {np.mean(aucs):.3f} ± {np.std(aucs):.3f}")

# # Overall confusion matrix on full independent set
# cm = confusion_matrix(labels, preds, labels=[0,1])
# print("\nOverall Confusion Matrix:")
# print(cm)

## 3 Different Seeds

In [10]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, matthews_corrcoef, roc_auc_score

def _mean_and_std(values):
    """Return ``(mean, std)`` of ``values``, or ``(nan, nan)`` when it is empty."""
    if len(values) == 0:
        return np.nan, np.nan
    return np.mean(values), np.std(values)


def _bootstrap_metrics(labels, preds, prob_pos, n_bootstrap, seed):
    """Score ``n_bootstrap`` resamples (with replacement) of the test predictions.

    Returns a dict of lists keyed by "acc", "sens", "spec", "mcc", "f1", "auc".
    The "auc" list can be shorter than the others because resamples that
    contain a single class are skipped for that metric.
    """
    rng = np.random.default_rng(seed=seed)
    scores = {"acc": [], "sens": [], "spec": [], "mcc": [], "f1": [], "auc": []}

    for _ in range(n_bootstrap):
        # Sample with replacement, same size as the original set
        idx = rng.integers(0, len(labels), len(labels))
        sample_labels = labels[idx]
        sample_preds = preds[idx]
        sample_probs = prob_pos[idx]

        # labels=[0, 1] forces a 2x2 matrix even if a class is missing from the resample
        tn, fp, fn, tp = confusion_matrix(sample_labels, sample_preds, labels=[0, 1]).ravel()

        scores["acc"].append(np.mean(sample_preds == sample_labels))
        scores["sens"].append(tp / (tp + fn) if (tp + fn) > 0 else 0.0)
        scores["spec"].append(tn / (tn + fp) if (tn + fp) > 0 else 0.0)
        scores["mcc"].append(matthews_corrcoef(sample_labels, sample_preds))
        scores["f1"].append(f1_score(sample_labels, sample_preds))

        # ROC-AUC is only computed when both classes are present in the resample
        auc = roc_auc_score(sample_labels, sample_probs) if len(np.unique(sample_labels)) > 1 else np.nan
        if not np.isnan(auc):
            scores["auc"].append(auc)

    return scores


# Function to train, validate, and test with bootstrap for a given seed
def run_pipeline(seed, epochs=150, n_bootstrap=1000):
    """Train the current model, then bootstrap-evaluate it on the independent set.

    This function works on notebook-level globals rather than arguments:
    ``model``, ``optimizer``, ``criterion``, ``train_loader``, ``val_loader``,
    ``X_test`` and ``y_test``. It trains whatever ``model`` currently is, so the
    caller must create a fresh model/optimizer before each call. The model's
    initial weights are only covered by ``seed`` if the caller seeds torch
    before constructing the model.

    Args:
        seed: Seed passed to ``torch.manual_seed``, ``np.random.seed`` and the
            bootstrap random generator.
        epochs: Number of passes over ``train_loader`` (default 150).
        n_bootstrap: Number of bootstrap resamples of the independent set
            (default 1000).

    Returns:
        dict with keys "acc", "sens", "spec", "mcc", "f1", "auc"; each value is
        the mean of that metric over the bootstrap resamples. A metric with no
        recorded resamples is reported as NaN.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # --- Training loop ---
    for epoch in range(epochs):
        # --- Training ---
        model.train()
        total_loss = 0
        correct_train = 0
        total_train = 0

        for batch_X, batch_y in train_loader:
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            correct_train += (preds == batch_y).sum().item()
            total_train += batch_y.size(0)

        train_acc = correct_train / total_train
        avg_loss = total_loss / len(train_loader)  # mean of the per-batch losses

        # --- Validation (monitoring only; it does not influence training) ---
        model.eval()
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for val_X, val_y in val_loader:
                outputs = model(val_X)
                preds = torch.argmax(outputs, dim=1)
                correct_val += (preds == val_y).sum().item()
                total_val += val_y.size(0)

        val_acc = correct_val / total_val

        print(f"Seed {seed} | Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

    # --- Independent set evaluation (single forward pass over the whole set) ---
    model.eval()
    with torch.no_grad():
        logits = model(X_test)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1).cpu().numpy()
        prob_pos = probs[:, 1].cpu().numpy()  # probability of class 1
        labels = y_test.cpu().numpy()

    scores = _bootstrap_metrics(labels, preds, prob_pos, n_bootstrap, seed)

    acc_mean, acc_std = _mean_and_std(scores["acc"])
    sens_mean, sens_std = _mean_and_std(scores["sens"])
    spec_mean, spec_std = _mean_and_std(scores["spec"])
    mcc_mean, mcc_std = _mean_and_std(scores["mcc"])
    f1_mean, f1_std = _mean_and_std(scores["f1"])
    auc_mean, auc_std = _mean_and_std(scores["auc"])

    # Print independent results for this seed
    print(f"\nIndependent Evaluation (Seed {seed})")
    print(f"Accuracy       : {acc_mean:.3f} ± {acc_std:.3f}")
    print(f"Sensitivity    : {sens_mean:.3f} ± {sens_std:.3f}")
    print(f"Specificity    : {spec_mean:.3f} ± {spec_std:.3f}")
    print(f"MCC            : {mcc_mean:.3f} ± {mcc_std:.3f}")
    print(f"F1-score       : {f1_mean:.3f} ± {f1_std:.3f}")
    print(f"AUC-ROC        : {auc_mean:.3f} ± {auc_std:.3f}")

    # Confusion matrix of the un-resampled predictions (rows = true, cols = predicted)
    cm = confusion_matrix(labels, preds, labels=[0,1])
    print("\nOverall Confusion Matrix:")
    print(cm)

    return {
        "acc": acc_mean,
        "sens": sens_mean,
        "spec": spec_mean,
        "mcc": mcc_mean,
        "f1": f1_mean,
        "auc": auc_mean
    }


# ===============================
# Run for each seed and average
# ===============================
all_metrics = []
seeds = [42, 33, 101]

for seed in seeds:
    # Seed BEFORE building the model: the layers draw their initial weights from
    # the global torch RNG at construction time, so seeding only inside
    # run_pipeline would leave the initialisation uncontrolled by `seed`.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Reinitialize model + optimizer each run so no state carries over between seeds
    model = EmsPredictor(embedding_size=len(X_train[0]), hidden_size=128, dropout=0.3, num_classes=2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    metrics = run_pipeline(seed)
    all_metrics.append(metrics)

# Compute mean ± std across seeds (np.std default, i.e. population std)
final_metrics = {}
for key in all_metrics[0].keys():
    vals = [m[key] for m in all_metrics]
    final_metrics[key] = (np.mean(vals), np.std(vals))

# Header reports the actual number of seeds instead of a hard-coded count
print(f"\n=== Final Average over {len(seeds)} seeds ===")
for k, (mean_val, std_val) in final_metrics.items():
    print(f"{k.upper():<12}: {mean_val:.3f} ± {std_val:.3f}")


Seed 42 | Epoch 1/150, Loss: 0.6814, Train Acc: 0.5928, Val Acc: 0.7590
Seed 42 | Epoch 2/150, Loss: 0.6595, Train Acc: 0.7014, Val Acc: 0.7410
Seed 42 | Epoch 3/150, Loss: 0.6432, Train Acc: 0.7255, Val Acc: 0.7229
Seed 42 | Epoch 4/150, Loss: 0.6251, Train Acc: 0.7315, Val Acc: 0.7530
Seed 42 | Epoch 5/150, Loss: 0.6036, Train Acc: 0.7587, Val Acc: 0.7771
Seed 42 | Epoch 6/150, Loss: 0.5824, Train Acc: 0.7647, Val Acc: 0.7831
Seed 42 | Epoch 7/150, Loss: 0.5600, Train Acc: 0.7738, Val Acc: 0.7651
Seed 42 | Epoch 8/150, Loss: 0.5414, Train Acc: 0.7873, Val Acc: 0.7831
Seed 42 | Epoch 9/150, Loss: 0.5177, Train Acc: 0.7934, Val Acc: 0.7952
Seed 42 | Epoch 10/150, Loss: 0.5078, Train Acc: 0.7919, Val Acc: 0.8072
Seed 42 | Epoch 11/150, Loss: 0.4940, Train Acc: 0.7888, Val Acc: 0.7952
Seed 42 | Epoch 12/150, Loss: 0.4747, Train Acc: 0.8024, Val Acc: 0.8072
Seed 42 | Epoch 13/150, Loss: 0.4642, Train Acc: 0.8190, Val Acc: 0.8133
Seed 42 | Epoch 14/150, Loss: 0.4512, Train Acc: 0.8054, Val