In [2]:
# CELL 1: imports & setup (PyTorch Geometric + RDKit)

import os
import random
import math

import pandas as pd
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

from torch.utils.data import Dataset

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch_geometric.nn import GINEConv, global_add_pool, global_mean_pool

from rdkit import Chem

from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    f1_score,
    precision_recall_curve,
)

from tqdm import tqdm

# reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cpu


In [4]:
# CELL 2: locate project root, load parquet data, configure columns

candidate_roots = [".", "..", "../.."]
project_root = None
for cand in candidate_roots:
    train_candidate = os.path.join(cand, "data", "train_brd4_50k_stratified.parquet")
    if os.path.exists(train_candidate):
        project_root = os.path.abspath(cand)
        break

if project_root is None:
    raise FileNotFoundError(
        "Could not find 'data/train_brd4_50k_stratified.parquet' from current directory."
    )

print("Project root:", project_root)

train_path = os.path.join(project_root, "data", "train_brd4_50k_stratified.parquet")
test_path  = os.path.join(project_root, "data", "test_brd4_50k.parquet")

train_df = pd.read_parquet(train_path)
test_df  = pd.read_parquet(test_path)

print("Train shape:", train_df.shape)
print("Test shape :", test_df.shape)
print("Train columns:", list(train_df.columns))

# Your dataset
SMILES_COL = "molecule_smiles"
LABEL_COL  = "binds"

assert SMILES_COL in train_df.columns, f"{SMILES_COL} not in train_df"
assert LABEL_COL  in train_df.columns, f"{LABEL_COL} not in train_df"

print("\nUsing:")
print("  SMILES_COL =", SMILES_COL)
print("  LABEL_COL  =", LABEL_COL, "dtype:", train_df[LABEL_COL].dtype)

display(train_df[[SMILES_COL, LABEL_COL]].head())


Project root: /Users/pabloperezgonzalez/F.I.T-PROTEINS-NEW
Train shape: (50000, 7)
Test shape : (50000, 6)
Train columns: ['id', 'protein_name', 'molecule_smiles', 'buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles', 'binds']

Using:
  SMILES_COL = molecule_smiles
  LABEL_COL  = binds dtype: uint8


Unnamed: 0,molecule_smiles,binds
0,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCOCC(=C)C)nc(Nc...,0
1,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCC2OCCC2(C)C)nc...,0
2,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCC(C)(O)CC)nc(N...,0
3,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCCOCC)nc(NCC2CC...,0
4,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCc2nnc(C(C)(C)C)...,0


In [5]:
# CELL 3: atom/bond features and SMILES â†’ PyG Data

# common atom and bond types; final slot in one-hot is "other"
ATOM_TYPES = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I"]
BOND_TYPES = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

def one_hot_with_other(x, choices):
    v = [0] * (len(choices) + 1)
    if x in choices:
        v[choices.index(x)] = 1
    else:
        v[-1] = 1
    return v

def atom_to_feature_vector(atom: Chem.rdchem.Atom):
    atom_type = one_hot_with_other(atom.GetSymbol(), ATOM_TYPES)
    formal_charge = atom.GetFormalCharge()
    is_aromatic = int(atom.GetIsAromatic())
    return torch.tensor(atom_type + [formal_charge, is_aromatic], dtype=torch.float)

def bond_to_feature_vector(bond: Chem.rdchem.Bond):
    bt = bond.GetBondType()
    bond_type_oh = one_hot_with_other(bt, BOND_TYPES)
    is_conjugated = int(bond.GetIsConjugated())
    is_in_ring = int(bond.IsInRing())
    return torch.tensor(bond_type_oh + [is_conjugated, is_in_ring], dtype=torch.float)

NODE_FEAT_DIM = len(ATOM_TYPES) + 1 + 2   # atom types + other + charge + aromatic
EDGE_FEAT_DIM = len(BOND_TYPES) + 1 + 2   # bond types + other + conjugated + in_ring

print("NODE_FEAT_DIM:", NODE_FEAT_DIM)
print("EDGE_FEAT_DIM:", EDGE_FEAT_DIM)


def smiles_to_data(smiles: str, y_value=None):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    Chem.SanitizeMol(mol)

    # node features
    x_list = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    if len(x_list) == 0:
        return None
    x = torch.stack(x_list, dim=0)  # [num_nodes, NODE_FEAT_DIM]

    # edges (undirected)
    edge_index_list = []
    edge_attr_list = []

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        e = bond_to_feature_vector(bond)

        edge_index_list.append([i, j])
        edge_attr_list.append(e)

        edge_index_list.append([j, i])
        edge_attr_list.append(e)

    if len(edge_index_list) == 0:
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, EDGE_FEAT_DIM), dtype=torch.float)
    else:
        edge_index = torch.tensor(edge_index_list, dtype=torch.long).t().contiguous()
        edge_attr = torch.stack(edge_attr_list, dim=0)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    if y_value is not None:
        data.y = torch.tensor([float(y_value)], dtype=torch.float)
    return data


