# Tox21 GINENet

We’re solving a multi-task binary toxicity prediction problem on Tox21. Concretely, for each molecule the model outputs 12 probabilities, one for each assay (e.g. NR-AR, SR-ARE, p53, etc.). At training time we use a binary cross-entropy loss (with masking for missing labels) over those 12 tasks, and at the end of each epoch we compute the ROC-AUC per task (then average) on the held-out validation set to see how well the model is distinguishing actives vs. inactives across all assays.

In [1]:
%pip -q install rdkit-pypi torch_geometric

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m52.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m54.3 MB/s[0m eta [36m0:00:00[0m
[?25hNote: you may need to restart the kernel to use updated packages.


In [2]:
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, WeightedRandomSampler
from torch.serialization import safe_globals, add_safe_globals

# Pytorch Geometric Imports
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.data.data import DataTensorAttr, DataEdgeAttr
from torch_geometric.data.storage import GlobalStorage
from torch_geometric.nn import (
    GINEConv,
    GatedGraphConv,
    Set2Set,
    GlobalAttention,
    global_add_pool
)
from torch_geometric.nn.models import JumpingKnowledge

# RDKit Imports
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import MolFromSmiles, MolToSmiles, rdchem
from rdkit.Chem import Descriptors

from sklearn.metrics import roc_auc_score
import os
import optuna
import numpy as np

In [3]:
!wget wget https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
!gunzip tox21.csv.gz
!mkdir -p raw
!mv tox21.csv raw/

--2025-05-31 20:05:58--  http://wget/
Resolving wget (wget)... failed: Name or service not known.
wget: unable to resolve host address ‘wget’
--2025-05-31 20:05:58--  https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 125310 (122K) [application/octet-stream]
Saving to: ‘tox21.csv.gz’


2025-05-31 20:05:58 (5.61 MB/s) - ‘tox21.csv.gz’ saved [125310/125310]

FINISHED --2025-05-31 20:05:58--
Total wall clock time: 0.2s
Downloaded: 1 files, 122K in 0.02s (5.61 MB/s)


In [4]:
add_safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage])

# Dataset

In [5]:
class Tox21Dataset(InMemoryDataset):
    def __init__(self, root: str = '.', transform=None, pre_transform=None):
        """
        Expects:
        root/
            raw/tox21.csv
        Will create:
            processed/data.pt
        """
        super().__init__(root, transform, pre_transform)
        # Load the processed data
        with safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage]):
            self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # File expected in root/raw/
        return ['tox21.csv']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # No download step needed; CSV is already in place
        return

    def process(self):
        df = pd.read_csv(self.raw_paths[0])
        df = df.fillna(value=0)
        
        if "mol_id" in df.columns:
            df = df.drop(columns=["mol_id"])

        # 3) convert all non-smiles columns to numeric
        non_smiles = [c for c in df.columns if c != "smiles"]
        df[non_smiles] = df[non_smiles].apply(pd.to_numeric, errors="coerce")
        df = df.reset_index(drop=True)

        data_list = []
        bond_types = [
            rdchem.BondType.SINGLE,
            rdchem.BondType.DOUBLE,
            rdchem.BondType.TRIPLE,
            rdchem.BondType.AROMATIC,
        ]
        stereo_types = [
            rdchem.BondStereo.STEREONONE,
            rdchem.BondStereo.STEREOZ,
            rdchem.BondStereo.STEREOE,
            rdchem.BondStereo.STEREOANY
        ]

        for _, row in df.iterrows():
            smiles = row["smiles"]
            mol = MolFromSmiles(smiles, sanitize=True)
            if mol is None:
                continue

            # node features x
            atom_feats = [
                [
                    a.GetAtomicNum(),
                    a.GetDegree(),
                    a.GetFormalCharge(),
                    a.GetNumRadicalElectrons(),
                    a.GetTotalNumHs(),
                    int(a.GetIsAromatic()),
                    int(a.IsInRing())
                ]
                for a in mol.GetAtoms()
            ]
            x = torch.tensor(atom_feats, dtype=torch.float32)

            # get molecule descriptors
            descriptors = [
                Descriptors.MolWt(mol),
                Descriptors.MolLogP(mol),
                Descriptors.TPSA(mol),
                Descriptors.NumHAcceptors(mol),
                Descriptors.NumHDonors(mol)
            ]

            # build edge_index AND edge_attr in lock‐step
            edge_index = []
            edge_attr  = []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # one‐hot encode this bond’s type
                bt = bond.GetBondType()
                bfeat = [int(bt == t) for t in bond_types]
                st = bond.GetStereo()
                sfeat = [int(st == s) for s in stereo_types]
                feat = bfeat + sfeat

                # add both directions
                edge_index += [[i, j], [j, i]]
                edge_attr  += [feat, feat]

            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr  = torch.tensor(edge_attr, dtype=torch.float32)  # [2E, len(bond_types)]
            desc = torch.tensor(descriptors, dtype=torch.float32)

            # your labels
            y = torch.tensor(row[[c for c in df.columns if c!="smiles"]].tolist(), dtype=torch.float32)

            
            data_list.append(Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                desc=desc
            ))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [6]:
