In [1]:
!pip install torchsummary
!pip install torchviz



In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, precision_recall_curve
)
import joblib
from tqdm import tqdm
from IPython.display import display
from torchsummary import summary
from torchviz import make_dot

In [3]:
# === Load and preprocess data ===
df = pd.read_csv("vcf_feature_vectors.csv")
df["GOLDEN"] = df["GOLDEN"].fillna(0).astype(int)

non_feature_cols = ["CHROM", "POS", "REF", "ALT", "GOLDEN"]
features = df.drop(columns=non_feature_cols).columns.tolist()
X = df[features].values.astype(np.float32)
y = df["GOLDEN"].values.astype(np.int64)

In [4]:
# === Define model ===
class FeedforwardNet(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [5]:
# === Visualize architecture (summary and torchviz) ===
dummy_input_dim = X.shape[1]  # assuming X already defined
model_viz = FeedforwardNet(input_dim=dummy_input_dim)
summary(model_viz, input_size=(dummy_input_dim,))

dummy_input = torch.randn(1, dummy_input_dim)
model_viz.eval()
make_dot(model_viz(dummy_input), params=dict(model_viz.named_parameters())).render("model_architecture", format="png")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                   [-1, 64]             768
       BatchNorm1d-2                   [-1, 64]             128
              ReLU-3                   [-1, 64]               0
           Dropout-4                   [-1, 64]               0
            Linear-5                   [-1, 32]           2,080
       BatchNorm1d-6                   [-1, 32]              64
              ReLU-7                   [-1, 32]               0
           Dropout-8                   [-1, 32]               0
            Linear-9                    [-1, 1]              33
          Sigmoid-10                    [-1, 1]               0
Total params: 3,073
Trainable params: 3,073
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.01
Estimated Total

'model_architecture.png'

In [6]:
# === Cross-validation ===
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
metrics_per_fold = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]

    model = FeedforwardNet(input_dim=X.shape[1])
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_loader = DataLoader(TensorDataset(torch.tensor(X_train), torch.tensor(y_train)), batch_size=128, shuffle=True)

    best_f1 = -1
    epochs_no_improve = 0
    patience = 5
    best_model_state = None

    model.train()
    for epoch in range(30):
        running_loss = 0.0
        with tqdm(train_loader, desc=f"Fold {fold+1} Epoch {epoch+1}/30", unit="batch") as tepoch:
            for xb, yb in tepoch:
                preds = model(xb).squeeze()
                loss = criterion(preds, yb.float())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                tepoch.set_postfix(loss=running_loss / (tepoch.n + 1))

        # Evaluate for early stopping
        model.eval()
        with torch.no_grad():
            val_probs = model(torch.tensor(X_val)).squeeze().numpy()
        val_preds = (val_probs >= 0.5).astype(int)
        val_f1 = f1_score(y_val, val_preds)

        if val_f1 > best_f1:
            best_f1 = val_f1
            epochs_no_improve = 0
            best_model_state = model.state_dict()
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"🔁 Early stopping at epoch {epoch+1}")
                break

    # Load best weights if available
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print(f"⚠️ Warning: No improvement found during training for fold {fold + 1}. Skipping weight restoration.")


    # Final evaluation with tuned threshold
    model.eval()
    with torch.no_grad():
        probas = model(torch.tensor(X_val)).squeeze().numpy()

    precisions, recalls, thresholds = precision_recall_curve(y_val, probas)
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-10)
    best_idx = np.argmax(f1_scores)
    best_thresh = thresholds[best_idx]
    preds = (probas >= best_thresh).astype(int)

    metrics_per_fold.append({
        "Fold": fold + 1,
        "Accuracy": accuracy_score(y_val, preds),
        "Precision": precision_score(y_val, preds),
        "Recall": recall_score(y_val, preds),
        "F1": f1_score(y_val, preds),
        "ROC AUC": roc_auc_score(y_val, probas),
        "Best Threshold": best_thresh
    })

results_df = pd.DataFrame(metrics_per_fold)
results_df.loc["Mean"] = results_df.mean(numeric_only=True)
display(results_df)

