In [None]:
# Notebook native
#import os
#os.environ["SWEEP_ID"] = "pwg5qe4n"
#os.environ["WANDB_DUMMY_RUN"] = "False"

In [None]:
import copy
import os
import sys
from typing import Any, Dict, List, Tuple
from sklearn.metrics import roc_auc_score, precision_score, recall_score
import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
import wandb
from loguru import logger
from sklearn.model_selection import StratifiedShuffleSplit
from torch.nn import BatchNorm1d, Linear, ModuleList, ReLU, Sequential
from torch_geometric.data import Dataset
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    GATv2Conv,
    GCNConv,
    GINEConv,
    global_add_pool,
    global_mean_pool,
)
import random

# Swap loguru's default sink for a compact stderr sink: a mm:ss timestamp
# followed by the message, emitting INFO and above.
logger.remove()
logger.add(
    sys.stderr,
    level="INFO",
    format="<green>{time:mm:ss}</green> | <level>{message}</level>",
)

# W&B sweep definition: Bayesian search maximizing the per-epoch "val_auc"
# logged by the training loop. Ranges use {"min", "max"}, and discrete
# choices use {"values": [...]}.
SWEEP_CONFIG: Dict[str, Any] = dict(
    method="bayes",
    metric=dict(name="val_auc", goal="maximize"),
    parameters=dict(
        # architecture
        model_type=dict(values=["GCN_Classic", "GCN_Res", "GINE", "GATv2"]),
        num_layers=dict(values=[3, 4, 5]),
        hidden_channels=dict(values=[64, 128, 256]),
        heads=dict(values=[2, 4]),  # only read by the GATv2 model
        dropout=dict(min=0.1, max=0.6),
        # optimisation
        weight_decay=dict(values=[1e-4, 1e-5, 0.0]),
        learning_rate=dict(min=0.0001, max=0.005),
        batch_size=dict(values=[32, 64, 128]),
        # mean-teacher / semi-supervised knobs
        consistency_weight=dict(min=0.1, max=2.0),
        teacher_alpha=dict(min=0.95, max=0.999),
        label_rate=dict(values=[0.1]),
    ),
)

# Drawn once, at import time, from numpy's global RNG and used as the default
# Config.SEED (seeded evaluation runs reassign Config.SEED explicitly).
random_seed = np.random.randint(low=0, high=100000)


class Config:
    """Process-wide experiment settings, read directly as ``Config.<NAME>``."""

    # --- Weights & Biases destination
    PROJECT: str = "drug-discovery-ssl"
    ENTITY: str = "REDACTED"

    # --- data
    # WANDB_DUMMY_RUN=true (case-insensitive) switches on a reduced smoke-test run.
    DUMMY_RUN: bool = os.environ.get("WANDB_DUMMY_RUN", "False").lower() == "true"
    # Upper bound on the number of graphs kept when DUMMY_RUN is on.
    DUMMY_SIZE: int = 2000
    DATASET: str = "BBBP"

    # --- execution
    DEVICE: torch.device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
    EPOCHS: int = 100
    SEED: int = random_seed


def get_stratified_validation_split(
    train_val_set: Dataset, val_ratio: float = 0.1
) -> Tuple[Dataset, Dataset]:
    """Split ``train_val_set`` into train/validation subsets with matching class balance.

    Args:
        train_val_set: graph dataset whose items carry a single-element ``y``.
        val_ratio: fraction of items assigned to the validation subset.

    Returns:
        ``(train_subset, val_subset)``, chosen by one seeded
        ``StratifiedShuffleSplit`` (``random_state=Config.SEED``).
    """
    # Iterate items (rather than reading a stored label tensor) so labels stay
    # aligned with this object's own indices even when it is itself a subset.
    labels = np.array([d.y.item() for d in train_val_set], dtype=float)
    # NaN labels are grouped with class 0, for stratification purposes only.
    y_full = np.nan_to_num(labels, nan=0.0)

    splitter = StratifiedShuffleSplit(
        n_splits=1, test_size=val_ratio, random_state=Config.SEED
    )
    # The splitter only uses X for its length, so a placeholder column suffices.
    placeholder = np.zeros((len(y_full), 1))
    # n_splits=1 -> exactly one (train, val) pair is produced.
    train_idx, val_idx = next(splitter.split(placeholder, y_full))
    return train_val_set[train_idx], train_val_set[val_idx]


