# Interpretability Analysis: Full Results Generation

This notebook runs the complete interpretability pipeline and generates all figures for the results section:

- **GaTmCC Performance**: Training curves and test metrics
- **PCA Baseline**: Performance comparison
- **SHAP Analysis**: Protein importance (100 test samples)
- **Attention Analysis**: Attention patterns and correlations
- **Comparative Analysis**: Model comparisons and visualizations
- **Graph Bias Scale**: Learned parameter analysis


In [1]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import json
import os
from scipy.stats import pearsonr, spearmanr
from collections import defaultdict

# Resolve the project root: prefer the working directory, then its parent, as long as the
# candidate contains both `src/` and `interpretability/`; otherwise fall back to the parent
# of the resolved working directory.
cwd = Path(os.getcwd())
PROJECT_ROOT = next(
    (
        candidate
        for candidate in (cwd, cwd.parent)
        if (candidate / "src").exists() and (candidate / "interpretability").exists()
    ),
    Path().resolve().parent,
)

# Make project packages importable. Each insert goes to the front, so the final search
# order is interpretability/, src/, then the project root.
for extra_path in (PROJECT_ROOT, PROJECT_ROOT / "src", PROJECT_ROOT / "interpretability"):
    sys.path.insert(0, str(extra_path))

import config
from model import GraphTransformerClassifier
from graph_prior import load_graph_prior, get_graph_features_as_tensors
from utils import load_trained_model, load_data, get_output_dirs, DEFAULT_PATHS, compute_graph_distances

plt.rcParams.update({'figure.dpi': 100})
sns.set_style('whitegrid')

# All figures from this notebook are written under <plots_dir>/Model_Comparison_Plots
plots_dir, results_dir = get_output_dirs(PROJECT_ROOT / "results")
output_dir = plots_dir / "Model_Comparison_Plots"
output_dir.mkdir(parents=True, exist_ok=True)

print(f"Project root: {PROJECT_ROOT}")
print(f"Output directory: {output_dir}")


Project root: /Users/leophelan/Code/projects/graph_transformer_proteomics/CleanedProject/Executable_Project_Code
Output directory: /Users/leophelan/Code/projects/graph_transformer_proteomics/CleanedProject/Executable_Project_Code/results/plots/Model_Comparison_Plots


## 1. Run Interpretability Analyses

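Run the SHAP, attention, and PCA-baseline analysis scripts in sequence. Each stage is best-effort: a failure is reported with a traceback, and the remaining stages still run.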


In [None]:
# Run each interpretability stage (SHAP, attention, PCA baseline) in turn.
import importlib
import traceback


def run_analysis_stage(title, module_name, label):
    """Import ``module_name`` and call its ``main()``, reporting success or failure.

    Failures are printed with a traceback instead of being raised, so one failing
    stage does not prevent the remaining stages from running (best-effort by design).

    Args:
        title: Banner text printed between two separator lines.
        module_name: Importable module that exposes a zero-argument ``main``.
        label: Human-readable stage name used in the status line.

    Returns:
        True if ``main()`` completed without raising, False otherwise.
    """
    print("=" * 70)
    print(title)
    print("=" * 70)
    try:
        importlib.import_module(module_name).main()
    except Exception as e:
        print(f"✗ {label} failed: {e}")
        traceback.print_exc()
        return False
    print(f"✓ {label} complete\n")
    return True


# (banner title, module providing main(), status label)
ANALYSIS_STAGES = [
    ("Running SHAP Analysis (100 test samples)", "shap_analysis", "SHAP analysis"),
    ("Running Attention Analysis", "attention_analysis", "Attention analysis"),
    ("Training PCA95 Baseline", "pca_baseline", "PCA baseline"),
]

# module name -> True/False, so later cells can check which stages actually succeeded
stage_status = {
    module_name: run_analysis_stage(title, module_name, label)
    for title, module_name, label in ANALYSIS_STAGES
}


Running SHAP Analysis (100 test samples)

SHAP Analysis: Protein Importance

Output directory: /Users/leophelan/Code/projects/graph_transformer_proteomics/CleanedProject/Executable_Project_Code/results/plots/SHAP_Plots
Device: cpu (optimized for SHAP compatibility)
Using 8 CPU threads

Loading model and data...
Loaded prior: 198 proteins, 1184 edges
Loading data from /Users/leophelan/Code/projects/graph_transformer_proteomics/CleanedProject/Executable_Project_Code/data/processed_datasets/tcga_pancan_rppa_compiled.csv...
Loaded 7523 samples
Found 198 protein columns in CSV
Filtering samples: 7523/7523 have ≤50.0% missing values
After filtering: 7523 samples across 32 cancer types
Cancer type distribution:
CANCER_TYPE_ACRONYM
ACC      45
BLCA    343
BRCA    876
CESC    166
CHOL     30
COAD    346
DLBC     33
ESCA    125
GBM     231
HNSC    211
KICH     62
KIRC    455
KIRP    209
LGG     428
LIHC    179
LUAD    360
LUSC    317
MESO     63
OV      414
PAAD    122
PCPG     79
PRAD    350
RE



## 2. Load Pre-computed Results


In [None]:
def _load_importance_json(json_path):
    """Load a ``top_proteins.json`` list of ``{"protein": ..., "importance": ...}`` records.

    Returns the raw records plus three views in file order: the protein names,
    an importance array, and a protein -> importance mapping.
    """
    with open(json_path, 'r') as handle:
        records = json.load(handle)
    names = [rec['protein'] for rec in records]
    scores = np.array([rec['importance'] for rec in records])
    lookup = dict(zip(names, (rec['importance'] for rec in records)))
    return records, names, scores, lookup


# Transformer SHAP importances
shap_results_path = plots_dir / 'SHAP_Plots' / 'top_proteins.json'
shap_data, shap_proteins, shap_importance, shap_dict = _load_importance_json(shap_results_path)
print(f"Loaded SHAP: {len(shap_data)} proteins")
print(f"Top 5: {shap_proteins[:5]}")

# PCA baseline importances
pca_results_path = plots_dir / 'PCA_Cox_Plots' / 'top_proteins.json'
pca_data, pca_proteins, pca_importance, pca_dict = _load_importance_json(pca_results_path)
print(f"\nLoaded PCA-Logistic: {len(pca_data)} proteins")
print(f"Top 5: {pca_proteins[:5]}")

# PPI prior: adjacency matrix `A` plus the protein order it is indexed by.
# NOTE(review): allow_pickle=True can execute pickled objects — only load trusted prior files.
prior_data = np.load(DEFAULT_PATHS['prior'], allow_pickle=True)
A = prior_data['A']
all_proteins = prior_data['protein_cols'].tolist()

# NOTE(review): edge count = sum/2 assumes A is symmetric, binary, and has no self-loops — confirm
print(f"\nLoaded PPI network: {A.shape[0]} proteins, {int(A.sum()//2)} edges")


## 3. Load Model and Extract Attention Scores


In [None]:
# Load the trained classifier on CPU, then the data splits together with their DataLoaders.
device = "cpu"
model, graph_prior, label_info = load_trained_model(device=device)
data_splits, protein_names, data_loaders = load_data(return_dataloaders=True)

# Extract attention using hooks
class AttentionExtractor:
    """Collect self-attention weights from every layer in ``model.transformer`` via forward hooks.

    Each forward pass through a hooked ``layer.self_attn`` appends that module's
    attention weights (second element of its output tuple), detached and moved to CPU,
    to ``attention_maps``. Call ``remove_hooks()`` when done so hooks do not accumulate
    across re-runs.
    """

    def __init__(self, model):
        self.model = model
        self.attention_maps = []  # one tensor per hooked forward call, in call order
        self.hooks = []  # RemovableHandle objects returned by register_forward_hook

    def register_hooks(self):
        """Attach the capture hook to every layer's ``self_attn`` module."""
        for layer in self.model.transformer:
            self.hooks.append(layer.self_attn.register_forward_hook(self._hook_fn))

    def _hook_fn(self, module, inputs, output):
        """Store ``output[1]`` (the attention weights) when the module returned them.

        Attention modules conventionally return ``(attn_output, attn_weights)`` and give
        ``None`` weights when called with ``need_weights=False``; such calls are skipped
        instead of crashing on ``None.detach()``.
        """
        if not isinstance(output, (tuple, list)) or len(output) < 2 or output[1] is None:
            return
        # assumed layout: (batch, heads, seq_len, seq_len) — TODO confirm for the custom attention
        self.attention_maps.append(output[1].detach().cpu())

    def remove_hooks(self):
        """Detach all registered hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

# Capture attention maps for the whole test set, then average them over
# every test sample, layer and head into one protein x protein matrix.
extractor = AttentionExtractor(model)
extractor.register_hooks()

model.eval()
test_loader = data_loaders['test']
all_attention = []  # one entry per test batch: list of per-layer attention tensors

try:
    with torch.no_grad():
        for batch_x, _batch_y in test_loader:
            batch_x = batch_x.to(device)
            _ = model(batch_x)
            if extractor.attention_maps:
                all_attention.append(extractor.attention_maps)
                extractor.attention_maps = []
finally:
    # Always detach hooks (even if a forward pass fails) so re-running the cell does not stack them
    extractor.remove_hooks()

if all_attention:
    n_layers = len(model.transformer)
    n_heads = model.n_heads
    n_proteins = A.shape[0]

    attention_matrix = np.zeros((n_proteins, n_proteins))
    count = 0      # number of (sample, layer, head) maps accumulated
    n_samples = 0  # number of test samples seen

    for batch_layers in all_attention:
        # every layer processes the same batch, so the first layer gives the batch size
        n_samples += batch_layers[0].shape[0]
        for layer_attn in batch_layers:
            if layer_attn.dim() == 3:
                # head-averaged weights (batch, seq, seq): treat as a single head
                layer_attn = layer_attn.unsqueeze(1)
            # assumed layout: (batch, heads, seq_len, seq_len) with a CLS token at index 0 — TODO confirm.
            # Drop the CLS row/col and sum over ALL samples and heads in the batch
            # (previously only the first sample of each batch was used).
            protein_attn = layer_attn[:, :, 1:, 1:]
            attention_matrix += protein_attn.sum(dim=(0, 1)).numpy()
            count += protein_attn.shape[0] * protein_attn.shape[1]

    attention_matrix /= count

    # Per-protein score = total attention the protein receives (column sum of the averaged matrix)
    received = attention_matrix.sum(axis=0)
    protein_attention = {protein: float(received[i]) for i, protein in enumerate(all_proteins)}

    print(f"Extracted attention from {n_samples} samples ({len(all_attention)} batches)")
    print(f"Averaged over {n_layers} layers × {n_heads} heads")
else:
    attention_matrix = None
    protein_attention = {}
    print("No attention extracted")


## 4. Model Performance Metrics


In [None]:
# Evaluate the transformer on the held-out test split and load the PCA baseline's accuracy.
from sklearn.metrics import accuracy_score, f1_score, classification_report

model.eval()
test_x, test_y = data_splits['test']
test_x_tensor = torch.FloatTensor(test_x).to(device)

with torch.no_grad():
    logits = model(test_x_tensor)
    preds = logits.argmax(dim=1).cpu().numpy()

test_acc = accuracy_score(test_y, preds)
test_f1_macro = f1_score(test_y, preds, average='macro')

print("=" * 70)
print("GaTmCC Performance")
print("=" * 70)
print(f"Test Accuracy: {test_acc*100:.2f}%")
print(f"Test F1 (macro): {test_f1_macro*100:.2f}%")
print(f"Number of classes: {label_info['n_classes']}")
print("=" * 70)


def read_pca_metric(stats_path, key):
    """Return a metric from the PCA stats text file as a fraction in [0, 1], or None.

    Scans for lines of the form ``<...key...>: <value>`` (``key`` must be lowercase and is
    matched case-insensitively); the last parseable match wins. Values written with a
    ``%`` sign, or greater than 1, are treated as percentages and divided by 100.

    Args:
        stats_path: Path to the stats file; a missing file yields None.
        key: Lowercase substring identifying the metric line, e.g. ``'test accuracy'``.
    """
    if not stats_path.exists():
        return None
    value = None
    for line in stats_path.read_text().splitlines():
        if key not in line.lower() or ':' not in line:
            continue
        raw = line.split(':', 1)[1].strip()
        try:
            parsed = float(raw.rstrip('%').strip())
        except ValueError:
            continue  # not a plain number on this line; keep scanning
        value = parsed / 100 if raw.endswith('%') or parsed > 1 else parsed
    return value


# The key must be lowercase: the previous 'Test accuracy' check against line.lower() never matched
pca_stats_path = plots_dir / 'PCA_Cox_Plots' / 'pca_stats.txt'
pca_test_acc = read_pca_metric(pca_stats_path, 'test accuracy')

if pca_test_acc is not None:
    gap = (test_acc - pca_test_acc) * 100
    leader = 'GaTmCC' if gap >= 0 else 'PCA'
    print(f"\nPCA95+LogReg Test Accuracy: {pca_test_acc*100:.2f}%")
    print(f"Difference: {abs(gap):.2f}% ({leader} advantage)")


## 5. Generate Result Figures

### Figure 3: SHAP vs Attention Correlation


In [None]:
# Fig 3a: SHAP vs attention for the transformer's top-50 SHAP proteins
top50_shap = shap_proteins[:50]
top50_pca_shap = pca_proteins[:50]  # slicing already copes with lists shorter than 50

# Only proteins with both a SHAP score and an attention score can be plotted
common_proteins = [p for p in top50_shap if p in protein_attention and p in all_proteins]
shap_vals = [shap_dict[p] for p in common_proteins]
attn_vals = [protein_attention[p] for p in common_proteins]

# Colour: red if the protein is in BOTH top-50 lists. Every plotted protein already comes from
# the transformer's top 50, so a union here would colour every point red and carry no information.
overlap_top50 = set(top50_shap) & set(top50_pca_shap)
colors = ['darkred' if p in overlap_top50 else 'steelblue' for p in common_proteins]

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: scatter plot
ax = axes[0]
ax.scatter(shap_vals, attn_vals, c=colors, alpha=0.7, s=100, edgecolors='black', linewidth=0.5)

# Label the 10 highest-SHAP proteins (common_proteins keeps SHAP rank order)
for i, protein in enumerate(common_proteins[:10]):
    ax.annotate(protein.split('|')[1] if '|' in protein else protein[:10],
                (shap_vals[i], attn_vals[i]),
                xytext=(5, 5), textcoords='offset points',
                fontsize=8, alpha=0.8,
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7, edgecolor='none'))

# NaN marks "not computed" so the summary print below never hits an undefined name
corr, p_val = float('nan'), float('nan')
if len(shap_vals) > 2:
    corr, p_val = pearsonr(shap_vals, attn_vals)
    ax.text(0.05, 0.95, f'Pearson r = {corr:.3f}\np = {p_val:.3f}',
            transform=ax.transAxes, fontsize=11,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

ax.set_xlabel('SHAP Importance', fontsize=12, fontweight='bold')
ax.set_ylabel('Attention Score', fontsize=12, fontweight='bold')
ax.set_title('SHAP vs Attention (Transformer Top 50)\nRed = Top 50 in both Transformer and PCA SHAP',
             fontsize=13, fontweight='bold')
ax.grid(alpha=0.3)

# Right: distribution comparison. The two metrics are on different scales, and a twinx axis
# would share (and mislabel) the x-axis, so both are z-scored onto one common axis.
ax = axes[1]
if shap_vals:
    shap_arr = np.asarray(shap_vals, dtype=float)
    attn_arr = np.asarray(attn_vals, dtype=float)
    shap_z = (shap_arr - shap_arr.mean()) / (shap_arr.std() or 1.0)
    attn_z = (attn_arr - attn_arr.mean()) / (attn_arr.std() or 1.0)
    ax.hist(shap_z, bins=20, alpha=0.6, label='SHAP', color='steelblue', edgecolor='black')
    ax.hist(attn_z, bins=20, alpha=0.6, label='Attention', color='coral', edgecolor='black')
    ax.legend()
else:
    ax.text(0.5, 0.5, 'No proteins with both scores', ha='center', va='center',
            transform=ax.transAxes, fontsize=12)
ax.set_xlabel('Standardized score (z)', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Distribution Comparison', fontsize=13, fontweight='bold')
ax.grid(alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(output_dir / 'fig3_shap_vs_attention.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Saved: {output_dir / 'fig3_shap_vs_attention.png'}")
if len(shap_vals) > 2:
    print(f"Correlation: r={corr:.3f}, p={p_val:.3f}")
else:
    print("Correlation: not computed (fewer than 3 proteins with both scores)")


### Figure 4: Attention-based Relational Structure (FASN Example)


In [None]:
# Fig 4: attention pattern around FASN compared with the PPI prior on the same proteins
if attention_matrix is not None:
    # First protein whose identifier contains 'FASN' (substring match)
    fasn_idx = next((i for i, p in enumerate(all_proteins) if 'FASN' in p), None)

    if fasn_idx is not None:
        # Row fasn_idx = attention FASN assigns to each protein (assuming rows index queries — TODO confirm)
        fasn_attention = attention_matrix[fasn_idx, :]
        # Strongest partners in descending order, excluding FASN's own self-attention
        partner_indices = [int(i) for i in np.argsort(fasn_attention)[::-1] if i != fasn_idx]

        # FASN first, then its 19 strongest partners (20 proteins total)
        selected_indices = [fasn_idx] + partner_indices[:19]
        selected_proteins = [all_proteins[i] for i in selected_indices]
        tick_labels = [p.split('|')[1] if '|' in p else p[:15] for p in selected_proteins]

        attention_sub = attention_matrix[np.ix_(selected_indices, selected_indices)]
        ppi_sub = A[np.ix_(selected_indices, selected_indices)]

        fig, axes = plt.subplots(1, 2, figsize=(16, 7))

        # Left: attention heatmap
        ax = axes[0]
        sns.heatmap(attention_sub, cmap='RdYlBu_r', square=True,
                    xticklabels=tick_labels, yticklabels=tick_labels,
                    cbar_kws={'label': 'Attention Weight'}, ax=ax)
        ax.set_title('FASN Attention Pattern\n(FASN + top 19 attended proteins)',
                     fontsize=14, fontweight='bold', pad=15)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)

        # Right: PPI adjacency for the same proteins, same order
        ax = axes[1]
        sns.heatmap(ppi_sub, cmap='Blues', square=True,
                    xticklabels=tick_labels, yticklabels=tick_labels,
                    cbar_kws={'label': 'PPI Edge'}, ax=ax)
        ax.set_title('PPI Network Structure\n(Same proteins)',
                     fontsize=14, fontweight='bold', pad=15)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)

        plt.tight_layout()
        plt.savefig(output_dir / 'fig4_fasn_attention.png', dpi=300, bbox_inches='tight')
        plt.show()

        # Report the 10 strongest partners (self-attention excluded) and whether a PPI edge exists
        print("FASN top attention connections:")
        for rank, idx in enumerate(partner_indices[:10], 1):
            protein = all_proteins[idx]
            attn_val = fasn_attention[idx]
            ppi_edge = int(A[fasn_idx, idx])
            print(f"  {rank:2d}. {protein:30s} | Attention: {attn_val:.4f} | PPI: {ppi_edge}")

        print(f"\nSaved: {output_dir / 'fig4_fasn_attention.png'}")
    else:
        print("FASN not found in protein list")
else:
    print("Attention matrix not available")


### Figure 5: Graph Bias Scale Analysis


In [None]:
# Fig 5: learned graph_bias_scale parameters, one row per transformer layer that has one
graph_bias_scales = [
    layer.self_attn.graph_bias_scale.detach().cpu().numpy()
    for layer in model.transformer
    if hasattr(layer.self_attn, 'graph_bias_scale')
]

if graph_bias_scales:
    # One row per layer: a per-head vector gives (n_layers, n_heads); a scalar-per-layer
    # parameter becomes (n_layers, 1) instead of breaking the 2-D shape unpacking below.
    graph_bias_scales = np.array(graph_bias_scales).reshape(len(graph_bias_scales), -1)
    n_layers, n_heads = graph_bias_scales.shape

    # Summary statistics over all layer x head entries
    mean_val = graph_bias_scales.mean()
    std_val = graph_bias_scales.std()
    var_val = graph_bias_scales.var()
    min_val = graph_bias_scales.min()
    max_val = graph_bias_scales.max()
    range_val = max_val - min_val
    cv = std_val / mean_val if mean_val != 0 else 0

    print("=" * 70)
    print("Graph Bias Scale Statistics")
    print("=" * 70)
    print(f"Mean: {mean_val:.6f}")
    print(f"Std: {std_val:.6f}")
    print(f"Variance: {var_val:.6f}")
    print(f"Min: {min_val:.6f}")
    print(f"Max: {max_val:.6f}")
    print(f"Range: {range_val:.6f}")
    print(f"Coefficient of Variation: {cv:.6f}")
    print(f"Shape: {n_layers} layers × {n_heads} heads")
    print("=" * 70)

    head_labels = [f'H{i+1}' for i in range(n_heads)]
    layer_labels = [f'L{i+1}' for i in range(n_layers)]

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    # Top left: per-head mean across layers
    ax = axes[0, 0]
    head_means = graph_bias_scales.mean(axis=0)
    ax.bar(range(n_heads), head_means, color='steelblue', edgecolor='black', alpha=0.8)
    ax.axhline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Overall mean: {mean_val:.4f}')
    ax.set_xlabel('Attention Head', fontsize=12, fontweight='bold')
    ax.set_ylabel('Mean Graph Bias Scale', fontsize=12, fontweight='bold')
    ax.set_title('Graph Bias Scale by Head\n(Mean across layers)', fontsize=13, fontweight='bold')
    ax.set_xticks(range(n_heads))
    ax.set_xticklabels(head_labels)
    ax.legend()
    ax.grid(alpha=0.3, axis='y')

    # Top right: histogram of all values
    ax = axes[0, 1]
    ax.hist(graph_bias_scales.flatten(), bins=30, color='steelblue', edgecolor='black', alpha=0.7)
    ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.4f}')
    ax.set_xlabel('Graph Bias Scale', fontsize=12, fontweight='bold')
    ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
    ax.set_title('Distribution of Graph Bias Scale\n(All layers × heads)', fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3, axis='y')

    # Bottom left: layers x heads heatmap
    ax = axes[1, 0]
    sns.heatmap(graph_bias_scales, cmap='RdYlBu_r', annot=True, fmt='.4f',
                xticklabels=head_labels,
                yticklabels=layer_labels,
                cbar_kws={'label': 'Graph Bias Scale'},
                ax=ax, square=True)
    ax.set_xlabel('Attention Head', fontsize=12, fontweight='bold')
    ax.set_ylabel('Layer', fontsize=12, fontweight='bold')
    ax.set_title('Graph Bias Scale Heatmap\n(Layers × Heads)', fontsize=13, fontweight='bold')

    # Bottom right: box plot per layer. Tick labels are set explicitly because boxplot's
    # `labels=` keyword was renamed to `tick_labels` in matplotlib 3.9.
    ax = axes[1, 1]
    data_for_box = [graph_bias_scales[i, :] for i in range(n_layers)]
    bp = ax.boxplot(data_for_box, patch_artist=True)
    ax.set_xticks(range(1, n_layers + 1))
    ax.set_xticklabels(layer_labels)
    for patch in bp['boxes']:
        patch.set_facecolor('steelblue')
        patch.set_alpha(0.7)
    ax.axhline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.4f}')
    ax.set_xlabel('Layer', fontsize=12, fontweight='bold')
    ax.set_ylabel('Graph Bias Scale', fontsize=12, fontweight='bold')
    ax.set_title('Graph Bias Scale Distribution by Layer', fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(output_dir / 'fig5_graph_bias_scale.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\nSaved: {output_dir / 'fig5_graph_bias_scale.png'}")
else:
    print("Graph bias scale not found in model")


In [None]:
# Fig 2: accuracy / macro-F1 comparison between the PCA baseline and GaTmCC.
# Requires the PCA baseline's test accuracy (`pca_test_acc`) from the performance cell.
if pca_test_acc is not None:
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    models = ['PCA95+LogReg', 'GaTmCC']
    # Local name so the scatter `colors` list from Fig 3 is not silently overwritten
    bar_colors = ['steelblue', 'darkgreen']

    # Left: accuracy comparison
    ax = axes[0]
    accuracies = [pca_test_acc * 100, test_acc * 100]
    bars = ax.bar(models, accuracies, color=bar_colors, edgecolor='black', linewidth=1.5, alpha=0.8)
    for bar, acc in zip(bars, accuracies):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                f'{acc:.2f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')

    ax.set_ylabel('Test Accuracy (%)', fontsize=12, fontweight='bold')
    ax.set_title('Model Comparison: Accuracy', fontsize=13, fontweight='bold')
    ax.set_ylim(0, max(accuracies) + 5)
    ax.grid(True, alpha=0.3, axis='y')

    # Right: macro-F1 comparison. The PCA F1 is read from the stats file: the last
    # parseable line mentioning both 'f1' and 'macro' wins; values with '%' or > 1
    # are treated as percentages.
    ax = axes[1]
    pca_f1 = None
    if pca_stats_path.exists():
        for line in pca_stats_path.read_text().splitlines():
            lowered = line.lower()
            if 'f1' not in lowered or 'macro' not in lowered or ':' not in line:
                continue
            raw = line.split(':', 1)[1].strip()
            try:
                parsed = float(raw.rstrip('%').strip())
            except ValueError:
                continue  # not a plain number on this line; keep scanning
            pca_f1 = parsed / 100 if raw.endswith('%') or parsed > 1 else parsed

    if pca_f1 is not None:
        f1_scores = [pca_f1 * 100, test_f1_macro * 100]
        bars = ax.bar(models, f1_scores, color=bar_colors, edgecolor='black', linewidth=1.5, alpha=0.8)
        for bar, f1 in zip(bars, f1_scores):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{f1:.2f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')
        ax.set_ylabel('Test F1 (macro) (%)', fontsize=12, fontweight='bold')
        ax.set_title('Model Comparison: F1 Score', fontsize=13, fontweight='bold')
        ax.set_ylim(0, max(f1_scores) + 5)
    else:
        ax.text(0.5, 0.5, 'F1 scores not available', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title('Model Comparison: F1 Score', fontsize=13, fontweight='bold')

    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(output_dir / 'fig2_model_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Saved: {output_dir / 'fig2_model_comparison.png'}")
else:
    print("PCA baseline results not available for comparison")


### Figure 1: Training Curves (if available)

Note: Training logs may be unavailable if the model was loaded from a pretrained checkpoint. This cell searches the expected log locations and skips the figure if no log is found.


In [None]:
# Fig 1: training curves parsed from the transformer's training log, if one exists.
# `re` is imported up front: it used to be imported only inside the epoch/loss branch,
# so a 'val acc' line appearing before any epoch line raised NameError.
import re

log_paths = [
    PROJECT_ROOT / "results" / "classifiers" / "cancer_type_classifiers" / "transformer" / "training.log",
    PROJECT_ROOT / "results" / "classifiers" / "cancer_type_classifiers" / "transformer" / "logs" / "training.log",
    PROJECT_ROOT.parent / "CleanedProject" / "Results" / "classifiers" / "cancer_type_classifiers" / "transformer" / "training.log",
]

# First existing candidate wins
training_log = next((path for path in log_paths if path.exists()), None)

if training_log:
    print(f"Found training log: {training_log}")
    # Heuristic parsing — adjust these patterns to the actual log format
    epoch_pattern = re.compile(r'epoch[:\s]+(\d+)', re.I)
    loss_pattern = re.compile(r'loss[:\s]+([\d.]+)', re.I)
    acc_pattern = re.compile(r'acc[:\s]+([\d.]+)', re.I)

    epochs = []
    train_losses = []
    val_accs = []

    with open(training_log, 'r') as f:
        for line in f:
            lowered = line.lower()
            if 'epoch' in lowered and 'loss' in lowered:
                epoch_match = epoch_pattern.search(line)
                loss_match = loss_pattern.search(line)
                if epoch_match and loss_match:
                    try:
                        loss_value = float(loss_match.group(1))
                    except ValueError:  # e.g. a stray '.' matched by [\d.]+
                        pass
                    else:
                        epochs.append(int(epoch_match.group(1)))
                        train_losses.append(loss_value)
            if 'val' in lowered and 'acc' in lowered:
                acc_match = acc_pattern.search(line)
                if acc_match:
                    try:
                        val_accs.append(float(acc_match.group(1)))
                    except ValueError:
                        pass

    if epochs and train_losses:
        # Truncate all series to a common length so they can share the epoch axis
        min_len = min(len(epochs), len(train_losses), len(val_accs) if val_accs else len(epochs))
        epochs = epochs[:min_len]
        train_losses = train_losses[:min_len]
        if val_accs:
            val_accs = val_accs[:min_len]

        fig, axes = plt.subplots(1, 2, figsize=(14, 6))

        # Left: training loss
        ax = axes[0]
        ax.plot(epochs, train_losses, 'b-', linewidth=2, label='Training Loss', marker='o', markersize=4)
        ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
        ax.set_ylabel('Loss', fontsize=12, fontweight='bold')
        ax.set_title('Training Loss Over Time', fontsize=13, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Right: validation accuracy (legend only when a labelled line exists)
        ax = axes[1]
        if val_accs:
            # Accept accuracies logged either as fractions (<= 1) or as percentages
            scale = 100 if max(val_accs) <= 1 else 1
            ax.plot(epochs, [a * scale for a in val_accs], 'g-', linewidth=2, label='Validation Accuracy', marker='s', markersize=4)
            ax.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
            ax.legend()
        else:
            ax.text(0.5, 0.5, 'Validation accuracy\nnot found in log',
                    ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
        ax.set_title('Validation Accuracy Over Time', fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(output_dir / 'fig1_training_curves.png', dpi=300, bbox_inches='tight')
        plt.show()

        print(f"Saved: {output_dir / 'fig1_training_curves.png'}")
    else:
        print("Could not parse training log")
else:
    print("Training log not found. If you have training logs, place them in:")
    print("  results/classifiers/cancer_type_classifiers/transformer/training.log")


## Summary

All interpretability analyses are complete. Generated figures:
- **Fig 1**: Training curves (if available)
- **Fig 2**: Model comparison (PCA vs Transformer)
- **Fig 3**: SHAP vs Attention correlation
- **Fig 4**: Attention relational structure (FASN example)
- **Fig 5**: Graph bias scale analysis

All figures saved to: `results/plots/Model_Comparison_Plots/`
