# Tox21 GINENet

We’re solving a multi-task binary toxicity prediction problem on Tox21. Concretely, for each molecule the model outputs 12 probabilities, one for each assay (e.g. NR-AR, SR-ARE, p53, etc.). At training time we use a binary cross-entropy loss (with masking for missing labels) over those 12 tasks, and at the end of each epoch we compute the ROC-AUC per task (then average) on the held-out validation set to see how well the model is distinguishing actives vs. inactives across all assays.

In [1]:
# Quietly install the third-party packages imported in the next cell
# (RDKit, PyTorch Geometric and PennyLane).
%pip -q install rdkit-pypi torch_geometric pennylane

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.1/56.1 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m58.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m60.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m81.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m930.8/930.8 kB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m78.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m73.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━

In [2]:
import pandas as pd
from math import ceil

# PennyLane Imports
import pennylane as qml
from pennylane import numpy as np
from pennylane.templates import AngleEmbedding, StronglyEntanglingLayers

# Pytorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, WeightedRandomSampler
from torch.serialization import safe_globals, add_safe_globals

# Pytorch Geometric Imports
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.data.data import DataTensorAttr, DataEdgeAttr
from torch_geometric.data.storage import GlobalStorage
from torch_geometric.nn import (
    GINEConv,
    GatedGraphConv,
    Set2Set,
    GlobalAttention,
    global_add_pool
)
from torch_geometric.nn.models import JumpingKnowledge

# RDKit Imports
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import MolFromSmiles, MolToSmiles, rdchem
from rdkit.Chem import Descriptors

from sklearn.metrics import roc_auc_score
import os



In [3]:
!wget wget https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
!gunzip tox21.csv.gz
!mkdir -p raw
!mv tox21.csv raw/

--2025-05-30 13:01:18--  http://wget/
Resolving wget (wget)... failed: Name or service not known.
wget: unable to resolve host address ‘wget’
--2025-05-30 13:01:18--  https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 125310 (122K) [application/octet-stream]
Saving to: ‘tox21.csv.gz’


2025-05-30 13:01:19 (1.54 MB/s) - ‘tox21.csv.gz’ saved [125310/125310]

FINISHED --2025-05-30 13:01:19--
Total wall clock time: 0.7s
Downloaded: 1 files, 122K in 0.08s (1.54 MB/s)


In [4]:
# Register the PyG data/storage classes as safe globals for torch serialization.
# Tox21Dataset.__init__ wraps its torch.load call in the same allow-list.
add_safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage])

# Dataset

In [5]:
class Tox21Dataset(InMemoryDataset):
    """In-memory PyG dataset built from ``<root>/raw/tox21.csv``.

    Every row whose SMILES RDKit can parse becomes one ``Data`` object with:

    * ``x``          -- ``[num_atoms, 7]`` atom features,
    * ``edge_index`` -- ``[2, 2 * num_bonds]`` (each bond stored in both directions),
    * ``edge_attr``  -- ``[2 * num_bonds, 8]`` one-hot bond type + one-hot bond stereo,
    * ``y``          -- one float per non-SMILES column of the CSV,
    * ``desc``       -- 5 RDKit molecule-level descriptors.

    Rows whose SMILES cannot be parsed are skipped, so dataset indices only
    match CSV row positions when every SMILES parses.
    """

    def __init__(self, root: str = '.', transform=None, pre_transform=None):
        """
        Expects:
        root/
            raw/tox21.csv
        Will create:
            processed/data.pt
        """
        super().__init__(root, transform, pre_transform)
        # Load the processed data; the PyG classes are allow-listed for the
        # duration of the load.
        with safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage]):
            self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # File expected in root/raw/
        return ['tox21.csv']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # No download step needed; CSV is already in place
        return

    def process(self):
        """Featurize every molecule in the raw CSV and save the collated result."""
        df = pd.read_csv(self.raw_paths[0])
        # NOTE(review): this turns missing assay labels into 0, so the
        # ``y >= 0`` masks applied during training/evaluation never exclude
        # anything. Confirm whether missing labels should be a negative
        # sentinel instead (the loss must then ignore those targets).
        df = df.fillna(value=0)

        if "mol_id" in df.columns:
            df = df.drop(columns=["mol_id"])

        # Convert all non-smiles (label) columns to numeric.
        label_cols = [c for c in df.columns if c != "smiles"]
        df[label_cols] = df[label_cols].apply(pd.to_numeric, errors="coerce")
        df = df.reset_index(drop=True)

        bond_types = [
            rdchem.BondType.SINGLE,
            rdchem.BondType.DOUBLE,
            rdchem.BondType.TRIPLE,
            rdchem.BondType.AROMATIC,
        ]
        stereo_types = [
            rdchem.BondStereo.STEREONONE,
            rdchem.BondStereo.STEREOZ,
            rdchem.BondStereo.STEREOE,
            rdchem.BondStereo.STEREOANY,
        ]
        n_edge_feats = len(bond_types) + len(stereo_types)

        data_list = []
        label_rows = df[label_cols].itertuples(index=False, name=None)
        for smiles, labels in zip(df["smiles"], label_rows):
            mol = MolFromSmiles(smiles, sanitize=True)
            if mol is None:
                continue

            # Node features x: one row of 7 values per atom.
            atom_feats = [
                [
                    a.GetAtomicNum(),
                    a.GetDegree(),
                    a.GetFormalCharge(),
                    a.GetNumRadicalElectrons(),
                    a.GetTotalNumHs(),
                    int(a.GetIsAromatic()),
                    int(a.IsInRing()),
                ]
                for a in mol.GetAtoms()
            ]
            x = torch.tensor(atom_feats, dtype=torch.float32)

            # Molecule-level descriptors.
            descriptors = [
                Descriptors.MolWt(mol),
                Descriptors.MolLogP(mol),
                Descriptors.TPSA(mol),
                Descriptors.NumHAcceptors(mol),
                Descriptors.NumHDonors(mol),
            ]

            # Build edge_index AND edge_attr in lock-step.
            edge_index = []
            edge_attr = []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                bt = bond.GetBondType()
                st = bond.GetStereo()
                # One-hot bond type followed by one-hot bond stereo.
                feat = [int(bt == t) for t in bond_types] + [int(st == s) for s in stereo_types]

                # Add both directions.
                edge_index += [[i, j], [j, i]]
                edge_attr += [feat, feat]

            # The explicit reshapes keep the rank correct for molecules without
            # bonds: [2, 0] and [0, n_edge_feats] instead of 1-D empty tensors.
            edge_index = (
                torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
            )
            edge_attr = torch.tensor(edge_attr, dtype=torch.float32).reshape(-1, n_edge_feats)
            desc = torch.tensor(descriptors, dtype=torch.float32)

            # Labels: one value per non-smiles column.
            y = torch.tensor(labels, dtype=torch.float32)

            data_list.append(Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                desc=desc,
            ))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [6]:
# Build (or load) the dataset rooted at the current directory:
# reads ./raw/tox21.csv and stores the processed tensors in ./processed/data.pt.
dataset = Tox21Dataset('.')

Processing...
Done!


A **Murcko scaffold** is the core structure of a molecule, obtained by removing side chains and other non-core components so that only the important ring systems and the chains connecting them remain. Grouping molecules by their Murcko scaffold is widely used in medicinal chemistry and drug design to identify core structures that have preferential activity against specific targets, which is very useful in **molecular property prediction**.

In [7]:
def GetMurckoScaffold(data, train_frac: float, val_frac: float, test_frac: float):
    """Split positions of ``data`` into train/valid/test by Murcko scaffold.

    ``data`` is a sized iterable of SMILES strings. Molecules sharing a
    scaffold always land in the same split. Scaffold groups are visited from
    largest to smallest: a group goes to train if it still fits within
    ``int(train_frac * len(data))`` items, otherwise to valid if it fits within
    ``int(val_frac * len(data))`` items, otherwise to test.

    ``test_frac`` is accepted but not used; the test split receives every
    group that fits in neither of the other two.

    Returns a tuple ``(train_idx, valid_idx, test_idx)`` of index lists.
    """
    by_scaffold = {}
    for position, smi in enumerate(data):
        key = MurckoScaffold.MurckoScaffoldSmiles(smi)
        if key not in by_scaffold:
            by_scaffold[key] = []
        by_scaffold[key].append(position)

    total = len(data)
    train_cap = int(train_frac * total)
    valid_cap = int(val_frac * total)

    train_idx, valid_idx, test_idx = [], [], []
    # Largest scaffold groups first; the sort is stable for equally sized groups.
    for members in sorted(by_scaffold.values(), key=len, reverse=True):
        size = len(members)
        if len(train_idx) + size <= train_cap:
            train_idx += members
            continue
        if len(valid_idx) + size <= valid_cap:
            valid_idx += members
            continue
        test_idx += members

    return train_idx, valid_idx, test_idx

In [8]:
# Seed the NumPy (imported via pennylane.numpy) and PyTorch random number generators.
np.random.seed(3411)
torch.manual_seed(3411)

<torch._C.Generator at 0x7cfec360c710>