dataset = Tox21Dataset('.')

Processing...
Done!


A **Murcko Scaffold** is a technique used to group molecules based on their core structural components. It identifies the essential building blocks by removing side chains and other non-core components, leaving behind the important ring systems and connecting chains. This method is widely used in medicinal chemistry and drug design to identify core structures that have preferential activity against specific targets, which is very useful in **molecular property prediction**.

In [7]:
def GetMurckoScaffold(data, train_frac: float, val_frac: float, test_frac: float):
    scaffold_indices = {}
    for idx, smiles in enumerate(data):
        scaffold_smiles = MurckoScaffold.MurckoScaffoldSmiles(smiles)
        #scaffold_smiles = MolToSmiles(scaffold)
        scaffold_indices.setdefault(scaffold_smiles, []).append(idx)

    groups = sorted(scaffold_indices.values(), key=len, reverse=True)
    n_total = len(data)
    n_train = int(train_frac * n_total)
    n_valid = int(val_frac * n_total)

    train_idx, valid_idx, test_idx = [], [], []
    for group in groups:
        if len(train_idx) + len(group) <= n_train:
            train_idx.extend(group)
        elif len(valid_idx) + len(group) <= n_valid:
            valid_idx.extend(group)
        else:
            test_idx.extend(group)

    return train_idx, valid_idx, test_idx

In [8]:
np.random.seed(3411)
torch.manual_seed(3411)

<torch._C.Generator at 0x7f576d24b230>

In [9]:
def train_epoch(model, loader, optimizer, criterion, device, use_desc=True):
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        if use_desc:
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch, batch.desc)
        else:
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        mask = (batch.y >= 0).float()
        bs, nt = logits.size()
        batch_y = batch.y.view(bs, nt)
        mask = mask.view(bs, nt)
        loss = (criterion(logits, batch_y) * mask).sum() / mask.sum()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

In [10]:
def evaluate(model, loader, device, use_desc=True):
    model.eval()
    y, preds = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            if use_desc:
                logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch, batch.desc)
            else:
                logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            y.append(batch.y.cpu())
            preds.append(torch.sigmoid(logits).cpu())
    return torch.cat(preds, dim=0).numpy(), torch.cat(y, dim=0).numpy()

In [11]:
df = pd.read_csv('raw/tox21.csv').fillna(0).reset_index(drop=True)
smiles_list = df['smiles'].tolist()
train_idx, valid_idx, test_idx = GetMurckoScaffold(smiles_list, 0.8, 0.1, 0.1)

train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, valid_idx)
test_ds  = Subset(dataset, test_idx)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)#, sampler=sampler)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=32)

sample = dataset[0]
in_channels = sample.x.size(1)
num_tasks   = sample.y.size(0)
desc_dim = sample.desc.size(0)

print("There are", num_tasks, "tasks.")



There are 12 tasks.


In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # logits: shape [N, ...], raw outputs
        # targets: same shape, 0 or 1
        probas = torch.sigmoid(logits)
        # p_t: prob of true class
        p_t = probas * targets + (1 - probas) * (1 - targets)
        # alpha factor
        alpha_factor = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        # focal weight
        focal_weight = alpha_factor * (1 - p_t) ** self.gamma
        # binary cross‐entropy per example
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        loss = focal_weight * bce

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

