In [1]:
pip install ogb


Collecting ogb
  Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)
Collecting outdated>=0.2.0 (from ogb)
  Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)
Collecting littleutils (from outdated>=0.2.0->ogb)
  Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.6.0->ogb)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.6.0->ogb)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.6.0->ogb)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.6.0->ogb)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5

In [2]:
pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m45.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [3]:
pip install torch==2.5.0 --force-reinstall

Collecting torch==2.5.0
  Downloading torch-2.5.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting filelock (from torch==2.5.0)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting typing-extensions>=4.8.0 (from torch==2.5.0)
  Downloading typing_extensions-4.13.2-py3-none-any.whl.metadata (3.0 kB)
Collecting networkx (from torch==2.5.0)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch==2.5.0)
  Downloading jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch==2.5.0)
  Downloading fsspec-2025.3.2-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.5.0)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.5.0)
  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti

In [4]:
import torch
from ogb.linkproppred import LinkPropPredDataset, Evaluator
import numpy as np
from torch_geometric.data import Data
import torch.nn as nn
from torch_geometric.utils import negative_sampling, scatter, add_self_loops, softmax
import torch.nn.functional as F
import torch.nn.init as init
from torch_geometric.nn import MessagePassing
from sklearn.metrics import accuracy_score, precision_score, recall_score

In [5]:
def load_dataset():
    dataset = LinkPropPredDataset(name="ogbl-ddi")
    split_edge = dataset.get_edge_split()
    data = dataset[0]
    return data, split_edge, dataset
data, split_edge, dataset = load_dataset()

Downloading http://snap.stanford.edu/ogb/data/linkproppred/ddi.zip


Downloaded 0.04 GB: 100%|██████████| 46/46 [00:05<00:00,  7.68it/s]


Extracting dataset/ddi.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 27.77it/s]

Saving...



  train = torch.load(osp.join(path, 'train.pt'))
  valid = torch.load(osp.join(path, 'valid.pt'))
  test = torch.load(osp.join(path, 'test.pt'))


In [6]:
edge_index = torch.from_numpy(split_edge['train']['edge']).long().T.contiguous()
num_nodes = 4267
print(f"Edge Index Shape: {edge_index.shape}")

Edge Index Shape: torch.Size([2, 1067911])


In [7]:
def compute_edge_features(edge_index, chunk_size=1000):
    device = edge_index.device
    row, col = edge_index
    edge_index = torch.cat([edge_index, torch.stack([col, row])], dim=1)

    adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.size(1), device=device),
                                  (num_nodes, num_nodes), device=device)
    degrees = torch.sparse.sum(adj, dim=1).to_dense()

    src, dst = torch.triu_indices(num_nodes, num_nodes, offset=1, device=device)
    num_pairs = src.size(0)

    feat_list = []
    pair_list = []

    for i in range(0, num_pairs, chunk_size):
        s = src[i:i+chunk_size]
        d = dst[i:i+chunk_size]

        adj_s = torch.index_select(adj.to_dense(), 0, s)
        adj_d = torch.index_select(adj.to_dense(), 0, d)

        inter = adj_s * adj_d
        union = ((adj_s + adj_d) > 0).float()

        cn = inter.sum(dim=1)
        jc = cn / union.sum(dim=1).clamp(min=1)
        aa = (inter / torch.log(degrees + 1e-10)[None, :]).nan_to_num(0).sum(dim=1)
        pa = degrees[s] * degrees[d]
        ra = (inter / degrees[None, :].clamp(min=1)).nan_to_num(0).sum(dim=1)
        si = 2 * cn / (degrees[s] + degrees[d]).clamp(min=1)
        hpi = cn / torch.min(degrees[s], degrees[d]).clamp(min=1)
        hdi = cn / torch.max(degrees[s], degrees[d]).clamp(min=1)

        feats = torch.stack([cn, jc, aa, pa, ra, si, hpi, hdi], dim=1)
        feat_list.append(feats)
        pair_list.append(torch.stack([s, d], dim=1))

        del adj_s, adj_d, inter, union, feats
        torch.cuda.empty_cache()

    all_pairs = torch.cat(pair_list, dim=0)  # [E, 2]
    all_feats = torch.cat(feat_list, dim=0)  # [E, 8]

    return all_feats, all_pairs

