# Address-Aware GNN for Cryptographic Function Detection

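This notebook walks through the workflow below. Later cells reuse state from earlier ones (file lists, datasets, model, trainer), so run it top to bottom: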
1. Data exploration and visualization
2. Model training and evaluation
3. Inference on new binaries
4. Analysis of address-based features

In [None]:
# Imports
import glob
import json
import os
import pickle
import random
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# Import our GNN modules
from new_gnn import (
    GraphDataset, AddressAwareGNN, HierarchicalGNN,
    GNNTrainer, CryptoDetectionPipeline, collate_fn,
    AddressFeatureExtractor
)

# Plotting configuration
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
%matplotlib inline

## Part 1: Data Exploration

In [None]:
# Load all JSON files
# Override the location with the VESTIGO_GHIDRA_JSON_DIR environment variable;
# the default is the original author's local checkout.
data_dir = os.environ.get(
    'VESTIGO_GHIDRA_JSON_DIR',
    '/home/bhoomi/Desktop/compilerRepo/vestigo-data/ghidra_json',
)
# glob returns files in arbitrary, filesystem-dependent order. Sort them so that
# json_files[...] slices and the seeded train/val/test split are reproducible
# across machines.
json_files = sorted(glob.glob(os.path.join(data_dir, '*.json')))

# Fail fast: every later cell indexes into or splits this list.
if not json_files:
    raise FileNotFoundError(f"No JSON files found in {data_dir!r}")

print(f"Found {len(json_files)} JSON files")
print("\nSample files:")
for path in json_files[:5]:
    print(f"  - {os.path.basename(path)}")

In [None]:
# Analyze label distribution
# Only a prefix of the (sorted) file list is scanned, to keep exploration fast.
MAX_FILES_TO_ANALYZE = 100

all_labels = []
all_addresses = []
graph_complexities = []
skipped_files = []  # (path, error repr) for files that could not be parsed

for json_file in tqdm(json_files[:MAX_FILES_TO_ANALYZE], desc="Analyzing data"):
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Collect this file's records first and only extend the shared lists once the
        # whole file parsed. Appending field-by-field could leave the three lists with
        # different lengths if a record lacked a key mid-file, which breaks the
        # DataFrame built from all_labels/graph_complexities below.
        file_records = []
        for func in data['functions']:
            if 'label' in func:
                # `graph_level` may be missing or null; treat both as "no metrics".
                graph_level = func.get('graph_level') or {}
                file_records.append((
                    func['label'],
                    func['address'],
                    graph_level.get('cyclomatic_complexity', 0),
                ))
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
        # Skip unreadable/malformed files (ValueError covers JSON and Unicode decode
        # errors), but record why instead of a bare `except:` that also hides
        # KeyboardInterrupt and genuine bugs.
        skipped_files.append((json_file, repr(exc)))
        continue

    for label, address, complexity in file_records:
        all_labels.append(label)
        all_addresses.append(address)
        graph_complexities.append(complexity)

print(f"Total functions analyzed: {len(all_labels)}")
if skipped_files:
    print(f"Skipped {len(skipped_files)} unreadable/malformed files")

In [None]:
# Visualize label distribution: every label ranked by count (left) and the share
# of the ten most common labels (right)
label_counts = Counter(all_labels)
labels, counts = zip(*label_counts.most_common())
labels_pie, counts_pie = zip(*label_counts.most_common(10))

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
bar_ax, pie_ax = axes

# Horizontal bars, most frequent label at y=0 (the bottom of the axis)
positions = range(len(labels))
bar_ax.barh(positions, counts, color='steelblue')
bar_ax.set_yticks(positions)
bar_ax.set_yticklabels(labels)
bar_ax.set_xlabel('Count')
bar_ax.set_title('Label Distribution', fontweight='bold', fontsize=14)
bar_ax.grid(axis='x', alpha=0.3)

# Pie chart restricted to the top ten labels
pie_ax.pie(counts_pie, labels=labels_pie, autopct='%1.1f%%', startangle=90)
pie_ax.set_title('Top 10 Label Distribution', fontweight='bold', fontsize=14)

plt.tight_layout()
plt.savefig('label_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Analyze complexity distribution by crypto type
df = pd.DataFrame({'label': all_labels, 'complexity': graph_complexities})

# Keep only rows whose label names a crypto algorithm
is_crypto_row = df['label'] != 'Non-Crypto'
crypto_df = df[is_crypto_row]

_complexity_fig, complexity_ax = plt.subplots(figsize=(14, 6))
sns.boxplot(data=crypto_df, x='label', y='complexity', palette='Set2', ax=complexity_ax)
plt.setp(complexity_ax.get_xticklabels(), rotation=45, ha='right')
complexity_ax.set_xlabel('Crypto Algorithm', fontsize=12)
complexity_ax.set_ylabel('Cyclomatic Complexity', fontsize=12)
complexity_ax.set_title('Complexity Distribution by Crypto Algorithm', fontweight='bold', fontsize=14)
complexity_ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('complexity_by_algorithm.png', dpi=300, bbox_inches='tight')
plt.show()

complexity_summary = crypto_df.groupby('label')['complexity'].describe()
print("\nComplexity statistics by algorithm:")
print(complexity_summary)

## Part 2: Address Feature Analysis

In [None]:
# Analyze address patterns in the first JSON file (first 50 functions only)
sample_file = json_files[0]
with open(sample_file, 'r') as sample_fh:
    sample_data = json.load(sample_fh)

# One row per entry of each function's node_level list, tagged with the function's
# label ('Unknown' when the function has none)
address_features_list = []
for function_entry in sample_data['functions'][:50]:
    function_label = function_entry.get('label', 'Unknown')
    for node_entry in function_entry.get('node_level', []):
        feature_row = AddressFeatureExtractor.extract_address_features(node_entry['address'])
        feature_row['label'] = function_label
        address_features_list.append(feature_row)

addr_df = pd.DataFrame(address_features_list)
print("Address features extracted:", addr_df.shape)
addr_df.head()

In [None]:
# Visualize address features: crypto vs. non-crypto histogram for each feature
features_to_plot = [
    'addr_entropy',
    'addr_alignment_4',
    'addr_alignment_16',
    'addr_ones_ratio',
    'addr_nibble_variety',
    'is_text_section'
]

# The extraction cell tags functions without a label as 'Unknown'. Exclude those
# rows from the "Crypto" group; previously anything != 'Non-Crypto' counted as crypto.
# The split does not depend on the feature, so compute it once, outside the loop.
crypto_only = addr_df[~addr_df['label'].isin(['Non-Crypto', 'Unknown'])]
non_crypto_only = addr_df[addr_df['label'] == 'Non-Crypto']

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

for feature_ax, feature in zip(axes, features_to_plot):
    # Cast to float: flag-like features (e.g. is_text_section) may be stored as
    # booleans, which numpy's histogram binning cannot handle directly.
    feature_ax.hist(crypto_only[feature].astype(float), bins=30, alpha=0.6, label='Crypto', color='orange')
    feature_ax.hist(non_crypto_only[feature].astype(float), bins=30, alpha=0.6, label='Non-Crypto', color='blue')
    feature_ax.set_xlabel(feature)
    feature_ax.set_ylabel('Frequency')
    feature_ax.set_title(f'{feature} Distribution')
    feature_ax.legend()
    feature_ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('address_features_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

## Part 3: Model Training

In [None]:
# Prepare datasets: split at the *file* level (~70/15/15), so all functions from
# one JSON file land in a single split
TEST_FRACTION = 0.15
VAL_FRACTION = 0.15

train_val_files, test_files = train_test_split(json_files, test_size=TEST_FRACTION, random_state=42)
# Rescale so validation is VAL_FRACTION of the full file list, not of the remainder
train_files, val_files = train_test_split(
    train_val_files, test_size=VAL_FRACTION / (1 - TEST_FRACTION), random_state=42
)

for split_name, split_files in (("Train", train_files), ("Val", val_files), ("Test", test_files)):
    print(f"{split_name}: {len(split_files)} files")

In [None]:
# Load datasets. Validation and test reuse the training dataset's label encoder and
# its node/edge/graph scalers, so every split is encoded and scaled like training.
print("Loading training data...")
train_dataset = GraphDataset(train_files)


def _share_training_scalers(dataset):
    """Point `dataset`'s node/edge/graph scalers at the training dataset's scalers."""
    for scaler_attr in ('node_scaler', 'edge_scaler', 'graph_scaler'):
        setattr(dataset, scaler_attr, getattr(train_dataset, scaler_attr))
    return dataset


print("\nLoading validation data...")
val_dataset = _share_training_scalers(GraphDataset(val_files, train_dataset.label_encoder))

print("\nLoading test data...")
test_dataset = _share_training_scalers(GraphDataset(test_files, train_dataset.label_encoder))

print(f"\nDatasets loaded successfully!")
print(f"Classes: {train_dataset.label_encoder.classes_}")

In [None]:
# Create data loaders; only the training loader shuffles
BATCH_SIZE = 32


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the shared batch size and graph collate_fn."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, collate_fn=collate_fn)


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

for loader_name, loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{loader_name} batches: {len(loader)}")

In [None]:
# Build model: input dimensions are read off the first training graph
sample = train_dataset[0]
num_node_features = sample.x.shape[1]
if sample.edge_attr.numel() > 0:
    num_edge_features = sample.edge_attr.shape[1]
else:
    num_edge_features = 0  # empty edge_attr on the sample graph
num_graph_features = sample.graph_features.shape[0]
num_classes = len(train_dataset.label_encoder.classes_)

for caption, value in (
    ("Node features", num_node_features),
    ("Edge features", num_edge_features),
    ("Graph features", num_graph_features),
    ("Number of classes", num_classes),
):
    print(f"{caption}: {value}")

model = AddressAwareGNN(**dict(
    num_node_features=num_node_features,
    num_edge_features=num_edge_features,
    num_graph_features=num_graph_features,
    num_classes=num_classes,
    hidden_dim=256,
    num_layers=4,
    dropout=0.3,
    conv_type='gat',
    pooling='concat',
))

param_count = sum(p.numel() for p in model.parameters())
print(f"\nModel parameters: {param_count:,}")

In [None]:
# Train model (checkpoints are written under MODEL_DIR)
NUM_EPOCHS = 50
MODEL_DIR = './gnn_models'

trainer = GNNTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    label_encoder=train_dataset.label_encoder,
    lr=0.001,
    weight_decay=1e-4,
)
trainer.train(num_epochs=NUM_EPOCHS, save_dir=MODEL_DIR)

In [None]:
# Plot the trainer's recorded training history; the path is presumably where the figure is saved
history_plot_path = './training_history.png'
trainer.plot_training_history(history_plot_path)

## Part 4: Model Evaluation

In [None]:
# Restore the best checkpoint (expected to be written by trainer.train(save_dir='./gnn_models')),
# then evaluate on the held-out test set.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if the checkpoint
# stores non-tensor objects this load may fail — confirm against GNNTrainer's save format.
best_checkpoint_path = './gnn_models/best_model.pth'
checkpoint = torch.load(best_checkpoint_path, map_location=trainer.device)
best_state = checkpoint['model_state_dict']
trainer.model.load_state_dict(best_state)

test_results = trainer.test()

In [None]:
# Plot the confusion matrix returned by trainer.test()
test_confusion = test_results['confusion_matrix']
trainer.plot_confusion_matrix(test_confusion, './confusion_matrix.png')

In [None]:
# Analyze per-class performance on the test set
precision, recall, f1, support = precision_recall_fscore_support(
    test_results['labels'],
    test_results['predictions'],
    labels=list(range(num_classes)),
    # Classes never predicted (or absent from the test split) have undefined
    # precision/recall. Report them explicitly as 0 rather than emitting an
    # UndefinedMetricWarning per class (the default also yields 0, but warns).
    zero_division=0,
)

# Sorted by F1; the plot below uses this same row order for every bar group.
performance_df = pd.DataFrame({
    'Class': train_dataset.label_encoder.classes_,
    'Precision': precision,
    'Recall': recall,
    'F1-Score': f1,
    'Support': support
}).sort_values('F1-Score', ascending=False)

print("\nPer-class performance:")
print(performance_df.to_string(index=False))

# Visualize: three bars (precision / recall / F1) per class
fig, ax = plt.subplots(figsize=(14, 6))
x = np.arange(len(performance_df))
width = 0.25

ax.bar(x - width, performance_df['Precision'], width, label='Precision', alpha=0.8)
ax.bar(x, performance_df['Recall'], width, label='Recall', alpha=0.8)
ax.bar(x + width, performance_df['F1-Score'], width, label='F1-Score', alpha=0.8)

ax.set_xlabel('Crypto Algorithm', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Per-Class Performance Metrics', fontweight='bold', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(performance_df['Class'], rotation=45, ha='right')
ax.legend()
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('per_class_performance.png', dpi=300, bbox_inches='tight')
plt.show()

## Part 5: Inference Pipeline

In [None]:
# Save metadata for inference (label encoder, fitted scalers, model config).
# NOTE: model_config duplicates the AddressAwareGNN(...) arguments from the
# "Build model" cell; keep the two in sync or inference will rebuild a different model.
METADATA_DIR = './gnn_models'
# trainer.train(save_dir=...) presumably creates this directory, but don't depend on
# training having run in this kernel.
os.makedirs(METADATA_DIR, exist_ok=True)

metadata = {
    'label_encoder': train_dataset.label_encoder,
    'node_scaler': train_dataset.node_scaler,
    'edge_scaler': train_dataset.edge_scaler,
    'graph_scaler': train_dataset.graph_scaler,
    'model_config': {
        'num_node_features': num_node_features,
        'num_edge_features': num_edge_features,
        'num_graph_features': num_graph_features,
        'num_classes': num_classes,
        'hidden_dim': 256,
        'num_layers': 4,
        'dropout': 0.3,
        'conv_type': 'gat',
        'pooling': 'concat',
    }
}

with open(os.path.join(METADATA_DIR, 'metadata.pkl'), 'wb') as f:
    pickle.dump(metadata, f)

print("✓ Metadata saved for inference")

In [None]:
# Run inference on one held-out file using the saved checkpoint and metadata
pipeline = CryptoDetectionPipeline(
    model_path='./gnn_models/best_model.pth',
    metadata_path='./gnn_models/metadata.pkl',
)

# The second argument is presumably the output path for the detection report
test_file = test_files[0]
detection_output_path = './detection_results.json'
results = pipeline.process_json(test_file, detection_output_path)

In [None]:
# Visualize detection results
crypto_functions = results['crypto_functions']
if crypto_functions:
    algo_counts = Counter(det['algorithm'] for det in crypto_functions)
    confidences = [det['confidence'] for det in crypto_functions]
    mean_confidence = np.mean(confidences)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Bar chart of detected algorithms
    algos, counts = zip(*algo_counts.most_common())
    axes[0].barh(range(len(algos)), counts, color='teal')
    axes[0].set_yticks(range(len(algos)))
    axes[0].set_yticklabels(algos)
    axes[0].set_xlabel('Count')
    axes[0].set_title('Detected Crypto Algorithms', fontweight='bold')
    axes[0].grid(axis='x', alpha=0.3)

    # Confidence distribution with the mean marked
    axes[1].hist(confidences, bins=20, color='coral', edgecolor='black')
    axes[1].set_xlabel('Confidence Score')
    axes[1].set_ylabel('Frequency')
    axes[1].set_title('Detection Confidence Distribution', fontweight='bold')
    axes[1].axvline(mean_confidence, color='red', linestyle='--', label=f'Mean: {mean_confidence:.3f}')
    axes[1].legend()
    axes[1].grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig('detection_results_visualization.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Rank explicitly: this cell cannot rely on the pipeline returning detections
    # sorted by confidence. sorted() is stable (also with reverse=True), so an
    # already-sorted list keeps its order.
    top_detections = sorted(crypto_functions, key=lambda det: det['confidence'], reverse=True)[:10]
    print("\nTop 10 crypto functions by confidence:")
    for i, func in enumerate(top_detections, 1):
        print(f"{i}. {func['address']} - {func['algorithm']} (confidence: {func['confidence']:.4f})")
else:
    print("No crypto functions detected in this file.")

## Part 6: Feature Importance Analysis

In [None]:
# Analyze which features are most important
# We'll use gradient-based feature attribution

def compute_feature_importance(model, data_loader, device):
    """
    Compute per-node-feature importance as mean absolute input gradient.

    For each batch, the predicted-class scores (summed over output rows) are
    back-propagated to the node feature matrix ``batch.x``. The absolute
    gradient is averaged over nodes, then the per-batch vectors are averaged.

    Args:
        model: Trained model; switched to eval mode and moved to ``device``.
        data_loader: Yields batch objects supporting ``.to(device)`` and exposing a
            floating-point node feature tensor ``.x``.
        device: Torch device to run on.

    Returns:
        np.ndarray with one importance value per column of ``batch.x``.

    Raises:
        RuntimeError: If no gradients were collected (e.g. an empty loader).
    """
    model.eval()
    model.to(device)

    feature_gradients = []

    # Attribution needs autograd even if gradients were disabled globally elsewhere.
    with torch.enable_grad():
        for batch in tqdm(data_loader, desc="Computing importance"):
            batch = batch.to(device)
            # Fresh leaf tensor: avoids flipping requires_grad on (or accumulating
            # .grad into) a tensor that might be shared with the dataset.
            batch.x = batch.x.detach().clone().requires_grad_(True)

            # backward() also populates parameter grads; clear them each batch so
            # they do not keep accumulating on the trained model.
            model.zero_grad(set_to_none=True)

            output = model(batch)

            # Get prediction
            pred = output.argmax(dim=1)

            # Gradient of each row's predicted-class score w.r.t. the input features
            score = output[torch.arange(pred.size(0), device=output.device), pred].sum()
            score.backward()

            if batch.x.grad is None:
                # Input did not reach the output graph; nothing to attribute.
                continue
            feature_gradients.append(batch.x.grad.abs().mean(dim=0).detach().cpu().numpy())

    # Do not leave attribution gradients behind on the model's parameters.
    model.zero_grad(set_to_none=True)

    if not feature_gradients:
        raise RuntimeError("No gradients were collected; is the data loader empty?")

    # Average across batches
    return np.mean(feature_gradients, axis=0)

# Compute feature importance
importance = compute_feature_importance(trainer.model, val_loader, trainer.device)

# Visualize top features (never request more bars than there are features)
top_k = min(30, len(importance))
top_indices = np.argsort(importance)[-top_k:][::-1]

plt.figure(figsize=(12, 8))
plt.barh(range(top_k), importance[top_indices], color='purple', alpha=0.7)
plt.yticks(range(top_k), [f'Feature {i}' for i in top_indices])
plt.xlabel('Importance (Gradient Magnitude)', fontsize=12)
plt.ylabel('Feature Index', fontsize=12)
plt.title(f'Top {top_k} Most Important Node Features', fontweight='bold', fontsize=14)
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig('feature_importance.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nTop {top_k} feature indices: {top_indices.tolist()}")

## Part 7: Model Comparison

In [None]:
# Train multiple architectures and compare
architectures = ['gcn', 'gat', 'sage', 'gin']
COMPARISON_EPOCHS = 30  # fewer than the main run, to keep the sweep affordable
COMPARISON_SEED = 42
results_comparison = []

for arch in architectures:
    print(f"\n{'='*60}")
    print(f"Training {arch.upper()} model")
    print('='*60)

    # Re-seed per architecture so each run starts from the same RNG state
    # (weight init, shuffling, dropout); differences then reflect the architecture.
    torch.manual_seed(COMPARISON_SEED)
    np.random.seed(COMPARISON_SEED)

    # Dedicated names: rebinding the global `model`/`trainer` here would silently
    # replace the main GAT model/trainer used by the evaluation and importance cells.
    arch_model = AddressAwareGNN(
        num_node_features=num_node_features,
        num_edge_features=num_edge_features,
        num_graph_features=num_graph_features,
        num_classes=num_classes,
        hidden_dim=256,
        num_layers=4,
        dropout=0.3,
        conv_type=arch,
        pooling='concat'
    )

    arch_trainer = GNNTrainer(
        model=arch_model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        label_encoder=train_dataset.label_encoder,
        lr=0.001,
        weight_decay=1e-4
    )

    arch_save_dir = os.path.join('./gnn_models', arch)
    arch_trainer.train(num_epochs=COMPARISON_EPOCHS, save_dir=arch_save_dir)

    # Evaluate the best checkpoint saved by train(), the same way Part 4 evaluates
    # the main model, rather than whatever weights the final epoch left behind.
    arch_checkpoint = torch.load(
        os.path.join(arch_save_dir, 'best_model.pth'), map_location=arch_trainer.device
    )
    arch_trainer.model.load_state_dict(arch_checkpoint['model_state_dict'])

    test_loss, test_acc, test_f1, _, _ = arch_trainer.evaluate(test_loader)

    results_comparison.append({
        'architecture': arch.upper(),
        'test_acc': test_acc,
        'test_f1': test_f1,
        'test_loss': test_loss
    })

# Compare results
comparison_df = pd.DataFrame(results_comparison)
print("\n" + "="*60)
print("ARCHITECTURE COMPARISON")
print("="*60)
print(comparison_df.to_string(index=False))

In [None]:
# Visualize architecture comparison: one bar panel per test metric
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

x = np.arange(len(comparison_df))
width = 0.35

# (axis, column, series label, y-axis label, title, bar colour)
panel_specs = [
    (axes[0], 'test_acc', 'Accuracy', 'Accuracy', 'Test Accuracy by Architecture', 'skyblue'),
    (axes[1], 'test_f1', 'F1 Score', 'F1 Score', 'Test F1 Score by Architecture', 'coral'),
]
for panel_ax, column, series_label, y_label, panel_title, bar_colour in panel_specs:
    panel_ax.bar(x, comparison_df[column], width, label=series_label, color=bar_colour)
    panel_ax.set_xlabel('Architecture', fontsize=12)
    panel_ax.set_ylabel(y_label, fontsize=12)
    panel_ax.set_title(panel_title, fontweight='bold', fontsize=14)
    panel_ax.set_xticks(x)
    panel_ax.set_xticklabels(comparison_df['architecture'])
    panel_ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('architecture_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## Summary

This notebook demonstrated:
1. ✓ Data exploration and label distribution analysis
2. ✓ Address-based feature extraction and visualization
3. ✓ GNN model training with comprehensive metrics
4. ✓ Model evaluation and confusion matrix analysis
5. ✓ Inference pipeline for crypto detection
6. ✓ Feature importance analysis
7. ✓ Architecture comparison (GCN, GAT, SAGE, GIN)

**Next steps:**
- Run hyperparameter tuning with `gnn_hyperparameter_tuning.py`
- Deploy model for production inference
- Analyze false positives/negatives for model improvement