# GINE with Jumping Knowledge

In [24]:
class GINEWithJK(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_tasks: int,
        n_layers: int,
        jk_mode: str = 'max',            # default to 'max'
        set2set_steps: int = 3,
        use_set2set: bool = True,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.use_set2set = use_set2set
        self.hidden_dim = hidden_dim

        # ——————————————————————————————
        # 1) Virtual‐node embedding (one learnable vector per graph):
        self.virtualnode_embedding = nn.Embedding(1, hidden_dim)
        # We will repeat this for each graph in the batch in `forward()`.

        # ——————————————————————————————
        # 2) First GINE layer sees: [input_dim ⊕ hidden_dim] → hidden_dim
        self.conv1 = GINEConv(
            nn.Sequential(
                nn.Linear(input_dim + hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ),
            edge_dim=8,
        )

        # ——————————————————————————————
        # 3) Additional GINE layers, each sees: [hidden_dim ⊕ hidden_dim] → hidden_dim
        self.convs = nn.ModuleList()
        for _ in range(n_layers):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim + hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GINEConv(mlp, edge_dim=8))

        # ——————————————————————————————
        # 4) JumpingKnowledge: 
        #    - If jk_mode == 'cat', output size = hidden_dim * (n_layers + 1)
        #    - If jk_mode in {'max','lstm'}, output size = hidden_dim
        self.jk = JumpingKnowledge(mode=jk_mode, channels=hidden_dim, num_layers=n_layers + 1)

        # Compute jk_out_dim for pooling / final fc:
        if jk_mode == 'cat':
            jk_out_dim = hidden_dim * (n_layers + 1)
        else:
            jk_out_dim = hidden_dim

        # ——————————————————————————————
        # 5) Pooling: Set2Set(jk_out_dim, steps) or global_mean_pool
        if use_set2set:
            self.pool = Set2Set(jk_out_dim, processing_steps=set2set_steps)
            fc_in_dim = 2 * jk_out_dim  # because Set2Set doubles the dimension
        else:
            self.pool = global_mean_pool
            fc_in_dim = jk_out_dim

        # ——————————————————————————————
        # 6) Final linear (graph → num_tasks)
        self.fc = nn.Linear(fc_in_dim, num_tasks)


    def forward(self, x, edge_index, edge_attr, batch):
        """
        x:          [num_nodes, input_dim]
        edge_index: [2, num_edges]
        edge_attr:  [num_edges, 8]
        batch:      [num_nodes]   # maps each node to a graph‐ID
        """

        # ——————————————————————————————
        # Step 1: build one virtual‐node embedding per graph in the batch
        batch_size = int(batch.max().item()) + 1
        vnode = self.virtualnode_embedding(
            torch.zeros(batch_size, dtype=torch.long, device=x.device)
        )  # shape: [batch_size, hidden_dim]

        # expand vnode to one vector per node:
        vnode_expanded = vnode[batch]  # shape: [num_nodes, hidden_dim]

        # ——————————————————————————————
        # Step 2: First conv: concat original features + vnode, then pass through GINEConv
        x_in = torch.cat([x, vnode_expanded], dim=1)  # [num_nodes, input_dim + hidden_dim]
        h = F.relu(self.conv1(x_in, edge_index, edge_attr))  # [num_nodes, hidden_dim]

        xs = [h]  # for JumpingKnowledge

        # ——————————————————————————————
        # Step 3: Additional conv layers:
        for conv in self.convs:
            # re‐fetch vnode_expanded each layer (same vector)
            vnode_expanded = vnode[batch]  # [num_nodes, hidden_dim]
            h_in = torch.cat([h, vnode_expanded], dim=1)  # [num_nodes, hidden_dim * 2]
            h = F.relu(conv(h_in, edge_index, edge_attr))  # [num_nodes, hidden_dim]
            xs.append(h)

        # ——————————————————————————————
        # Step 4: JumpingKnowledge aggregation:
        #    - If jk_mode='cat', x_jk has size [num_nodes, hidden_dim*(n_layers+1)]
        #    - If jk_mode in {'max','lstm'}, x_jk has size [num_nodes, hidden_dim]
        x_jk = self.jk(xs)

        # ——————————————————————————————
        # Step 5: Graph‐level pooling
        if self.use_set2set:
            # Set2Set expects input dim = jk_out_dim
            x_graph = self.pool(x_jk, batch)
            # shape: [batch_size, 2 * jk_out_dim]
        else:
            x_graph = self.pool(x_jk, batch)
            # shape: [batch_size, jk_out_dim]

        # ——————————————————————————————
        # Step 6: Final MLP → num_tasks
        out = self.fc(x_graph)  # [batch_size, num_tasks]
        return out


In [30]:
def objective(trial):
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
    hidden_dim = trial.suggest_categorical("hidden_dim", [192, 384, 576])
    weight_decay = trial.suggest_loguniform("weight_decay", 5e-5, 1e-3)
    n_layers = trial.suggest_categorical("n_layers", [3, 4])
    set2set_steps = trial.suggest_categorical("set2set_steps", [3, 4, 5])
    alpha = trial.suggest_uniform("alpha", 0.10, 0.40)
    gamma = trial.suggest_uniform("gamma", 1.2, 1.6)
    model = GINEWithJK(
        in_channels,
        hidden_dim=hidden_dim,
        num_tasks=num_tasks,
        n_layers=n_layers,
        jk_mode='max',
        set2set_steps=set2set_steps
    ).to(device)

    criterion = FocalLoss(alpha=alpha, gamma=gamma, reduction='none')
    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
    )
    
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=30,
        eta_min=1e-5
    )

    for epoch in range(1, 21):
        loss = train_epoch(model, train_loader, optimizer, criterion, device, False)
        ps_val, ys_val = evaluate(model, val_loader, device, False)
        ys_val = ys_val.reshape(ps_val.shape)
        aucs = []
        for i in range(ps_val.shape[1]):
            mask = ys_val[:, i] >= 0
            if mask.sum() > 0:
                aucs.append(roc_auc_score(ys_val[mask, i], ps_val[mask, i]))
        scheduler.step(np.mean(aucs)) # we use the mean like below
    
    ps_test, ys_test = evaluate(model, test_loader, device, False)
    ys_test = ys_test.reshape(ps_test.shape)
    aucs = []
    for i in range(ps_test.shape[1]):
        mask = ys_test[:, i] >= 0
        if mask.sum() > 0:
            aucs.append(roc_auc_score(ys_test[mask, i], ps_test[mask, i]))
    
    return np.mean(aucs)

