# Tox21 GINENet

We’re solving a multi-task binary toxicity prediction problem on Tox21. Concretely, for each molecule the model outputs 12 probabilities, one for each assay (e.g. NR-AR, SR-ARE, p53, etc.). At training time we use a binary cross-entropy loss (with masking for missing labels) over those 12 tasks, and at the end of each epoch we compute the ROC-AUC per task (then average) on the held-out validation set to see how well the model is distinguishing actives vs. inactives across all assays.

In [1]:
%pip -q install rdkit-pypi torch_geometric pennylane

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.1/56.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m52.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m42.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m56.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m930.8/930.8 kB[0m [31m40.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m67.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m67.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━

In [2]:
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, WeightedRandomSampler
from torch.serialization import safe_globals, add_safe_globals

# Pytorch Geometric Imports
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.data.data import DataTensorAttr, DataEdgeAttr
from torch_geometric.data.storage import GlobalStorage
from torch_geometric.nn import (
    GINEConv,
    GatedGraphConv,
    Set2Set,
    GlobalAttention,
    global_add_pool,
    global_mean_pool
)
from torch_geometric.nn.models import JumpingKnowledge

# RDKit Imports
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import MolFromSmiles, MolToSmiles, rdchem
from rdkit.Chem import Descriptors

from sklearn.metrics import roc_auc_score
import os
import optuna
from pennylane import numpy as np
import pennylane as qml
from pennylane.templates import (
    AngleEmbedding,
    AmplitudeEmbedding,
    StronglyEntanglingLayers
)

add_safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage])



In [3]:
# Fetch the gzipped Tox21 CSV and place the extracted file where the dataset
# class expects it (./raw/tox21.csv).
!wget https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
!gunzip tox21.csv.gz
!mkdir -p raw
!mv tox21.csv raw/

--2025-06-02 20:09:50--  http://wget/
Resolving wget (wget)... failed: Name or service not known.
wget: unable to resolve host address ‘wget’
--2025-06-02 20:09:50--  https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 125310 (122K) [application/octet-stream]
Saving to: ‘tox21.csv.gz’


2025-06-02 20:09:50 (5.61 MB/s) - ‘tox21.csv.gz’ saved [125310/125310]

FINISHED --2025-06-02 20:09:50--
Total wall clock time: 0.3s
Downloaded: 1 files, 122K in 0.02s (5.61 MB/s)


# Dataset