In [8]:

def build_all_feats_matrix(all_pairs, all_feats):
    u = torch.minimum(all_pairs[:, 0], all_pairs[:, 1])
    v = torch.maximum(all_pairs[:, 0], all_pairs[:, 1])

    all_feats_matrix = torch.zeros((num_nodes, num_nodes, all_feats.size(1)),
                                   dtype=all_feats.dtype,
                                   device=all_feats.device)
    all_feats_matrix[u, v] = all_feats

    return all_feats_matrix


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
edge_index=edge_index.to(device)
edge_feats, all_pairs = compute_edge_features(edge_index)


In [10]:
means = edge_feats.mean(dim=0, keepdim=True)
stds = edge_feats.std(dim=0, keepdim=True)
edge_feats[:, 0] = torch.log1p(edge_feats[:, 0])  # cn
edge_feats[:, 2] = torch.log1p(edge_feats[:, 2])  # aa
edge_feats[:, 3] = torch.log1p(edge_feats[:, 3])  # pa
edge_feats[:, 4] = torch.log1p(edge_feats[:, 4])  # ra

edge_feats = (edge_feats - means) / (stds + 1e-10)
edge_features = build_all_feats_matrix(all_pairs, edge_feats)

In [11]:
def fast_undirected_negative_sampling(edge_index, num_nodes, num_samples):
    u, v = edge_index
    pos_u = torch.min(u, v)
    pos_v = torch.max(u, v)
    pos_pairs = pos_u * num_nodes + pos_v  # unique undirected ID
    pos_set = pos_pairs.unique()

    neg_set = set()
    max_trials = num_samples * 10
    trials = 0

    while len(neg_set) < num_samples and trials < max_trials:
        i = torch.randint(0, num_nodes, (num_samples * 2,))
        j = torch.randint(0, num_nodes, (num_samples * 2,))
        mask = i != j
        i, j = i[mask], j[mask]

        u = torch.min(i, j)
        v = torch.max(i, j)
        pair_ids = u * num_nodes + v

        valid_mask = ~torch.isin(pair_ids, pos_set)
        u, v, pair_ids = u[valid_mask], v[valid_mask], pair_ids[valid_mask]

        for a, b, pid in zip(u.tolist(), v.tolist(), pair_ids.tolist()):
            if pid not in neg_set:
                neg_set.add(pid)
                if len(neg_set) == num_samples:
                    break
        trials += 1

    neg_pairs = torch.tensor([(pid // num_nodes, pid % num_nodes) for pid in neg_set], dtype=torch.long).t()
    neg_sym = torch.cat([neg_pairs, neg_pairs[[1, 0]]], dim=1)  # add reverse edges
    return neg_sym



In [12]:
num_edges = edge_index.shape[1]

perm = torch.randperm(num_edges)
shuffled_edges = edge_index[:, perm]

split_idx = int(0.8 * num_edges)
message_passing_edges = shuffled_edges[:, :split_idx]
train_supervision_edges = shuffled_edges[:, split_idx:]

def add_reverse_edges(edges):
    return torch.cat([edges, edges[[1, 0]]], dim=1)

message_passing_edges = add_reverse_edges(message_passing_edges)
train_supervision_edges = add_reverse_edges(train_supervision_edges)

print(f"Message Passing Edges Shape: {message_passing_edges.shape}")
print(f"Train Supervision Edges Shape: {train_supervision_edges.shape}")

num_neg_samples = train_supervision_edges.shape[1] // 2
neg_edge_index = fast_undirected_negative_sampling(edge_index.to('cpu'), num_nodes, num_neg_samples)
print(f"Negative Edge Index Shape: {neg_edge_index.shape}")


Message Passing Edges Shape: torch.Size([2, 1708656])
Train Supervision Edges Shape: torch.Size([2, 427166])
Negative Edge Index Shape: torch.Size([2, 427166])


In [13]:
class EdgeAwareGINLayer(MessagePassing):
    def __init__(self, in_dim, hidden_dim, eps=0.0):
        super().__init__(aggr='add')  # sum over weighted neighbor features

        self.mlp_phi = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.3),
            nn.ReLU()
        )

        self.mlp_a = nn.Sequential(
            nn.Linear(8, 32),
            nn.LayerNorm(32),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(32, 1)
        )


        self.eps = nn.Parameter(torch.Tensor([eps]))

    def forward(self, x, edge_index):
        x = x.to(self.mlp_phi[0].weight.device)
        return self.propagate(edge_index, x=x)

    def message(self, x_j, edge_index_i, edge_index_j):
        i, j = torch.min(edge_index_i, edge_index_j), torch.max(edge_index_i, edge_index_j)
        edge_feats = edge_features[i, j]

        raw_scores = self.mlp_a(edge_feats).squeeze(-1)
        return raw_scores.unsqueeze(-1) * x_j

    def update(self, aggr_out, x):
        out = (1 + self.eps) * x + aggr_out
        return self.mlp_phi(out)

class GINLinkPredictor(nn.Module):
    def __init__(self, num_nodes, hidden_dim):
        super().__init__()

        self.node_emb = nn.Embedding(num_nodes, hidden_dim)
        init.xavier_uniform_(self.node_emb.weight)
        self.gin1 = EdgeAwareGINLayer(hidden_dim, hidden_dim)
        self.gin2 = EdgeAwareGINLayer(hidden_dim, hidden_dim)
        self.edge_predictor1 = nn.Sequential(
                nn.Linear(2 * hidden_dim, 100),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(100, 60),
                nn.ReLU(),
                nn.Linear(60, 30),
                nn.ReLU(),
                nn.Linear(30, 1)
                )


    def forward(self, edge_index):
        x = self.node_emb.weight.to(edge_index.device)

        x = self.gin1(x, edge_index)
        x = self.gin2(x, edge_index)

        return x

    def predict_edges(self, x, edge_index):
        src, dst = edge_index

        node_embeddings = torch.cat([x[src], x[dst]], dim=1)

        return self.edge_predictor1(node_embeddings).squeeze(-1)

@torch.compile
def train_step(model, optimizer, scheduler, edge_index, train_edges, neg_edges, val_edges, val_neg_edges, device):
    model.train()
    optimizer.zero_grad()

    x = model(edge_index.to(device))

    pos = model.predict_edges(x, train_edges.to(device))
    neg = model.predict_edges(x, neg_edges.to(device))

    scores = torch.cat([pos, neg], dim=0)
    labels = torch.cat([
        torch.ones_like(pos),
        torch.zeros_like(neg)
    ], dim=0)

    bce_loss_fn = nn.BCEWithLogitsLoss()
    loss = bce_loss_fn(scores, labels)
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)

    optimizer.step()
    scheduler.step()
    model.eval()
    with torch.no_grad():
        x = model(edge_index.to(device))
        pos_val = model.predict_edges(x, val_edges.to(device))
        neg_val = model.predict_edges(x, val_neg_edges.to(device))
        input_dict = {"y_pred_pos": pos_val.view(-1), "y_pred_neg": neg_val.view(-1)}
        result = run_eval(input_dict)
    return loss, result["hits@20"]