In [32]:
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=3411),  # repeatable
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=2),  # optional
)

optuna.logging.set_verbosity(optuna.logging.CRITICAL)

study.optimize(objective, n_trials=50, timeout=3411)

[I 2025-05-31 20:34:38,006] A new study created in memory with name: no-name-87aa9691-b2a5-47aa-891d-c6b3482c5257
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
  weight_decay = trial.suggest_loguniform("weight_decay", 5e-5, 1e-3)
  alpha = trial.suggest_uniform("alpha", 0.10, 0.40)
  gamma = trial.suggest_uniform("gamma", 1.2, 1.6)
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
  weight_decay = trial.suggest_loguniform("weight_decay", 5e-5, 1e-3)
  alpha = trial.suggest_uniform("alpha", 0.10, 0.40)
  gamma = trial.suggest_uniform("gamma", 1.2, 1.6)
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
  weight_decay = trial.suggest_loguniform("weight_decay", 5e-5, 1e-3)
  alpha = trial.suggest_uniform("alpha", 0.10, 0.40)
  gamma = trial.suggest_uniform("gamma", 1.2, 1.6)
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
  weight_decay = trial.suggest_loguniform("weight_decay", 5e-5, 1e-3)
  alpha = trial.suggest_uniform("alpha", 0.10, 0.40)
  gamma = trial.suggest_uniform("gamma", 

In [33]:
print("Number of finished trials:", len(study.trials))
print("Best trial:")
trial = study.best_trial
print("  Value (val_loss):", trial.value)
print("  Params: ")
for key, val in trial.params.items():
    print(f"{key}: {val}")

Number of finished trials: 42
Best trial:
  Value (val_loss): 0.6711583766076324
  Params: 
lr: 5.605748567213143e-05
hidden_dim: 192
weight_decay: 7.377934206677797e-05
n_layers: 4
set2set_steps: 4
alpha: 0.25746549424033294
gamma: 1.41338724579117


In [35]:
jk_model = GINEWithJK(
    input_dim = in_channels,
    hidden_dim = trial.params["hidden_dim"],
    num_tasks = num_tasks,
    n_layers = trial.params["n_layers"],
    jk_mode = 'max',
    set2set_steps = trial.params["set2set_steps"],
).to(device)

print(jk_model)

GINEWithJK(
  (virtualnode_embedding): Embedding(1, 192)
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=199, out_features=192, bias=True)
    (1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=192, out_features=192, bias=True)
  ))
  (convs): ModuleList(
    (0-3): 4 x GINEConv(nn=Sequential(
      (0): Linear(in_features=384, out_features=192, bias=True)
      (1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=192, out_features=192, bias=True)
    ))
  )
  (jk): JumpingKnowledge(max)
  (pool): Set2Set(192, 384)
  (fc): Linear(in_features=384, out_features=12, bias=True)
)