In [4]:
class Tox21Dataset(InMemoryDataset):
    """In-memory PyG dataset built from ``raw/tox21.csv``.

    Every parsable molecule becomes one ``Data`` object holding:
      * ``x``          -- 7 features per atom (atomic number, degree, formal
                          charge, radical electrons, total Hs, aromatic flag,
                          ring flag),
      * ``edge_index`` -- both directions of every bond, shape ``[2, 2E]``,
      * ``edge_attr``  -- one-hot bond type (4) + one-hot bond stereo (4),
      * ``y``          -- one value per non-``smiles`` CSV column,
      * ``desc``       -- 5 RDKit molecule-level descriptors.
    """

    def __init__(self, root: str = '.', transform=None, pre_transform=None,
                 missing_label: float = 0.0):
        """
        Expects:
        root/
            raw/tox21.csv
        Will create:
            processed/data.pt

        Args:
            root: Directory containing ``raw/`` (and receiving ``processed/``).
            transform: Forwarded unchanged to ``InMemoryDataset``.
            pre_transform: Forwarded unchanged to ``InMemoryDataset``.
            missing_label: Value written into ``y`` where the CSV has no
                measurement. The default ``0.0`` keeps the historical
                behaviour (missing == inactive). Pass a negative value such
                as ``-1`` so that ``y >= 0`` masks can exclude unmeasured
                assays from the loss and metrics.
        """
        # Must exist before super().__init__(): the base constructor reads
        # processed_file_names and may run process().
        self.missing_label = float(missing_label)
        super().__init__(root, transform, pre_transform)
        # Load the processed data
        with safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage]):
            self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # File expected in root/raw/
        return ['tox21.csv']

    @property
    def processed_file_names(self):
        # The default fill value keeps the original cache name so existing
        # processed files stay valid; any other value gets its own file so
        # caches built with different fill values are never mixed up.
        if self.missing_label == 0.0:
            return ['data.pt']
        return [f'data_missing_{self.missing_label:g}.pt']

    def download(self):
        # No download step needed; CSV is already in place
        return

    def process(self):
        """Convert every CSV row into a ``Data`` graph and save the collated result."""
        df = pd.read_csv(self.raw_paths[0])

        if "mol_id" in df.columns:
            df = df.drop(columns=["mol_id"])

        # Every column except the SMILES string is treated as a label.
        label_cols = [c for c in df.columns if c != "smiles"]
        df[label_cols] = df[label_cols].apply(pd.to_numeric, errors="coerce")
        # Fill after the numeric coercion so values that failed to parse are
        # treated as missing too.
        df[label_cols] = df[label_cols].fillna(self.missing_label)
        df = df.reset_index(drop=True)
        label_rows = df[label_cols].to_numpy(dtype="float32")

        data_list = []
        bond_types = [
            rdchem.BondType.SINGLE,
            rdchem.BondType.DOUBLE,
            rdchem.BondType.TRIPLE,
            rdchem.BondType.AROMATIC,
        ]
        stereo_types = [
            rdchem.BondStereo.STEREONONE,
            rdchem.BondStereo.STEREOZ,
            rdchem.BondStereo.STEREOE,
            rdchem.BondStereo.STEREOANY
        ]
        n_edge_feats = len(bond_types) + len(stereo_types)

        for smiles, labels in zip(df["smiles"], label_rows):
            mol = MolFromSmiles(smiles, sanitize=True)
            if mol is None:
                # NOTE(review): skipping a row makes dataset indices diverge
                # from CSV row indices; any split computed on the raw CSV
                # then points at the wrong graphs.
                continue

            # node features x
            atom_feats = [
                [
                    a.GetAtomicNum(),
                    a.GetDegree(),
                    a.GetFormalCharge(),
                    a.GetNumRadicalElectrons(),
                    a.GetTotalNumHs(),
                    int(a.GetIsAromatic()),
                    int(a.IsInRing())
                ]
                for a in mol.GetAtoms()
            ]
            x = torch.tensor(atom_feats, dtype=torch.float32)

            # get molecule descriptors
            descriptors = [
                Descriptors.MolWt(mol),
                Descriptors.MolLogP(mol),
                Descriptors.TPSA(mol),
                Descriptors.NumHAcceptors(mol),
                Descriptors.NumHDonors(mol)
            ]

            # build edge_index AND edge_attr in lock-step
            edge_index = []
            edge_attr = []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # one-hot encode this bond's type and stereo configuration
                bt = bond.GetBondType()
                st = bond.GetStereo()
                feat = [int(bt == t) for t in bond_types] + [int(st == s) for s in stereo_types]

                # add both directions
                edge_index += [[i, j], [j, i]]
                edge_attr += [feat, feat]

            # The reshape calls are no-ops for molecules with bonds; for
            # bond-free molecules they turn the 1-D empty tensors into the
            # proper [2, 0] / [0, n_edge_feats] shapes.
            edge_index = (
                torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
            )
            edge_attr = torch.tensor(edge_attr, dtype=torch.float32).reshape(-1, n_edge_feats)
            desc = torch.tensor(descriptors, dtype=torch.float32)

            # labels, one entry per label column
            y = torch.tensor(labels, dtype=torch.float32)

            data_list.append(Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                desc=desc
            ))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [5]:
# Build the graph dataset rooted at the current directory (reads ./raw/tox21.csv).
dataset = Tox21Dataset('.')

Processing...
Done!


A **Murcko scaffold** is the core structural framework of a molecule: it is obtained by removing side chains and other non-core components, leaving behind the ring systems and the chains that connect them. Grouping molecules by their scaffold is widely used in medicinal chemistry and drug design to identify core structures that have preferential activity against specific targets, which is very useful in **molecular property prediction**. Here the scaffolds are used to split the data so that molecules sharing a scaffold always land in the same subset, which makes the held-out sets a harder and more realistic test than a random split.