In [14]:
@torch.no_grad()
@torch.compile
def test_step(model, edge_index, test_pos_edges, test_neg_edges, evaluator, device):
    model.eval()
    x = model(edge_index.to(device))

    pos_preds = model.predict_edges(x, test_pos_edges.to(device))
    neg_preds = model.predict_edges(x, test_neg_edges.to(device))

    avg_pos_score = pos_preds.mean().item()
    avg_neg_score = neg_preds.mean().item()
    top20_neg_avg = torch.topk(neg_preds, 20, largest=True).values.mean().item()

    print(f"🔹 Avg Positive Edge Score: {avg_pos_score:.15f}")
    print(f"🔻 Avg Negative Edge Score: {avg_neg_score:.15f}")
    print(f"🔺 Avg Top 20 Negative Edge Score: {top20_neg_avg:.15f}")

    input_dict = {"y_pred_pos": pos_preds.view(-1), "y_pred_neg": neg_preds.view(-1)}
    result = evaluator.eval(input_dict)

    pos_labels = torch.ones_like(pos_preds)
    neg_labels = torch.zeros_like(neg_preds)

    y_true = torch.cat([pos_labels, neg_labels], dim=0)
    y_pred = torch.cat([pos_preds, neg_preds], dim=0)
    y_pred = (torch.sigmoid(y_pred) > 0.5).float()

    correct = (y_pred == y_true).sum().item()
    accuracy = correct / y_true.numel()

    true_positives = ((y_pred == 1) & (y_true == 1)).sum().item()
    predicted_positives = (y_pred == 1).sum().item()
    actual_positives = (y_true == 1).sum().item()

    precision = true_positives / predicted_positives if predicted_positives > 0 else 0.0
    recall = true_positives / actual_positives if actual_positives > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    print(f"✅ Accuracy: {accuracy:.4f}")
    print(f"🎯 Precision: {precision:.4f}")
    print(f"🔄 Recall: {recall:.4f}")
    print(f"⭐ F1 Score: {f1:.4f}")

    return result["hits@20"]