def get_scaffold_split(dataset: Dataset) -> Tuple[Dataset, Dataset, Dataset]:
    """Split 80/10/10 after sorting molecules by their SMILES string.

    NOTE: despite the name this is a lexicographic SMILES sort, not a
    Bemis-Murcko scaffold split. Sorting places textually similar molecules
    next to each other, so the last 10% (test) is shifted away from the first
    80% (pool), approximating an out-of-distribution evaluation.

    Falls back to a single random permutation when items carry no ``smiles``.

    Args:
        dataset: graph dataset; each item is expected to expose ``smiles``.

    Returns:
        ``(pool, middle, test)``: disjoint 80% / 10% / 10% subsets.
    """
    n: int = len(dataset)
    train_end: int = int(0.8 * n)
    val_end: int = int(0.9 * n)

    if n == 0 or not hasattr(dataset[0], "smiles"):
        logger.warning("data -> smiles missing: random split fallback")
        # Shuffle ONCE and slice, so the three subsets cannot overlap.
        # Independent shuffles of the full set would leak test graphs into training.
        shuffled = dataset.shuffle()
        return (
            shuffled[:train_end],
            shuffled[train_end:val_end],
            shuffled[val_end:],
        )

    # Read smiles per item: this stays aligned for sliced or permuted subsets,
    # unlike indexing the underlying storage by position.
    smiles: List[str] = [d.smiles for d in dataset]
    perm: List[int] = sorted(range(n), key=smiles.__getitem__)

    # split: 80% pool | 10% unused | 10% strict OOD test
    return (
        dataset[perm[:train_end]],
        dataset[perm[train_end:val_end]],
        dataset[perm[val_end:]],
    )


def _labeled_pos_weight(labeled_ds: Dataset) -> float:
    """Return the negative/positive ratio over the non-NaN labels of ``labeled_ds``.

    Returns 1.0 (no re-weighting) when the subset has no positives, because the
    ratio would otherwise explode.
    """
    y = np.array([d.y.item() for d in labeled_ds], dtype=float)
    y = y[~np.isnan(y)]
    pos_count = float(y.sum())
    if pos_count == 0.0:
        return 1.0
    return (len(y) - pos_count) / pos_count


def get_loaders(config: wandb.Config) -> Dict[str, Any]:
    """Build the labeled / unlabeled / val / test loaders for one run.

    Pipeline:
    1. Load MoleculeNet ``Config.DATASET``, keeping only each molecule's
       largest connected component (salt removal).
    2. Optionally truncate the data for dummy runs.
    3. Split 80/10/10 by sorted SMILES.
    4. Carve a stratified validation set out of the 80% pool.
    5. Keep labels for only a ``config.label_rate`` fraction of the remaining
       training graphs.

    Args:
        config: run config; reads ``label_rate`` and ``batch_size``.

    Returns:
        Dict with ``"labeled"``, ``"unlabeled"``, ``"val"`` and ``"test"``
        DataLoaders, plus ``"pos_weight"``: a 1-element tensor on
        ``Config.DEVICE`` holding negatives / positives of the LABELED subset.
        Using only the labeled subset avoids leaking val/test labels into the
        training loss.

    Raises:
        ValueError: if ``label_rate`` leaves no labeled training graph.
    """
    # force_reload re-runs pre_transform, so a cached copy processed without
    # the largest-connected-component (salt removal) step is never reused.
    dataset = MoleculeNet(
        root="./data/MoleculeNet",
        name=Config.DATASET,
        pre_transform=T.LargestConnectedComponents(),
        force_reload=True,
    )

    if Config.DUMMY_RUN:
        dataset = dataset[: Config.DUMMY_SIZE]
        logger.warning(f"mode -> dummy run active: subset to {len(dataset)} graphs")
    else:
        logger.info(f"mode -> production: full dataset ({len(dataset)} graphs)")

    # splits
    train_val_pool, _, test_ds = get_scaffold_split(dataset)
    train_ds, val_ds = get_stratified_validation_split(train_val_pool, val_ratio=0.11)

    # ssl scarcity: only the first label_rate fraction of a shuffled train set keeps labels
    train_ds = train_ds.shuffle()
    n_lbl: int = int(config.label_rate * len(train_ds))
    if n_lbl == 0:
        raise ValueError(
            f"label_rate {config.label_rate} leaves no labeled graphs "
            f"out of {len(train_ds)} training graphs"
        )
    labeled_ds = train_ds[:n_lbl]
    unlabeled_ds = train_ds[n_lbl:]

    # class-imbalance weight from training labels that are actually available
    pos_weight: float = _labeled_pos_weight(labeled_ds)

    return {
        "labeled": DataLoader(
            labeled_ds,
            batch_size=config.batch_size,
            shuffle=True,
            # drop ragged batches only while at least one full batch remains
            drop_last=len(labeled_ds) >= config.batch_size,
        ),
        "unlabeled": DataLoader(
            unlabeled_ds,
            batch_size=config.batch_size,
            shuffle=True,
            drop_last=len(unlabeled_ds) >= config.batch_size,
        ),
        "val": DataLoader(val_ds, batch_size=config.batch_size),
        "test": DataLoader(test_ds, batch_size=config.batch_size),
        "pos_weight": torch.tensor([pos_weight]).to(Config.DEVICE),
    }


