# Full Training and Explanation Example

This notebook demonstrates MEGAN's complete workflow including multi-channel explanation extraction and visualization. Building on the basic usage example, this notebook focuses on MEGAN's unique explainability features and the systematic analysis of attention-based explanations.

## MEGAN's Multi-Channel Architecture in Practice

**Dual-Channel Regression**: For regression tasks, MEGAN employs two explanation channels that learn to identify molecular features contributing positively and negatively to the target property. This separation provides more nuanced explanations than single-attention mechanisms.

**Explanation Co-Training**: During training, MEGAN optimizes prediction accuracy and explanation consistency at the same time through a self-supervised objective. This encourages the attention weights to reflect structure that is actually predictive of the target rather than spurious correlations.
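
Co-training can be pictured as a weighted sum of loss terms. The sketch below is an illustration only: the function name `combined_loss` is hypothetical, the individual loss terms are passed in as plain numbers, and MEGAN's actual loss may differ in form and normalization. The two weights correspond to the `importance_factor` and `sparsity_factor` parameters set in the model configuration later in this notebook.

```python
def combined_loss(prediction_loss: float,
                  explanation_loss: float,
                  sparsity_loss: float,
                  importance_factor: float = 1.0,
                  sparsity_factor: float = 0.5) -> float:
    """Weighted sum of the three training objectives (illustrative only)."""
    return (prediction_loss
            + importance_factor * explanation_loss
            + sparsity_factor * sparsity_loss)


# All three terms contribute
print(combined_loss(0.8, 0.4, 0.2))
# With importance_factor=0 the explanation term drops out entirely
print(combined_loss(0.8, 0.4, 0.2, importance_factor=0.0))
```

Raising `importance_factor` shifts the optimization toward explanation quality, while raising `sparsity_factor` penalizes diffuse attention more strongly.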

**Attention-Based Interpretability**: Unlike post-hoc explanation methods, MEGAN's explanations are produced in the forward pass as part of the prediction itself. They therefore need no separate explanation step after inference and reflect the attention weights the model actually used.

# 🧪 Dataset Loading and Splitting

We'll prepare train and test datasets to evaluate MEGAN's performance and extract explanations from test samples.
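
A minimal sketch of a seeded index split, using only the standard library. The helper name `split_indices` is hypothetical and not part of MEGAN; it illustrates how a fixed seed makes the split repeatable and how a set keeps the membership check cheap.

```python
import random


def split_indices(num_samples: int, test_fraction: float = 0.2, seed: int = 42):
    """Return (train_indices, test_indices) for a seeded random split."""
    rng = random.Random(seed)  # Local generator, leaves the global random state untouched
    test_size = int(test_fraction * num_samples)
    test_indices = rng.sample(range(num_samples), k=test_size)
    test_set = set(test_indices)  # Set lookup is O(1) per index
    train_indices = [i for i in range(num_samples) if i not in test_set]
    return train_indices, test_indices


train_idx, test_idx = split_indices(10)
print(len(train_idx), len(test_idx))
```

Calling the helper twice with the same seed returns the same indices, which is what makes later comparisons between runs meaningful.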

In [None]:
import os
import csv
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from rich.pretty import pprint

plt.style.use('default')


In [None]:
# --- loading the dataset ---
# Load the molecular dataset - same as basic example
PATH: str = os.getcwd()
DATASET_PATH: str = os.path.join(PATH, "clogp.csv")

dataset: pd.DataFrame = pd.read_csv(DATASET_PATH)

print('Dataset size:', len(dataset))
print('\nDataset preview:')
print(dataset.head())

# --- train-test split ---
# Create a reproducible train-test split for unbiased performance evaluation
random.seed(42)
np.random.seed(42)

# Split into 80% train, 20% test
indices = list(range(len(dataset)))
test_size = int(0.2 * len(dataset))
test_indices = random.sample(indices, k=test_size)
test_index_set = set(test_indices)  # Set lookup avoids scanning the whole list for every index
train_indices = [i for i in indices if i not in test_index_set]

train_dataset = dataset.iloc[train_indices].copy()
test_dataset = dataset.iloc[test_indices].copy()

# Save split datasets for reproducibility and potential reuse
TRAIN_PATH = os.path.join(PATH, "train_clogp.csv")
TEST_PATH = os.path.join(PATH, "test_clogp.csv")

train_dataset.to_csv(TRAIN_PATH, index=False)
test_dataset.to_csv(TEST_PATH, index=False)

print(f'\nTrain set size: {len(train_dataset)}')
print(f'Test set size: {len(test_dataset)}')
print(f'Train set saved to: {TRAIN_PATH}')
print(f'Test set saved to: {TEST_PATH}')

# ⚙️ Dataset Processing

We'll use the same MoleculeProcessing pipeline to ensure consistency with the basic example. This processing converts SMILES strings to graph dictionaries that MEGAN can process.
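
To make the graph dictionary format concrete, the sketch below hand-builds a toy dictionary for ethanol (`CCO`). The keys `node_attributes` and `edge_attributes` appear in the processing output inspected in the next cell; the `edge_indices` key, the feature values, and the directed-edge layout are assumptions for illustration. The real `MoleculeProcessing` output contains more keys and much wider feature vectors.

```python
# Toy graph dictionary for ethanol (CCO), heavy atoms only.
# Feature values are made up: [is_carbon, is_oxygen] per atom, [is_single_bond] per edge.
toy_graph = {
    "node_attributes": [
        [1.0, 0.0],  # C
        [1.0, 0.0],  # C
        [0.0, 1.0],  # O
    ],
    # Assumed layout: each bond is stored twice, once per direction
    "edge_indices": [[0, 1], [1, 0], [1, 2], [2, 1]],
    "edge_attributes": [[1.0], [1.0], [1.0], [1.0]],
}

num_nodes = len(toy_graph["node_attributes"])
num_edges = len(toy_graph["edge_attributes"])
node_dim = len(toy_graph["node_attributes"][0])
edge_dim = len(toy_graph["edge_attributes"][0])

print(f"nodes={num_nodes}, edges={num_edges}, node_dim={node_dim}, edge_dim={edge_dim}")
```

The per-node and per-edge feature widths are the values a model's input layers have to match, which is why the model configuration reads them from the processing object rather than hard-coding them.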

In [None]:
from visual_graph_datasets.processing.molecules import MoleculeProcessing

# Initialize the molecular processing pipeline - same as basic example
processing = MoleculeProcessing()

# Test the processing with a sample molecule to verify setup
SAMPLE_SMILES = 'C1=CC=C2C=C(CCN)C=CC2=C1'
sample_graph: dict = processing.process(SAMPLE_SMILES)

print('Graph attributes:')
pprint(list(sample_graph.keys()))
print(f'\nNumber of nodes: {len(sample_graph["node_attributes"])}')
print(f'Number of edges: {len(sample_graph["edge_attributes"])}')
print(f'Node features dimension: {processing.get_num_node_attributes()}')
print(f'Edge features dimension: {processing.get_num_edge_attributes()}')

# 🤖 MEGAN Model Configuration

Here we configure MEGAN with explanation capabilities enabled. The key parameters that distinguish MEGAN from standard GNNs are the explanation-related settings that enable the dual-channel attention mechanism and co-training objectives.
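
Before building the model it can help to sanity-check the settings against each other. The helper below is a hypothetical sketch, not part of `graph_attention_student`; it only encodes two conventions stated in this notebook: the prediction MLP ends in a single output for one regression target, and regression explanations use two channels.

```python
def check_regression_config(units, final_units, num_channels, num_targets=1):
    """Return a list of human-readable problems; an empty list means the settings look consistent."""
    problems = []
    if not units:
        problems.append("units must contain at least one layer size")
    if not final_units or final_units[-1] != num_targets:
        problems.append(f"final_units should end in {num_targets}")
    if num_channels != 2:
        problems.append("regression explanations expect 2 channels (negative and positive)")
    return problems


# The settings used in this notebook pass the check
print(check_regression_config([64, 64, 64], [64, 32, 1], num_channels=2))
# A mismatched configuration reports both problems
print(check_regression_config([64, 64, 64], [64, 32, 2], num_channels=1))
```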

In [None]:
import pytorch_lightning as pl
from torch_geometric.loader import DataLoader
from graph_attention_student import Megan, SmilesDataset

# --- MEGAN configuration for explainable predictions ---
model = Megan(
    # Input dimensions must match the MoleculeProcessing output
    node_dim=processing.get_num_node_attributes(),
    edge_dim=processing.get_num_edge_attributes(),
    
    # Graph encoder architecture - deeper networks can capture more complex patterns
    units=[64, 64, 64],  # Three message-passing layers for hierarchical feature learning
    final_units=[64, 32, 1],  # Prediction MLP: embedding -> intermediate -> single logP value
    
    # Task configuration for regression
    prediction_mode='regression',
    learning_rate=1e-4,
    
    # --- MEGAN's unique explanation configuration ---
    importance_mode='regression',  # Explanation channels tailored for continuous targets
    
    # KEY: importance_factor > 0 activates explanation co-training
    # This forces attention weights to be predictive of the target independently
    # Higher values prioritize explanation quality over pure prediction accuracy
    importance_factor=1.0,
    
    # Sparsity regularization prevents the model from highlighting everything as "important"
    # This promotes focused, interpretable explanations rather than diffuse attention
    sparsity_factor=0.5,
    
    # Controls the baseline attention level - affects explanation granularity
    # Higher values create more sparse (selective) explanations
    importance_offset=1.0,
    
    # For regression: 2 channels capture positive and negative evidence separately
    # This dual-channel approach provides more nuanced explanations than single attention
    num_channels=2,
)

print('MEGAN model configured with explanation capabilities:')
print(f'- Explanation channels: {model.num_channels}')
print(f'- Explanation co-training factor: {model.importance_factor}')
print(f'- Sparsity regularization factor: {model.sparsity_factor}')

In [None]:
# --- Training with SmilesDataset for efficient molecular data streaming ---
# SmilesDataset provides lazy loading and on-the-fly graph conversion
# This is memory-efficient for large molecular datasets
train_smiles_dataset = SmilesDataset(
    dataset=TRAIN_PATH,
    smiles_column='smiles',
    target_columns=['value'],
    processing=processing,  # Must use the same processing pipeline for consistency
    reservoir_sampling=True,  # Enables proper shuffling without loading entire dataset
)

# Configure DataLoader for batch processing
train_loader = DataLoader(
    train_smiles_dataset,
    batch_size=64,  # Balanced for stable gradients and memory efficiency
    drop_last=True,  # Prevents BatchNorm issues with variable-sized final batches
    num_workers=4,   # Parallel SMILES->graph conversion
    prefetch_factor=2,  # Pre-load batches to hide processing latency
)

# --- MEGAN training with multi-objective optimization ---
trainer = pl.Trainer(
    max_epochs=150,  # Extended training needed for explanation convergence
    accelerator='auto',
    devices='auto',
)

print('Starting MEGAN training with explanation co-training...')
print('Training objectives:')
print('1. Prediction accuracy (MSE loss)')
print('2. Explanation consistency (importance_factor * explanation_loss)')  
print('3. Attention sparsity (sparsity_factor * sparsity_loss)')

trainer.fit(model, train_dataloaders=train_loader)
model.eval()  # Critical: switch to evaluation mode for consistent inference
print('MEGAN training completed with explanation capabilities!')

# 📊 MEGAN Performance Evaluation

We evaluate the trained model on unseen test data to assess both prediction accuracy and the reliability of the explanation mechanism.
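
For reference, the three metrics reported in this section can be computed by hand. This standard-library sketch follows the usual single-target definitions of RMSE, MAE and R² that scikit-learn's `mean_squared_error`, `mean_absolute_error` and `r2_score` implement; the function name `regression_metrics` is hypothetical.

```python
import math


def regression_metrics(targets, predictions):
    """Return (rmse, mae, r2) for two equally long sequences of numbers."""
    n = len(targets)
    errors = [p - t for t, p in zip(targets, predictions)]
    mse = sum(e * e for e in errors) / n
    mae = sum(abs(e) for e in errors) / n
    mean_target = sum(targets) / n
    ss_res = sum(e * e for e in errors)                    # residual sum of squares
    ss_tot = sum((t - mean_target) ** 2 for t in targets)  # total sum of squares
    r2 = 1.0 - ss_res / ss_tot                             # undefined for constant targets
    return math.sqrt(mse), mae, r2


rmse, mae, r2 = regression_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 4.0])
print(f"RMSE={rmse:.4f} MAE={mae:.4f} R2={r2:.4f}")  # RMSE=0.5774 MAE=0.3333 R2=0.5000
```

A large gap between the training and test values of these metrics is the usual sign of overfitting.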

In [None]:
# --- Evaluate MEGAN on both training and test sets ---
# This compares performance on seen vs unseen data to assess overfitting
def evaluate_split(frame: pd.DataFrame):
    """Predict every molecule in `frame`; returns (predictions, targets, smiles) for those that succeed."""
    predictions, targets, smiles_list = [], [], []
    for smiles, target in zip(frame['smiles'], frame['value']):
        try:
            graph = processing.process(smiles)
            result = model.forward_graph(graph)
            prediction = result['graph_output'].item()
        except Exception as e:
            # Molecules that fail processing or inference are reported and skipped
            print(f'Skipping molecule {smiles}: {e}')
            continue
        predictions.append(prediction)
        targets.append(target)
        smiles_list.append(smiles)
    return np.array(predictions), np.array(targets), smiles_list

print('Evaluating MEGAN on training set...')
train_predictions, train_targets, _ = evaluate_split(train_dataset)

print('Evaluating MEGAN on test set...')
test_predictions, test_targets, test_smiles_list = evaluate_split(test_dataset)

# Calculate metrics for both sets
train_mse = mean_squared_error(train_targets, train_predictions)
train_mae = mean_absolute_error(train_targets, train_predictions)
train_r2 = r2_score(train_targets, train_predictions)
train_rmse = np.sqrt(train_mse)

test_mse = mean_squared_error(test_targets, test_predictions)
test_mae = mean_absolute_error(test_targets, test_predictions)
test_r2 = r2_score(test_targets, test_predictions)
test_rmse = np.sqrt(test_mse)

print(f'\nMEGAN Training Set Performance:')
print(f'Samples: {len(train_predictions)} | RMSE: {train_rmse:.4f} | MAE: {train_mae:.4f} | R²: {train_r2:.4f}')

print(f'\nMEGAN Test Set Performance:')
print(f'Samples: {len(test_predictions)} | RMSE: {test_rmse:.4f} | MAE: {test_mae:.4f} | R²: {test_r2:.4f}')

# Create side-by-side regression plots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Training set plot (red)
ax1.scatter(train_targets, train_predictions, alpha=0.6, color='red', s=20)
train_range = [train_targets.min(), train_targets.max()]
ax1.plot(train_range, train_range, 'k--', lw=2)
ax1.set_xlabel('True cLogP')
ax1.set_ylabel('Predicted cLogP')
ax1.set_title(f'Training Set (n={len(train_predictions)})\nRMSE: {train_rmse:.3f} | MAE: {train_mae:.3f} | R²: {train_r2:.3f}')
ax1.grid(True, alpha=0.3)
ax1.set_aspect('equal', adjustable='box')

# Test set plot (blue)
ax2.scatter(test_targets, test_predictions, alpha=0.6, color='blue', s=20)
test_range = [test_targets.min(), test_targets.max()]
ax2.plot(test_range, test_range, 'k--', lw=2)
ax2.set_xlabel('True cLogP')
ax2.set_ylabel('Predicted cLogP')
ax2.set_title(f'Test Set (n={len(test_predictions)})\nRMSE: {test_rmse:.3f} | MAE: {test_mae:.3f} | R²: {test_r2:.3f}')
ax2.grid(True, alpha=0.3)
ax2.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()

# 🔍 Extracting MEGAN's Multi-Channel Explanations

Now we demonstrate how to extract and interpret MEGAN's dual-channel attention mechanisms. Unlike standard GNNs, MEGAN provides separate importance scores for positive and negative evidence, enabling more detailed analysis of model reasoning.
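
As a sketch of how a `(num_nodes, num_channels)` importance matrix can be read, the example below uses made-up values for a three-atom molecule and reports, per atom, which channel carries more weight. The channel order (0 = negative, 1 = positive) follows the convention used in this notebook; the numbers, the atom labels in the comments, and the helper name `dominant_channel` are purely illustrative.

```python
CHANNEL_NAMES = ["negative (lowers logP)", "positive (raises logP)"]


def dominant_channel(node_importance):
    """For each atom, return the index of the channel with the larger importance."""
    return [max(range(len(row)), key=row.__getitem__) for row in node_importance]


# Made-up importances for 3 atoms x 2 channels
toy_importance = [
    [0.05, 0.90],  # e.g. an aliphatic carbon
    [0.10, 0.70],  # e.g. another carbon
    [0.85, 0.02],  # e.g. a hydroxyl oxygen
]

for atom, channel in enumerate(dominant_channel(toy_importance)):
    print(f"atom {atom}: {CHANNEL_NAMES[channel]}")
```

An atom can also score low in both channels, which simply means the model did not treat it as evidence in either direction.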

In [None]:
# Select representative examples for explanation analysis
np.random.seed(42)
explanation_indices = np.random.choice(len(test_smiles_list), size=10, replace=False)

print(f'Selected {len(explanation_indices)} examples for MEGAN explanation analysis:')
for i, idx in enumerate(explanation_indices):
    smiles = test_smiles_list[idx]
    target = test_targets[idx]
    prediction = test_predictions[idx]
    print(f'{i+1}. SMILES: {smiles[:50]}... | True: {target:.3f} | Pred: {prediction:.3f}')

In [None]:
# Function to normalize MEGAN's importance values for consistent visualization
def normalize_importances(importances):
    """
    Min-max normalize MEGAN importance values to the [0, 1] range for visualization.
    A single minimum and maximum is taken over the whole array, so both channels of
    one molecule share a scale. Intensities are therefore relative within a molecule
    and should not be compared as absolute values across molecules.
    """
    min_val = np.min(importances)
    max_val = np.max(importances)
    if max_val > min_val:
        return (importances - min_val) / (max_val - min_val)
    else:
        return np.zeros_like(importances)

# Extract MEGAN's dual-channel explanations for selected molecules
explanation_results = []

print('Extracting MEGAN explanations...')
for idx in explanation_indices:
    smiles = test_smiles_list[idx]
    target = test_targets[idx]
    prediction = test_predictions[idx]
    
    # Process molecule and extract full MEGAN output including attention weights
    graph = processing.process(smiles)
    result = model.forward_graph(graph)
    
    # MEGAN's key outputs:
    # - node_importance: (num_nodes, num_channels) - attention weights for each atom
    # - edge_importance: (num_edges, num_channels) - attention weights for each bond
    node_importance = result['node_importance']
    edge_importance = result['edge_importance']
    
    # Normalize for consistent visualization across molecules
    node_importance_norm = normalize_importances(node_importance)
    edge_importance_norm = normalize_importances(edge_importance)
    
    explanation_results.append({
        'smiles': smiles,
        'target': target,
        'prediction': prediction,
        'graph': graph,
        'node_importance': node_importance_norm,  # Shape: (num_atoms, 2) for dual channels
        'edge_importance': edge_importance_norm,  # Shape: (num_bonds, 2) for dual channels
    })

print(f'Extracted MEGAN explanations for {len(explanation_results)} molecules.')
print(f'Each explanation has {explanation_results[0]["node_importance"].shape[1]} attention channels.')
print('Channel 0: Features contributing negatively to logP (hydrophilic)')  
print('Channel 1: Features contributing positively to logP (hydrophobic)')

# 🎨 Visualizing MEGAN's Dual-Channel Explanations

MEGAN's visualization system overlays attention-based importance maps onto molecular structures. The dual-channel approach separates positive and negative evidence, providing clearer interpretation of model reasoning than single-channel attention mechanisms.
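
How importances are scaled before plotting changes what the overlay shows. Min-max scaling over the whole matrix keeps the two channels on a shared scale, whereas scaling each channel separately stretches a weak channel to full intensity. The sketch below contrasts the two on made-up values, using nested lists for brevity; `normalize_joint` and `normalize_per_channel` are hypothetical helpers, not part of MEGAN or `visual_graph_datasets`.

```python
def minmax(values):
    """Min-max scale a flat list of numbers to [0, 1]; constant input maps to zeros."""
    lo, hi = min(values), max(values)
    if hi <= lo:
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


def normalize_joint(importances):
    """Scale all entries of a (num_items x num_channels) nested list with one shared min and max."""
    num_channels = len(importances[0])
    flat = minmax([v for row in importances for v in row])
    return [flat[i:i + num_channels] for i in range(0, len(flat), num_channels)]


def normalize_per_channel(importances):
    """Scale each channel (column) independently."""
    columns = [minmax(list(col)) for col in zip(*importances)]
    return [list(row) for row in zip(*columns)]


# Channel 0 is weak everywhere, channel 1 is strong on the second item
toy = [[0.0, 0.2],
       [0.1, 0.8],
       [0.05, 0.4]]

print(normalize_joint(toy))        # channel 0 stays faint next to channel 1
print(normalize_per_channel(toy))  # channel 0 is stretched to the full range
```

Joint scaling preserves the relative strength of the two channels; per-channel scaling makes faint patterns visible at the cost of hiding that difference.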

In [None]:
from visual_graph_datasets.visualization.importances import plot_node_importances_background
from visual_graph_datasets.visualization.importances import plot_edge_importances_background
from visual_graph_datasets.visualization.base import draw_image
import tempfile

# Create comprehensive visualization showing MEGAN's dual-channel explanations
num_examples = len(explanation_results)
num_channels = 2  # MEGAN's dual channels for regression
fig_width = 4 * (num_channels + 1)  # Original molecule + 2 explanation channels
fig_height = 3 * num_examples

fig, axes = plt.subplots(
    nrows=num_examples, 
    ncols=num_channels + 1,  # Molecular structure + dual explanation channels
    figsize=(fig_width, fig_height)
)

# Ensure axes is always 2D for consistent indexing
if num_examples == 1:
    axes = axes.reshape(1, -1)

# Color scheme for MEGAN's dual channels
channel_colors = ['royalblue', 'orangered']  # Blue for negative, red for positive  
channel_labels = ['Hydrophilic Evidence', 'Hydrophobic Evidence']

print('Creating MEGAN explanation visualizations...')
for row, result in enumerate(explanation_results):
    smiles = result['smiles']
    target = result['target']
    prediction = result['prediction']
    graph = result['graph']
    node_importance = result['node_importance']
    edge_importance = result['edge_importance']
    
    # Generate molecular structure visualization using MoleculeProcessing
    mol_fig, node_positions = processing.visualize_as_figure(smiles, width=400, height=400)
    # Coordinate transformation for matplotlib compatibility
    node_positions[:, 1] = 400 - node_positions[:, 1]  # Invert y-axis for proper overlay alignment
    
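    # Saving the figure forces it to render before its RGBA buffer is read below.
    # The temporary PNG itself is not needed afterwards, so it is removed right away.
    temp_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    temp_file.close()
    mol_fig.savefig(temp_file.name)
    os.remove(temp_file.name)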
    
    # Store node positions in graph dict for visualization functions
    graph['node_positions'] = node_positions
    
    # Column 0: Original molecular structure
    ax_orig = axes[row, 0]
    ax_orig.imshow(np.array(mol_fig.canvas.renderer.buffer_rgba()))
    ax_orig.set_title(f'Molecule {row+1}\nTrue: {target:.3f}, Pred: {prediction:.3f}')
    ax_orig.axis('off')
    
    # Columns 1-2: MEGAN's dual explanation channels
    for channel in range(num_channels):
        ax = axes[row, channel + 1]
        
        # Draw base molecular structure
        ax.imshow(np.array(mol_fig.canvas.renderer.buffer_rgba()))
        
        # Overlay MEGAN's node importance (atomic contributions)
        # Higher intensity = greater importance in this channel
        plot_node_importances_background(
            ax=ax,
            g=graph,
            node_positions=node_positions,
            node_importances=node_importance[:, channel],
            color=channel_colors[channel],
            radius=20,  # Size of importance circles around atoms
            v_min=0,
            v_max=1,
        )
        
        # Overlay MEGAN's edge importance (bond contributions) 
        # Line thickness/opacity indicates bond importance
        plot_edge_importances_background(
            ax=ax,
            g=graph,
            node_positions=node_positions,
            edge_importances=edge_importance[:, channel],
            color=channel_colors[channel],
            thickness=10,  # Width of importance lines over bonds
            v_min=0,
            v_max=1,
        )
        
        ax.set_title(f'{channel_labels[channel]}')
        ax.axis('off')
    
    plt.close(mol_fig)  # Clean up temporary figure

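# Use the figure handle explicitly: pyplot's "current figure" may have changed
# while the per-molecule figures were created and closed in the loop above
fig.tight_layout()
fig.suptitle('MEGAN Dual-Channel Explanations for Test Molecules', y=1.02, fontsize=16)
plt.show()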

print('\nExplanation Guide:')
print('• Blue highlights: Molecular features that decrease logP (increase hydrophilicity)')
print('• Red highlights: Molecular features that increase logP (increase hydrophobicity)')
print('• Intensity indicates the strength of contribution according to MEGAN attention weights')
print('• Both atoms (circles) and bonds (lines) can contribute to the final prediction')

# 📋 MEGAN Analysis Summary

This notebook demonstrated MEGAN's complete explainable AI workflow for molecular property prediction:

## Key MEGAN Features Demonstrated:

### 1. **Multi-Objective Training**
- **Prediction accuracy**: Standard regression loss for logP values
- **Explanation consistency**: Self-supervised loss ensuring attention weights are predictive
- **Sparsity regularization**: Promotes focused, interpretable explanations

### 2. **Dual-Channel Architecture** 
- **Channel separation**: Distinct attention mechanisms for positive/negative evidence
- **Chemical relevance**: Hydrophobic vs. hydrophilic feature identification
- **Nuanced explanations**: More detailed than single-attention approaches

### 3. **Integrated Explanation Generation**
- **Forward-pass explanations**: Generated during prediction, not post-hoc
- **Node and edge importance**: Both atomic and bonding contributions captured
- **Visualization-ready output**: Direct integration with molecular structure displays

## Advantages Over Standard GNNs:

**Interpretability**: MEGAN explanations reveal *why* predictions are made, not just *what* is predicted.

**Chemical Validity**: Attention weights can be checked directly against known structure-property relationships, such as hydrophobic and hydrophilic groups in the case of logP.

**Reliability**: Co-training ties the explanations to the prediction task, so they are more than after-the-fact visualizations.

**Efficiency**: Explanations are generated together with predictions, so no separate explanation pass is required at inference time.

This explainable AI approach makes MEGAN particularly valuable for scientific applications where understanding model reasoning is crucial for hypothesis generation, experimental design, and regulatory approval processes.