# quick sanity check on 1 molecule
example_smiles = train_df[SMILES_COL].iloc[0]
example_label  = train_df[LABEL_COL].iloc[0]

example_data = smiles_to_data(example_smiles, example_label)

print(example_data)
print("x shape       :", example_data.x.shape)
print("edge_index    :", example_data.edge_index.shape)
print("edge_attr     :", example_data.edge_attr.shape)
print("y             :", example_data.y)


NODE_FEAT_DIM: 13
EDGE_FEAT_DIM: 7
Data(x=[33, 13], edge_index=[2, 68], edge_attr=[68, 7], y=[1])
x shape       : torch.Size([33, 13])
edge_index    : torch.Size([2, 68])
edge_attr     : torch.Size([68, 7])
y             : tensor([0.])


In [6]:
# CELL 4: PyTorch Dataset wrapping PyG Data objects + DataLoaders

class BRD4MoleculeDataset(Dataset):
    def __init__(self, df, smiles_col, label_col=None):
        self.graphs = []
        self.has_labels = label_col is not None

        iterator = zip(
            df[smiles_col],
            df[label_col] if label_col is not None else [None] * len(df),
        )

        for smiles, label in tqdm(iterator, total=len(df), desc="Building graphs"):
            if label_col is not None and pd.isna(label):
                continue
            g = smiles_to_data(smiles, label if label_col is not None else None)
            if g is not None:
                self.graphs.append(g)

        print(
            f"Built {len(self.graphs)} graphs from {len(df)} rows "
            f"({'with' if self.has_labels else 'without'} labels)."
        )

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]


train_dataset_full = BRD4MoleculeDataset(train_df, SMILES_COL, LABEL_COL)
test_dataset       = BRD4MoleculeDataset(test_df,  SMILES_COL, label_col=None)

# 80/20 train/val split
indices = np.arange(len(train_dataset_full))
np.random.shuffle(indices)
split = int(0.8 * len(indices))
train_idx, val_idx = indices[:split], indices[split:]

from torch.utils.data import Subset

train_dataset = Subset(train_dataset_full, train_idx)
val_dataset   = Subset(train_dataset_full, val_idx)

print("Train graphs:", len(train_dataset))
print("Val graphs  :", len(val_dataset))
print("Test graphs :", len(test_dataset))

BATCH_SIZE = 256  # adjust if you run out of memory

train_loader = GeoDataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = GeoDataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)
test_loader  = GeoDataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

# sanity-check one batch
batch_example = next(iter(train_loader))
print(batch_example)
print("Batch x shape      :", batch_example.x.shape)
print("Batch edge_attr    :", batch_example.edge_attr.shape)
print("Batch y shape      :", batch_example.y.shape)
print("num_graphs in batch:", batch_example.num_graphs)


Building graphs: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 50000/50000 [02:23<00:00, 349.13it/s]


Built 50000 graphs from 50000 rows (with labels).