class BaseGNN(torch.nn.Module):
    """Shared interface for every graph classifier in this file.

    Concrete models implement ``forward(x, edge_index, edge_attr, batch)``
    and return one row of logits per graph in the batch.
    """

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        # Abstract hook: concrete architectures must override it.
        raise NotImplementedError()


class GCN_Classic(BaseGNN):
    """Plain GCN stack after Kipf & Welling (ICLR 2017).

    The input projection feeds ``layers`` x (GCNConv -> ReLU -> dropout),
    followed by mean pooling and a linear head. There are no residuals and no
    normalization. ``edge_attr``/``edge_dim`` exist only for interface
    compatibility and are ignored.
    """

    def __init__(
        self,
        in_dim: int,
        hidden: int,
        out_dim: int,
        edge_dim: int,
        dropout: float,
        layers: int,
    ) -> None:
        super().__init__()
        # Construction order (projection, convs, head) fixes how a seeded
        # RNG is consumed by weight init.
        self.node_lin = Linear(in_dim, hidden)
        self.convs = ModuleList([GCNConv(hidden, hidden) for _ in range(layers)])
        self.lin = Linear(hidden, out_dim)
        self.dropout = dropout

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        h = self.node_lin(x.float())
        for gcn in self.convs:
            h = F.dropout(
                F.relu(gcn(h, edge_index)), p=self.dropout, training=self.training
            )
        pooled = global_mean_pool(h, batch)
        return self.lin(pooled)


class GCN_Res(BaseGNN):
    """Residual GCN baseline.

    Each layer is GCNConv -> BatchNorm -> ReLU -> dropout, plus an identity
    skip. Sum pooling and a linear head give the graph output.
    ``edge_attr``/``edge_dim`` are accepted for interface compatibility but
    unused.
    """

    def __init__(
        self,
        in_dim: int,
        hidden: int,
        out_dim: int,
        edge_dim: int,
        dropout: float,
        layers: int,
    ) -> None:
        super().__init__()
        self.node_lin = Linear(in_dim, hidden)
        self.convs = ModuleList()
        self.bns = ModuleList()

        # conv and norm are created pairwise, layer by layer
        for _layer in range(layers):
            conv_layer = GCNConv(hidden, hidden)
            self.convs.append(conv_layer)
            self.bns.append(BatchNorm1d(hidden))

        self.lin = Linear(hidden, out_dim)
        self.dropout = dropout

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        h = self.node_lin(x.float())
        for conv, norm in zip(self.convs, self.bns):
            delta = F.relu(norm(conv(h, edge_index)))
            delta = F.dropout(delta, p=self.dropout, training=self.training)
            h = delta + h  # identity skip
        return self.lin(global_add_pool(h, batch))


class GINE(BaseGNN):
    """Residual GINEConv network (GIN extended with edge features).

    Node and edge features are both projected to ``hidden``. Each layer is
    GINEConv (2-layer MLP with BatchNorm) -> BatchNorm -> ReLU -> dropout,
    plus an identity skip, followed by sum pooling and a linear head.
    """

    def __init__(
        self,
        in_dim: int,
        hidden: int,
        out_dim: int,
        edge_dim: int,
        dropout: float,
        layers: int,
    ) -> None:
        super().__init__()
        self.node_lin = Linear(in_dim, hidden)
        self.edge_lin = Linear(edge_dim, hidden)
        self.convs = ModuleList()
        self.bns = ModuleList()

        for _layer in range(layers):
            # The MLP is built before the conv that wraps it; the order of
            # creation determines seeded weight initialization.
            self.convs.append(
                GINEConv(
                    Sequential(
                        Linear(hidden, hidden),
                        BatchNorm1d(hidden),
                        ReLU(),
                        Linear(hidden, hidden),
                    ),
                    edge_dim=hidden,
                )
            )
            self.bns.append(BatchNorm1d(hidden))

        self.lin = Linear(hidden, out_dim)
        self.dropout = dropout

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        h = self.node_lin(x.float())
        e = self.edge_lin(edge_attr.float())

        for conv, norm in zip(self.convs, self.bns):
            delta = norm(conv(h, edge_index, edge_attr=e))
            delta = F.dropout(F.relu(delta), p=self.dropout, training=self.training)
            h = delta + h  # identity skip

        return self.lin(global_add_pool(h, batch))