In [6]:
def GetMurckoScaffold(data, train_frac: float, val_frac: float, test_frac: float):
    """Split SMILES strings into train/valid/test index lists by Murcko scaffold.

    Molecules sharing a scaffold always end up in the same subset. Scaffold
    groups are visited from largest to smallest; each group goes to train if
    it still fits under ``train_frac``, otherwise to valid if it fits under
    ``val_frac``, otherwise to test.

    Args:
        data: Sized iterable of SMILES strings.
        train_frac: Fraction of ``len(data)`` that caps the training subset.
        val_frac: Fraction of ``len(data)`` that caps the validation subset.
        test_frac: Accepted for symmetry but not used; the test subset simply
            receives every group that fits nowhere else.

    Returns:
        ``(train_idx, valid_idx, test_idx)`` -- lists of positions into ``data``.
    """
    members_by_scaffold = {}
    for position, smi in enumerate(data):
        key = MurckoScaffold.MurckoScaffoldSmiles(smi)
        if key not in members_by_scaffold:
            members_by_scaffold[key] = []
        members_by_scaffold[key].append(position)

    # Stable sort: groups of equal size keep their first-seen order.
    buckets = list(members_by_scaffold.values())
    buckets.sort(key=len, reverse=True)

    total = len(data)
    train_cap = int(train_frac * total)
    valid_cap = int(val_frac * total)

    train_idx, valid_idx, test_idx = [], [], []
    for members in buckets:
        size = len(members)
        if len(train_idx) + size <= train_cap:
            destination = train_idx
        elif len(valid_idx) + size <= valid_cap:
            destination = valid_idx
        else:
            destination = test_idx
        destination.extend(members)

    return train_idx, valid_idx, test_idx

In [7]:
# Seed the global NumPy RNG (np is pennylane's numpy wrapper here) and the
# global PyTorch RNG with the same fixed value.
np.random.seed(3411)
torch.manual_seed(3411)

<torch._C.Generator at 0x7f1cd57331f0>