In [15]:
evaluator = Evaluator(name="ogbl-ddi")

In [16]:
@torch._dynamo.disable
def run_eval(input_dict):
    return evaluator.eval(input_dict)

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_PATH = "model_path"


In [18]:
import os

model = GINLinkPredictor(num_nodes=4267, hidden_dim=256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999)
model.compile()
print(f"Using device: {device}")
best = 0.0
num_epochs=3000

for epoch in range(num_epochs):
    loss, val_hits_at_20 = train_step(model, optimizer, scheduler, message_passing_edges, train_supervision_edges, neg_edge_index,val_edges, val_neg_edges, device)
    print(f"Epoch {epoch+1}, Train Loss: {loss:.15f}, Val hits at 20: {val_hits_at_20:.15f}")

    if val_hits_at_20 > best:
        best = val_hits_at_20
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"Model saved at epoch {epoch+1} with loss {loss:.15f}")


Using device: cuda


W0419 12:07:40.548000 231 torch/_logging/_internal.py:1081] [4/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored


Epoch 1, Train Loss: 0.694143533706665, Val hits at 20: 0.026998479275446
Model saved at epoch 1 with loss 0.694143533706665
Epoch 2, Train Loss: 0.678128480911255, Val hits at 20: 0.022211567994367
Epoch 3, Train Loss: 0.660415947437286, Val hits at 20: 0.021245196233398
Epoch 4, Train Loss: 0.640608668327332, Val hits at 20: 0.019986665567949
Epoch 5, Train Loss: 0.619268238544464, Val hits at 20: 0.018031448284128
Epoch 6, Train Loss: 0.597509145736694, Val hits at 20: 0.018788064934189
Epoch 7, Train Loss: 0.575061917304993, Val hits at 20: 0.019619594123860


W0419 12:09:37.574000 231 torch/_dynamo/convert_frame.py:844] [5/8] torch._dynamo hit config.cache_size_limit (8)
W0419 12:09:37.574000 231 torch/_dynamo/convert_frame.py:844] [5/8]    function: 'step' (/usr/local/lib/python3.11/dist-packages/torch/optim/adam.py:189)
W0419 12:09:37.574000 231 torch/_dynamo/convert_frame.py:844] [5/8]    last reason: 5/0: L['self'].param_groups[0]['lr'] == 0.001                    
W0419 12:09:37.574000 231 torch/_dynamo/convert_frame.py:844] [5/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0419 12:09:37.574000 231 torch/_dynamo/convert_frame.py:844] [5/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.


Epoch 8, Train Loss: 0.553680896759033, Val hits at 20: 0.019836840488729
Epoch 9, Train Loss: 0.533146560192108, Val hits at 20: 0.020765755979893
Epoch 10, Train Loss: 0.513334035873413, Val hits at 20: 0.021200248709632
Epoch 11, Train Loss: 0.493969798088074, Val hits at 20: 0.021290143757164
Epoch 12, Train Loss: 0.475940585136414, Val hits at 20: 0.020878124789308
Epoch 13, Train Loss: 0.460966706275940, Val hits at 20: 0.020518544599180
Epoch 14, Train Loss: 0.449489742517471, Val hits at 20: 0.021125336170021
Epoch 15, Train Loss: 0.440307080745697, Val hits at 20: 0.021679688963136
Epoch 16, Train Loss: 0.435202449560165, Val hits at 20: 0.021784566518590
Epoch 17, Train Loss: 0.432927638292313, Val hits at 20: 0.021612267677487
Epoch 18, Train Loss: 0.432333678007126, Val hits at 20: 0.020458614567492
Epoch 19, Train Loss: 0.432580739259720, Val hits at 20: 0.018413502236139
Epoch 20, Train Loss: 0.432417780160904, Val hits at 20: 0.018301133426724
Epoch 21, Train Loss: 0.430

In [19]:
def load_model(save_path, num_nodes, device):
    model = GINLinkPredictor(num_nodes=num_nodes, hidden_dim=256).to(device)
    model.load_state_dict(torch.load(save_path, map_location=device, weights_only = False))
    model.eval()

    return model

In [32]:
loaded_model = load_model(SAVE_PATH, num_nodes=4267, device=device)


In [33]:
message_passing_edges = torch.load("message_passing_edges_path",weights_only = False).to(device)


In [34]:
val_edges = torch.tensor(split_edge['valid']['edge'].T, dtype=torch.long, device=device)
val_neg_edges = torch.tensor(split_edge['valid']['edge_neg'].T, dtype=torch.long, device=device)

hits_at_20 = test_step(loaded_model, message_passing_edges.to(device), val_edges, val_neg_edges, evaluator, device)
print(f"Validation Hits@20 (OGB Evaluator): {hits_at_20:.15f}")

🔹 Avg Positive Edge Score: 8.911684036254883
🔻 Avg Negative Edge Score: -8.524980545043945
🔺 Avg Top 20 Negative Edge Score: 9.181633949279785
✅ Accuracy: 0.9715
🎯 Precision: 0.9693
🔄 Recall: 0.9808
⭐ F1 Score: 0.9750
Validation Hits@20 (OGB Evaluator): 0.723550255077197


In [35]:
test_pos_edges = torch.tensor(split_edge['test']['edge'].T, dtype=torch.long, device=device)
test_neg_edges = torch.tensor(split_edge['test']['edge_neg'].T, dtype=torch.long, device=device)

hits_at_20 = test_step(loaded_model, message_passing_edges.to(device), test_pos_edges, test_neg_edges, evaluator, device)
print(f"Test Hits@20 (OGB Evaluator): {hits_at_20:.15f}")

🔹 Avg Positive Edge Score: 8.239475250244141
🔻 Avg Negative Edge Score: -8.508003234863281
🔺 Avg Top 20 Negative Edge Score: 7.999917507171631
✅ Accuracy: 0.9706
🎯 Precision: 0.9752
🔄 Recall: 0.9744
⭐ F1 Score: 0.9748
Test Hits@20 (OGB Evaluator): 0.694708927327345


In [36]:
def print_trainable_params(model):
    print("Trainable Parameters:")
    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.shape} | {param.numel()} params")
            total_params += param.numel()

    print(f"Total Trainable Parameters: {total_params}")