class GATv2(BaseGNN):
    """Residual GATv2 network whose attention also sees projected edge features.

    Each layer is GATv2Conv (``heads`` heads, averaged because concat=False)
    -> BatchNorm -> ReLU -> dropout, plus an identity skip. Sum pooling and a
    linear head give the graph output.
    """

    def __init__(
        self,
        in_dim: int,
        hidden: int,
        out_dim: int,
        edge_dim: int,
        dropout: float,
        layers: int,
        heads: int,
    ) -> None:
        super().__init__()
        self.node_lin = Linear(in_dim, hidden)
        self.edge_lin = Linear(edge_dim, hidden)
        self.convs = ModuleList()
        self.bns = ModuleList()

        for _layer in range(layers):
            attention = GATv2Conv(
                hidden, hidden, heads=heads, concat=False, edge_dim=hidden
            )
            self.convs.append(attention)
            self.bns.append(BatchNorm1d(hidden))

        self.lin = Linear(hidden, out_dim)
        self.dropout = dropout

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        h = self.node_lin(x.float())
        e = self.edge_lin(edge_attr.float())

        for attn, norm in zip(self.convs, self.bns):
            delta = norm(attn(h, edge_index, edge_attr=e))
            delta = F.dropout(F.relu(delta), p=self.dropout, training=self.training)
            h = delta + h  # identity skip

        return self.lin(global_add_pool(h, batch))


def get_model(config: wandb.Config, in_dim: int, edge_dim: int) -> BaseGNN:
    """Instantiate the architecture named by ``config.model_type``.

    Every model gets a single output logit (out_dim=1) and reads
    ``hidden_channels``, ``dropout`` and ``num_layers`` from ``config``.
    ``GATv2`` additionally reads ``heads``.

    Raises:
        ValueError: if ``config.model_type`` names no known architecture.
    """
    model_type = config.model_type
    # architectures sharing the 6-argument constructor
    six_arg_models = {
        "GCN_Classic": GCN_Classic,
        "GCN_Res": GCN_Res,
        "GINE": GINE,
    }

    if model_type in six_arg_models:
        model_cls = six_arg_models[model_type]
        return model_cls(
            in_dim,
            config.hidden_channels,
            1,
            edge_dim,
            config.dropout,
            config.num_layers,
        )

    if model_type == "GATv2":
        return GATv2(
            in_dim,
            config.hidden_channels,
            1,
            edge_dim,
            config.dropout,
            config.num_layers,
            config.heads,
        )

    raise ValueError(f"unknown model_type: {model_type}")


def update_teacher(
    student: torch.nn.Module, teacher: torch.nn.Module, alpha: float
) -> None:
    """EMA step, in place: teacher <- alpha * teacher + (1 - alpha) * student.

    Only ``parameters()`` are blended; buffers (e.g. BatchNorm running stats)
    are left untouched. Parameters are paired positionally, so both modules
    are expected to share the same architecture.
    """
    keep = alpha
    blend = 1.0 - alpha
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(keep).add_(s_param, alpha=blend)






