# AMLGentex Framework Tutorial

This notebook demonstrates the complete AMLGentex workflow for synthetic AML detection data generation and federated learning.

## Overview

AMLGentex is a framework for:
1. **Generating** synthetic AML transaction data with configurable patterns
2. **Preprocessing** transactions into ML-ready features
3. **Training** ML models in three settings:
   - **Centralized**: All data combined
   - **Federated**: Privacy-preserving collaborative learning
   - **Isolated**: Each institution trains independently
4. **Visualizing** results and transaction networks

## Convention Over Configuration

The framework follows a **convention-over-configuration** approach:
- Paths are auto-discovered from the experiment name
- Clients (banks) are auto-discovered from the data
- Results follow a standard directory structure

**Standard experiment structure:**
```
experiments/my_experiment/
├── config/
│   ├── data.yaml              # Data generation config
│   ├── preprocessing.yaml     # Feature engineering config
│   └── models.yaml            # Model training config
├── spatial/                   # Generated spatial graph
├── temporal/                  # Generated transactions
├── preprocessed/              # ML-ready features
│   ├── centralized/           # Combined data
│   └── clients/               # Per-bank data
├── results/                   # Training results
│   ├── centralized/
│   ├── federated/
│   └── isolated/
└── visualizations/            # Plots and analysis
```
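
To make the convention concrete, here is a minimal sketch of how the standard layout could be derived from nothing but an experiment name. The helpers `experiment_paths` and `discover_clients` are hypothetical names invented for this illustration; they are not part of the AMLGentex API, which exposes its own loaders later in this tutorial.

```python
from pathlib import Path


def experiment_paths(project_root, experiment):
    """Hypothetical helper: derive the standard layout from an experiment name."""
    root = Path(project_root) / "experiments" / experiment
    return {
        "root": root,
        "data_config": root / "config" / "data.yaml",
        "preprocessing_config": root / "config" / "preprocessing.yaml",
        "models_config": root / "config" / "models.yaml",
        "spatial": root / "spatial",
        "temporal": root / "temporal",
        "preprocessed_centralized": root / "preprocessed" / "centralized",
        "preprocessed_clients": root / "preprocessed" / "clients",
        "results": root / "results",
        "visualizations": root / "visualizations",
    }


def discover_clients(clients_dir):
    """Hypothetical helper: treat each subdirectory of clients/ as one bank."""
    clients_dir = Path(clients_dir)
    if not clients_dir.is_dir():
        return []
    return sorted(p.name for p in clients_dir.iterdir() if p.is_dir())


paths = experiment_paths(".", "my_experiment")
print(paths["data_config"])
print(discover_clients(paths["preprocessed_clients"]))
```

Because every location is a pure function of the experiment name, switching experiments only requires changing that one string.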

## Setup

First, let's import the necessary libraries and set up our experiment.

In [None]:
import os
import sys
from pathlib import Path
import yaml
import pickle
import pandas as pd
import matplotlib.pyplot as plt

# Add project root to path
project_root = Path.cwd()
sys.path.insert(0, str(project_root))

# Set experiment name - this is the ONLY thing you need to configure!
EXPERIMENT = "tutorial_demo"
experiment_root = project_root / "experiments" / EXPERIMENT

print(f"Project root: {project_root}")
print(f"Experiment: {EXPERIMENT}")
print(f"Experiment root: {experiment_root}")

## Step 1: Create Experiment Configuration

We'll create a minimal experiment with 1,000 accounts and 2 banks.

### 1.1 Create Directory Structure

In [None]:
# Create experiment directories
config_dir = experiment_root / "config"
os.makedirs(config_dir, exist_ok=True)

print(f"✓ Created experiment structure at: {experiment_root}")

### 1.2 Create Data Generation Config

In [None]:
# Instead of creating from scratch, copy the working config from template
import shutil

config_files = ['data.yaml',
                'preprocessing.yaml', 
                'models.yaml']

template_config = project_root / "experiments" / "template_experiment" / "config"

if template_config.exists():
    # Copy the config files
    for f in config_files:
        shutil.copy(template_config / f, config_dir / f)
        print(f"Copied {f} from {template_config}")
    
    # Update simulation_name in data.yaml to match experiment name
    data_yaml_path = config_dir / 'data.yaml'
    with open(data_yaml_path, 'r') as f:
        data_config = yaml.safe_load(f)
    
    data_config['general']['simulation_name'] = EXPERIMENT
    
    with open(data_yaml_path, 'w') as f:
        yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)
    
    print(f"\n✓ Updated simulation_name to: {EXPERIMENT}")
else:
    print("Template config not found. Please ensure template_experiment exists.")

### 1.3 Create Supporting Configuration Files

We'll copy minimal configuration files from the template experiment.

In [None]:
# Copy supporting files from template
template_dir = project_root / "experiments" / "template_experiment" / "config"

support_files = [
    'accounts.csv',
    'alertPatterns.csv', 
    'normalModels.csv',
    'degree.csv',
    'transactionType.csv',
    'demographics.csv'
]

for file in support_files:
    src = template_dir / file
    dst = config_dir / file
    if src.exists():
        shutil.copy(src, dst)
        print(f"✓ Copied {file}")
    else:
        print(f"⚠ Warning: {file} not found in template")

## Step 2: Generate Synthetic Transaction Data

Now we'll generate synthetic AML transaction data with both normal and suspicious (SAR) patterns.

**Convention**: Paths are automatically constructed from the experiment name!

In [None]:
from src.data_creation import DataGenerator
from src.utils.config import load_data_config
from src.utils.logging import configure_logging
import tempfile

# Load config with auto-discovered paths
data_config = load_data_config(str(config_dir / 'data.yaml'))

print("Configuration loaded with auto-discovered paths:")
print(f"  Input dir: {data_config['input']['directory']}")
print(f"  Output dir: {data_config['output']['directory']}")
print()

# DataGenerator expects a config file path, so we create a temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp:
    yaml.dump(data_config, tmp, default_flow_style=False, sort_keys=False)
    tmp_config = tmp.name

try:
    print("Generating synthetic transaction data...")
    print("This may take a few minutes...\n")
    
    # Enable verbose logging to see progress
    configure_logging(verbose=True)
    
    generator = DataGenerator(tmp_config)
    # force_spatial=True ensures fresh data generation with latest format
    # (includes demographics-based SALARY, AGE, CITY columns)
    tx_log_file = generator(force_spatial=True)
    
    print(f"\n✓ Transaction data generated: {tx_log_file}")
finally:
    os.unlink(tmp_config)

# Load and inspect the generated data
df = pd.read_parquet(tx_log_file)
print(f"\nGenerated {len(df):,} transactions")
print(f"Banks: {df['bankOrig'].unique().tolist()}")
print(f"SAR transactions: {df['isSAR'].sum():,} ({df['isSAR'].mean()*100:.2f}%)")
print(f"Time range: steps {df['step'].min()} to {df['step'].max()}")

## Step 2.1: Optimize Data Generation (Optional)

**This step is optional but recommended.** It uses Bayesian optimization to tune the data generation parameters so that a model trained on the generated data reaches a target performance under an operational constraint.

The optimizer uses **two-phase spatial generation**:
1. **Baseline (once)**: Generates normal accounts, graph structure, and demographics
2. **Alert Injection (per trial)**: Loads the baseline and injects alerts using the trial's ML selector configuration

This allows efficient exploration of ML selector weights without regenerating the entire graph.
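
The two-phase idea can be sketched with a toy example: the expensive baseline is built once, and each trial only adds alert labels on top of it. The functions `generate_baseline` and `inject_alerts` are hypothetical stand-ins written for this illustration, and the random graph is far simpler than what the framework generates.

```python
import random


def generate_baseline(n_accounts, seed):
    """Expensive phase (run once): build a toy graph of normal accounts."""
    rng = random.Random(seed)
    accounts = list(range(n_accounts))
    edges = [(a, rng.choice(accounts)) for a in accounts]
    return {"accounts": accounts, "edges": edges}


def inject_alerts(baseline, n_alerts, seed):
    """Cheap phase (run per trial): mark accounts as SAR without touching the graph."""
    rng = random.Random(seed)
    sar_accounts = set(rng.sample(baseline["accounts"], n_alerts))
    # Return a new dict so the cached baseline is never mutated between trials.
    return {**baseline, "sar_accounts": sar_accounts}


baseline = generate_baseline(n_accounts=1000, seed=0)  # phase 1, once
trials = [inject_alerts(baseline, n_alerts=20, seed=t) for t in range(3)]  # phase 2

for t, trial in enumerate(trials):
    # The graph object is shared across trials, not regenerated.
    assert trial["edges"] is baseline["edges"]
    print(f"trial {t}: {len(trial['sar_accounts'])} SAR accounts")
```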

### What It Optimizes

**Temporal parameters** (in `optimisation_bounds.temporal`):
- SAR transaction amounts, spending patterns, behavioral features

**ML Selector parameters** (in `optimisation_bounds.ml_selector`):
- Structure weights: How graph centrality influences account selection for SAR patterns
- KYC weights: How account attributes (balance, salary, age) influence selection
- Participation decay: How participation probability decays for multi-pattern accounts
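
To build intuition for the ML selector parameters, the sketch below shows one plausible way such weights could turn account features into selection probabilities. The scoring rule (a softmax over a weighted feature sum, damped by `decay` for each pattern an account already joined) is an assumption made for illustration; the tutorial does not specify the framework's actual formula, and `selection_weights` is a hypothetical function.

```python
import math
import random


def selection_weights(accounts, structure_w, kyc_w, decay, participation):
    """Hypothetical scoring: softmax over a weighted feature sum, damped per prior pattern."""
    weights = {}
    for acc_id, feats in accounts.items():
        score = sum(structure_w.get(k, 0.0) * v for k, v in feats.items())
        score += sum(kyc_w.get(k, 0.0) * v for k, v in feats.items())
        # Each pattern the account already joined multiplies its weight by `decay`.
        weights[acc_id] = math.exp(score) * decay ** participation.get(acc_id, 0)
    total = sum(weights.values())
    return {acc_id: w / total for acc_id, w in weights.items()}


# Toy accounts with standardized features (values are made up).
accounts = {
    "A": {"degree": 1.0, "pagerank": 0.5, "balance": -0.2},
    "B": {"degree": -0.5, "pagerank": 0.0, "balance": 1.0},
    "C": {"degree": 0.0, "pagerank": -1.0, "balance": 0.3},
}
structure_w = {"degree": 0.8, "pagerank": 0.4}
kyc_w = {"balance": 0.5}

# Account A already participates in two patterns, so its weight is damped twice.
probs = selection_weights(accounts, structure_w, kyc_w, decay=0.5, participation={"A": 2})
chosen = random.Random(0).choices(list(probs), weights=list(probs.values()), k=1)[0]
print(probs, chosen)
```

Under this assumed rule, larger structure weights pull SAR patterns toward central accounts, while a smaller `decay` spreads patterns across more distinct accounts.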

### Optimization Objective

Multi-objective optimization balancing:
1. Utility metric loss: `|achieved_metric - target|`
2. Feature importance variance: Lower variance = more stable features

Results in a **Pareto front** of optimal trade-offs.
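
The two objectives and the resulting Pareto front can be sketched in plain Python. The utility loss follows the formula above; the variance objective is written here as the population variance of a single importance vector, which is an assumed form since the tutorial does not spell out how the framework computes it. All numbers are made up.

```python
def utility_loss(achieved, target):
    """First objective: absolute gap between the achieved metric and the target."""
    return abs(achieved - target)


def importance_variance(importances):
    """Second objective (assumed form): population variance of feature importances."""
    mean = sum(importances) / len(importances)
    return sum((x - mean) ** 2 for x in importances) / len(importances)


def pareto_front(points):
    """Return indices of points not dominated by any other (all objectives minimized)."""
    front = []
    for i, p in enumerate(points):
        dominated = any(
            all(q[k] <= p[k] for k in range(len(p)))
            and any(q[k] < p[k] for k in range(len(p)))
            for j, q in enumerate(points)
            if j != i
        )
        if not dominated:
            front.append(i)
    return front


# Three hypothetical trials with target precision 0.8.
trials = [
    (utility_loss(0.75, 0.8), importance_variance([0.5, 0.3, 0.2])),
    (utility_loss(0.80, 0.8), importance_variance([0.9, 0.05, 0.05])),
    (utility_loss(0.60, 0.8), importance_variance([0.6, 0.3, 0.1])),
]
print(pareto_front(trials))  # → [0, 1]
```

The third trial is worse than the first on both objectives, so it is dropped; the first two remain because each wins on one objective.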

### Three Optimization Scenarios

Choose one based on your operational scenario:

**1. Precision@K (Alert Budget Constraint)**
- Use when: Your team can investigate K alerts per day
- Example: "We can review 100 alerts daily. Optimize for 80% precision in top 100."
- Parameters: `constraint_type='K'`, `constraint_value=100`, `utility_metric='precision'`, `target=0.8`

**2. Recall at FPR≤α (Regulatory Constraint)**
- Use when: Compliance requires FPR below a threshold
- Example: "FPR must be ≤1%. Optimize for 70% recall at this limit."
- Parameters: `constraint_type='fpr'`, `constraint_value=0.01`, `utility_metric='recall'`, `target=0.7`

**3. Precision at Recall≥target (Coverage Constraint)**
- Use when: Must detect a minimum fraction of SARs
- Example: "Must detect 70% of SARs. Optimize for 60% precision at this recall."
- Parameters: `constraint_type='recall'`, `constraint_value=0.7`, `utility_metric='precision'`, `target=0.6`
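
The three constrained metrics can be computed from a ranked list of model scores. The following is a dependency-free sketch on toy data, not the framework's implementation; it assumes distinct scores and at least one positive and one negative label.

```python
def _ranked(scores, labels):
    """Labels sorted by descending score."""
    return [y for _, y in sorted(zip(scores, labels), key=lambda t: -t[0])]


def precision_at_k(scores, labels, k):
    """Fraction of true SARs among the k highest-scored items."""
    top = _ranked(scores, labels)[:k]
    return sum(top) / k


def recall_at_fpr(scores, labels, max_fpr):
    """Best recall over all score cutoffs whose false positive rate stays <= max_fpr."""
    ranked = _ranked(scores, labels)
    pos, neg = sum(ranked), len(ranked) - sum(ranked)
    tp = fp = 0
    best = 0.0
    for y in ranked:
        tp, fp = tp + y, fp + (1 - y)
        if fp / neg <= max_fpr:
            best = max(best, tp / pos)
    return best


def precision_at_recall(scores, labels, min_recall):
    """Best precision over all cutoffs whose recall reaches min_recall."""
    ranked = _ranked(scores, labels)
    pos = sum(ranked)
    tp = 0
    best = 0.0
    for n, y in enumerate(ranked, start=1):
        tp += y
        if tp / pos >= min_recall:
            best = max(best, tp / n)
    return best


# Toy data: 10 accounts, 3 of them SAR (label 1).
scores = [0.95, 0.90, 0.80, 0.70, 0.60, 0.50, 0.40, 0.30, 0.20, 0.10]
labels = [1, 0, 1, 0, 0, 1, 0, 0, 0, 0]

print(precision_at_k(scores, labels, k=3))
print(recall_at_fpr(scores, labels, max_fpr=0.15))
print(precision_at_recall(scores, labels, min_recall=0.7))  # → 0.5
```

In each case the constraint fixes where the decision threshold may sit, and the utility metric is read off at the best threshold that satisfies it.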

### 2.1.1 Configure and Run Optimization

In [None]:
from src.data_tuning import DataTuner
from src.data_creation import DataGenerator
from src.feature_engineering import DataPreprocessor
from src.utils.config import load_data_config, load_preprocessing_config
from src.utils.logging import configure_logging, set_verbosity
import yaml

# ===== CONFIGURE YOUR OPTIMIZATION OBJECTIVE =====
# Choose ONE of the three objectives below:

# Option 1: Precision@K (Alert Budget)
# CONSTRAINT_TYPE = 'K'
# CONSTRAINT_VALUE = 100  # Top 100 alerts
# UTILITY_METRIC = 'precision'
# TARGET = 0.02  # only 2% are true positives out of the K reported ones

# Option 2: Recall at FPR (Regulatory Constraint)
# CONSTRAINT_TYPE = 'fpr'
# CONSTRAINT_VALUE = 0.01  # FPR ≤ 1%
# UTILITY_METRIC = 'recall'
# TARGET = 0.02  # 2% recall

# Option 3: Precision at Recall (Coverage Constraint)
CONSTRAINT_TYPE = 'recall'
CONSTRAINT_VALUE = 0.5  # Recall ≥ 50%
UTILITY_METRIC = 'precision'
TARGET = 0.02  # 2% precision

N_TRIALS = 10  # Number of data optimization trials
N_MODEL_TRIALS = 20  # Number of model hyperparameter trials per data configuration
# ==================================================

print("=" * 60)
print("DATA OPTIMIZATION")
print("=" * 60)
if CONSTRAINT_TYPE == 'K':
    print(f"\nObjective: Optimize {UTILITY_METRIC} in top {int(CONSTRAINT_VALUE)} reported alerts")
elif CONSTRAINT_TYPE == 'fpr':
    print(f"\nObjective: Optimize {UTILITY_METRIC} at FPR ≤ {CONSTRAINT_VALUE}")
elif CONSTRAINT_TYPE == 'recall':
    print(f"\nObjective: Optimize {UTILITY_METRIC} at Recall ≥ {CONSTRAINT_VALUE}")
print(f"Target: {TARGET:.1%} {UTILITY_METRIC}")
print(f"Data optimization trials: {N_TRIALS}")
print(f"Model hyperparameter trials per data config: {N_MODEL_TRIALS}")
print()

# Set up paths
data_yaml_path = str(config_dir / 'data.yaml')

# Load configs with auto-discovered paths
data_config = load_data_config(data_yaml_path)
preproc_config = load_preprocessing_config(str(config_dir / 'preprocessing.yaml'))

# Write config back with absolute paths
with open(data_yaml_path, 'w') as f:
    yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)

# Load model config
with open(config_dir / 'models.yaml', 'r') as f:
    models_config = yaml.safe_load(f)

# Set up optimization directory
bo_dir = experiment_root / 'data_tuning'
os.makedirs(bo_dir, exist_ok=True)

# Create tuning config
tuning_config = {
    'preprocess': preproc_config.copy(),
    **models_config
}
tuning_config['preprocess']['preprocessed_data_dir'] = str(bo_dir / 'preprocessed')
tuning_config['DecisionTreeClassifier']['default']['device'] = 'cpu'

print(f"Optimizer will directly modify: {data_yaml_path}\n")

# Initialize generator and preprocessor
# NOTE: No verbose parameter needed - the optimizer controls logging automatically
generator = DataGenerator(data_yaml_path)
preprocessor = DataPreprocessor(preproc_config)

# Initialize data tuner
print("Initializing data optimizer...")
print("Two-phase generation:")
print("  1. Generate baseline (normal accounts + demographics) - ONCE")
print("  2. For each trial:")
print("     a. Inject alerts from baseline with trial's ML selector weights")
print("     b. Run temporal simulation")
print("     c. Preprocess into features")
print(f"     d. Tune DecisionTreeClassifier hyperparameters ({N_MODEL_TRIALS} trials)")
print(f"     e. Measure {UTILITY_METRIC} and feature stability")
print()

# The DataTuner automatically silences data generation during optimization trials
# but still shows trial progress (Trial 1/N: precision=0.xxx, loss=0.xxx)
tuner = DataTuner(
    data_conf_file=data_yaml_path,
    config=tuning_config,
    generator=generator,
    preprocessor=preprocessor,
    target=TARGET,
    constraint_type=CONSTRAINT_TYPE,
    constraint_value=CONSTRAINT_VALUE,
    utility_metric=UTILITY_METRIC,
    model='DecisionTreeClassifier',
    bo_dir=str(bo_dir),
    seed=42,
    num_trials_model=N_MODEL_TRIALS
)

# Run optimization
print(f"Starting Bayesian optimization ({N_TRIALS} data trials × {N_MODEL_TRIALS} model trials each)...")
print("This may take some time depending on data size and number of trials.\n")

best_trials = tuner(n_trials=N_TRIALS)

print("\n" + "=" * 60)
print("OPTIMIZATION COMPLETE")
print("=" * 60)
print(f"\nFound {len(best_trials)} Pareto-optimal solutions")
print(f"Results saved to: {bo_dir}")
print(f"  - pareto_front.png: Visualization of trade-offs")
print(f"  - best_trials.txt: Parameter values for each solution")
print(f"  - data_tuning_study.db: Full optimization history")
print(f"\nNote: {data_yaml_path} now contains parameters from the last trial.")

### 2.1.2 Visualize Pareto Front and Select Solution

The optimizer returns multiple Pareto-optimal solutions. Each represents a different trade-off:
- **Low utility loss** = the achieved metric (precision or recall, depending on `UTILITY_METRIC`) closely matches the target
- **Low feature importance loss** = more stable/consistent features
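
When no single solution is clearly preferable, one simple heuristic is to rescale both losses to [0, 1] and pick the solution closest to the ideal point (0, 0). This is a sketch of that heuristic under made-up loss values; `pick_balanced` is a hypothetical helper, not something the framework provides.

```python
def pick_balanced(solutions):
    """Return the trial number closest to the ideal point after min-max scaling each loss."""
    n_obj = len(next(iter(solutions.values())))
    lows = [min(v[k] for v in solutions.values()) for k in range(n_obj)]
    highs = [max(v[k] for v in solutions.values()) for k in range(n_obj)]

    def distance(values):
        total = 0.0
        for k in range(n_obj):
            span = highs[k] - lows[k]
            # If all solutions share the same loss, that objective cannot discriminate.
            scaled = (values[k] - lows[k]) / span if span > 0 else 0.0
            total += scaled ** 2
        return total ** 0.5

    return min(solutions, key=lambda number: distance(solutions[number]))


# Hypothetical (utility loss, feature importance loss) pairs keyed by trial number.
front = {3: (0.00, 0.16), 7: (0.05, 0.02), 9: (0.02, 0.10)}
print(pick_balanced(front))  # → 9
```

With the `best_trials` list returned by the tuner, the input could be built as `{t.number: tuple(t.values) for t in best_trials}`. Scaling matters because the two losses live on different ranges; without it the larger-ranged objective would dominate the distance.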

Let's visualize the results and select the best trial.

In [None]:
# Display Pareto front
from IPython.display import Image, display

pareto_front_path = bo_dir / 'pareto_front.png'
if pareto_front_path.exists():
    display(Image(filename=str(pareto_front_path)))
else:
    print("Pareto front plot not found")

# Show best trials
print("\n" + "=" * 60)
print("PARETO-OPTIMAL SOLUTIONS")
print("=" * 60)

for i, trial in enumerate(best_trials):
    utility_loss = trial.values[0]
    importance_loss = trial.values[1]
    
    # Get achieved utility from trial attributes
    # Fall back to NaN (not TARGET) so a missing attribute is not reported as a perfect match
    achieved_utility = trial.user_attrs.get('utility_metric', float('nan'))
    
    print(f"\nTrial {trial.number} (Solution #{i+1}):")
    print(f"  {UTILITY_METRIC.capitalize()}: {achieved_utility:.4f} (target: {TARGET:.4f}, loss: {utility_loss:.4f})")
    print(f"  Feature stability loss: {importance_loss:.4f}")
    print(f"  Parameters:")
    for param, value in trial.params.items():
        if isinstance(value, float):
            print(f"    {param}: {value:.4f}")
        else:
            print(f"    {param}: {value}")

print("\n" + "=" * 60)
print("RECOMMENDATION")
print("=" * 60)
print("\nTo select a trial:")
print(f"  1. Low {UTILITY_METRIC} loss = better match to target")
print("  2. Low importance loss = more stable features")
print("  3. Balance both for general use")
print(f"\nTrial {best_trials[0].number} is listed first on the Pareto front")
print("Run the next cell to apply it (or choose a different trial number)")

### 2.1.3 Apply Selected Trial and Regenerate Data

Choose which trial to use and regenerate the data with optimized parameters.

In [None]:
# ===== SELECT WHICH TRIAL TO USE =====
SELECTED_TRIAL = best_trials[0].number  # Default: first Pareto-optimal trial
# You can change this to any trial number from the list above
# ======================================

print("=" * 60)
print("APPLYING OPTIMIZED PARAMETERS")
print("=" * 60)

# Update data.yaml with selected trial's parameters
tuner.optimizer.update_config_with_trial(SELECTED_TRIAL)

# Regenerate data with optimized parameters
print("\nRegenerating data with optimized parameters...")
print("This may take a few minutes...\n")

# Reload config and regenerate
data_config_optimized = load_data_config(data_yaml_path)
with open(data_yaml_path, 'w') as f:
    yaml.dump(data_config_optimized, f, default_flow_style=False, sort_keys=False)

# Reinitialize generator with updated config
generator = DataGenerator(data_yaml_path)
optimized_tx_log = generator()

print(f"\n✓ Optimized transaction data generated: {optimized_tx_log}")

# Load and inspect the optimized data
df_optimized = pd.read_parquet(optimized_tx_log)
print(f"\nGenerated {len(df_optimized):,} transactions")
print(f"Banks: {df_optimized['bankOrig'].unique().tolist()}")
print(f"SAR transactions: {df_optimized['isSAR'].sum():,} ({df_optimized['isSAR'].mean()*100:.2f}%)")
print(f"Time range: steps {df_optimized['step'].min()} to {df_optimized['step'].max()}")

# Compare with original (if available)
if 'df' in globals():
    print("\n" + "=" * 60)
    print("COMPARISON: Original vs Optimized")
    print("=" * 60)
    print(f"  Total transactions: {len(df):,} → {len(df_optimized):,}")
    print(f"  SAR rate: {df['isSAR'].mean()*100:.2f}% → {df_optimized['isSAR'].mean()*100:.2f}%")

print("\n✅ Data optimization complete!")
if CONSTRAINT_TYPE == 'K':
    print(f"The optimized data should achieve ~{TARGET:.1%} {UTILITY_METRIC} in top {int(CONSTRAINT_VALUE)} alerts.")
elif CONSTRAINT_TYPE == 'fpr':
    print(f"The optimized data should achieve ~{TARGET:.1%} {UTILITY_METRIC} at FPR ≤ {CONSTRAINT_VALUE}.")
elif CONSTRAINT_TYPE == 'recall':
    print(f"The optimized data should achieve ~{TARGET:.1%} {UTILITY_METRIC} at Recall ≥ {CONSTRAINT_VALUE}.")
print(f"\nProceed to Step 3 to preprocess this optimized data.")

---

**Note about Step 2.1:**
- **If you ran optimization**: The data has been regenerated with optimized parameters. Proceed to Step 3.
- **If you skipped optimization**: You can proceed directly to Step 3 with the original data from Step 2.

The optimization step tunes two categories of parameters:

**1. Temporal parameters** (transaction behaviors):
- SAR amounts, spending patterns, behavioral timing

**2. ML Selector parameters** (account selection for SAR patterns):
- Structure weights: degree, betweenness, pagerank centrality
- KYC weights: balance, salary, age
- Participation decay for multi-pattern accounts

This is particularly useful when:
- **Precision@K**: You have a fixed alert review budget and need high precision
- **Recall at FPR**: Regulatory requirements constrain false positive rates
- **Precision at Recall**: You must maintain minimum SAR detection coverage

---

## Step 3: Preprocess Data for ML

Convert raw transactions into ML-ready features with graph structure.

**Convention**: Creates both centralized and per-client datasets automatically!
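
One consequence of the per-client split is worth seeing on a toy graph: each bank keeps only edges whose two endpoints are both its own accounts, so cross-bank transactions disappear from the client datasets. The sketch below uses plain lists of dictionaries instead of DataFrames so it runs without dependencies; `split_by_bank` is a hypothetical helper and the data is made up.

```python
def split_by_bank(nodes, edges):
    """Partition nodes by bank and keep only intra-bank edges for each client."""
    clients = {}
    for bank in sorted({n["bank"] for n in nodes}):
        bank_nodes = [n for n in nodes if n["bank"] == bank]
        accounts = {n["account"] for n in bank_nodes}
        bank_edges = [
            e for e in edges if e["src"] in accounts and e["dst"] in accounts
        ]
        clients[bank] = {"nodes": bank_nodes, "edges": bank_edges}
    return clients


nodes = [
    {"account": 1, "bank": "bank_a"},
    {"account": 2, "bank": "bank_a"},
    {"account": 3, "bank": "bank_b"},
    {"account": 4, "bank": "bank_b"},
]
edges = [
    {"src": 1, "dst": 2},  # intra-bank (bank_a)
    {"src": 2, "dst": 3},  # cross-bank, dropped from both clients
    {"src": 3, "dst": 4},  # intra-bank (bank_b)
    {"src": 4, "dst": 1},  # cross-bank, dropped from both clients
]

clients = split_by_bank(nodes, edges)
kept = sum(len(c["edges"]) for c in clients.values())
dropped = len(edges) - kept
print(f"kept {kept} intra-bank edges, dropped {dropped} cross-bank edges")
```

The centralized dataset retains all four edges, which is one reason centralized training can see structure that individual banks cannot.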

In [None]:
from src.feature_engineering import DataPreprocessor, summarize_dataset
from src.utils.config import load_preprocessing_config
import warnings

warnings.filterwarnings('ignore', module='matplotlib.text')

# Load config with auto-discovered paths
preproc_config = load_preprocessing_config(str(config_dir / 'preprocessing.yaml'))

print("Preprocessing configuration:")
print(f"  Raw data: {preproc_config['raw_data_file']}")
print(f"  Output dir: {preproc_config['preprocessed_data_dir']}")
print(f"  Learning mode: {preproc_config.get('learning_mode', 'inductive')}")
print()

preprocessor = DataPreprocessor(preproc_config)
datasets = preprocessor(preproc_config['raw_data_file'])

# =============================================================================
# SAVE PREPROCESSED DATA
# =============================================================================
preprocessed_dir = Path(preproc_config['preprocessed_data_dir'])
os.makedirs(preprocessed_dir / 'centralized', exist_ok=True)
os.makedirs(preprocessed_dir / 'clients', exist_ok=True)

# Save centralized datasets
print("\nSaving centralized datasets:")
for name, dataset in datasets.items():
    dataset.to_parquet(preprocessed_dir / 'centralized' / f'{name}.parquet', index=False)
    print(f"  {name}: {len(dataset):,} samples")

# Save per-client datasets
banks = datasets['trainset_nodes']['bank'].unique()
print(f"\nSaving per-client datasets for {len(banks)} banks:")

for bank in banks:
    bank_dir = preprocessed_dir / 'clients' / str(bank)
    os.makedirs(bank_dir, exist_ok=True)
    
    for split in ['trainset', 'valset', 'testset']:
        # Save nodes
        df_nodes = datasets[f'{split}_nodes']
        bank_nodes = df_nodes[df_nodes['bank'] == bank]
        bank_nodes.to_parquet(bank_dir / f'{split}_nodes.parquet', index=False)
        
        # Save edges (intra-bank only)
        if f'{split}_edges' in datasets:
            bank_accounts = set(bank_nodes['account'])
            df_edges = datasets[f'{split}_edges']
            bank_edges = df_edges[
                (df_edges['src'].isin(bank_accounts)) & 
                (df_edges['dst'].isin(bank_accounts))
            ]
            bank_edges.to_parquet(bank_dir / f'{split}_edges.parquet', index=False)
    
    print(f"  ✓ {bank}")

# Generate summary
print("\nGenerating dataset summary...")
summarize_dataset(str(preprocessed_dir), raw_data_file=preproc_config['raw_data_file'])

# =============================================================================
# FEATURE ANALYSIS PLOTS
# =============================================================================
print("\n" + "=" * 60)
print("FEATURE ANALYSIS")
print("=" * 60)

train_nodes = datasets.get('trainset_nodes')
if train_nodes is not None:
    preprocessor.plot_feature_analysis(train_nodes, str(preprocessed_dir))
    
    from IPython.display import Image, display
    import glob
    
    feature_plots = sorted(glob.glob(str(preprocessed_dir / 'features_*.png')))
    for plot_path in feature_plots:
        group_name = Path(plot_path).stem.replace('features_', '').replace('_', ' ').title()
        print(f"\n{group_name}:")
        display(Image(filename=plot_path, width=800))
    
    city_plot = preprocessed_dir / 'city_analysis.png'
    if city_plot.exists():
        print("\nCity distribution:")
        display(Image(filename=str(city_plot), width=900))

print("\n✅ Preprocessing complete!")


## Step 4: Train ML Models

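Train models in three different settings. The centralized run compares every model found in `models.yaml`; the federated and isolated runs use a Graph Neural Network (GraphSAGE).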

**Convention**: All paths and clients are auto-discovered!

### 4.1 Centralized Training

Train on all data combined (baseline).

In [None]:
from src.ml.training.centralized import centralized
from src.ml import clients, models
from src.utils.config import load_training_config
import matplotlib.pyplot as plt
import yaml

# =============================================================================
# CENTRALIZED TRAINING - All Models
# =============================================================================
print("=" * 60)
print("CENTRALIZED TRAINING - All Models")
print("=" * 60)

# Load models config to get all available models
with open(config_dir / 'models.yaml', 'r') as f:
    models_config = yaml.safe_load(f)

# All models to evaluate
all_models = ['DecisionTreeClassifier', 'RandomForestClassifier', 'GradientBoostingClassifier',
              'LogisticRegressor', 'MLP', 'GCN', 'GAT', 'GraphSAGE']

results_all = {}

for model_name in all_models:
    if model_name not in models_config:
        print(f"\n⚠ {model_name} not in config, skipping")
        continue
    
    print(f"\n{'─'*60}")
    print(f"Training: {model_name}")
    print(f"{'─'*60}")
    
    config = load_training_config(
        str(config_dir / 'models.yaml'),
        model_name,
        setting='centralized'
    )
    config['device'] = 'cpu'
    
    Client = getattr(clients, config['client_type'])
    Model = getattr(models, model_name)
    
    results = centralized(seed=42, Client=Client, Model=Model, **config)
    
    # Save results
    results_dir = experiment_root / 'results' / 'centralized' / model_name
    os.makedirs(results_dir, exist_ok=True)
    with open(results_dir / 'results.pkl', 'wb') as f:
        pickle.dump(results, f)
    
    metrics = list(results.values())[0]['testset']
    results_all[model_name] = metrics
    
    print(f"  Test AP: {metrics['average_precision'][-1]:.4f}, F1: {metrics['f1'][-1]:.4f}")

# =============================================================================
# COMPARISON: Precision-Recall Curves
# =============================================================================
print("\n" + "=" * 60)
print("PR CURVE COMPARISON - All Models")
print("=" * 60)

# Colors for different model types
colors = {
    'DecisionTreeClassifier': '#1f77b4',
    'RandomForestClassifier': '#2ca02c', 
    'GradientBoostingClassifier': '#9467bd',
    'LogisticRegressor': '#ff7f0e',
    'MLP': '#d62728',
    'GCN': '#8c564b',
    'GAT': '#e377c2',
    'GraphSAGE': '#17becf'
}

fig, ax = plt.subplots(figsize=(12, 8))

for model_name, metrics in results_all.items():
    precision, recall, _ = metrics['precision_recall_curve']
    ap = metrics['average_precision'][-1]
    color = colors.get(model_name, '#333333')
    ax.plot(recall, precision, linewidth=2, color=color,
            label=f'{model_name} (AP={ap:.3f})')

ax.set_xlabel('Recall', fontsize=12)
ax.set_ylabel('Precision', fontsize=12)
ax.set_title('Precision-Recall Curves: All Models (Centralized)', fontsize=14)
ax.legend(loc='upper right', fontsize=10)
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.grid(True, alpha=0.3)
plt.tight_layout()

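pr_curve_path = experiment_root / 'results' / 'centralized' / 'pr_curve_all_models.png'
# The directory only exists if at least one model was trained above
pr_curve_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(pr_curve_path, dpi=150)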
plt.show()

# Print summary table
print(f"\n{'Model':<30} {'Avg Precision':<15} {'F1 Score':<12}")
print("-" * 57)
for model_name, metrics in sorted(results_all.items(), key=lambda x: -x[1]['average_precision'][-1]):
    ap = metrics['average_precision'][-1]
    f1 = metrics['f1'][-1]
    print(f"{model_name:<30} {ap:<15.4f} {f1:<12.4f}")

print("\n✅ Centralized training complete!")

### 4.2 Federated Training

Privacy-preserving collaborative learning across banks.
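
In federated learning the banks exchange model parameters rather than transactions. The tutorial does not state which aggregation rule the server uses, so the sketch below assumes federated averaging (FedAvg), a common choice in which the server averages client parameters weighted by local dataset size. The parameter vectors and sizes are made up.

```python
def federated_average(client_weights, client_sizes):
    """Average parameter vectors, weighting each client by its number of training samples."""
    total = sum(client_sizes)
    n_params = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
        for i in range(n_params)
    ]


# Two hypothetical banks with locally trained parameters and different data sizes.
bank_a = [0.2, -1.0, 0.5]
bank_b = [0.6, 0.0, 0.1]

# One aggregation round: only parameters leave each bank, never raw transactions.
global_weights = federated_average([bank_a, bank_b], client_sizes=[300, 100])
print(global_weights)
```

In a full training loop this aggregation would repeat every round: the server sends the averaged parameters back, each client continues training locally, and the cycle starts again.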

In [None]:
from src.ml.training.federated import federated
from src.ml import servers

print("=" * 60)
print("FEDERATED TRAINING - GraphSAGE")
print("=" * 60)

# NOTE: Even if models.yaml contains multiple model configurations,
# we only load GraphSAGE here by specifying it as the model_type parameter
config = load_training_config(
    str(config_dir / 'models.yaml'),
    'GraphSAGE',  # Selects only GraphSAGE config from models.yaml
    setting='federated'
)

# Force CPU device (override config if needed)
config['device'] = 'cpu'

print(f"\nAuto-discovered clients: {list(config['clients'].keys())}")
for client_id, client_config in config['clients'].items():
    print(f"  {client_id}:")
    print(f"    Nodes: {client_config['trainset_nodes']}")

print(f"\nDevice: {config['device']}")
print()

# NOTE: input_dim is automatically detected from the data by the client
# No need to manually calculate it!

# Train GraphSAGE model with federated learning
Server = getattr(servers, config.get('server_type', 'TorchServer'))
Client = getattr(clients, config['client_type'])
Model = getattr(models, 'GraphSAGE')

results_federated = federated(
    seed=42,
    Server=Server,
    Client=Client,
    Model=Model,
    n_workers=2,
    **config
)

# Save results
results_dir = experiment_root / 'results' / 'federated' / 'GraphSAGE'
os.makedirs(results_dir, exist_ok=True)
with open(results_dir / 'results.pkl', 'wb') as f:
    pickle.dump(results_federated, f)

# Display final metrics
print("\nFinal Metrics (per client):")
for client_id, metrics in results_federated.items():
    print(f"\n{client_id}:")
    print(f"  Test Average Precision: {metrics['testset']['average_precision'][-1]:.4f}")
    print(f"  Test F1: {metrics['testset']['f1'][-1]:.4f}")

print(f"\n✓ Results saved to: {results_dir}")

### 4.3 Isolated Training

Each bank trains independently on their own data.

In [None]:
from src.ml.training.isolated import isolated

print("=" * 60)
print("ISOLATED TRAINING - GraphSAGE")
print("=" * 60)

# NOTE: Even if models.yaml contains multiple model configurations,
# we only load GraphSAGE here by specifying it as the model_type parameter
config = load_training_config(
    str(config_dir / 'models.yaml'),
    'GraphSAGE',  # Selects only GraphSAGE config from models.yaml
    setting='isolated'
)

# Force CPU device (override config if needed)
config['device'] = 'cpu'

print(f"\nTraining isolated GraphSAGE models for: {list(config['clients'].keys())}")
print(f"Device: {config['device']}")
print()

# NOTE: input_dim is automatically detected from the data by the client
# No need to manually calculate it!

# Train GraphSAGE models independently
Client = getattr(clients, config['client_type'])
Model = getattr(models, 'GraphSAGE')

results_isolated = isolated(
    seed=42,
    Client=Client,
    Model=Model,
    n_workers=2,
    **config
)

# Save results
results_dir = experiment_root / 'results' / 'isolated' / 'GraphSAGE'
os.makedirs(results_dir, exist_ok=True)
with open(results_dir / 'results.pkl', 'wb') as f:
    pickle.dump(results_isolated, f)

# Display final metrics
print("\nFinal Metrics (per client):")
for client_id, metrics in results_isolated.items():
    print(f"\n{client_id}:")
    print(f"  Test Average Precision: {metrics['testset']['average_precision'][-1]:.4f}")
    print(f"  Test F1: {metrics['testset']['f1'][-1]:.4f}")

print(f"\n✓ Results saved to: {results_dir}")
print("\n✅ All training complete!")

## Step 5: Visualize Results

Compare performance across the three training settings.

### 5.1 Generate Plots

In [None]:
from src.visualize import plot_metrics
from src.visualize.utils import discover_results, load_results

print("=" * 60)
print("VISUALIZATION")
print("=" * 60)

# Auto-discover all results
results_files = discover_results(experiment_root)

print(f"\nFound {len(results_files)} results file(s):")
for key in results_files:
    print(f"  - {key}")

# Generate plots for each
print("\nGenerating plots...")
for key, results_file in results_files.items():
    print(f"  Processing {key}...")
    data = load_results(results_file)
    output_dir = results_file.parent
    
    plot_metrics(
        data,
        str(output_dir),
        metrics=['average_precision', 'f1', 'loss'],
        clients=None,
        datasets=['trainset', 'valset', 'testset'],
        reduction='mean',
        formats=['png']
    )
    print(f"    Saved plots to: {output_dir}/png/")

print("\n✓ All plots generated!")

### 5.2 Compare Results

In [None]:
# Load all results and compare
import numpy as np

comparison = {}

for setting in ['centralized', 'federated', 'isolated']:
    results_file = experiment_root / 'results' / setting / 'GraphSAGE' / 'results.pkl'
    if results_file.exists():
        with open(results_file, 'rb') as f:
            results = pickle.load(f)
        
        # Average across clients
        avg_ap = np.mean([r['testset']['average_precision'][-1] for r in results.values()])
        avg_f1 = np.mean([r['testset']['f1'][-1] for r in results.values()])
        
        comparison[setting] = {
            'Average Precision': avg_ap,
            'F1 Score': avg_f1
        }

# Display comparison
print("\n" + "=" * 60)
print("PERFORMANCE COMPARISON")
print("=" * 60)
print(f"\n{'Setting':<15} {'Avg Precision':<15} {'F1 Score':<15}")
print("-" * 45)
for setting, metrics in comparison.items():
    print(f"{setting.capitalize():<15} {metrics['Average Precision']:<15.4f} {metrics['F1 Score']:<15.4f}")

# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

settings = list(comparison.keys())
ap_scores = [comparison[s]['Average Precision'] for s in settings]
f1_scores = [comparison[s]['F1 Score'] for s in settings]

axes[0].bar(settings, ap_scores, color=['#1f77b4', '#ff7f0e', '#2ca02c'])
axes[0].set_ylabel('Average Precision')
axes[0].set_title('Average Precision by Training Setting')
axes[0].set_ylim(0, 1)

axes[1].bar(settings, f1_scores, color=['#1f77b4', '#ff7f0e', '#2ca02c'])
axes[1].set_ylabel('F1 Score')
axes[1].set_title('F1 Score by Training Setting')
axes[1].set_ylim(0, 1)

plt.tight_layout()

# Save comparison plot
comparison_dir = experiment_root / 'visualizations' / 'comparison'
os.makedirs(comparison_dir, exist_ok=True)
plt.savefig(comparison_dir / 'setting_comparison.png', dpi=150)
print(f"\n✓ Comparison plot saved to: {comparison_dir / 'setting_comparison.png'}")
plt.show()

## Step 6: Interactive Network Explorer Dashboard

For interactive exploration of the data with filtering by banks, transaction models, and time ranges, use the standalone dashboard:

```bash
# Install network-explorer dependencies if not already installed
uv sync --extra network-explorer

# Run the dashboard
uv run python src/visualize/transaction_network_explorer/dashboard.py --experiment tutorial_demo
```

Then open http://localhost:5006 in your browser.

The dashboard provides:
- **Bank Filtering**: Select which banks to include
- **Transaction Model Filtering**: Choose specific laundering and legitimate patterns
- **Time Range Selection**: Focus on specific time periods
- **Multiple Visualizations**: Amount distributions, degree distributions, and temporal patterns

All visualizations update in real-time as you adjust the filters.

## Summary

Congratulations! You've completed a full workflow through the AMLGentex framework:

✅ **Created** an experiment with minimal configuration  
✅ **Generated** synthetic AML transaction data  
✅ **Preprocessed** transactions into ML-ready features  
✅ **Trained** models in three settings (centralized, federated, isolated)  
✅ **Visualized** and compared results  
✅ **Explored** transactions with the interactive dashboard  

### Key Takeaways

1. **Convention over Configuration**: Because experiments follow a standard directory structure, the framework auto-discovers paths and clients
2. **Privacy-Preserving Learning**: Federated learning enables collaboration without sharing raw data
3. **Flexible Architecture**: Easy to compare different training paradigms
4. **Scalable**: The framework handles everything from small demos to large-scale experiments

### Next Steps

- **Experiment with different models**: Try GCN, GAT, MLP, or tree-based models
- **Optimize hyperparameters**: Use `scripts/tune_hyperparams.py` for Bayesian optimization
- **Scale up**: Create larger experiments with more accounts and banks
- **Customize patterns**: Modify SAR transaction patterns in config files
- **Analyze results**: Use the visualization tools to understand model behavior

### Documentation

- **ML Module**: `src/ml/README.md`
- **Scripts**: `scripts/README.md`
- **Project**: `README.md`

### Command-Line Usage

All steps can also be run from the command line:

```bash
# Generate data
python scripts/generate.py --conf_file experiments/tutorial_demo/config/data.yaml

# Preprocess
python scripts/preprocess.py --config experiments/tutorial_demo/config/preprocessing.yaml

# Train (centralized)
python -m src.ml.training.centralized \
  --config experiments/tutorial_demo/config/models.yaml \
  --model_type GraphSAGE

# Visualize
python scripts/plot.py --experiment tutorial_demo
```