In [8]:
def train_epoch(model, loader, optimizer, criterion, device, use_desc=True):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Labels ``< 0`` (and NaN labels) are treated as missing and excluded from
    the loss. This holds whether ``criterion`` returns an element-wise loss
    (``reduction='none'``) or an already reduced scalar.

    Args:
        model: Called as ``model(x, edge_index, edge_attr, batch[, desc])`` and
            expected to return logits of shape ``[num_graphs, num_tasks]``.
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y``, ``num_graphs`` (and ``desc`` when
            ``use_desc`` is true); must also expose ``loader.dataset``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss taking ``(logits, targets)``.
        device: Device every batch is moved to.
        use_desc: Whether to pass ``batch.desc`` as a fifth model argument.

    Returns:
        Sum of per-batch losses weighted by ``batch.num_graphs``, divided by
        ``len(loader.dataset)``. Batches without any valid label are skipped
        and contribute zero to the numerator.
    """
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        if use_desc:
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch, batch.desc)
        else:
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        # batch.y arrives flattened; bring it to the logits' [graphs, tasks] layout.
        bs, nt = logits.size()
        targets = batch.y.view(bs, nt)
        valid = targets >= 0
        if not bool(valid.any()):
            # Nothing to learn from; avoids a 0/0 loss poisoning the weights.
            continue

        # Never hand placeholder values (negative / NaN) to the criterion.
        safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
        raw = criterion(logits, safe_targets)
        if raw.dim() == 0:
            # The criterion already reduced to a scalar, so a mask can no
            # longer be applied afterwards; recompute on valid entries only.
            if not bool(valid.all()):
                raw = criterion(logits[valid], targets[valid])
            loss = raw
        else:
            mask = valid.to(raw.dtype)
            loss = (raw * mask).sum() / mask.sum()

        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

def evaluate(model, loader, device, use_desc=True):
    """Run ``model`` over ``loader`` without gradients and collect predictions.

    Args:
        model: Called as ``model(x, edge_index, edge_attr, batch[, desc])``.
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y`` (and ``desc`` when ``use_desc``).
        device: Device every batch is moved to.
        use_desc: Whether to pass ``batch.desc`` as a fifth model argument.

    Returns:
        ``(probs, labels)`` as NumPy arrays: the sigmoid of the model outputs
        concatenated over batches, and the batches' ``y`` tensors concatenated
        as stored (no reshaping is applied to the labels).
    """
    model.eval()
    prob_chunks = []
    label_chunks = []
    with torch.no_grad():
        for graph_batch in loader:
            graph_batch = graph_batch.to(device)
            model_args = [
                graph_batch.x,
                graph_batch.edge_index,
                graph_batch.edge_attr,
                graph_batch.batch,
            ]
            if use_desc:
                model_args.append(graph_batch.desc)
            scores = model(*model_args)
            label_chunks.append(graph_batch.y.cpu())
            prob_chunks.append(torch.sigmoid(scores).cpu())

    probs = torch.cat(prob_chunks, dim=0).numpy()
    labels = torch.cat(label_chunks, dim=0).numpy()
    return probs, labels

In [9]:
# Only the SMILES column is used here: scaffolds are computed from the raw strings.
df = pd.read_csv('raw/tox21.csv').fillna(0).reset_index(drop=True)
smiles_list = df['smiles'].tolist()

# The split below returns CSV row positions that are then used as dataset
# indices. That is only valid when every CSV row produced exactly one graph,
# so fail loudly instead of silently splitting the wrong molecules.
if len(smiles_list) != len(dataset):
    raise RuntimeError(
        f"CSV has {len(smiles_list)} rows but the dataset holds {len(dataset)} graphs; "
        "scaffold-split indices would not line up with the dataset."
    )
train_idx, valid_idx, test_idx = GetMurckoScaffold(smiles_list, 0.8, 0.1, 0.1)

train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, valid_idx)
test_ds  = Subset(dataset, test_idx)

# Only the training data needs shuffling; evaluation metrics do not depend on order.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=32)
test_loader  = DataLoader(test_ds,  batch_size=32)

# Derive the model dimensions from one example graph.
sample = dataset[0]
in_channels = sample.x.size(1)
num_tasks   = sample.y.size(0)
desc_dim = sample.desc.size(0)

print("There are", num_tasks, "tasks.")



There are 12 tasks.


In [10]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
class FocalLoss(nn.Module):
    """Binary focal loss evaluated on raw logits.

    Each element's binary cross-entropy is scaled by
    ``alpha_t * (1 - p_t) ** gamma``, where ``p_t`` is the predicted
    probability of the element's true class and ``alpha_t`` is ``alpha`` for
    targets of 1 and ``1 - alpha`` for targets of 0.

    Args:
        alpha: Weight given to positive targets.
        gamma: Exponent applied to ``1 - p_t``.
        reduction: ``'mean'`` or ``'sum'`` to reduce to a scalar; any other
            value returns the element-wise loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, logits, targets):
        """Return the focal loss for ``logits`` against same-shaped 0/1 ``targets``."""
        negatives = 1 - targets
        prob_pos = torch.sigmoid(logits)

        # Probability assigned to whichever class is the true one.
        prob_true = prob_pos * targets + (1 - prob_pos) * negatives
        # Class-balancing weight, then the focusing term that damps easy examples.
        class_weight = self.alpha * targets + (1 - self.alpha) * negatives
        modulation = class_weight * (1 - prob_true) ** self.gamma

        per_element = modulation * F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none'
        )

        if self.reduction == 'mean':
            return per_element.mean()
        if self.reduction == 'sum':
            return per_element.sum()
        return per_element

# GINE with Jumping Knowledge

In [12]:
import warnings
# Silence every warning category for the rest of the session. This is a blanket
# filter, so genuinely useful warnings are hidden as well.
warnings.filterwarnings('ignore')