def train_epoch(
    student: BaseGNN,
    teacher: BaseGNN,
    loaders: Dict,
    optimizer: torch.optim.Adam,
    config: wandb.Config,
) -> Dict[str, float]:
    """Run one Mean-Teacher epoch driven by the unlabeled loader.

    Each step pairs one unlabeled batch with the next labeled batch (the
    labeled loader restarts whenever it is exhausted) and minimizes
    ``BCE(labeled, pos_weight) + consistency_weight * MSE(sigmoid(student), sigmoid(teacher))``,
    where the consistency term is computed on the unlabeled batch. After
    ``optimizer.step()`` the teacher is EMA-updated with ``config.teacher_alpha``.

    Returns:
        Mean ``loss``, ``sup`` and ``cons`` over the epoch's steps.

    Raises:
        ValueError: if the labeled or unlabeled loader yields no batches.
            Without this check the epoch would fail with a ZeroDivisionError
            or a bare StopIteration.
    """
    steps = len(loaders["unlabeled"])
    n_lbl_batches = len(loaders["labeled"])
    if steps == 0 or n_lbl_batches == 0:
        raise ValueError(
            f"empty loader -> labeled: {n_lbl_batches} batches, "
            f"unlabeled: {steps} batches; lower batch_size or adjust label_rate"
        )

    student.train()
    # The teacher also runs in train mode: dropout stays active and BatchNorm
    # uses batch statistics, so teacher predictions are stochastic.
    teacher.train()
    metrics = {"loss": 0.0, "sup": 0.0, "cons": 0.0}
    iter_lbl = iter(loaders["labeled"])

    for batch_unlbl in loaders["unlabeled"]:
        # cycle the labeled loader so every unlabeled batch gets a partner
        try:
            batch_lbl = next(iter_lbl)
        except StopIteration:
            iter_lbl = iter(loaders["labeled"])
            batch_lbl = next(iter_lbl)

        # move to device
        batch_lbl = batch_lbl.to(Config.DEVICE)
        batch_unlbl = batch_unlbl.to(Config.DEVICE)

        optimizer.zero_grad()

        # supervised: weighted BCE on graphs with a known (non-NaN) label
        out_lbl = student(
            batch_lbl.x,
            batch_lbl.edge_index,
            batch_lbl.edge_attr,
            batch_lbl.batch,
        )
        mask = ~torch.isnan(batch_lbl.y)
        loss_sup = F.binary_cross_entropy_with_logits(
            out_lbl[mask],
            batch_lbl.y[mask].float(),
            pos_weight=loaders["pos_weight"],
        )

        # Consistency: student (with grad) vs teacher (no grad) on the same
        # unlabeled batch. This runs even when consistency_weight == 0, so
        # student BatchNorm layers still see unlabeled graphs.
        out_student = student(
            batch_unlbl.x,
            batch_unlbl.edge_index,
            batch_unlbl.edge_attr,
            batch_unlbl.batch,
        )
        with torch.no_grad():
            out_teacher = teacher(
                batch_unlbl.x,
                batch_unlbl.edge_index,
                batch_unlbl.edge_attr,
                batch_unlbl.batch,
            )

        loss_cons = F.mse_loss(torch.sigmoid(out_student), torch.sigmoid(out_teacher))

        # optimize the student, then move the teacher towards it
        loss = loss_sup + (config.consistency_weight * loss_cons)
        loss.backward()
        optimizer.step()
        update_teacher(student, teacher, config.teacher_alpha)

        metrics["loss"] += loss.item()
        metrics["sup"] += loss_sup.item()
        metrics["cons"] += loss_cons.item()

    return {k: v / steps for k, v in metrics.items()}


# Helper: enrichment factor (EF@x%), a ranking metric for virtual screening
def calculate_enrichment_factor(
    preds: np.ndarray, targets: np.ndarray, percentile: float = 0.05
) -> float:
    """Return the enrichment factor EF@``percentile`` of a ranking.

    EF = (hit rate among the ``int(percentile * n)`` highest-scored items)
         / (hit rate over all ``n`` items).
    An EF of 1.0 means the top slice is no better than random selection.

    Args:
        preds: scores; higher means more likely active. Any 1-D array-like.
        targets: binary labels (1 = active), aligned with ``preds``.
        percentile: fraction of the ranked list to inspect (0.05 = top 5%).

    Returns:
        The enrichment factor as a Python float. Returns 0.0 when the top slice
        is empty (too few items) or when there are no actives at all.

    Raises:
        ValueError: if ``percentile`` is outside [0, 1] or the inputs differ
            in length.
    """
    scores = np.asarray(preds, dtype=float).ravel()
    labels = np.asarray(targets, dtype=float).ravel()
    if scores.shape != labels.shape:
        raise ValueError(
            f"preds and targets differ in length: {scores.size} vs {labels.size}"
        )
    if not 0.0 <= percentile <= 1.0:
        raise ValueError(f"percentile must be in [0, 1], got {percentile}")

    n_total = scores.size
    n_top = int(percentile * n_total)
    n_actives = labels.sum()
    if n_top == 0 or n_actives == 0:
        return 0.0

    # Stable sort of negated scores gives descending order with ties broken
    # by original position, so the result is deterministic.
    top_indices = np.argsort(-scores, kind="stable")[:n_top]

    subset_hit_rate = labels[top_indices].sum() / n_top
    base_hit_rate = n_actives / n_total
    return float(subset_hit_rate / base_hit_rate)