In [36]:
jk_optimizer = optim.AdamW(
    jk_model.parameters(),
    lr = trial.params["lr"],
    weight_decay = trial.params["weight_decay"],
)

jk_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    jk_optimizer,
    mode="max",           # we track val AUC, so “max”
    factor=0.5,           # halve the lr
    patience=5,           # after 5 epochs with no AUC gain
    min_lr=1e-6
)

In [37]:
jk_criterion = FocalLoss(alpha=trial.params["alpha"], gamma=trial.params["gamma"], reduction='none')

In [38]:
for epoch in range(1, 21):
    loss = train_epoch(jk_model, train_loader, jk_optimizer, jk_criterion, device, False)
    ps_val, ys_val = evaluate(jk_model, val_loader, device, False)
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        if mask.sum() > 0:
            aucs.append(roc_auc_score(ys_val[mask, i], ps_val[mask, i]))
    jk_scheduler.step(np.mean(aucs)) # we use the mean like below
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {np.mean(aucs):.3f}")

Epoch 01 | Loss: 0.0419 | Val AUC: 0.597
Epoch 02 | Loss: 0.0292 | Val AUC: 0.609
Epoch 03 | Loss: 0.0289 | Val AUC: 0.629
Epoch 04 | Loss: 0.0287 | Val AUC: 0.627
Epoch 05 | Loss: 0.0285 | Val AUC: 0.631
Epoch 06 | Loss: 0.0284 | Val AUC: 0.631
Epoch 07 | Loss: 0.0282 | Val AUC: 0.661
Epoch 08 | Loss: 0.0277 | Val AUC: 0.654
Epoch 09 | Loss: 0.0275 | Val AUC: 0.650
Epoch 10 | Loss: 0.0272 | Val AUC: 0.664
Epoch 11 | Loss: 0.0271 | Val AUC: 0.673
Epoch 12 | Loss: 0.0268 | Val AUC: 0.666
Epoch 13 | Loss: 0.0269 | Val AUC: 0.674
Epoch 14 | Loss: 0.0267 | Val AUC: 0.673
Epoch 15 | Loss: 0.0267 | Val AUC: 0.689
Epoch 16 | Loss: 0.0265 | Val AUC: 0.689
Epoch 17 | Loss: 0.0264 | Val AUC: 0.687
Epoch 18 | Loss: 0.0263 | Val AUC: 0.660
Epoch 19 | Loss: 0.0264 | Val AUC: 0.683
Epoch 20 | Loss: 0.0261 | Val AUC: 0.699


In [39]:
ps_test, ys_test = evaluate(jk_model, test_loader, device, False)
ys_test = ys_test.reshape(ps_test.shape)
aucs = []
for i in range(ps_test.shape[1]):
    mask = ys_test[:, i] >= 0
    if mask.sum() > 0:
        aucs.append(roc_auc_score(ys_test[mask, i], ps_test[mask, i]))
print(f"Testing | Loss: {loss:.4f} | Test AUC: {np.mean(aucs):.3f}")

Testing | Loss: 0.0261 | Test AUC: 0.646