print_trainable_params(loaded_model)


Trainable Parameters:
node_emb.weight: torch.Size([4267, 256]) | 1092352 params
gin1.eps: torch.Size([1]) | 1 params
gin1.mlp_phi.0.weight: torch.Size([256, 256]) | 65536 params
gin1.mlp_phi.0.bias: torch.Size([256]) | 256 params
gin1.mlp_phi.1.weight: torch.Size([256]) | 256 params
gin1.mlp_phi.1.bias: torch.Size([256]) | 256 params
gin1.mlp_phi.4.weight: torch.Size([256, 256]) | 65536 params
gin1.mlp_phi.4.bias: torch.Size([256]) | 256 params
gin1.mlp_phi.5.weight: torch.Size([256]) | 256 params
gin1.mlp_phi.5.bias: torch.Size([256]) | 256 params
gin1.mlp_a.0.weight: torch.Size([32, 8]) | 256 params
gin1.mlp_a.0.bias: torch.Size([32]) | 32 params
gin1.mlp_a.1.weight: torch.Size([32]) | 32 params
gin1.mlp_a.1.bias: torch.Size([32]) | 32 params
gin1.mlp_a.4.weight: torch.Size([1, 32]) | 32 params
gin1.mlp_a.4.bias: torch.Size([1]) | 1 params
gin2.eps: torch.Size([1]) | 1 params
gin2.mlp_phi.0.weight: torch.Size([256, 256]) | 65536 params
gin2.mlp_phi.0.bias: torch.Size([256]) | 256 par