In [9]:
def QLinearCircuit(dev, num_qubits):
    """Wrap a small variational circuit on ``dev`` as a torch layer.

    The circuit angle-embeds ``inputs`` with Y rotations on ``num_qubits``
    wires, applies one ``StronglyEntanglingLayers`` block with CNOT entanglers
    and measures the Pauli-Z expectation value of every wire.

    Returns a ``qml.qnn.TorchLayer`` whose trainable ``weights`` have shape
    ``(1, num_qubits, 3)``.
    """
    wires = range(num_qubits)
    weight_shapes = {"weights": (1, num_qubits, 3)}

    @qml.qnode(dev, interface='torch')
    def func(inputs, weights):
        AngleEmbedding(features=inputs, wires=wires, rotation='Y')
        StronglyEntanglingLayers(weights=weights, wires=wires, imprimitive=qml.CNOT)
        return [qml.expval(qml.PauliZ(wire)) for wire in wires]

    return qml.qnn.TorchLayer(func, weight_shapes)

In [10]:
class QReLU(nn.Module):
    """Element-wise activation: ``x`` for ``x > 0``, ``C1 * x - C2 * x`` otherwise.

    With ``C1 = 0.01`` and ``C2 = 2`` non-positive inputs are mapped to
    ``-1.99 * x``, i.e. to non-negative outputs.
    """

    def __init__(self):
        super().__init__()
        # Registered as buffers so they follow the module across .to()/.cuda()
        # calls; non-persistent so the state_dict stays empty, as before.
        self.register_buffer("C1", torch.tensor(0.01), persistent=False)
        self.register_buffer("C2", torch.tensor(2), persistent=False)

    def forward(self, x):
        """Apply the activation element-wise; the output has the shape of ``x``."""
        return torch.where(
            x <= 0,
            self.C1 * x - self.C2 * x,
            x
        )

In [11]:
class SmallQNN(nn.Module):
    """Quantum layer followed by batch normalisation and QReLU.

    The circuit runs on a ``default.qubit`` device with ``num_qubits`` wires
    and returns one expectation value per wire, so the output width equals
    ``num_qubits``.
    """

    def __init__(self, num_qubits: int):
        super().__init__()
        self.qfunc = QLinearCircuit(
            qml.device('default.qubit', wires=range(num_qubits)),
            num_qubits,
        )
        self.qrelu = QReLU()
        self.bn_gate = nn.BatchNorm1d(num_qubits)
        self.in_channels = num_qubits

    def forward(self, inputs):
        """Run circuit -> BatchNorm1d -> QReLU on ``inputs``."""
        expectations = self.qfunc(inputs)
        return self.qrelu(self.bn_gate(expectations))

In [12]:
class GINE(nn.Module):
    """GINE graph network with residual blocks, a gated update, Set2Set pooling
    and molecule descriptors concatenated before the final linear layer.

    Args:
        input_dim: number of features per node.
        hidden_dim: width of every hidden node representation.
        desc_dim: number of molecule-level descriptor values per graph.
        num_tasks: number of output logits per graph.
        n_layers: number of residual GINE blocks after the first GINE layer.
        set2set_steps: processing steps of the Set2Set pooling.
        edge_dim: number of features per edge (8 in this notebook's dataset).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        desc_dim: int,
        num_tasks: int,
        n_layers: int,
        set2set_steps: int = 3,
        edge_dim: int = 8,
    ):
        super().__init__()
        # Kept so forward() can reshape the batched descriptor vector.
        self.desc_dim = desc_dim

        # First GINE layer
        self.conv1 = GINEConv(
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ),
            edge_dim=edge_dim
        )

        # Additional GINE layers, used with skip connections in forward()
        self.conv_block = nn.ModuleList()
        for _ in range(n_layers):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.conv_block.append(GINEConv(mlp, edge_dim=edge_dim))

        # Gated graph update followed by batch norm
        self.gate = GatedGraphConv(hidden_dim, 3)
        self.bn_gate = nn.BatchNorm1d(hidden_dim)

        # Set2Set pooling: graph embedding of width 2 * hidden_dim
        self.pool = Set2Set(hidden_dim, processing_steps=set2set_steps)

        # Fully connected head: pooled embedding plus descriptors
        self.fc = nn.Linear(2 * hidden_dim + desc_dim, num_tasks)

    def forward(self, x, edge_index, edge_attr, batch, descriptors):
        """Return ``[num_graphs, num_tasks]`` logits for a batch of graphs.

        ``descriptors`` holds ``desc_dim`` values per graph; it may arrive
        flattened (as produced by batching 1-D per-graph vectors) and is
        reshaped to ``[num_graphs, desc_dim]``.
        """
        # 1st GINE layer
        x = self.conv1(x, edge_index, edge_attr)

        # GINE layers + skip connections + dropout
        for block in self.conv_block:
            h = block(x, edge_index, edge_attr)
            x = (x + h).relu()
            x = F.dropout(x, p=0.1, training=self.training)

        # Gated graph conv and BN
        x = self.gate(x, edge_index)
        x = self.bn_gate(x)

        # Set2Set pooling: output shape [batch_size, 2*hidden_dim]
        x = self.pool(x, batch)

        # Concatenate descriptor features. Uses the configured desc_dim rather
        # than a hard-coded width so the reshape matches self.fc.
        descs = descriptors.view(-1, self.desc_dim)
        x = torch.cat([x, descs], dim=1)

        # Final prediction
        return self.fc(x)

In [13]:
def train_epoch(model, loader, optimizer, criterion, device, use_desc=True):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Args:
        model: network called as ``model(x, edge_index, edge_attr, batch[, desc])``.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``edge_attr``,
            ``batch``, ``y``, ``num_graphs`` (and ``desc`` when ``use_desc``);
            ``loader.dataset`` must support ``len()``.
        optimizer: optimizer stepping the model parameters.
        criterion: element-wise loss (no reduction) taking ``(logits, targets)``.
        device: device each batch is moved to.
        use_desc: whether to pass ``batch.desc`` as fifth model argument.

    Returns:
        Sum of per-batch losses weighted by ``num_graphs``, divided by
        ``len(loader.dataset)``.

    Labels ``< 0`` (or NaN) are treated as missing: they are excluded from the
    loss and replaced by 0 before being handed to ``criterion`` so that a loss
    undefined outside {0, 1} cannot leak NaN through the mask. Batches with no
    labelled entry are skipped instead of producing 0/0.
    """
    model.train()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        inputs = [batch.x, batch.edge_index, batch.edge_attr, batch.batch]
        if use_desc:
            inputs.append(batch.desc)
        logits = model(*inputs)

        # batch.y arrives flattened; give it the [num_graphs, num_tasks] layout of logits.
        bs, nt = logits.size()
        targets = batch.y.view(bs, nt)
        labelled = targets >= 0
        mask = labelled.float()
        n_labelled = mask.sum()
        if n_labelled.item() == 0:
            # Nothing to learn from in this batch.
            continue

        safe_targets = torch.where(labelled, targets, torch.zeros_like(targets))
        loss = (criterion(logits, safe_targets) * mask).sum() / n_labelled
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