Building graphs: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 50000/50000 [02:50<00:00, 292.76it/s]


Built 50000 graphs from 50000 rows (without labels).
Train graphs: 40000
Val graphs  : 10000
Test graphs : 50000
DataBatch(x=[9783, 13], edge_index=[2, 20716], edge_attr=[20716, 7], y=[256], batch=[9783], ptr=[257])
Batch x shape      : torch.Size([9783, 13])
Batch edge_attr    : torch.Size([20716, 7])
Batch y shape      : torch.Size([256])
num_graphs in batch: 256


In [12]:
# CELL 5 (UPDATED): focal loss + improved GNN with sum+mean pooling

def binary_focal_loss_with_logits(
    logits,
    targets,
    alpha: float = 0.99,   # more emphasis on positives
    gamma: float = 1.5,    # slightly less peaky than 2.0
    reduction: str = "mean",
):
    """
    Focal loss for severe class imbalance.
    alpha is the weight for the positive class.
    """
    logits = logits.view(-1)
    targets = targets.view(-1).float()

    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)
    loss = alpha * (1 - p_t) ** gamma * bce

    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    else:
        return loss


class GNNWithVirtualNode(nn.Module):
    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 128,
        num_layers: int = 4,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        self.node_encoder = nn.Linear(node_in_dim, hidden_dim)
        self.edge_encoder = nn.Linear(edge_in_dim, hidden_dim)

        self.convs = nn.ModuleList()
        self.bns   = nn.ModuleList()

        for _ in range(num_layers):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            conv = GINEConv(mlp)  # GIN with edge features
            self.convs.append(conv)
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        # virtual node embedding
        self.virtualnode_embedding = nn.Embedding(1, hidden_dim)
        nn.init.constant_(self.virtualnode_embedding.weight.data, 0.0)

        # MLPs to update virtual node after each layer (except last)
        self.mlp_virtual_list = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.mlp_virtual_list.append(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                )
            )

        # graph-level readout MLP
        # note: 2*hidden_dim input because we concat [sum || mean] pooling
        self.mlp_readout = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, data: Data):
        x, edge_index, edge_attr, batch = (
            data.x,
            data.edge_index,
            data.edge_attr,
            data.batch,
        )

        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)

        num_graphs = batch.max().item() + 1
        virtualnode_emb = self.virtualnode_embedding.weight.repeat(num_graphs, 1)

        for layer in range(self.num_layers):
            # add virtual node embedding to node features
            x = x + virtualnode_emb[batch]

            x = self.convs[layer](x, edge_index, edge_attr)
            x = self.bns[layer](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

            # update virtual node (except last layer)
            if layer < self.num_layers - 1:
                pooled = global_add_pool(x, batch)  # [num_graphs, hidden_dim]
                virtualnode_emb = virtualnode_emb + self.mlp_virtual_list[layer](pooled)

        # global pooling: use both sum and mean
        graph_sum  = global_add_pool(x, batch)           # [num_graphs, hidden_dim]
        graph_mean = global_mean_pool(x, batch)          # [num_graphs, hidden_dim]
        graph_emb  = torch.cat([graph_sum, graph_mean], dim=1)  # [num_graphs, 2H]

        logits = self.mlp_readout(graph_emb).view(-1)    # [num_graphs]
        return logits


# instantiate model (same as before)
sample = train_dataset_full[0]
node_in_dim = sample.x.size(1)
edge_in_dim = sample.edge_attr.size(1)

print("node_in_dim:", node_in_dim, "| edge_in_dim:", edge_in_dim)

model = GNNWithVirtualNode(
    node_in_dim=node_in_dim,
    edge_in_dim=edge_in_dim,
    hidden_dim=128,
    num_layers=4,
    dropout=0.2,
).to(device)

print(model)



node_in_dim: 13 | edge_in_dim: 7
GNNWithVirtualNode(
  (node_encoder): Linear(in_features=13, out_features=128, bias=True)
  (edge_encoder): Linear(in_features=7, out_features=128, bias=True)
  (convs): ModuleList(
    (0-3): 4 x GINEConv(nn=Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    ))
  )
  (bns): ModuleList(
    (0-3): 4 x BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (virtualnode_embedding): Embedding(1, 128)
  (mlp_virtual_list): ModuleList(
    (0-2): 3 x Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (mlp_readout): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=128, out_features=1, bias=True)


In [14]:
# CELL 6 (FIXED): longer training + LR scheduler + early stopping on ROC-AUC

def train_one_epoch(model, loader, optimizer, alpha=0.99, gamma=1.5):
    model.train()
    total_loss = 0.0
    n_graphs = 0

    for batch in tqdm(loader, desc="Train", leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()

        logits = model(batch)
        targets = batch.y.view(-1)

        loss = binary_focal_loss_with_logits(
            logits,
            targets,
            alpha=alpha,
            gamma=gamma,
            reduction="mean",
        )
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch.num_graphs
        n_graphs += batch.num_graphs

    return total_loss / n_graphs


@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total_loss = 0.0
    n_graphs = 0
    all_probs = []
    all_targets = []

    for batch in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)
        logits = model(batch)
        targets = batch.y.view(-1)

        loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="sum")
        total_loss += loss.item()
        n_graphs += batch.num_graphs

        probs = torch.sigmoid(logits).cpu().numpy()
        all_probs.append(probs)
        all_targets.append(targets.cpu().numpy())

    y_true = np.concatenate(all_targets)
    y_prob = np.concatenate(all_probs)

    roc_auc = roc_auc_score(y_true, y_prob)
    ap = average_precision_score(y_true, y_prob)

    return total_loss / n_graphs, roc_auc, ap


MAX_EPOCHS = 30       # upper bound; early stopping will usually stop earlier
LR = 1e-3
WEIGHT_DECAY = 1e-5

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",      # maximize ROC-AUC
    factor=0.5,
    patience=2,
    verbose=True,
)

