In [1]:
import os
import sys
import torch
import argparse

# Resolve the project root. The default is the original author's checkout;
# set the YUANQI_XUE_PROJECT_ROOT environment variable to run this notebook
# from another machine or directory without editing the cell. An unset or
# empty variable falls back to the default.
project_root = (
    os.environ.get("YUANQI_XUE_PROJECT_ROOT")
    or "/home/yuanqi/Desktop/pycharm_projects/yuanqi_xue"
)

# Change working directory to project root so that relative paths used later
# in the notebook (e.g. "./src/checkpoint/...") resolve against it.
os.chdir(project_root)

# Add src/ to sys.path (once) so the project-local imports below resolve.
src_path = os.path.join(project_root, "src")
if src_path not in sys.path:
    sys.path.append(src_path)

print(f"Working directory set to: {os.getcwd()}")
print(f"'src' added to sys.path: {src_path}")

from utils.Configures import data_args, train_args, model_args, mcts_args
from data_processing.dataUtils import load_dataset
from models import setup_model
from models.train import train
from evaluation.test import test
from evaluation.explanation import exp_visualize

  from .autonotebook import tqdm as notebook_tqdm


Working directory set to: /home/yuanqi/Desktop/pycharm_projects/yuanqi_xue
'src' added to sys.path: /home/yuanqi/Desktop/pycharm_projects/yuanqi_xue/src


In [2]:
# ========== Main Experiment Pipeline ==========
def main(args):
    """Run the experiment end to end: load, build, train, test, explain.

    Args:
        args: Object exposing ``clst`` and ``sep`` attributes; both are
            forwarded to ``train`` as its first two positional arguments.

    The dataset name and model name are read from the imported ``data_args``
    and ``model_args`` configuration objects, not from ``args``.
    """
    # Stage 1: dataset plus its dataloaders and input/output dimensions
    print('====================Loading Data====================')
    graphs, in_dim, out_dim, loaders = load_dataset()
    print(f"Dataset Name: {data_args.dataset_name}")
    print(f"Dataset Length: {len(graphs)}")

    # Stage 2: network and loss function
    print('====================Setting Up Model====================')
    net, loss_fn = setup_model(in_dim, out_dim, model_args)
    print(f"gnnNets: {net}")

    # Stage 3: training; checkpoints go under a per-dataset directory
    print('====================Training Model==================')
    ckpt_dir = f"./src/checkpoint/{data_args.dataset_name}/"
    train(args.clst, args.sep, graphs, loaders, net, out_dim, loss_fn, ckpt_dir)

    # Stage 4: reload the "<model_name>_best.pth" checkpoint and evaluate it
    # on the 'test' dataloader
    print('====================Testing==================')
    best_ckpt_path = os.path.join(ckpt_dir, f'{model_args.model_name}_best.pth')
    best_state = torch.load(best_ckpt_path)
    net.update_state_dict(best_state['net'])
    test(loaders['test'], net, loss_fn)

    # Stage 5: produce explanation visualizations
    print('====================Generating Explanations====================')
    exp_visualize(graphs, loaders, net, out_dim)

In [3]:
class Args:
    """Plain attribute holder for the loss weights that ``main`` reads.

    Both values are class-level defaults, so ``Args()`` takes no arguments.
    """

    # cluster loss weight
    clst: float = 0.0
    # separation loss weight
    sep: float = 0.0

# Build the argument holder with its defaults (clst = 0.0, sep = 0.0) and
# run the whole pipeline defined in main().
args = Args()
main(args)

dataloader used
Dataset Name: MUTAG
Dataset Length: 188




gnnNets: _GnnNets(
  (model): GINNet(
    (gnn_layers): ModuleList(
      (0): GINConv(nn=Sequential(
        (0): Linear(in_features=7, out_features=128, bias=False)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Linear(in_features=128, out_features=128, bias=False)
        (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      ))
      (1): GINConv(nn=Sequential(
        (0): Linear(in_features=128, out_features=128, bias=False)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Linear(in_features=128, out_features=128, bias=False)
        (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      ))
      (2): GINConv(nn=Sequential(
        (0): Linear(in_features=128, out_features=128, bias=False)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=T

### Sampled Explanations (MUTAG Dataset)

The MUTAG dataset contains molecular graphs of nitroaromatic compounds labeled by mutagenicity. Our model is trained to extract **prototype subgraphs** (i.e., key explanatory regions) that ideally align with known mutagenic structures, such as **nitro (NO<sub>2</sub>) groups attached to aromatic rings**.

Below, we present three randomly selected prototype subgraphs from the model's output. Each image shows the full molecular graph, with **bolded edges representing the key subgraph** extracted by the model. These visualizations allow us to assess whether the model accurately isolates chemically meaningful explanations.

---

**Figure 1: Ground-truth explanation successfully captured**  
_Example 40 - from results folder_  
![Figure 1](./sampled_explanations/example_40_sampled.png)

- The bolded subgraph includes the **NO<sub>2</sub> group** and part of a **fused aromatic ring**, both known to correlate with mutagenicity.
- This explanation aligns strongly with ground-truth domain knowledge.
- This is a high-quality prototype and a clear interpretability success case.

---

**Figure 2: One extra C-C edge and one missing N-O edge**  
_Example 19 - from results folder_  
![Figure 2](./sampled_explanations/example_19_sampled.png)

- The model **misses one of the N-O bonds** in the nitro group.
- It also includes an **extra C-C bond** that is not chemically relevant to the mutagenic substructure.
- This indicates **partial faithfulness** - the model recognizes the general area of importance but fails to precisely localize the explanation.

---

**Figure 3: Partial capture of mutagenic structure**  
_Example 202 - from results folder_  
![Figure 3](./sampled_explanations/example_202_sampled.png)

- The molecule contains **two NO<sub>2</sub> groups**, but only one is captured in the bolded subgraph.
- The highlighted region covers part of the upper aromatic ring, omitting the second relevant NO<sub>2</sub> group.
- This suggests either a limit on subgraph size or model uncertainty about how much of the molecule the explanation should cover.

---

### Summary

| Example      | Model Behavior                                             | Qualitative Assessment                |
|--------------|------------------------------------------------------------|---------------------------|
| `example_40` | Captures NO<sub>2</sub> and aromatic ring (true explanation)  | Ground-truth aligned   |
| `example_19` | Misses N-O edge, includes irrelevant C-C edge              | Imperfect explanation  |
| `example_202`| Captures only one NO<sub>2</sub> group, misses a second key region | Partial coverage       |