In [14]:
def evaluate(model, loader, device, use_desc=True):
    """Run ``model`` in eval mode over ``loader`` without gradients.

    Returns ``(probabilities, labels)`` as NumPy arrays: the sigmoid of the
    logits of every batch concatenated along dim 0, and every ``batch.y``
    concatenated along dim 0 (labels keep whatever layout the batches use).
    ``batch.desc`` is passed as fifth model argument only when ``use_desc``.
    """
    model.eval()
    label_chunks, prob_chunks = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            graph_args = (batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            logits = model(*graph_args, batch.desc) if use_desc else model(*graph_args)
            label_chunks.append(batch.y.cpu())
            prob_chunks.append(torch.sigmoid(logits).cpu())

    probabilities = torch.cat(prob_chunks, dim=0).numpy()
    labels = torch.cat(label_chunks, dim=0).numpy()
    return probabilities, labels

In [15]:
# The scaffold split is computed on the SMILES column of the raw CSV; the
# resulting row positions are then used as indices into `dataset`.
df = pd.read_csv('raw/tox21.csv').fillna(0).reset_index(drop=True)
smiles_list = df['smiles'].tolist()

# Tox21Dataset.process() skips rows whose SMILES RDKit cannot parse. If that
# happened, CSV row positions no longer line up with dataset indices and the
# split below would pick the wrong molecules.
if len(smiles_list) != len(dataset):
    import warnings
    warnings.warn(
        f"raw CSV has {len(smiles_list)} rows but the dataset has {len(dataset)} "
        "graphs; scaffold-split indices may be misaligned."
    )

train_idx, valid_idx, test_idx = GetMurckoScaffold(smiles_list, 0.8, 0.1, 0.1)

train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, valid_idx)
test_ds  = Subset(dataset, test_idx)

# Only the training data needs shuffling; evaluation order does not affect
# ROC-AUC, and not shuffling keeps validation from drawing on the global RNG.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=32)

# Model dimensions are read off the first graph.
sample = dataset[0]
in_channels = sample.x.size(1)
num_tasks   = sample.y.size(0)
desc_dim = sample.desc.size(0)

print("There are", num_tasks, "tasks.")



There are 12 tasks.


In [16]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Each element's binary cross-entropy is scaled by
    ``alpha_t * (1 - p_t) ** gamma``, where ``p_t`` is the predicted
    probability of the true class and ``alpha_t`` is ``alpha`` for positive
    targets and ``1 - alpha`` for negative ones.

    ``reduction`` is ``'mean'`` or ``'sum'``; any other value returns the
    unreduced, element-wise loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, logits, targets):
        """``logits`` and ``targets`` share one shape; targets are 0 or 1."""
        prob = torch.sigmoid(logits)
        negatives = 1 - targets

        # Probability assigned to the true class of each element.
        p_true = prob * targets + (1 - prob) * negatives
        # Class-balancing weight.
        alpha_t = self.alpha * targets + (1 - self.alpha) * negatives
        # Down-weight elements that are already classified well.
        modulator = alpha_t * (1 - p_true) ** self.gamma

        elementwise = modulator * F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none'
        )

        if self.reduction == 'mean':
            return elementwise.mean()
        if self.reduction == 'sum':
            return elementwise.sum()
        return elementwise