best_val_auc = -float("inf")
best_val_ap = 0.0
best_epoch = 0
no_improve = 0

print("Training on device:", device)
for epoch in range(1, MAX_EPOCHS + 1):
    print(f"\nEpoch {epoch}/{MAX_EPOCHS}")
    # ðŸ”§ FIX: pass optimizer into train_one_epoch
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_auc, val_ap = evaluate(model, val_loader)

    scheduler.step(val_auc)

    print(
        f"Epoch {epoch:02d} | "
        f"train loss: {train_loss:.4f} | "
        f"val loss: {val_loss:.4f} | "
        f"val ROC-AUC: {val_auc:.4f} | "
        f"val AP: {val_ap:.4f}"
    )

    if val_auc > best_val_auc + 1e-4:
        best_val_auc = val_auc
        best_val_ap = val_ap
        best_epoch = epoch
        no_improve = 0
    else:
        no_improve += 1

    if no_improve >= 5:
        print(f"Early stopping: no ROC-AUC improvement for {no_improve} epochs.")
        break

print(
    f"\nBest validation ROC-AUC: {best_val_auc:.4f}, "
    f"AP: {best_val_ap:.4f} at epoch {best_epoch}"
)




Training on device: cpu

Epoch 1/30


                                                        

Epoch 01 | train loss: 0.0245 | val loss: 0.7273 | val ROC-AUC: 0.5690 | val AP: 0.0066

Epoch 2/30


                                                        

Epoch 02 | train loss: 0.0147 | val loss: 0.0700 | val ROC-AUC: 0.7893 | val AP: 0.0221

Epoch 3/30


                                                        

Epoch 03 | train loss: 0.0132 | val loss: 0.0805 | val ROC-AUC: 0.8040 | val AP: 0.0515

Epoch 4/30


                                                        

Epoch 04 | train loss: 0.0118 | val loss: 0.0490 | val ROC-AUC: 0.7696 | val AP: 0.0717

Epoch 5/30


                                                        