Fold 1 Epoch 1/30: 100%|█████████████████████████████████████████| 1041/1041 [00:07<00:00, 146.89batch/s, loss=0.0992]
Fold 1 Epoch 2/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 110.62batch/s, loss=0.0429]
Fold 1 Epoch 3/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 123.90batch/s, loss=0.0404]
Fold 1 Epoch 4/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 128.96batch/s, loss=0.0384]
Fold 1 Epoch 5/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 115.61batch/s, loss=0.0375]
Fold 1 Epoch 6/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 111.31batch/s, loss=0.0368]
Fold 1 Epoch 7/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 114.99batch/s, loss=0.0367]
Fold 1 Epoch 8/30: 100%|██████████████████████████████████████████| 1041/1041 [00:07<00:00, 132.16batch/s, loss=0.036]
Fold 1 Epoch 9/30: 100%|████████████████████████

🔁 Early stopping at epoch 19


Fold 2 Epoch 1/30: 100%|██████████████████████████████████████████| 1041/1041 [00:10<00:00, 102.85batch/s, loss=0.107]
Fold 2 Epoch 2/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 117.98batch/s, loss=0.0441]
Fold 2 Epoch 3/30: 100%|███████████████████████████████████████████| 1041/1041 [00:12<00:00, 81.09batch/s, loss=0.042]
Fold 2 Epoch 4/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 114.51batch/s, loss=0.0395]
Fold 2 Epoch 5/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 112.05batch/s, loss=0.0391]
Fold 2 Epoch 6/30: 100%|█████████████████████████████████████████| 1041/1041 [00:10<00:00, 102.08batch/s, loss=0.0379]
Fold 2 Epoch 7/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 110.00batch/s, loss=0.0376]
Fold 2 Epoch 8/30: 100%|██████████████████████████████████████████| 1041/1041 [00:10<00:00, 99.55batch/s, loss=0.0374]
Fold 2 Epoch 9/30: 100%|████████████████████████

🔁 Early stopping at epoch 9


Fold 3 Epoch 1/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 127.92batch/s, loss=0.0925]
Fold 3 Epoch 2/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 128.87batch/s, loss=0.0429]
Fold 3 Epoch 3/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 121.81batch/s, loss=0.0403]
Fold 3 Epoch 4/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 128.97batch/s, loss=0.0387]
Fold 3 Epoch 5/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 105.09batch/s, loss=0.0374]
Fold 3 Epoch 6/30: 100%|█████████████████████████████████████████| 1041/1041 [00:10<00:00, 102.40batch/s, loss=0.0372]
Fold 3 Epoch 7/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 124.31batch/s, loss=0.0365]
Fold 3 Epoch 8/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 109.34batch/s, loss=0.0361]
Fold 3 Epoch 9/30: 100%|████████████████████████

🔁 Early stopping at epoch 22


Fold 4 Epoch 1/30: 100%|██████████████████████████████████████████| 1041/1041 [00:09<00:00, 111.16batch/s, loss=0.119]
Fold 4 Epoch 2/30: 100%|█████████████████████████████████████████| 1041/1041 [00:07<00:00, 144.59batch/s, loss=0.0441]
Fold 4 Epoch 3/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 127.63batch/s, loss=0.0419]
Fold 4 Epoch 4/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 123.54batch/s, loss=0.0408]
Fold 4 Epoch 5/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 112.13batch/s, loss=0.0387]
Fold 4 Epoch 6/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 118.19batch/s, loss=0.0378]
Fold 4 Epoch 7/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 124.08batch/s, loss=0.0375]
Fold 4 Epoch 8/30: 100%|█████████████████████████████████████████| 1041/1041 [00:08<00:00, 129.19batch/s, loss=0.0371]
Fold 4 Epoch 9/30: 100%|████████████████████████

🔁 Early stopping at epoch 25