# Model 1: Classic GINE 

In [17]:
# Instantiate the descriptor-augmented GINE model (first GINE layer + 3 residual
# blocks, hidden width 384) and move it to the selected device.
model = GINE(
    in_channels,
    hidden_dim = 384,
    num_tasks = num_tasks,
    desc_dim = desc_dim,
    n_layers = 3
).to(device)

# Show the layer structure.
print(model)

GINE(
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=7, out_features=384, bias=True)
    (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=384, out_features=384, bias=True)
  ))
  (conv_block): ModuleList(
    (0-2): 3 x GINEConv(nn=Sequential(
      (0): Linear(in_features=384, out_features=384, bias=True)
      (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=384, out_features=384, bias=True)
    ))
  )
  (gate): GatedGraphConv(384, num_layers=3)
  (bn_gate): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): Set2Set(384, 768)
  (fc): Linear(in_features=773, out_features=12, bias=True)
)


In [20]:
# AdamW over all parameters of the GINE model.
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

# The scheduler is stepped with the mean validation AUC in the training loop.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",           # the tracked metric (val AUC) should increase
    factor=0.5,           # halve the lr
    patience=5,           # after 5 epochs with no AUC gain
    min_lr=1e-6
)

# reduction='none' keeps the loss element-wise so train_epoch can apply its label mask.
criterion = FocalLoss(alpha=0.10, gamma=1.0, reduction='none')

In [21]:
# Train for 20 epochs; after each epoch report the mean per-task ROC-AUC on the
# validation split and feed it to the LR scheduler.
for epoch in range(1, 21):
    loss = train_epoch(model, train_loader, optimizer, criterion, device)
    ps_val, ys_val = evaluate(model, val_loader, device)
    # evaluate() returns the labels flattened; restore [num_graphs, num_tasks].
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        task_labels = ys_val[mask, i]
        # roc_auc_score raises ValueError unless both classes are present,
        # so tasks with a single class in this split are skipped.
        if task_labels.size > 0 and task_labels.min() != task_labels.max():
            aucs.append(roc_auc_score(task_labels, ps_val[mask, i]))
    mean_auc = float(np.mean(aucs)) if aucs else float("nan")
    scheduler.step(mean_auc)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {mean_auc:.3f}")

Epoch 01 | Loss: 0.1738 | Val AUC: 0.555
Epoch 02 | Loss: 0.0341 | Val AUC: 0.610
Epoch 03 | Loss: 0.0276 | Val AUC: 0.586
Epoch 04 | Loss: 0.0254 | Val AUC: 0.631
Epoch 05 | Loss: 0.0224 | Val AUC: 0.653
Epoch 06 | Loss: 0.0207 | Val AUC: 0.658
Epoch 07 | Loss: 0.0195 | Val AUC: 0.673
Epoch 08 | Loss: 0.0184 | Val AUC: 0.680
Epoch 09 | Loss: 0.0177 | Val AUC: 0.678
Epoch 10 | Loss: 0.0175 | Val AUC: 0.680
Epoch 11 | Loss: 0.0171 | Val AUC: 0.693
Epoch 12 | Loss: 0.0168 | Val AUC: 0.699
Epoch 13 | Loss: 0.0167 | Val AUC: 0.700
Epoch 14 | Loss: 0.0164 | Val AUC: 0.688
Epoch 15 | Loss: 0.0162 | Val AUC: 0.697
Epoch 16 | Loss: 0.0163 | Val AUC: 0.703
Epoch 17 | Loss: 0.0165 | Val AUC: 0.707
Epoch 18 | Loss: 0.0159 | Val AUC: 0.700
Epoch 19 | Loss: 0.0157 | Val AUC: 0.703
Epoch 20 | Loss: 0.0157 | Val AUC: 0.711


In [22]:
# Mean per-task ROC-AUC of the GINE model on the held-out test split.
ps_test, ys_test = evaluate(model, test_loader, device)
ys_test = ys_test.reshape(ps_test.shape)
aucs = []
for i in range(ps_test.shape[1]):
    mask = ys_test[:, i] >= 0
    task_labels = ys_test[mask, i]
    # roc_auc_score raises ValueError unless both classes are present.
    if task_labels.size > 0 and task_labels.min() != task_labels.max():
        aucs.append(roc_auc_score(task_labels, ps_test[mask, i]))