Epoch 05 | train loss: 0.0118 | val loss: 0.0328 | val ROC-AUC: 0.8018 | val AP: 0.0288

Epoch 6/30


                                                        

Epoch 06 | train loss: 0.0111 | val loss: 0.0588 | val ROC-AUC: 0.8426 | val AP: 0.0370

Epoch 7/30


                                                        

Epoch 07 | train loss: 0.0110 | val loss: 0.0697 | val ROC-AUC: 0.8541 | val AP: 0.0874

Epoch 8/30


                                                        

Epoch 08 | train loss: 0.0113 | val loss: 0.0728 | val ROC-AUC: 0.8697 | val AP: 0.0623

Epoch 9/30


                                                        

Epoch 09 | train loss: 0.0114 | val loss: 0.0644 | val ROC-AUC: 0.8510 | val AP: 0.0753

Epoch 10/30


                                                        

Epoch 10 | train loss: 0.0111 | val loss: 0.0615 | val ROC-AUC: 0.9022 | val AP: 0.1131

Epoch 11/30


                                                        

Epoch 11 | train loss: 0.0099 | val loss: 0.0836 | val ROC-AUC: 0.8933 | val AP: 0.0799

Epoch 12/30


                                                        

Epoch 12 | train loss: 0.0103 | val loss: 0.0708 | val ROC-AUC: 0.8540 | val AP: 0.0613

Epoch 13/30


                                                        

Epoch 13 | train loss: 0.0100 | val loss: 0.0460 | val ROC-AUC: 0.8411 | val AP: 0.0524

Epoch 14/30


                                                        

Epoch 14 | train loss: 0.0094 | val loss: 0.0445 | val ROC-AUC: 0.9172 | val AP: 0.1644

Epoch 15/30


                                                        

Epoch 15 | train loss: 0.0086 | val loss: 0.1032 | val ROC-AUC: 0.9111 | val AP: 0.1119

Epoch 16/30


                                                        

Epoch 16 | train loss: 0.0086 | val loss: 0.0392 | val ROC-AUC: 0.9046 | val AP: 0.1336

Epoch 17/30


                                                        

Epoch 17 | train loss: 0.0088 | val loss: 0.0410 | val ROC-AUC: 0.8942 | val AP: 0.1490

Epoch 18/30


                                                        

Epoch 18 | train loss: 0.0082 | val loss: 0.0471 | val ROC-AUC: 0.9162 | val AP: 0.1693

Epoch 19/30


                                                        

Epoch 19 | train loss: 0.0079 | val loss: 0.0373 | val ROC-AUC: 0.9163 | val AP: 0.1775
Early stopping: no ROC-AUC improvement for 5 epochs.

Best validation ROC-AUC: 0.9172, AP: 0.1644 at epoch 14


In [9]:
# CELL 7: detailed metrics on validation set (AP, ROC-AUC, F1, best threshold)

@torch.no_grad()
def compute_metrics_pyg(model, loader, name="val"):
    model.eval()
    all_probs = []
    all_targets = []

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        probs = torch.sigmoid(logits).cpu().numpy()
        y = batch.y.view(-1).cpu().numpy()

        all_probs.append(probs)
        all_targets.append(y)

    y_true = np.concatenate(all_targets)
    y_prob = np.concatenate(all_probs)

    prevalence = float(y_true.mean())
    roc_auc = roc_auc_score(y_true, y_prob)
    ap = average_precision_score(y_true, y_prob)

    prec, rec, thr = precision_recall_curve(y_true, y_prob)
    f1_scores = 2 * prec[:-1] * rec[:-1] / (prec[:-1] + rec[:-1] + 1e-8)
    best_idx = int(np.argmax(f1_scores))
    best_thr = float(thr[best_idx])
    best_f1 = float(f1_scores[best_idx])

    f1_at_05 = float(f1_score(y_true, (y_prob >= 0.5).astype(int)))

    print(f"=== Metrics on {name} set ===")
    print(f"Prevalence (mean label) : {prevalence:.4f}")
    print(f"ROC-AUC                 : {roc_auc:.4f}")
    print(f"Average Precision (AP)  : {ap:.4f}")
    print(f"Best F1                 : {best_f1:.4f} at threshold {best_thr:.3f}")
    print(f"F1 at threshold 0.5     : {f1_at_05:.4f}")

    return {
        "prevalence": prevalence,
        "roc_auc": roc_auc,
        "ap": ap,
        "best_f1": best_f1,
        "best_thr": best_thr,
        "f1_at_0.5": f1_at_05,
    }