Fold 5 Epoch 1/30: 100%|████████████████████████████████████████████| 1041/1041 [00:14<00:00, 71.56batch/s, loss=0.12]
Fold 5 Epoch 2/30: 100%|██████████████████████████████████████████| 1041/1041 [00:13<00:00, 76.55batch/s, loss=0.0447]
Fold 5 Epoch 3/30: 100%|██████████████████████████████████████████| 1041/1041 [00:11<00:00, 90.69batch/s, loss=0.0415]
Fold 5 Epoch 4/30: 100%|██████████████████████████████████████████| 1041/1041 [00:14<00:00, 70.83batch/s, loss=0.0395]
Fold 5 Epoch 5/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 105.92batch/s, loss=0.0387]
Fold 5 Epoch 6/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 105.73batch/s, loss=0.0378]
Fold 5 Epoch 7/30: 100%|██████████████████████████████████████████| 1041/1041 [00:08<00:00, 116.89batch/s, loss=0.037]
Fold 5 Epoch 8/30: 100%|█████████████████████████████████████████| 1041/1041 [00:09<00:00, 107.66batch/s, loss=0.0368]
Fold 5 Epoch 9/30: 100%|████████████████████████

🔁 Early stopping at epoch 22


Unnamed: 0,Fold,Accuracy,Precision,Recall,F1,ROC AUC,Best Threshold
0,1.0,0.987781,0.532609,0.454756,0.490613,0.964834,0.208299
1,2.0,0.988891,0.593168,0.444186,0.507979,0.965854,0.251655
2,3.0,0.987,0.496437,0.486047,0.491187,0.965279,0.190984
3,4.0,0.988711,0.580838,0.451163,0.507853,0.968028,0.232157
4,5.0,0.990332,0.698182,0.445476,0.543909,0.970826,0.323058
Mean,3.0,0.988543,0.580247,0.456325,0.508308,0.966964,0.24123


In [7]:
# Save average threshold
avg_best_thresh = results_df["Best Threshold"].dropna().mean()

# Save model metadata
metadata = {
    "features": features,  # or top_features if you're using selection
    "threshold": float(avg_best_thresh)
}
with open("model_metadata.json", "w") as f:
    json.dump(metadata, f)

In [8]:
# === Final training ===
final_model = FeedforwardNet(input_dim=X.shape[1])
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(final_model.parameters(), lr=0.001)
loader = DataLoader(TensorDataset(torch.tensor(X), torch.tensor(y)), batch_size=128, shuffle=True)

final_model.train()
for epoch in range(30):
    running_loss = 0.0
    with tqdm(loader, desc=f"Final Model Epoch {epoch+1}/30", unit="batch") as tepoch:
        for xb, yb in tepoch:
            preds = final_model(xb).squeeze()
            loss = criterion(preds, yb.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            tepoch.set_postfix(loss=running_loss / (tepoch.n + 1))

Final Model Epoch 1/30: 100%|█████████████████████████████████████| 1302/1302 [00:14<00:00, 87.40batch/s, loss=0.0952]
Final Model Epoch 2/30: 100%|████████████████████████████████████| 1302/1302 [00:11<00:00, 113.41batch/s, loss=0.0482]
Final Model Epoch 3/30: 100%|████████████████████████████████████| 1302/1302 [00:08<00:00, 154.36batch/s, loss=0.0466]
Final Model Epoch 4/30: 100%|████████████████████████████████████| 1302/1302 [00:07<00:00, 175.84batch/s, loss=0.0453]
Final Model Epoch 5/30: 100%|████████████████████████████████████| 1302/1302 [00:07<00:00, 174.95batch/s, loss=0.0447]
Final Model Epoch 6/30: 100%|████████████████████████████████████| 1302/1302 [00:07<00:00, 175.66batch/s, loss=0.0441]
Final Model Epoch 7/30: 100%|████████████████████████████████████| 1302/1302 [00:07<00:00, 164.34batch/s, loss=0.0438]
Final Model Epoch 8/30: 100%|████████████████████████████████████| 1302/1302 [00:08<00:00, 159.64batch/s, loss=0.0438]
Final Model Epoch 9/30: 100%|███████████████████

In [9]:
# === Save model ===
torch.save(final_model.state_dict(), "final_nn_model.pt")