# `loss` is the training loss of the last epoch; no loss is computed on the test set.
print(f"Testing | Final train loss: {loss:.4f} | Test AUC: {np.mean(aucs):.3f}")

Testing | Loss: 0.0157 | Test AUC: 0.711


# Model 2: GINE with Jumping Knowledge

In [28]:
class GINEWithJK(nn.Module):
    """GINE stack whose per-layer outputs are merged by Jumping Knowledge.

    Args:
        input_dim: number of features per node.
        hidden_dim: width of every GINE layer output.
        num_tasks: number of output logits per graph.
        n_layers: number of GINE layers after the first one.
        jk_mode: Jumping Knowledge aggregation: ``'cat'``, ``'max'`` or ``'lstm'``.
        set2set_steps: processing steps of Set2Set (used when ``use_set2set``).
        use_set2set: pool with Set2Set if True, otherwise with a global mean.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_tasks: int,
        n_layers: int,
        jk_mode: str = 'cat',
        set2set_steps: int = 3,
        use_set2set: bool = True,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.use_set2set = use_set2set

        # first GINE
        self.conv1 = GINEConv(
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ),
            edge_dim=8,
        )

        # additional GINE layers
        self.convs = nn.ModuleList()
        for _ in range(n_layers):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GINEConv(mlp, edge_dim=8))

        # Jumping Knowledge over the n_layers + 1 layer outputs
        self.jk = JumpingKnowledge(mode=jk_mode, channels=hidden_dim, num_layers=n_layers+1)

        # Only 'cat' widens the node representation; 'max' and 'lstm' keep hidden_dim.
        if jk_mode.lower() == 'cat':
            jk_out_dim = hidden_dim * (n_layers + 1)
        else:
            jk_out_dim = hidden_dim

        # pooling: either Set2Set or mean
        if use_set2set:
            self.pool = Set2Set(jk_out_dim, processing_steps=set2set_steps)
            # Set2Set doubles the feature width
            fc_in_dim = 2 * jk_out_dim
        else:
            # Imported here because the notebook's import cell does not bring it in.
            from torch_geometric.nn import global_mean_pool
            self.pool = global_mean_pool
            fc_in_dim = jk_out_dim

        # final FC
        self.fc = nn.Linear(fc_in_dim, num_tasks)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return ``[num_graphs, num_tasks]`` logits for a batch of graphs."""
        # layer 0
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        xs = [x]

        # layers 1…n
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_attr))
            xs.append(x)

        # Jumping Knowledge aggregation over all layer outputs
        x_jk = self.jk(xs)

        # graph-level pooling (Set2Set and global_mean_pool share this call signature)
        x_graph = self.pool(x_jk, batch)

        # final predictions
        return self.fc(x_graph)


In [29]:
# Instantiate the Jumping Knowledge variant (concatenating all layer outputs)
# and move it to the selected device. This model takes no descriptor input.
jk_model = GINEWithJK(
    in_channels,
    hidden_dim = 384,
    num_tasks = num_tasks,
    n_layers = 3,
    jk_mode = 'cat',
    set2set_steps = 3,
).to(device)

# Show the layer structure.
print(jk_model)

GINEWithJK(
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=7, out_features=384, bias=True)
    (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=384, out_features=384, bias=True)
  ))
  (convs): ModuleList(
    (0-2): 3 x GINEConv(nn=Sequential(
      (0): Linear(in_features=384, out_features=384, bias=True)
      (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=384, out_features=384, bias=True)
    ))
  )
  (jk): JumpingKnowledge(cat)
  (pool): Set2Set(1536, 3072)
  (fc): Linear(in_features=3072, out_features=12, bias=True)
)


In [30]:
# AdamW over all parameters of the Jumping Knowledge model.
jk_optimizer = optim.AdamW(
    jk_model.parameters(),
    lr=3e-4,
    weight_decay=1e-5,
)

# The scheduler is stepped with the mean validation AUC in the training loop.
jk_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    jk_optimizer,
    mode="max",           # the tracked metric (val AUC) should increase
    factor=0.5,           # halve the lr
    patience=5,           # after 5 epochs with no AUC gain
    min_lr=1e-6
)

In [31]:
# reduction='none' keeps the loss element-wise so train_epoch can apply its label mask.
jk_criterion = FocalLoss(alpha=0.10, gamma=1.0, reduction='none')