# Evaluation: ROC-AUC, plus detailed W&B logging for the final test split
@torch.no_grad()
def evaluate(
    model: BaseGNN,
    loader: DataLoader,
    log_table: bool = False,
    split_name: str = "val",
) -> float:
    """Score ``model`` on ``loader`` and return its ROC-AUC.

    The model runs in eval mode, sigmoid is applied to the logits, and graphs
    whose target is NaN (missing label) are dropped. When ``log_table`` is
    True, the following are also logged to W&B, each prefixed by
    ``split_name``:

    - AUC and EF@10%;
    - precision/recall at a 0.5 threshold;
    - a prediction/target table;
    - a ROC curve.

    Args:
        model: classifier returning one logit per graph.
        loader: batches to evaluate.
        log_table: emit the detailed metrics/artifacts described above.
        split_name: prefix for log keys and the console message.

    Returns:
        ROC-AUC as a float, or 0.5 when it is undefined (empty loader, only
        one class present, or invalid predictions such as NaN).
    """
    model.eval()
    preds, targets = [], []

    for batch in loader:
        batch = batch.to(Config.DEVICE)
        out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        preds.append(torch.sigmoid(out).cpu())
        targets.append(batch.y.cpu())

    if not preds:
        # torch.cat on an empty list would raise
        logger.warning(f"{split_name} -> empty loader: auc undefined, returning 0.5")
        return 0.5

    cat_preds = torch.cat(preds).numpy().flatten()
    cat_targets = torch.cat(targets).numpy().flatten()

    # drop graphs without a label
    mask = ~np.isnan(cat_targets)
    clean_preds = cat_preds[mask]
    clean_targets = cat_targets[mask]

    try:
        auc = float(roc_auc_score(clean_targets, clean_preds))
    except ValueError:
        auc = float("nan")
    if not np.isfinite(auc):
        # e.g. a single class in the split, or NaN predictions from a diverged model
        logger.warning(f"{split_name} -> auc undefined: returning 0.5")
        return 0.5

    if log_table:
        hard_preds = (clean_preds > 0.5).astype(int)
        prec = precision_score(clean_targets, hard_preds, zero_division=0)
        rec = recall_score(clean_targets, hard_preds, zero_division=0)
        ef_10 = calculate_enrichment_factor(clean_preds, clean_targets, percentile=0.1)

        logger.info(
            f"{split_name} metrics -> auc: {auc:.4f} | ef@10%: {ef_10:.2f} "
            f"| prec: {prec:.2f} | rec: {rec:.2f}"
        )
        wandb.log({
            f"{split_name}_auc": auc,
            f"{split_name}_ef_10": ef_10,
            f"{split_name}_precision": prec,
            f"{split_name}_recall": rec,
        })

        # Raw per-graph predictions, so any other metric can be recomputed offline.
        table = wandb.Table(
            data=[[float(p), float(t)] for p, t in zip(clean_preds, clean_targets)],
            columns=["pred_prob", "target"],
        )
        wandb.log({f"{split_name}_predictions": table})

        # Per-class probability columns: [P(class 0), P(class 1)].
        probas = np.stack([1 - clean_preds, clean_preds], axis=1)
        wandb.log({f"{split_name}_roc": wandb.plot.roc_curve(clean_targets, probas)})

    return auc


