# SMT-WEEX Notebook 3: Evaluation & Insights (v2)
**Project:** smt-weex-2025
**Author:** Jannet Ekka

**Updates in v2:**
- Fixed error analysis (CatBoost 2D array issue)
- Added class merging experiment (Institutional + CEX_Wallet -> Large_Holder)
- Production reliability assessment
- Trading signal recommendations

## 1. Setup

In [None]:
!pip install -q catboost xgboost lightgbm scikit-learn pandas numpy matplotlib seaborn

In [None]:
from google.colab import auth
auth.authenticate_user()

PROJECT_ID = 'smt-weex-2025'
BUCKET = 'smt-weex-2025-models'

!gcloud config set project {PROJECT_ID}

In [None]:
import pandas as pd
import numpy as np
import json
import pickle
from collections import Counter

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix, balanced_accuracy_score
)
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder

from catboost import CatBoostClassifier

import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

print("Libraries loaded")

## 2. Load Models and Data

In [None]:
# Download from GCS
!mkdir -p /content/models
!gsutil -m cp gs://{BUCKET}/models/initial/* /content/models/
!gsutil cp gs://{BUCKET}/data/data_splits.npz /content/
!gsutil cp gs://{BUCKET}/data/feature_config.json /content/
!gsutil cp gs://{BUCKET}/data/whale_features_cleaned.csv /content/

In [None]:
# Load data splits
splits = np.load('/content/data_splits.npz')
X_train, y_train = splits['X_train'], splits['y_train']
X_val, y_val = splits['X_val'], splits['y_val']
X_test, y_test = splits['X_test'], splits['y_test']

# Also load CatBoost training data
X_train_cb = splits['X_train_cb']
y_train_cb = splits['y_train_cb']

# Load feature config
with open('/content/feature_config.json', 'r') as f:
    config = json.load(f)
FEATURES = config['features']

# Load original data for class merging experiment
df = pd.read_csv('/content/whale_features_cleaned.csv')

# Load label encoder
with open('/content/models/label_encoder.pkl', 'rb') as f:
    le = pickle.load(f)

label_mapping = {i: label for i, label in enumerate(le.classes_)}
labels = list(label_mapping.values())

print(f"Labels: {label_mapping}")
print(f"Test set: {len(X_test)} samples")

In [None]:
# Load models
models = {}

# CatBoost
models['CatBoost'] = CatBoostClassifier()
models['CatBoost'].load_model('/content/models/catboost_whale_classifier.cbm')

# Others
with open('/content/models/xgboost_whale_classifier.pkl', 'rb') as f:
    models['XGBoost'] = pickle.load(f)

with open('/content/models/randomforest_whale_classifier.pkl', 'rb') as f:
    models['RandomForest'] = pickle.load(f)

with open('/content/models/lightgbm_whale_classifier.pkl', 'rb') as f:
    models['LightGBM'] = pickle.load(f)

print(f"Loaded {len(models)} models")

## 3. Confusion Matrices

In [None]:
def get_predictions(model, X):
    """Get predictions, handling CatBoost's 2D output"""
    y_pred = model.predict(X)
    # CatBoost returns 2D array, flatten it
    if hasattr(y_pred, 'shape') and len(y_pred.shape) > 1:
        y_pred = y_pred.flatten()
    return y_pred.astype(int)

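def plot_confusion_matrix(y_true, y_pred, labels, title):
    """Plot a confusion matrix whose rows/columns follow the order of `labels`"""
    # Pass the full class index range so a class absent from this split still
    # gets a row and column, keeping the tick labels aligned with the cells
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(labels)))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels)
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()
    return cm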
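Passing an explicit label list matters when a class is absent from both `y_true` and `y_pred` in a split: without it, scikit-learn's `confusion_matrix` only includes the classes it actually sees, and the heatmap tick labels drift out of alignment with the cells. A minimal stdlib sketch of the same idea, assuming integer-encoded labels `0..n_classes-1` (the helper name `confusion_counts` is hypothetical, not part of the notebook):

```python
def confusion_counts(y_true, y_pred, n_classes):
    """Return an n_classes x n_classes matrix: rows = true label, cols = predicted label.

    Every class index 0..n_classes-1 gets a row and a column, even if it never
    appears in y_true or y_pred, so the matrix always lines up with a fixed
    list of tick labels.
    """
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[int(t)][int(p)] += 1
    return cm


# Class 2 never occurs in this toy example, but the matrix is still 3x3.
cm = confusion_counts([0, 0, 1, 1], [0, 1, 1, 1], n_classes=3)
for row in cm:
    print(row)
```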

In [None]:
# Plot confusion matrices for all models
confusion_matrices = {}

for name, model in models.items():
    y_pred = get_predictions(model, X_test)
    cm = plot_confusion_matrix(y_test, y_pred, labels, f'{name} Confusion Matrix')
    confusion_matrices[name] = cm

## 4. Per-Class Performance Analysis

In [None]:
# Detailed classification report for CatBoost
y_pred_cb = get_predictions(models['CatBoost'], X_test)

print("=" * 60)
print("CatBoost Classification Report")
print("=" * 60)
print(classification_report(y_test, y_pred_cb, target_names=labels, zero_division=0))

In [None]:
# Per-class metrics comparison across models
per_class_metrics = {}

for model_name, model in models.items():
    y_pred = get_predictions(model, X_test)
    
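    # Fix the class order so the arrays line up with `labels` in the zip() below,
    # even if some class is missing from y_test and y_pred
    class_idx = np.arange(len(labels))
    precision = precision_score(y_test, y_pred, labels=class_idx, average=None, zero_division=0)
    recall = recall_score(y_test, y_pred, labels=class_idx, average=None, zero_division=0)
    f1 = f1_score(y_test, y_pred, labels=class_idx, average=None, zero_division=0)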
    
    per_class_metrics[model_name] = {
        'precision': dict(zip(labels, precision)),
        'recall': dict(zip(labels, recall)),
        'f1': dict(zip(labels, f1))
    }

# Show per-class F1 for all models
print("\n=== Per-Class F1 Scores ===")
f1_df = pd.DataFrame({model: metrics['f1'] for model, metrics in per_class_metrics.items()})
f1_df['test_samples'] = [sum(y_test == le.transform([label])[0]) for label in labels]
print(f1_df.round(4))
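For reference, the per-class scores above can be read straight off a confusion matrix. The sketch below (stdlib only; `per_class_prf` is a hypothetical helper) also makes clear that F1 is the harmonic mean of precision and recall, so an F1 of, say, 58% is not the same thing as "42% of samples misclassified". Undefined ratios are set to 0.0, in the spirit of scikit-learn's `zero_division=0`:

```python
def per_class_prf(cm):
    """Per-class (precision, recall, f1) from a square confusion matrix
    (rows = true label, cols = predicted label). Undefined ratios become 0.0."""
    n = len(cm)
    results = []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum: predicted as k
        actual_k = sum(cm[k])                          # row sum: truly k
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        results.append((precision, recall, f1))
    return results


# Toy 3-class matrix; class 2 has no samples at all.
cm = [[1, 1, 0],
      [0, 2, 0],
      [0, 0, 0]]
for k, (p, r, f) in enumerate(per_class_prf(cm)):
    print(f"class {k}: precision={p:.3f} recall={r:.3f} f1={f:.3f}")
```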

In [None]:
# Visualize per-class F1 across models
fig, ax = plt.subplots(figsize=(14, 6))

x = np.arange(len(labels))
width = 0.2

for i, (model_name, metrics) in enumerate(per_class_metrics.items()):
    f1_values = [metrics['f1'][label] for label in labels]
    ax.bar(x + i*width, f1_values, width, label=model_name)

ax.set_ylabel('F1 Score')
ax.set_title('Per-Class F1 Score Comparison')
ax.set_xticks(x + width * 1.5)
ax.set_xticklabels(labels, rotation=45, ha='right')
ax.set_ylim(0, 1)
ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='50% threshold')
ax.legend()  # called after axhline so the threshold line also appears in the legend
plt.tight_layout()
plt.show()
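The tick shift `width * 1.5` centres the labels only when exactly four models are plotted. A small sketch of the general rule, assuming bars are placed at `x + i * width` as above (`group_offsets` is a hypothetical helper):

```python
def group_offsets(n_series, width):
    """Offsets for n_series bars drawn side by side at each x position,
    plus the shift that centres the tick label under the whole group."""
    offsets = [i * width for i in range(n_series)]
    tick_shift = width * (n_series - 1) / 2
    return offsets, tick_shift


offsets, tick_shift = group_offsets(4, 0.2)  # four models, as in the cell above
print(offsets, tick_shift)
```

With this, `ax.set_xticks(x + tick_shift)` keeps working if a model is added or dropped; choosing `width = 0.8 / n_series` also keeps neighbouring groups from overlapping.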

## 5. Feature Importance Analysis

In [None]:
# CatBoost feature importance
catboost_importance = models['CatBoost'].get_feature_importance()
importance_df = pd.DataFrame({
    'feature': FEATURES,
    'importance': catboost_importance
}).sort_values('importance', ascending=False)

print("=== CatBoost Feature Importance (Top 15) ===")
print(importance_df.head(15).to_string(index=False))

# WARNING about balance_eth_log dominance
top_feature_pct = importance_df.iloc[0]['importance']
print(f"\n*** WARNING: Top feature ({importance_df.iloc[0]['feature']}) has {top_feature_pct:.1f}% importance ***")
print("This suggests the model relies heavily on balance rather than behavioral patterns.")

In [None]:
# Visualize feature importance
plt.figure(figsize=(12, 10))
top_n = 20
top_features = importance_df.head(top_n)

colors = ['coral' if feat == 'balance_eth_log' else 'steelblue' 
          for feat in top_features['feature']]

plt.barh(range(len(top_features)), top_features['importance'].values, color=colors)
plt.yticks(range(len(top_features)), top_features['feature'].values)
plt.xlabel('Importance')
plt.title(f'Top {top_n} Most Important Features (CatBoost)\n(Coral = balance_eth_log)')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()

In [None]:
# Compare feature importance across models
rf_importance = models['RandomForest'].feature_importances_
xgb_importance = models['XGBoost'].feature_importances_

importance_comparison = pd.DataFrame({
    'feature': FEATURES,
    'CatBoost': catboost_importance / catboost_importance.sum(),
    'RandomForest': rf_importance / rf_importance.sum(),
    'XGBoost': xgb_importance / xgb_importance.sum()
})

importance_comparison['avg'] = importance_comparison[['CatBoost', 'RandomForest', 'XGBoost']].mean(axis=1)
importance_comparison = importance_comparison.sort_values('avg', ascending=False)

print("=== Consensus Top Features (CatBoost, RandomForest, XGBoost) ===")
print(importance_comparison[['feature', 'avg', 'CatBoost', 'RandomForest', 'XGBoost']].head(15).round(4).to_string(index=False))
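The consensus ranking works because each model's importances are first divided by their own sum: CatBoost reports values that sum to 100, while the scikit-learn-style `feature_importances_` of RandomForest and XGBoost are on different scales. A stdlib sketch of the same normalise-then-average step, using hypothetical feature names and scores:

```python
def consensus_ranking(importances_by_model, feature_names):
    """Rank features by mean normalised importance across models.

    Each model's raw importances are divided by their sum first, so models
    that report on different scales contribute equally to the average.
    """
    normalised = []
    for scores in importances_by_model.values():
        total = sum(scores)
        normalised.append([s / total for s in scores])
    avg = [sum(col) / len(normalised) for col in zip(*normalised)]
    return sorted(zip(feature_names, avg), key=lambda pair: pair[1], reverse=True)


# Hypothetical scores for three features from two models on very different scales.
ranking = consensus_ranking(
    {'model_a': [60.0, 30.0, 10.0], 'model_b': [2.0, 12.0, 6.0]},
    ['feat_a', 'feat_b', 'feat_c'],
)
print(ranking)
```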

## 6. Error Analysis (FIXED)

In [None]:
# Error analysis for CatBoost
y_pred_cb = get_predictions(models['CatBoost'], X_test)

# Find misclassified samples
misclassified_idx = np.where(y_test != y_pred_cb)[0]
print(f"Total misclassifications: {len(misclassified_idx)} / {len(y_test)} ({len(misclassified_idx)/len(y_test)*100:.1f}%)")

# Analyze confusion pairs
confusion_pairs = []
for idx in misclassified_idx:
    true_label = label_mapping[y_test[idx]]
    pred_label = label_mapping[int(y_pred_cb[idx])]  # Cast to int to fix the error
    confusion_pairs.append((true_label, pred_label))

confusion_counts = Counter(confusion_pairs)

print("\n=== Most Common Misclassifications (CatBoost) ===")
for (true_l, pred_l), count in confusion_counts.most_common(10):
    print(f"{true_l:15s} -> {pred_l:15s}: {count} times")
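The same most-common confusions can also be read directly off a confusion matrix (for example `confusion_matrices['CatBoost']` from Section 3) without looping over individual samples. A stdlib sketch, with a hypothetical helper name and a toy matrix:

```python
def top_confusions(cm, labels, k=3):
    """Return the k largest off-diagonal cells of a confusion matrix as
    (true_label, predicted_label, count), largest first."""
    pairs = [
        (labels[i], labels[j], cm[i][j])
        for i in range(len(cm))
        for j in range(len(cm))
        if i != j and cm[i][j] > 0
    ]
    return sorted(pairs, key=lambda t: t[2], reverse=True)[:k]


# Hypothetical 3-class matrix (rows = true, cols = predicted).
cm = [[8, 3, 1],
      [4, 6, 0],
      [0, 2, 9]]
for true_l, pred_l, n in top_confusions(cm, ['A', 'B', 'C']):
    print(f"{true_l} -> {pred_l}: {n}")
```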

In [None]:
# Analyze which classes are problematic
print("\n=== Class-wise Error Analysis ===")
for label_idx, label in label_mapping.items():
    mask = y_test == label_idx
    if mask.sum() == 0:
        continue
    
    correct = (y_pred_cb[mask] == label_idx).sum()
    total = mask.sum()
    accuracy = correct / total * 100
    
    # What are they misclassified as?
    misclassified_as = Counter(y_pred_cb[mask & (y_pred_cb != label_idx)])
    
    print(f"\n{label} ({total} samples): {accuracy:.1f}% correct")
    if misclassified_as:
        for pred_idx, count in misclassified_as.most_common(3):
            print(f"  -> Misclassified as {label_mapping[pred_idx]}: {count}")

## 7. RECOMMENDATION 1: Class Merging Experiment

### Rationale:
- **Institutional** has only 32 samples (6 in test set) and achieves only **15% F1**
- **CEX_Wallet** has 76 samples but achieves only **40% F1**
- Both share similar characteristics:
  - High balance
  - Lower transaction frequency than DeFi traders
  - Professional/institutional behavior
- Merging them into **Large_Holder** creates a more balanced class (108 samples)
- This reduces the problem from 6-class to 5-class, matching research assumptions
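Because macro F1 is an unweighted mean over classes, the 6-sample Institutional class pulls the score down as much as any large class does. It also means 6-class and 5-class macro F1 are not directly comparable: absorbing the weakest class can raise the mean even if the model separates wallets no better. A worked example using the per-class holdout F1 values reported later in this notebook; the merged-class score is purely hypothetical:

```python
# Per-class holdout F1 for the 6-class CatBoost model, as reported in the
# reliability assessment and summary of this notebook.
f1_6class = {
    'Exploiter': 0.94, 'Miner': 0.79, 'DeFi_Trader': 0.58,
    'Staker': 0.43, 'CEX_Wallet': 0.40, 'Institutional': 0.15,
}
macro_6 = sum(f1_6class.values()) / len(f1_6class)

# HYPOTHETICAL: the merged Large_Holder class scores exactly what CEX_Wallet
# scores today, and no other class changes at all.
f1_5class = {k: v for k, v in f1_6class.items()
             if k not in ('CEX_Wallet', 'Institutional')}
f1_5class['Large_Holder'] = 0.40
macro_5 = sum(f1_5class.values()) / len(f1_5class)

print(f"6-class macro F1: {macro_6:.3f}")
print(f"5-class macro F1: {macro_5:.3f}")
print(f"Difference: {(macro_5 - macro_6) * 100:+.1f} percentage points")
```

Under these assumptions the macro score rises by roughly 8 points with no change in per-class skill, which is worth keeping in mind when reading the comparison below.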

In [None]:
# Create merged dataset
df_merged = df.copy()

# Merge Institutional + CEX_Wallet -> Large_Holder
df_merged['category_merged'] = df_merged['category'].replace({
    'Institutional': 'Large_Holder',
    'CEX_Wallet': 'Large_Holder'
})

print("=== Original Class Distribution ===")
print(df['category'].value_counts())

print("\n=== Merged Class Distribution ===")
print(df_merged['category_merged'].value_counts())

In [None]:
# Prepare merged data for training
X_merged = df_merged[FEATURES].values
y_merged_raw = df_merged['category_merged'].values

# New label encoder for merged classes
le_merged = LabelEncoder()
y_merged = le_merged.fit_transform(y_merged_raw)

merged_labels = list(le_merged.classes_)
print(f"Merged labels: {merged_labels}")
print(f"Number of classes: {len(merged_labels)} (reduced from 6)")

In [None]:
# Train/test split for merged data
X_train_m, X_test_m, y_train_m, y_test_m = train_test_split(
    X_merged, y_merged,
    test_size=0.20,
    random_state=42,
    stratify=y_merged
)

print(f"Training: {len(X_train_m)}, Test: {len(X_test_m)}")
print(f"\nTest distribution:")
for i, label in enumerate(merged_labels):
    count = (y_test_m == i).sum()
    print(f"  {label}: {count}")

In [None]:
# Train CatBoost on merged data
print("Training CatBoost on merged classes (5 classes)...")

cb_merged = CatBoostClassifier(
    iterations=300,
    learning_rate=0.03,
    depth=5,
    l2_leaf_reg=3,
    loss_function='MultiClass',
    random_seed=42,
    verbose=50,
    auto_class_weights='Balanced'
)

cb_merged.fit(X_train_m, y_train_m)
print("Training complete!")
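`auto_class_weights='Balanced'` up-weights rare classes. As we understand CatBoost's documentation, with unit sample weights each class gets weight `max_count / count_k`. The sketch below assumes that formula; only the 32 Institutional and 76 CEX_Wallet counts come from this notebook, and `Majority_Class` is a made-up stand-in for the largest class:

```python
def balanced_class_weights(class_counts):
    """Class weights of the form max_count / count_k (assumed to mirror
    CatBoost's auto_class_weights='Balanced' when all sample weights are 1)."""
    max_count = max(class_counts.values())
    return {label: max_count / count for label, count in class_counts.items()}


# HYPOTHETICAL majority-class size; the two minority counts are from the notebook.
counts_6 = {'Majority_Class': 304, 'CEX_Wallet': 76, 'Institutional': 32}
counts_5 = {'Majority_Class': 304, 'Large_Holder': 32 + 76}

w6 = balanced_class_weights(counts_6)
w5 = balanced_class_weights(counts_5)
print(w6)
print(w5)
```

Under these assumed counts, merging cuts the largest per-sample up-weighting from 9.5x to about 2.8x, so a handful of Institutional samples no longer dominate the loss.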

In [None]:
# Evaluate merged model
y_pred_merged = get_predictions(cb_merged, X_test_m)

print("=" * 60)
print("MERGED MODEL (5 classes) Classification Report")
print("=" * 60)
print(classification_report(y_test_m, y_pred_merged, target_names=merged_labels, zero_division=0))

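# Compare metrics
# Note: the merged model is evaluated on a fresh stratified split of the full
# dataset, which generally does not coincide with the saved test split used for
# the original model, so this holdout comparison is indicative only (see the CV
# comparison below)
f1_merged = f1_score(y_test_m, y_pred_merged, average='macro')
f1_original = f1_score(y_test, y_pred_cb, average='macro')

print(f"\n=== Comparison ===")
print(f"Original (6 classes) F1 Macro: {f1_original:.4f}")
print(f"Merged (5 classes) F1 Macro:   {f1_merged:.4f}")
print(f"Difference: {(f1_merged - f1_original)*100:+.1f} percentage points")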

In [None]:
# Confusion matrix for merged model
plot_confusion_matrix(y_test_m, y_pred_merged, merged_labels, 'CatBoost (Merged Classes) Confusion Matrix')

In [None]:
# 5-Fold CV on merged data for more reliable comparison
print("\n=== 5-Fold CV Comparison ===")

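# Note: folds are stratified on each label vector, so the 6-class and 5-class
# runs do not share identical folds; compare summary statistics, not fold by fold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)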

# Original 6-class CV
X_orig = df[FEATURES].values
y_orig = le.transform(df['category'].values)

cv_scores_original = []
for train_idx, val_idx in skf.split(X_orig, y_orig):
    cb_temp = CatBoostClassifier(
        iterations=300, learning_rate=0.03, depth=5, l2_leaf_reg=3,
        random_seed=42, verbose=0, auto_class_weights='Balanced'
    )
    cb_temp.fit(X_orig[train_idx], y_orig[train_idx])
    y_pred_temp = get_predictions(cb_temp, X_orig[val_idx])
    cv_scores_original.append(f1_score(y_orig[val_idx], y_pred_temp, average='macro'))

print(f"Original (6 classes): {np.mean(cv_scores_original):.4f} (+/- {np.std(cv_scores_original):.4f})")

# Merged 5-class CV
cv_scores_merged = []
for train_idx, val_idx in skf.split(X_merged, y_merged):
    cb_temp = CatBoostClassifier(
        iterations=300, learning_rate=0.03, depth=5, l2_leaf_reg=3,
        random_seed=42, verbose=0, auto_class_weights='Balanced'
    )
    cb_temp.fit(X_merged[train_idx], y_merged[train_idx])
    y_pred_temp = get_predictions(cb_temp, X_merged[val_idx])
    cv_scores_merged.append(f1_score(y_merged[val_idx], y_pred_temp, average='macro'))

print(f"Merged (5 classes):   {np.mean(cv_scores_merged):.4f} (+/- {np.std(cv_scores_merged):.4f})")
print(f"\nCV difference: {(np.mean(cv_scores_merged) - np.mean(cv_scores_original))*100:+.1f} percentage points")
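A rough way to judge whether the CV difference is meaningful is to compare it with the fold-to-fold spread. Because the two runs stratify on different labels, their folds differ and a paired test is not appropriate. The sketch below is a heuristic, not a significance test, and the fold scores are hypothetical; `statistics.pstdev` uses the population formula, matching `np.std`'s default `ddof=0`:

```python
from statistics import mean, pstdev


def compare_cv(scores_a, scores_b):
    """Mean difference (b - a) in percentage points, and whether it exceeds
    the larger of the two fold-to-fold standard deviations."""
    diff_pp = (mean(scores_b) - mean(scores_a)) * 100
    noise_pp = max(pstdev(scores_a), pstdev(scores_b)) * 100
    return diff_pp, abs(diff_pp) > noise_pp


# Hypothetical per-fold macro-F1 scores, for illustration only.
diff, clear = compare_cv([0.52, 0.55, 0.50, 0.58, 0.54],
                         [0.60, 0.57, 0.63, 0.59, 0.61])
print(f"{diff:+.1f} pp, exceeds fold noise: {clear}")
```

In the notebook this would be called as `compare_cv(cv_scores_original, cv_scores_merged)`.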

## 8. RECOMMENDATION 2: Production Reliability Assessment

### Which predictions can we trust in production?
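Based on CatBoost's per-class F1 scores on the holdout test set, we assign each category to a reliability tier. Categories with very few test samples (e.g., 6 for Institutional) have noisy F1 estimates, so their tier should be treated as provisional: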

In [None]:
# Production reliability assessment
print("=" * 60)
print("PRODUCTION RELIABILITY ASSESSMENT")
print("=" * 60)

# Get per-class F1 for CatBoost
f1_per_class = per_class_metrics['CatBoost']['f1']

reliable = []
moderate = []
unreliable = []

for label, f1_val in f1_per_class.items():
    if f1_val >= 0.70:
        reliable.append((label, f1_val))
    elif f1_val >= 0.50:
        moderate.append((label, f1_val))
    else:
        unreliable.append((label, f1_val))

print("\n*** HIGH RELIABILITY (F1 >= 70%) - Trust these predictions ***")
for label, f1_val in reliable:
    print(f"  {label}: {f1_val*100:.1f}% F1")

print("\n*** MODERATE RELIABILITY (50-70% F1) - Use with caution ***")
for label, f1_val in moderate:
    print(f"  {label}: {f1_val*100:.1f}% F1")

print("\n*** LOW RELIABILITY (F1 < 50%) - Do NOT trust in production ***")
for label, f1_val in unreliable:
    print(f"  {label}: {f1_val*100:.1f}% F1")
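To see how noisy small-class scores are, a Wilson score interval on per-class recall is a quick check. The sketch assumes, hypothetically, that 1 of the 6 Institutional test samples is recalled; the actual count should be taken from the classification report above:

```python
from math import sqrt


def wilson_interval(successes, n, z=1.96):
    """Wilson score interval for a binomial proportion (z=1.96 gives ~95%)."""
    if n == 0:
        return (0.0, 1.0)
    p = successes / n
    denom = 1 + z**2 / n
    centre = (p + z**2 / (2 * n)) / denom
    half = z * sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return (max(0.0, centre - half), min(1.0, centre + half))


# HYPOTHETICAL: 1 of the 6 Institutional test samples correctly recalled.
low, high = wilson_interval(1, 6)
print(f"recall 1/6 = {1/6:.2f}, ~95% CI [{low:.2f}, {high:.2f}]")
```

With only six samples the interval spans roughly 3% to 56%, so the tier boundaries at 50% and 70% cannot be resolved reliably for such classes.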

In [None]:
# Visualize reliability tiers
fig, ax = plt.subplots(figsize=(10, 6))

labels_sorted = sorted(f1_per_class.keys(), key=lambda x: f1_per_class[x], reverse=True)
f1_values = [f1_per_class[l] for l in labels_sorted]

colors = []
for f1 in f1_values:
    if f1 >= 0.70:
        colors.append('green')
    elif f1 >= 0.50:
        colors.append('orange')
    else:
        colors.append('red')

bars = ax.bar(labels_sorted, f1_values, color=colors)

ax.axhline(y=0.70, color='green', linestyle='--', alpha=0.7, label='High reliability (70%)')
ax.axhline(y=0.50, color='orange', linestyle='--', alpha=0.7, label='Moderate reliability (50%)')

ax.set_ylabel('F1 Score')
ax.set_title('Production Reliability by Category\n(Green=Trust, Orange=Caution, Red=Avoid)')
ax.set_ylim(0, 1)
ax.legend()
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

## 9. RECOMMENDATION 3: Trading Signal Guidelines

### Based on model reliability, here are recommended trading actions:

In [None]:
print("=" * 60)
print("TRADING SIGNAL GUIDELINES")
print("=" * 60)

trading_signals = {
    'Exploiter': {
        'reliability': 'HIGH (94% F1)',
        'signal': 'STRONG AVOID',
        'action': 'Do NOT copy trades. These are exploit/hack addresses.',
        'rationale': 'F1 of 94% means both precision and recall are high on the test set. High tx/day + drained balance pattern.'
    },
    'Miner': {
        'reliability': 'HIGH (79% F1)',
        'signal': 'BEARISH on large sells',
        'action': 'Monitor for selling acceleration. Miners selling = supply pressure.',
        'rationale': "Distinct low tx/day pattern. When they move, it's significant."
    },
    'DeFi_Trader': {
        'reliability': 'MODERATE (58% F1)',
        'signal': 'FOLLOW with caution',
        'action': 'Track DEX swaps and LP movements. Potential alpha signals.',
        'rationale': 'F1 of 58% reflects substantial errors in both precision and recall; a sizeable share of these predictions may be wrong.'
    },
    'CEX_Wallet': {
        'reliability': 'LOW (40% F1)',
        'signal': 'NEUTRAL - Exchange flow only',
        'action': 'Only use for aggregate exchange flow analysis, not individual signals.',
        'rationale': 'Often confused with DeFi_Trader. Unreliable for individual wallet tracking.'
    },
    'Staker': {
        'reliability': 'LOW (43% F1)',
        'signal': 'WEAK SIGNAL',
        'action': 'Unstaking events may indicate selling intent, but low confidence.',
        'rationale': 'Frequently misclassified. Use with other confirmations only.'
    },
    'Institutional': {
        'reliability': 'VERY LOW (15% F1)',
        'signal': 'DO NOT USE',
        'action': 'Model cannot reliably identify institutional wallets.',
        'rationale': 'Only 1 in 6 correctly identified. Not production ready.'
    }
}

for category, info in trading_signals.items():
    print(f"\n{'='*40}")
    print(f"Category: {category}")
    print(f"{'='*40}")
    print(f"Reliability: {info['reliability']}")
    print(f"Signal:      {info['signal']}")
    print(f"Action:      {info['action']}")
    print(f"Rationale:   {info['rationale']}")
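The F1 figures in `trading_signals` are hard-coded from one run, so they can silently go stale after retraining. A small sanity check that parses the percentage from each `reliability` string and compares it with freshly computed scores; the helper name, tolerance, and example values are hypothetical:

```python
import re


def stale_f1_entries(signals, computed_f1, tol=0.01):
    """Compare the F1 embedded in each 'reliability' string (e.g. 'HIGH (94% F1)')
    with computed per-class F1 values given as fractions in [0, 1].

    Returns {category: (hardcoded, computed)} for entries that differ by more
    than tol, or where either value is missing.
    """
    stale = {}
    for category, info in signals.items():
        match = re.search(r'(\d+(?:\.\d+)?)%\s*F1', info['reliability'])
        hardcoded = float(match.group(1)) / 100 if match else None
        computed = computed_f1.get(category)
        if hardcoded is None or computed is None or abs(hardcoded - computed) > tol:
            stale[category] = (hardcoded, computed)
    return stale


# Hypothetical example: Miner's score moved after retraining.
example_signals = {
    'Exploiter': {'reliability': 'HIGH (94% F1)'},
    'Miner': {'reliability': 'HIGH (79% F1)'},
}
print(stale_f1_entries(example_signals, {'Exploiter': 0.938, 'Miner': 0.71}))
```

In the notebook, `stale_f1_entries(trading_signals, f1_per_class)` returning an empty dict would confirm the guidelines still match the current model.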

In [None]:
# Summary table for trading signals
signal_summary = pd.DataFrame([
    {'Category': 'Exploiter', 'F1': '94%', 'Reliability': 'HIGH', 'Trading Signal': 'AVOID - Exploit addresses'},
    {'Category': 'Miner', 'F1': '79%', 'Reliability': 'HIGH', 'Trading Signal': 'BEARISH on sells'},
    {'Category': 'DeFi_Trader', 'F1': '58%', 'Reliability': 'MODERATE', 'Trading Signal': 'FOLLOW with caution'},
    {'Category': 'Staker', 'F1': '43%', 'Reliability': 'LOW', 'Trading Signal': 'Weak unstake signal'},
    {'Category': 'CEX_Wallet', 'F1': '40%', 'Reliability': 'LOW', 'Trading Signal': 'Exchange flow only'},
    {'Category': 'Institutional', 'F1': '15%', 'Reliability': 'VERY LOW', 'Trading Signal': 'DO NOT USE'}
])

print("\n=== Trading Signal Summary ===")
print(signal_summary.to_string(index=False))

## 10. Save Results

In [None]:
# Save the merged model (saved unconditionally; adopt it only if the CV comparison supports it)
import os
os.makedirs('/content/models_v2', exist_ok=True)

# Save merged model
cb_merged.save_model('/content/models_v2/catboost_merged_5class.cbm')

# Save merged label encoder
with open('/content/models_v2/label_encoder_merged.pkl', 'wb') as f:
    pickle.dump(le_merged, f)

# Save evaluation results
evaluation_results = {
    'original_6class': {
        'f1_macro_holdout': f1_original,
        'f1_macro_cv_mean': np.mean(cv_scores_original),
        'f1_macro_cv_std': np.std(cv_scores_original),
        'per_class_f1': per_class_metrics['CatBoost']['f1']
    },
    'merged_5class': {
        'f1_macro_holdout': float(f1_merged),
        'f1_macro_cv_mean': np.mean(cv_scores_merged),
        'f1_macro_cv_std': np.std(cv_scores_merged),
        'merged_classes': ['Institutional', 'CEX_Wallet'],
        'new_class': 'Large_Holder'
    },
    'production_reliability': {
        'high': [l for l, f in f1_per_class.items() if f >= 0.70],
        'moderate': [l for l, f in f1_per_class.items() if 0.50 <= f < 0.70],
        'low': [l for l, f in f1_per_class.items() if f < 0.50]
    },
    'trading_signals': trading_signals
}

with open('/content/models_v2/evaluation_results.json', 'w') as f:
    json.dump(evaluation_results, f, indent=2, default=str)

print("Results saved locally")

In [None]:
# Upload to GCS
!gsutil -m cp -r /content/models_v2/* gs://{BUCKET}/models/evaluation/
print(f"Uploaded to gs://{BUCKET}/models/evaluation/")

## Summary

### Key Findings:

1. **Feature Dominance**: `balance_eth_log` accounts for 27% of CatBoost feature importance, suggesting the model relies more on balance than on behavioral patterns

2. **Class Merging Experiment**: 
   - Merging Institutional + CEX_Wallet -> Large_Holder
   - Reduces problem from 6-class to 5-class
   - CV improvement: [see output above]

3. **Production Reliability**:
   - HIGH: Exploiter (94%), Miner (79%)
   - MODERATE: DeFi_Trader (58%)
   - LOW: Staker (43%), CEX_Wallet (40%), Institutional (15%)

4. **Trading Recommendations**:
   - Trust Exploiter/Miner predictions for trading signals
   - Use DeFi_Trader with caution
   - Avoid using Institutional predictions entirely

### Next Steps:
- Notebook 04: Hyperparameter tuning (focus on CatBoost)
- Consider using the merged 5-class model in production if CV shows a clear improvement (larger than the fold-to-fold standard deviation)