In [32]:
# Train the Jumping Knowledge model for 20 epochs (no descriptor input); after
# each epoch report the mean per-task validation ROC-AUC and step the scheduler.
for epoch in range(1, 21):
    loss = train_epoch(jk_model, train_loader, jk_optimizer, jk_criterion, device, False)
    ps_val, ys_val = evaluate(jk_model, val_loader, device, False)
    # evaluate() returns the labels flattened; restore [num_graphs, num_tasks].
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        task_labels = ys_val[mask, i]
        # roc_auc_score raises ValueError unless both classes are present,
        # so tasks with a single class in this split are skipped.
        if task_labels.size > 0 and task_labels.min() != task_labels.max():
            aucs.append(roc_auc_score(task_labels, ps_val[mask, i]))
    mean_auc = float(np.mean(aucs)) if aucs else float("nan")
    jk_scheduler.step(mean_auc)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {mean_auc:.3f}")

Epoch 01 | Loss: 0.0225 | Val AUC: 0.633
Epoch 02 | Loss: 0.0173 | Val AUC: 0.669
Epoch 03 | Loss: 0.0167 | Val AUC: 0.648
Epoch 04 | Loss: 0.0166 | Val AUC: 0.672
Epoch 05 | Loss: 0.0164 | Val AUC: 0.699
Epoch 06 | Loss: 0.0161 | Val AUC: 0.700
Epoch 07 | Loss: 0.0160 | Val AUC: 0.695
Epoch 08 | Loss: 0.0158 | Val AUC: 0.703
Epoch 09 | Loss: 0.0155 | Val AUC: 0.700
Epoch 10 | Loss: 0.0154 | Val AUC: 0.701
Epoch 11 | Loss: 0.0154 | Val AUC: 0.709
Epoch 12 | Loss: 0.0151 | Val AUC: 0.704
Epoch 13 | Loss: 0.0150 | Val AUC: 0.700
Epoch 14 | Loss: 0.0149 | Val AUC: 0.696
Epoch 15 | Loss: 0.0146 | Val AUC: 0.707
Epoch 16 | Loss: 0.0147 | Val AUC: 0.704
Epoch 17 | Loss: 0.0145 | Val AUC: 0.715
Epoch 18 | Loss: 0.0142 | Val AUC: 0.724
Epoch 19 | Loss: 0.0141 | Val AUC: 0.715
Epoch 20 | Loss: 0.0138 | Val AUC: 0.695


In [33]:
# Mean per-task ROC-AUC of the Jumping Knowledge model on the held-out test split.
ps_test, ys_test = evaluate(jk_model, test_loader, device, False)
ys_test = ys_test.reshape(ps_test.shape)
aucs = []
for i in range(ps_test.shape[1]):
    mask = ys_test[:, i] >= 0
    task_labels = ys_test[mask, i]
    # roc_auc_score raises ValueError unless both classes are present.
    if task_labels.size > 0 and task_labels.min() != task_labels.max():
        aucs.append(roc_auc_score(task_labels, ps_test[mask, i]))
# `loss` is the training loss of the last epoch; no loss is computed on the test set.
print(f"Testing | Final train loss: {loss:.4f} | Test AUC: {np.mean(aucs):.3f}")

Testing | Loss: 0.0138 | Test AUC: 0.712


# Model 3: GINE with QNN as Initial Mixing Layer

In [35]:
class GINEWithQNN(nn.Module):
    """GINE network whose node features first pass through a small quantum layer.

    Pipeline: ``SmallQNN`` (one qubit per input feature) -> first GINE layer ->
    ``n_layers`` further GINE layers, each followed by ReLU -> Set2Set pooling
    (width ``2 * hidden_dim``) -> linear head with ``num_tasks`` logits.
    Edge features are expected to have 8 dimensions.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_tasks: int,
        n_layers: int,
        set2set_steps: int = 3,
    ):
        super().__init__()

        def gine_layer(in_features):
            # Linear -> BatchNorm -> ReLU -> Linear MLP wrapped in a GINEConv.
            return GINEConv(
                nn.Sequential(
                    nn.Linear(in_features, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                ),
                edge_dim=8,
            )

        self.n_layers = n_layers
        self.qnn = SmallQNN(num_qubits=input_dim)

        # The quantum layer keeps the feature width, so conv1 still sees input_dim.
        self.conv1 = gine_layer(input_dim)
        self.convs = nn.ModuleList([gine_layer(hidden_dim) for _ in range(n_layers)])

        # Set2Set doubles the feature width seen by the final linear layer.
        self.pool = Set2Set(hidden_dim, processing_steps=set2set_steps)
        self.fc = nn.Linear(2 * hidden_dim, num_tasks)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return ``[num_graphs, num_tasks]`` logits for a batch of graphs."""
        h = self.qnn(x)

        # First GINE layer followed by the remaining ones, ReLU after each.
        for layer in (self.conv1, *self.convs):
            h = F.relu(layer(h, edge_index, edge_attr))

        # Graph-level embedding of width 2 * hidden_dim.
        pooled = self.pool(h, batch)

        return self.fc(pooled)