def main_sweep() -> None:
    """Run one W&B sweep trial end to end.

    Hyperparameters are read from ``wandb.config``. A student/teacher pair is
    trained for ``Config.EPOCHS`` epochs, and the student weights with the
    best validation AUC are kept. Those weights are then restored and
    detailed test metrics are logged. Dummy runs log to ``<PROJECT>_min``.
    """
    project_name = Config.PROJECT
    if Config.DUMMY_RUN:
        project_name = f"{project_name}_min"

    # NOTE(review): unlike the final-evaluation cell, nothing is seeded here,
    # so the shuffles inside get_loaders differ between trials. Confirm this is intended.
    with wandb.init(project=project_name, entity=Config.ENTITY):
        config = wandb.config
        loaders = get_loaders(config)

        # feature sizes come from a real batch instead of being hard-coded
        sample = next(iter(loaders["labeled"]))
        in_dim = sample.num_node_features
        edge_dim = sample.num_edge_features

        # factory init
        student = get_model(config, in_dim, edge_dim).to(Config.DEVICE)
        teacher = get_model(config, in_dim, edge_dim).to(Config.DEVICE)

        # The teacher starts as an exact copy and receives no gradients; its
        # parameters change only through update_teacher's EMA step.
        teacher.load_state_dict(student.state_dict())
        for p in teacher.parameters():
            p.requires_grad = False

        optimizer = torch.optim.Adam(
            student.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        logger.info(
            f"start -> model: {config.model_type} | layers: {config.num_layers}"
        )

        best_model_state = None
        # -inf guarantees the first epoch is checkpointed, whatever its AUC
        best_val_auc = float("-inf")

        for epoch in range(Config.EPOCHS):
            m = train_epoch(student, teacher, loaders, optimizer, config)
            val_auc = evaluate(student, loaders["val"], log_table=False)
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                best_model_state = copy.deepcopy(student.state_dict())

            wandb.log({**m, "val_auc": val_auc, "epoch": epoch})

        # test eval on the best-validation checkpoint
        if best_model_state is not None:
            student.load_state_dict(best_model_state)
            # log_table=True: detailed metrics, predictions table and ROC curve
            evaluate(student, loaders["test"], log_table=True, split_name="test")
            logger.success("run complete")
        else:
            # only reachable when Config.EPOCHS == 0
            logger.error("failed -> no model saved")


# Sweep controller/worker entry point. It is opt-in (RUN_SWEEP=true), so
# running the notebook top to bottom does not start a sweep.
if __name__ == "__main__" and os.getenv("RUN_SWEEP", "False").lower() == "true":
    try:
        import sklearn  # noqa: F401
    except ImportError:
        logger.error("missing -> pip install scikit-learn")
        sys.exit(1)

    # DISTRIBUTED LOGIC:
    # 1. SWEEP_ID set   => worker mode: join the existing sweep
    # 2. SWEEP_ID unset => controller mode: create a new sweep, then work on it too
    if "SWEEP_ID" in os.environ:
        sweep_id = os.environ["SWEEP_ID"]
        logger.info(f"worker start -> joining sweep: {sweep_id}")
    else:
        sweep_id = wandb.sweep(
            SWEEP_CONFIG, project=Config.PROJECT, entity=Config.ENTITY
        )
        logger.info(f"controller start -> init sweep: {sweep_id}")
        logger.info(
            f"to join workers, run: export SWEEP_ID={sweep_id} && python main.py"
        )

    # entity/project let workers resolve a bare sweep id
    wandb.agent(
        sweep_id,
        main_sweep,
        entity=Config.ENTITY,
        project=Config.PROJECT,
        count=None,
    )

In [None]:
RUN_FINAL_EVALUATION = True

if RUN_FINAL_EVALUATION:
    import copy
    from typing import Any, Dict, List, Tuple

    import numpy as np
    import torch
    import wandb

    # Earlier sweep winner, superseded by BEST_PARAMS below; kept for reference only.
    PREVIOUS_BEST_PARAMS: Dict[str, Any] = {
        "model_type": "GCN_Res",
        "num_layers": 4,
        "hidden_channels": 128,
        "heads": 4,
        "dropout": 0.3280387,
        "batch_size": 32,
        "label_rate": 0.1,
        "weight_decay": 1e-05,
        "learning_rate": 0.000692,
        "teacher_alpha": 0.9905,
        "epochs": 100,
    }

    # Best W&B sweep run hyperparameters; used for every final-evaluation run.
    BEST_PARAMS: Dict[str, Any] = {
        "heads": 2,
        "dropout": 0.5174957708684919,
        "batch_size": 64,
        "label_rate": 0.1,
        "model_type": "GCN_Res",
        "num_layers": 4,
        "weight_decay": 0,
        "learning_rate": 0.0014010278646211564,
        "teacher_alpha": 0.9854702728613248,
        "hidden_channels": 64,
        "consistency_weight": 0.5082421350750431,
    }

    # Each experiment is trained once per seed and reported as mean ± std.
    SEEDS: List[int] = [0, 1, 2, 3, 4]
    EXPERIMENTS: Dict[str, Dict[str, Any]] = {
        # consistency_weight 0 zeroes the teacher consistency term in the loss
        "Baseline (No Teacher)": {"consistency_weight": 0.0},
        # Reuse the tuned weight that belongs to BEST_PARAMS. A hard-coded
        # 1.03712 previously overrode it silently.
        "Mean Teacher": {"consistency_weight": BEST_PARAMS["consistency_weight"]},
    }

    def train_one_run(run_config: Dict[str, Any]) -> float:
        """Train a student/teacher pair for one seeded config and return its test AUC.

        Seeds torch, numpy and (when available) CUDA with ``run_config["seed"]``.
        Also sets ``Config.SEED``, so the stratified validation split is
        reproducible. The returned AUC is that of the best-validation
        checkpoint; 0.0 means no checkpoint was stored.
        """
        Config.SEED = run_config["seed"]
        torch.manual_seed(Config.SEED)
        np.random.seed(Config.SEED)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(Config.SEED)

        mode = "disabled" if Config.DUMMY_RUN else "online"
        run_name = (
            f"eval_lambda{run_config['consistency_weight']}_s{run_config['seed']}"
        )
        # Read from the plain dict, which avoids depending on how wandb.Config
        # reports a missing key.
        epochs = int(run_config.get("epochs", Config.EPOCHS))

        with wandb.init(
            project=Config.PROJECT,
            entity=Config.ENTITY,
            config=run_config,
            reinit=True,
            mode=mode,
            name=run_name,
        ):
            config = wandb.config
            loaders = get_loaders(config)

            # dimensions
            sample = next(iter(loaders["labeled"]))
            in_dim = sample.num_node_features
            edge_dim = sample.num_edge_features

            # models & optimizer
            student = get_model(config, in_dim, edge_dim).to(Config.DEVICE)
            teacher = get_model(config, in_dim, edge_dim).to(Config.DEVICE)
            teacher.load_state_dict(student.state_dict())

            # freeze: the teacher is updated only through the EMA
            for p in teacher.parameters():
                p.requires_grad = False

            optimizer = torch.optim.Adam(
                student.parameters(),
                lr=config.learning_rate,
                weight_decay=config.weight_decay,
            )

            best_model_state = None
            # -inf guarantees the first epoch is checkpointed, whatever its AUC
            best_val_auc = float("-inf")

            for epoch in range(epochs):
                metrics = train_epoch(student, teacher, loaders, optimizer, config)
                val_auc = evaluate(student, loaders["val"], log_table=False)

                if val_auc > best_val_auc:
                    best_val_auc = val_auc
                    best_model_state = copy.deepcopy(student.state_dict())

                # log: internal loop metrics
                wandb.log({**metrics, "val_auc": val_auc, "epoch": epoch})

            if best_model_state is None:
                return 0.0

            # eval: test set with the best-validation weights
            student.load_state_dict(best_model_state)
            # log_table=True logs rich media (ROC, tables) for the final eval
            test_auc = evaluate(
                student, loaders["test"], log_table=True, split_name="test"
            )
            return float(test_auc)

    final_results: Dict[str, Tuple[float, float]] = {}
    for exp_name, specific_args in EXPERIMENTS.items():
        print(f"\n>> Evaluating: {exp_name}")
        scores: List[float] = []

        for seed in SEEDS:
            # prepare the specific config for this run
            run_config = copy.deepcopy(BEST_PARAMS)
            run_config.update(specific_args)
            run_config["seed"] = seed

            try:
                test_auc = train_one_run(run_config)
            except Exception:
                # Best effort: log the traceback and keep evaluating the remaining seeds.
                logger.exception(f"{exp_name} -> seed {seed} failed")
                continue

            scores.append(test_auc)
            print(f"Seed {seed} Test AUC: {test_auc:.4f}")

        if len(scores) < len(SEEDS):
            logger.warning(
                f"{exp_name} -> only {len(scores)}/{len(SEEDS)} seeds succeeded"
            )

        # Aggregate stats (np.std is the population std, ddof=0)
        if scores:
            mean_score = float(np.mean(scores))
            std_score = float(np.std(scores))
            final_results[exp_name] = (mean_score, std_score)
            print(f"Average: {mean_score:.4f} ± {std_score:.4f}")

    print("\n>> Summary (test AUC, mean ± std over successful seeds)")
    for name, (mean, std) in final_results.items():
        print(f"{name:<30} | {mean:.4f} ± {std:.4f}")