In [13]:
class OriginalGINE(nn.Module):
    """Stack of GINEConv layers with JumpingKnowledge, mean pooling and an MLP head.

    Args:
        input_dim: Number of features per node.
        num_tasks: Number of output logits per graph.
        hidden_dim: Width of every GINE layer.
        n_layers: Number of GINE layers.
        jk_mode: JumpingKnowledge aggregation mode; ``'cat'`` concatenates the
            layer outputs, any other mode is assumed to keep ``hidden_dim``.
        dropout: Dropout probability after each GINE layer and inside the head.
        edge_dim: Number of features per edge. Defaults to 8, the bond
            encoding width the layers were previously hard-coded to.
    """

    def __init__(
        self,
        input_dim: int,
        num_tasks: int,
        hidden_dim: int = 300,
        n_layers: int = 5,
        jk_mode: str = 'cat',
        dropout: float = 0.5,
        edge_dim: int = 8,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.dropout = dropout

        # n_layers GINEConv blocks; only the first one sees the raw node features.
        self.convs = nn.ModuleList(
            self._gine_block(input_dim if i == 0 else hidden_dim, hidden_dim, edge_dim)
            for i in range(n_layers)
        )

        # JumpingKnowledge over all layer outputs
        self.jk = JumpingKnowledge(mode=jk_mode, channels=hidden_dim, num_layers=n_layers)

        # Final MLP: jk_out_dim -> 128 -> num_tasks
        jk_out_dim = hidden_dim * n_layers if jk_mode == 'cat' else hidden_dim
        self.lin1 = nn.Linear(jk_out_dim, 128)
        self.lin2 = nn.Linear(128, num_tasks)

    @staticmethod
    def _gine_block(in_dim: int, hidden_dim: int, edge_dim: int):
        """Build one GINEConv whose update MLP is Linear -> BatchNorm -> ReLU -> Linear."""
        mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        return GINEConv(mlp, edge_dim=edge_dim)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return logits of shape ``[num_graphs, num_tasks]``."""
        xs = []
        h = x
        for conv in self.convs:
            h = conv(h, edge_index, edge_attr)
            h = F.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            xs.append(h)

        # JumpingKnowledge aggregation
        h_jk = self.jk(xs)  # [num_nodes, hidden_dim * n_layers] if mode='cat'

        # Graph-level mean pooling
        g = global_mean_pool(h_jk, batch)  # [batch_size, jk_out_dim]

        # Final 2-layer MLP
        out = F.relu(self.lin1(g))
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.lin2(out)  # [batch_size, num_tasks]
        return out


In [14]:
# Classical baseline: 5 GINE layers of width 300 whose outputs are concatenated
# by JumpingKnowledge; input/output sizes come from the dataset sample.
jk_model = OriginalGINE(
    input_dim = in_channels,
    num_tasks = num_tasks,
    hidden_dim = 300,
    n_layers = 5,
    jk_mode = 'cat'
).to(device)

print(jk_model)

OriginalGINE(
  (convs): ModuleList(
    (0): GINEConv(nn=Sequential(
      (0): Linear(in_features=7, out_features=300, bias=True)
      (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=300, out_features=300, bias=True)
    ))
    (1-4): 4 x GINEConv(nn=Sequential(
      (0): Linear(in_features=300, out_features=300, bias=True)
      (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=300, out_features=300, bias=True)
    ))
  )
  (jk): JumpingKnowledge(cat)
  (lin1): Linear(in_features=1500, out_features=128, bias=True)
  (lin2): Linear(in_features=128, out_features=12, bias=True)
)


In [15]:
# AdamW over all baseline parameters (learning rate 1e-3, weight decay 5e-4).
jk_optimizer = optim.AdamW(
    jk_model.parameters(),
    lr = 1e-3,
    weight_decay = 5e-4
)

In [18]:
# Keep the loss element-wise: train_epoch multiplies it by the valid-label mask
# and averages itself. A reducing criterion would collapse to a scalar before
# the mask is applied, which makes the mask a no-op.
jk_criterion = nn.BCEWithLogitsLoss(reduction='none')
#jk_criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='none')

In [19]:
# Train the baseline for 25 epochs, reporting the mean per-task validation ROC-AUC.
for epoch in range(1, 26):
    loss = train_epoch(jk_model, train_loader, jk_optimizer, jk_criterion, device, False)
    ps_val, ys_val = evaluate(jk_model, val_loader, device, False)
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        task_y = ys_val[mask, i]
        # roc_auc_score raises when only one class is present, so a task is
        # scored only if it has labels and both classes appear.
        if task_y.size > 0 and task_y.min() != task_y.max():
            aucs.append(roc_auc_score(task_y, ps_val[mask, i]))
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {np.mean(aucs):.3f}")

Epoch 01 | Loss: 0.2346 | Val AUC: 0.671
Epoch 02 | Loss: 0.2082 | Val AUC: 0.668
Epoch 03 | Loss: 0.2039 | Val AUC: 0.681
Epoch 04 | Loss: 0.2022 | Val AUC: 0.696
Epoch 05 | Loss: 0.1997 | Val AUC: 0.629
Epoch 06 | Loss: 0.1983 | Val AUC: 0.690
Epoch 07 | Loss: 0.1967 | Val AUC: 0.706
Epoch 08 | Loss: 0.1958 | Val AUC: 0.704
Epoch 09 | Loss: 0.1932 | Val AUC: 0.689
Epoch 10 | Loss: 0.1918 | Val AUC: 0.692
Epoch 11 | Loss: 0.1900 | Val AUC: 0.718
Epoch 12 | Loss: 0.1906 | Val AUC: 0.716
Epoch 13 | Loss: 0.1881 | Val AUC: 0.712
Epoch 14 | Loss: 0.1875 | Val AUC: 0.726
Epoch 15 | Loss: 0.1868 | Val AUC: 0.717
Epoch 16 | Loss: 0.1851 | Val AUC: 0.702
Epoch 17 | Loss: 0.1841 | Val AUC: 0.727
Epoch 18 | Loss: 0.1848 | Val AUC: 0.711
Epoch 19 | Loss: 0.1832 | Val AUC: 0.717
Epoch 20 | Loss: 0.1831 | Val AUC: 0.725
Epoch 21 | Loss: 0.1809 | Val AUC: 0.718
Epoch 22 | Loss: 0.1819 | Val AUC: 0.726
Epoch 23 | Loss: 0.1805 | Val AUC: 0.733
Epoch 24 | Loss: 0.1796 | Val AUC: 0.719
Epoch 25 | Loss:

In [20]:
# Score the trained baseline on the held-out test split.
ps_test, ys_test = evaluate(jk_model, test_loader, device, False)
ys_test = ys_test.reshape(ps_test.shape)
aucs = []
for i in range(ps_test.shape[1]):
    mask = ys_test[:, i] >= 0
    task_y = ys_test[mask, i]
    # roc_auc_score raises when only one class is present, so a task is
    # scored only if it has labels and both classes appear.
    if task_y.size > 0 and task_y.min() != task_y.max():
        aucs.append(roc_auc_score(task_y, ps_test[mask, i]))
# `loss` is the value left over from the last training epoch, not a test loss,
# so it is labelled as such.
print(f"Testing | Final train loss: {loss:.4f} | Test AUC: {np.mean(aucs):.3f}")

Testing | Loss: 0.1794 | Test AUC: 0.711


### Try Adding Quantum Computing

In [27]:
def QLinearCircuit(dev, num_qubits, num_layers: int = 1):
    """Build a trainable PennyLane circuit wrapped as a Torch layer.

    The circuit angle-embeds the inputs with Y rotations (one feature per
    wire), applies ``StronglyEntanglingLayers`` with CNOT entanglers, and
    returns the Pauli-Z expectation value of every wire.

    Args:
        dev: PennyLane device the QNode runs on.
        num_qubits: Number of wires; also the number of outputs per input row.
        num_layers: Number of entangling layers, i.e. the leading dimension of
            the trainable weight tensor. Defaults to 1, the previously
            hard-coded depth.

    Returns:
        A ``qml.qnn.TorchLayer`` with one trainable tensor ``weights`` of shape
        ``(num_layers, num_qubits, 3)``.
    """
    def func(inputs, weights):
        AngleEmbedding(
            features=inputs,
            wires=range(num_qubits),
            rotation='Y'
        )
        StronglyEntanglingLayers(
            weights=weights,
            wires=range(num_qubits),
            imprimitive=qml.CNOT,
        )
        return [qml.expval(qml.PauliZ(w)) for w in range(num_qubits)]

    qn = qml.QNode(func, dev, interface='torch')
    #qml.add_noise(qn, noise_model = model_pl)
    return qml.qnn.TorchLayer(qn, {"weights": (num_layers, num_qubits, 3)})

In [28]:
class QReLU(nn.Module):
    """Activation that is the identity for ``x > 0`` and ``C1*x - C2*x`` otherwise.

    With the constants used here (``C1 = 0.01``, ``C2 = 2``) a non-positive
    input ``x`` maps to ``-1.99 * x``, so negative inputs yield positive outputs.
    """

    def __init__(self):
        super().__init__()
        # Registered as buffers so they follow the module across devices.
        # persistent=False keeps them out of the state_dict, so checkpoints
        # saved before this change still load.
        self.register_buffer("C1", torch.tensor(0.01), persistent=False)
        self.register_buffer("C2", torch.tensor(2), persistent=False)

    def forward(self, x):
        """Apply the activation element-wise; the output has the shape of ``x``."""
        return torch.where(
            x <= 0,
            self.C1 * x - self.C2 * x,
            x
        )

In [37]:
class SmallQNN(nn.Module):
    """Quantum feature layer: circuit -> BatchNorm1d -> QReLU.

    The circuit is built by ``QLinearCircuit`` on a ``default.qubit`` device
    with ``num_qubits`` wires. ``in_channels`` records ``num_qubits``.
    """

    def __init__(self, num_qubits: int):
        super().__init__()
        self.in_channels = num_qubits

        simulator = qml.device("default.qubit", wires=num_qubits)
        self.qfunc = QLinearCircuit(simulator, num_qubits)
        self.qrelu = QReLU()
        self.bn_gate = nn.BatchNorm1d(num_qubits)

    def forward(self, inputs):
        """Run the circuit on ``inputs``, batch-normalise, then apply QReLU."""
        return self.qrelu(self.bn_gate(self.qfunc(inputs)))

In [47]:
class GINEWithQNN(nn.Module):
    """GINE graph network whose node features first pass through a SmallQNN.

    Pipeline: SmallQNN on node features -> ``n_layers`` GINEConv layers (each
    followed by ReLU) -> global mean pooling -> linear head.

    Args:
        input_dim: Number of features per node; also the number of qubits.
        hidden_dim: Width of every GINE layer.
        num_tasks: Number of output logits per graph.
        n_layers: Total number of GINE layers (the first one plus
            ``n_layers - 1`` additional ones).
        edge_dim: Number of features per edge. Defaults to 8, the bond
            encoding width the layers were previously hard-coded to.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_tasks: int,
        n_layers: int,
        edge_dim: int = 8,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.qnn = SmallQNN(num_qubits=input_dim)

        # first GINE: the QNN keeps the feature width, so it still takes input_dim
        self.conv1 = self._gine_block(input_dim, hidden_dim, edge_dim)

        # additional GINE layers
        self.convs = nn.ModuleList(
            self._gine_block(hidden_dim, hidden_dim, edge_dim)
            for _ in range(n_layers - 1)
        )

        # final FC: mean pooling keeps the width at hidden_dim
        self.fc = nn.Linear(hidden_dim, num_tasks)

    @staticmethod
    def _gine_block(in_dim: int, hidden_dim: int, edge_dim: int):
        """Build one GINEConv whose update MLP is Linear -> BatchNorm -> ReLU -> Linear."""
        mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        return GINEConv(mlp, edge_dim=edge_dim)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return logits of shape ``[num_graphs, num_tasks]``."""
        x = self.qnn(x)  # quantum feature layer on the node features

        # Initial GINEConv Layer
        x = F.relu(self.conv1(x, edge_index, edge_attr))

        # Main GINEConv Layers
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_attr))

        # Pooling
        x = global_mean_pool(x, batch)

        # Final prediction
        return self.fc(x)

In [48]:
# Hybrid model: same depth and width as the baseline, with the quantum layer in
# front; the number of qubits equals the number of node features (in_channels).
qnn_model = GINEWithQNN(
    in_channels,
    hidden_dim = 300,
    num_tasks = num_tasks,
    n_layers = 5,
).to(device)

print(qnn_model)

GINEWithQNN(
  (qnn): SmallQNN(
    (qfunc): <Quantum Torch Layer: func=func>
    (qrelu): QReLU()
    (bn_gate): BatchNorm1d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=7, out_features=300, bias=True)
    (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=300, out_features=300, bias=True)
  ))
  (convs): ModuleList(
    (0-3): 4 x GINEConv(nn=Sequential(
      (0): Linear(in_features=300, out_features=300, bias=True)
      (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=300, out_features=300, bias=True)
    ))
  )
  (fc): Linear(in_features=300, out_features=12, bias=True)
)


In [49]:
# AdamW over all hybrid-model parameters (learning rate 1e-3, weight decay 5e-4).
qnn_optimizer = optim.AdamW(
    qnn_model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
)

In [50]:
# Element-wise focal loss (reduction='none'); train_epoch applies its label
# mask to the per-element values and averages them itself.
qnn_criterion = FocalLoss(alpha=0.25, gamma=2.1, reduction='none')

In [51]:
# Train the hybrid model for 30 epochs, reporting the mean per-task validation ROC-AUC.
for epoch in range(1, 31):
    loss = train_epoch(qnn_model, train_loader, qnn_optimizer, qnn_criterion, device, False)
    ps_val, ys_val = evaluate(qnn_model, val_loader, device, False)
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        task_y = ys_val[mask, i]
        # roc_auc_score raises when only one class is present, so a task is
        # scored only if it has labels and both classes appear.
        if task_y.size > 0 and task_y.min() != task_y.max():
            aucs.append(roc_auc_score(task_y, ps_val[mask, i]))
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {np.mean(aucs):.3f}")

Epoch 01 | Loss: 0.0200 | Val AUC: 0.655
Epoch 02 | Loss: 0.0181 | Val AUC: 0.673
Epoch 03 | Loss: 0.0178 | Val AUC: 0.681
Epoch 04 | Loss: 0.0175 | Val AUC: 0.699
Epoch 05 | Loss: 0.0173 | Val AUC: 0.698
Epoch 06 | Loss: 0.0171 | Val AUC: 0.682
Epoch 07 | Loss: 0.0169 | Val AUC: 0.684
Epoch 08 | Loss: 0.0169 | Val AUC: 0.671
Epoch 09 | Loss: 0.0167 | Val AUC: 0.704
Epoch 10 | Loss: 0.0166 | Val AUC: 0.682
Epoch 11 | Loss: 0.0165 | Val AUC: 0.704
Epoch 12 | Loss: 0.0165 | Val AUC: 0.715
Epoch 13 | Loss: 0.0162 | Val AUC: 0.698
Epoch 14 | Loss: 0.0162 | Val AUC: 0.683
Epoch 15 | Loss: 0.0161 | Val AUC: 0.702
Epoch 16 | Loss: 0.0160 | Val AUC: 0.709
Epoch 17 | Loss: 0.0160 | Val AUC: 0.694
Epoch 18 | Loss: 0.0157 | Val AUC: 0.704
Epoch 19 | Loss: 0.0157 | Val AUC: 0.711
Epoch 20 | Loss: 0.0156 | Val AUC: 0.710
Epoch 21 | Loss: 0.0156 | Val AUC: 0.713
Epoch 22 | Loss: 0.0153 | Val AUC: 0.714
Epoch 23 | Loss: 0.0154 | Val AUC: 0.726
Epoch 24 | Loss: 0.0153 | Val AUC: 0.708
Epoch 25 | Loss:

In [52]:
# Score the trained hybrid model on the held-out test split.
ps_test, ys_test = evaluate(qnn_model, test_loader, device, False)
ys_test = ys_test.reshape(ps_test.shape)
aucs = []
for i in range(ps_test.shape[1]):
    mask = ys_test[:, i] >= 0
    task_y = ys_test[mask, i]
    # roc_auc_score raises when only one class is present, so a task is
    # scored only if it has labels and both classes appear.
    if task_y.size > 0 and task_y.min() != task_y.max():
        aucs.append(roc_auc_score(task_y, ps_test[mask, i]))
# `loss` is the value left over from the last training epoch, not a test loss,
# so it is labelled as such.
print(f"Testing | Final train loss: {loss:.4f} | Test AUC: {np.mean(aucs):.3f}")

Testing | Loss: 0.0146 | Test AUC: 0.696