In [36]:
# Instantiate the QNN-fronted GINE model (one qubit per node feature) and move
# it to the selected device. This model takes no descriptor input.
qnn_model = GINEWithQNN(
    in_channels,
    hidden_dim = 384,
    num_tasks = num_tasks,
    n_layers = 3,
    set2set_steps = 3,
).to(device)

# Show the layer structure.
print(qnn_model)

GINEWithQNN(
  (qnn): SmallQNN(
    (qfunc): <Quantum Torch Layer: func=func>
    (qrelu): QReLU()
    (bn_gate): BatchNorm1d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=7, out_features=384, bias=True)
    (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=384, out_features=384, bias=True)
  ))
  (convs): ModuleList(
    (0-2): 3 x GINEConv(nn=Sequential(
      (0): Linear(in_features=384, out_features=384, bias=True)
      (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=384, out_features=384, bias=True)
    ))
  )
  (pool): Set2Set(384, 768)
  (fc): Linear(in_features=768, out_features=12, bias=True)
)


In [37]:
# AdamW over all parameters of the QNN-fronted model.
qnn_optimizer = optim.AdamW(
    qnn_model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

# The scheduler is stepped with the mean validation AUC in the training loop.
qnn_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    qnn_optimizer,
    mode="max",           # the tracked metric (val AUC) should increase
    factor=0.5,           # halve the lr
    patience=5,           # after 5 epochs with no AUC gain
    min_lr=1e-6
)

In [41]:
# reduction='none' keeps the loss element-wise so train_epoch can apply its label mask.
qnn_criterion = FocalLoss(alpha=0.38, gamma=0.938, reduction='none')

In [42]:
# Train the QNN-fronted model for 20 epochs (no descriptor input); after each
# epoch report the mean per-task validation ROC-AUC and step the scheduler.
for epoch in range(1, 21):
    loss = train_epoch(qnn_model, train_loader, qnn_optimizer, qnn_criterion, device, False)
    ps_val, ys_val = evaluate(qnn_model, val_loader, device, False)
    # evaluate() returns the labels flattened; restore [num_graphs, num_tasks].
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        task_labels = ys_val[mask, i]
        # roc_auc_score raises ValueError unless both classes are present,
        # so tasks with a single class in this split are skipped.
        if task_labels.size > 0 and task_labels.min() != task_labels.max():
            aucs.append(roc_auc_score(task_labels, ps_val[mask, i]))
    mean_auc = float(np.mean(aucs)) if aucs else float("nan")
    qnn_scheduler.step(mean_auc)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {mean_auc:.3f}")

Epoch 01 | Loss: 0.0403 | Val AUC: 0.709
Epoch 02 | Loss: 0.0392 | Val AUC: 0.710
Epoch 03 | Loss: 0.0377 | Val AUC: 0.724
Epoch 04 | Loss: 0.0370 | Val AUC: 0.714
Epoch 05 | Loss: 0.0354 | Val AUC: 0.724
Epoch 06 | Loss: 0.0347 | Val AUC: 0.732
Epoch 07 | Loss: 0.0341 | Val AUC: 0.730
Epoch 08 | Loss: 0.0336 | Val AUC: 0.718
Epoch 09 | Loss: 0.0334 | Val AUC: 0.728
Epoch 10 | Loss: 0.0329 | Val AUC: 0.719
Epoch 11 | Loss: 0.0325 | Val AUC: 0.716
Epoch 12 | Loss: 0.0321 | Val AUC: 0.725
Epoch 13 | Loss: 0.0310 | Val AUC: 0.721
Epoch 14 | Loss: 0.0306 | Val AUC: 0.723
Epoch 15 | Loss: 0.0306 | Val AUC: 0.720
Epoch 16 | Loss: 0.0302 | Val AUC: 0.722
Epoch 17 | Loss: 0.0299 | Val AUC: 0.718
Epoch 18 | Loss: 0.0299 | Val AUC: 0.713
Epoch 19 | Loss: 0.0290 | Val AUC: 0.723
Epoch 20 | Loss: 0.0288 | Val AUC: 0.720


In [43]:
# Mean per-task ROC-AUC of the QNN-fronted model on the held-out test split.
ps_test, ys_test = evaluate(qnn_model, test_loader, device, False)
ys_test = ys_test.reshape(ps_test.shape)
aucs = []
for i in range(ps_test.shape[1]):
    mask = ys_test[:, i] >= 0
    task_labels = ys_test[mask, i]
    # roc_auc_score raises ValueError unless both classes are present.
    if task_labels.size > 0 and task_labels.min() != task_labels.max():
        aucs.append(roc_auc_score(task_labels, ps_test[mask, i]))
# `loss` is the training loss of the last epoch; no loss is computed on the test set.
print(f"Testing | Final train loss: {loss:.4f} | Test AUC: {np.mean(aucs):.3f}")

Testing | Loss: 0.0288 | Test AUC: 0.679