metrics_val = compute_metrics_pyg(model, val_loader, name="val_full")


=== Metrics on val_full set ===
Prevalence (mean label) : 0.0045
ROC-AUC                 : 0.8493
Average Precision (AP)  : 0.0520
Best F1                 : 0.1446 at threshold 0.354
F1 at threshold 0.5     : 0.0000


In [1]:
# CELL 7: detailed metrics on validation set (AP, ROC-AUC, F1, best threshold)

@torch.no_grad()
def compute_metrics_pyg(model, loader, name="val"):
    model.eval()
    all_probs = []
    all_targets = []

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        probs = torch.sigmoid(logits).cpu().numpy()
        y = batch.y.view(-1).cpu().numpy()

        all_probs.append(probs)
        all_targets.append(y)

    y_true = np.concatenate(all_targets)
    y_prob = np.concatenate(all_probs)

    prevalence = float(y_true.mean())
    roc_auc = roc_auc_score(y_true, y_prob)
    ap = average_precision_score(y_true, y_prob)

    prec, rec, thr = precision_recall_curve(y_true, y_prob)
    f1_scores = 2 * prec[:-1] * rec[:-1] / (prec[:-1] + rec[:-1] + 1e-8)
    best_idx = int(np.argmax(f1_scores))
    best_thr = float(thr[best_idx])
    best_f1 = float(f1_scores[best_idx])

    f1_at_05 = float(f1_score(y_true, (y_prob >= 0.5).astype(int)))

    print(f"=== Metrics on {name} set ===")
    print(f"Prevalence (mean label) : {prevalence:.4f}")
    print(f"ROC-AUC                 : {roc_auc:.4f}")
    print(f"Average Precision (AP)  : {ap:.4f}")
    print(f"Best F1                 : {best_f1:.4f} at threshold {best_thr:.3f}")
    print(f"F1 at threshold 0.5     : {f1_at_05:.4f}")

    return {
        "prevalence": prevalence,
        "roc_auc": roc_auc,
        "ap": ap,
        "best_f1": best_f1,
        "best_thr": best_thr,
        "f1_at_0.5": f1_at_05,
    }

metrics_val = compute_metrics_pyg(model, val_loader, name="val_full")


NameError: name 'torch' is not defined

In [11]:
# CELL 8 (optional): predict on test set and save CSV

@torch.no_grad()
def predict_loader_pyg(model, loader):
    model.eval()
    all_probs = []
    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_probs.append(probs)
    return np.concatenate(all_probs)

test_preds = predict_loader_pyg(model, test_loader)
print("Test predictions shape:", test_preds.shape)

submission = pd.DataFrame({
    "id": np.arange(len(test_preds)),
    "y_pred": test_preds,
})

artifact_dir = os.path.join(project_root, "notebooks", "neural_networks", "artifacts")
os.makedirs(artifact_dir, exist_ok=True)
save_path = os.path.join(artifact_dir, "gnn_pyg_predictions.csv")
submission.to_csv(save_path, index=False)
print("Saved predictions to:", save_path)

submission.head()


Test predictions shape: (50000,)
Saved predictions to: /Users/pabloperezgonzalez/F.I.T-PROTEINS-NEW/notebooks/neural_networks/artifacts/gnn_pyg_predictions.csv


Unnamed: 0,id,y_pred
0,0,0.309667
1,1,0.005264
2,2,0.212053
3,3,0.331305
4,4,0.039361
