# üìò Notebook 6: Neural Networks & Fraud Detection - Putting It All Together

Welcome to the grand finale! This notebook brings together **everything you've learned** to build a real fraud detection system from scratch.

## üéØ What You'll Learn (40-50 minutes)

By the end of this notebook, you'll have built:
- ‚úÖ A complete neural network classifier in JAX
- ‚úÖ A real-world fraud detection model
- ‚úÖ End-to-end training pipeline
- ‚úÖ Performance evaluation and metrics
- ‚úÖ Model interpretation and analysis
- ‚úÖ Practical deployment considerations

**This is where theory meets practice!** üöÄ

## ü§î What is Fraud Detection?

### The Real-World Problem
Credit card companies process millions of transactions daily. A tiny fraction (<0.2%) are fraudulent, but catching them is critical:

**Challenges:**
- **Extreme imbalance:** 99.8% legitimate, 0.2% fraud
- **High stakes:** Miss fraud = $$ lost; False alarm = angry customer
- **Real-time:** Must decide in milliseconds
- **Evolving patterns:** Fraudsters constantly adapt

**Your goal:** Build a neural network that identifies fraudulent transactions!

### The Dataset: Credit Card Fraud Detection
**Source:** Real credit card transactions from European cardholders (anonymized)

**Size:** 284,807 transactions over 2 days

**Features:**
- `Time`: Seconds since first transaction
- `V1-V28`: Anonymized features (PCA transformed for privacy)
- `Amount`: Transaction amount
- `Class`: 0 = Legitimate, 1 = Fraud

**Why this dataset?**
- Real-world imbalanced classification problem
- Demonstrates practical ML challenges
- Commonly used benchmark

## üß† Neural Network Architecture

### What You'll Build
A **Multi-Layer Perceptron (MLP)** with:
- **Input layer:** Features from the dataset (V1-V28 + Time + Amount, typically 29-30 features)
- **Hidden layer 1:** 64 neurons + ReLU activation
- **Hidden layer 2:** 32 neurons + ReLU activation
- **Hidden layer 3:** 16 neurons + ReLU activation
- **Output layer:** 1 neuron + Sigmoid activation (probability of fraud)

### Why This Architecture?
- **Not too complex:** Small dataset (284K samples) doesn't need huge network
- **Enough capacity:** 3 hidden layers can learn complex patterns
- **Fast training:** Small enough to train on CPU in minutes
- **Proven effective:** This architecture works well for tabular data

### Architecture Diagram
```
Input (dynamic) ‚Üí Dense(64) + ReLU ‚Üí Dense(32) + ReLU ‚Üí Dense(16) + ReLU ‚Üí Dense(1) + Sigmoid ‚Üí Fraud Probability
```

## üìö Key Concepts for Beginners

### 1. What is a Neural Network?
**Simple answer:** A function that learns patterns from data!

**How it works:**
1. Takes input features (transaction data)
2. Multiplies by weights and adds biases (learned parameters)
3. Applies activation functions (introduces non-linearity)
4. Produces output (fraud probability)

**Learning = adjusting weights to minimize errors**

### 2. Activation Functions

**ReLU (Rectified Linear Unit):**
- Formula: `max(0, x)`
- Purpose: Introduces non-linearity (lets network learn complex patterns)
- Why: Simple, fast, works well

**Sigmoid:**
- Formula: `1 / (1 + e^(-x))`
- Purpose: Squashes output to [0, 1] range
- Why: Perfect for probabilities!

### 3. Loss Function: Binary Cross-Entropy
**What:** Measures how wrong the model's predictions are

**Formula:** `-[y*log(p) + (1-y)*log(1-p)]`
- `y`: True label (0 or 1)
- `p`: Predicted probability

**Why:** Penalizes confident wrong predictions heavily

### 4. Optimizer: Stochastic Gradient Descent (SGD)
**What:** Algorithm that updates weights to minimize loss

**How:**
1. Compute gradient (how to change weights to reduce loss)
2. Update: `weight = weight - learning_rate * gradient`
3. Repeat until loss stops decreasing

**Learning rate:** Step size (too big = unstable, too small = slow)

### 5. Metrics for Imbalanced Data

**Accuracy is misleading!**
- If 99.8% are legitimate, predicting "all legitimate" gives 99.8% accuracy
- But catches ZERO fraud!

**Better metrics explained:**

#### **Confusion Matrix - The Foundation**
Every prediction falls into one of four categories:
```
                    Predicted: Fraud    Predicted: Legitimate
Actual: Fraud       TP (True Positive)  FN (False Negative)
Actual: Legitimate  FP (False Positive) TN (True Negative)
```
- **TP (True Positive):** Correctly caught fraud ‚úÖ (Good!)
- **TN (True Negative):** Correctly identified legitimate ‚úÖ (Good!)
- **FP (False Positive):** Flagged legitimate as fraud ‚ùå (Annoying customer)
- **FN (False Negative):** Missed actual fraud ‚ùå (Lost money!)

#### **Precision - "How accurate are our fraud alerts?"**
**Formula:** `Precision = TP / (TP + FP)`

**What it means:** Of all transactions we flagged as fraud, what percentage were actually fraud?

**Example:** 
- Flagged 100 transactions as fraud (our predictions)
- Only 80 were actually fraud (true positives)
- 20 were legitimate (false positives - we annoyed 20 customers!)
- Precision = 80/100 = 0.80 or 80%

**Interpretation:**
- **High precision (close to 1.0):** Few false alarms, customers rarely get bothered
- **Low precision (close to 0.0):** Many false alarms, customers get angry
- **Target:** Usually want >0.70 (70%) in fraud detection

#### **Recall (Sensitivity) - "How many frauds did we catch?"**
**Formula:** `Recall = TP / (TP + FN)`

**What it means:** Of all actual fraud cases, what percentage did we successfully catch?

**Example:**
- 100 actual fraud transactions happened (reality)
- We caught 90 of them (true positives)
- We missed 10 (false negatives - lost money!)
- Recall = 90/100 = 0.90 or 90%

**Interpretation:**
- **High recall (close to 1.0):** Catching most frauds, minimizing losses
- **Low recall (close to 0.0):** Missing many frauds, big financial loss!
- **Target:** Usually want >0.80 (80%) in fraud detection

#### **The Precision-Recall Trade-off**
**The dilemma:** You can't maximize both simultaneously!

**Make model more sensitive (predict fraud more often):**
- ‚Üë Recall increases (catch more frauds) ‚úÖ
- ‚Üì Precision decreases (more false alarms) ‚ùå

**Make model more conservative (predict fraud less often):**
- ‚Üë Precision increases (fewer false alarms) ‚úÖ
- ‚Üì Recall decreases (miss more frauds) ‚ùå

**Real-world decision:**
- Banks often prefer **higher recall** (catch frauds, even with false alarms)
- Why? Losing $1000 to fraud >> annoying one customer with a call

#### **F1-Score - "Overall balance of precision and recall"**
**Formula:** `F1 = 2 √ó (Precision √ó Recall) / (Precision + Recall)`

**What it means:** Harmonic mean that balances precision and recall. Only high if BOTH are high!

**Example scenarios:**
- Precision=0.90, Recall=0.90 ‚Üí F1=0.90 (Excellent! ‚≠ê)
- Precision=0.95, Recall=0.50 ‚Üí F1=0.66 (Unbalanced)
- Precision=0.50, Recall=0.95 ‚Üí F1=0.66 (Unbalanced)
- Precision=1.00, Recall=0.10 ‚Üí F1=0.18 (Terrible!)

**Interpretation:**
- **F1 > 0.80:** Excellent model for imbalanced data
- **F1 = 0.60-0.80:** Good, room for improvement
- **F1 < 0.60:** Poor, needs significant work
- **Why use it:** Single metric that punishes extreme imbalance

#### **PR-AUC (Precision-Recall Area Under Curve)**
**Formula:** Area under the Precision-Recall curve across all thresholds

**What it means:** How well the model performs across ALL possible decision thresholds (0.1, 0.5, 0.9, etc.)

**Threshold concept:**
- Model outputs probability: 0.83 = "83% chance of fraud"
- We pick threshold (e.g., 0.5): if prob ‚â• 0.5, predict fraud
- Different thresholds give different precision/recall trade-offs

**Example:**
- Threshold=0.9 (very strict): High precision, low recall (few predictions)
- Threshold=0.3 (very lenient): Low precision, high recall (many predictions)

**Interpretation:**
- **PR-AUC = 1.0:** Perfect model (impossible in practice)
- **PR-AUC > 0.80:** Excellent performance
- **PR-AUC = 0.40-0.80:** Decent to good
- **PR-AUC < 0.40:** Poor, barely better than random
- **Baseline:** Random guessing = fraud prevalence rate (0.002 for this dataset)

**Why use PR-AUC:** Best metric for imbalanced data! Better than ROC-AUC because it focuses on the minority class (fraud).

#### **ROC-AUC (Receiver Operating Characteristic)**
**Formula:** Area under the ROC curve (True Positive Rate vs False Positive Rate)

**What it means:** Model's ability to distinguish between classes across all thresholds

**Components:**
- True Positive Rate (TPR) = Recall = TP/(TP+FN)
- False Positive Rate (FPR) = FP/(FP+TN)

**Interpretation:**
- **ROC-AUC = 1.0:** Perfect discrimination
- **ROC-AUC > 0.90:** Excellent
- **ROC-AUC = 0.70-0.90:** Good
- **ROC-AUC = 0.50:** Random guessing (useless!)
- **ROC-AUC < 0.50:** Worse than random (model is backwards!)

**Caveat for imbalanced data:** 
- ROC-AUC can be misleading with severe imbalance (like 577:1)
- May look good even when model performs poorly on minority class
- **Prefer PR-AUC for this fraud dataset!**

#### **Quick Decision Guide:**
```
Question                          ‚Üí Metric to Check
---------------------------       ‚Üí -----------------
"Are fraud alerts accurate?"      ‚Üí Precision
"Are we catching most frauds?"    ‚Üí Recall  
"Overall balance of both?"        ‚Üí F1-Score
"Performance across thresholds?"  ‚Üí PR-AUC (best for imbalance)
"General discrimination ability?" ‚Üí ROC-AUC
```

#### **What's "Good" for Fraud Detection?**
Based on industry standards:
- **Precision:** 0.70-0.90 (70-90%)
- **Recall:** 0.75-0.95 (75-95%)
- **F1-Score:** 0.70-0.85 (70-85%)
- **PR-AUC:** 0.60-0.90 (60-90%)
- **ROC-AUC:** 0.85-0.98 (85-98%)

**Remember:** There's always a trade-off! The "best" model depends on business priorities (lose money vs. annoy customers).

## üéì What's in This Notebook?

This comprehensive notebook includes:

1. **Data Loading & Exploration**
   - Load credit card fraud dataset
   - Understand data distribution and imbalance
   - Visualize key patterns

2. **Data Preprocessing**
   - Normalization (scale features to same range)
   - Train/test split (evaluate on unseen data)
   - Batch preparation using Polars

3. **Model Definition**
   - Neural network architecture in pure JAX
   - Weight initialization
   - Forward pass implementation

4. **Training Pipeline**
   - Loss function with binary cross-entropy
   - Gradient computation using `jax.grad`
   - Optimization step with SGD
   - Full training loop with `jit` and `vmap`

5. **Evaluation**
   - Compute predictions on test set
   - Calculate precision, recall, F1, ROC-AUC
   - Confusion matrix
   - Identify optimal threshold

6. **Analysis & Insights**
   - Feature importance
   - Error analysis (false positives/negatives)
   - Model interpretation
   - Deployment considerations

## üöÄ Prerequisites

Before starting this notebook, you should:
- ‚úÖ Complete Notebooks 1-4 (JAX Basics through vmap)
- ‚úÖ Understand what a neural network is (conceptually)
- ‚úÖ Know basic Python and NumPy
- ‚ùå **Don't need**: Deep learning expertise (we build everything from scratch!)

## üèÜ JAX Transformations in Action

This notebook showcases **all JAX superpowers together:**

| Transformation | Purpose in This Project |
|----------------|-------------------------|
| `jit` | 10-100x faster training |
| `grad` | Automatic gradient computation |
| `vmap` | Batch processing (no loops!) |
| Functional style | Clean, composable code |

**This is JAX at its best!** ‚ö°

## üí° Key Takeaway

**You're building a complete ML system:**
- Data ‚Üí Preprocessing ‚Üí Model ‚Üí Training ‚Üí Evaluation ‚Üí Insights

**Using only JAX + basic libraries** - no high-level frameworks!

This shows you how everything works under the hood. üîç

## üéØ Learning Outcomes

After completing this notebook, you'll be able to:
- ‚úÖ Build neural networks from scratch in JAX
- ‚úÖ Handle imbalanced datasets
- ‚úÖ Train models efficiently with JAX transformations
- ‚úÖ Evaluate models with appropriate metrics
- ‚úÖ Apply ML to real-world problems

**You'll have a complete, working fraud detection system!** üéâ

Let's build something real! üí≥üõ°Ô∏è

## üìö Part 1: Understanding Data Handling

### What is Data Preprocessing?

**Think of it like cooking:** Before you cook a meal, you need to prep ingredients - wash vegetables, cut meat, measure spices. Similarly, before training a neural network, you need to prep your data!

**Raw data ‚Üí Preprocessed data ‚Üí Model training ‚Üí Predictions**

### Why Can't We Use Raw Data?

**Problem 1: Different Scales**
- Feature 1 (Transaction Amount): ranges from $0 to $10,000
- Feature 2 (Time): ranges from 0 to 172,792 seconds
- Neural networks struggle when features have vastly different scales!

**Solution:** Standardization (make all features have similar ranges)

**Problem 2: Data Leakage**
- If we train and test on the same data, model just memorizes!
- Like studying with the exact test questions - cheating!

**Solution:** Split data into train/validation/test sets

**Problem 3: Class Imbalance**
- 99.8% legitimate, 0.2% fraud
- Model learns to always predict "legitimate" and gets 99.8% accuracy!
- But catches zero fraud!

**Solution:** Class weights (penalize errors on rare class more)

---

### The Data Pipeline (Step-by-Step)

#### **Step 1: Load Raw Data**
```python
data = fetch_openml('creditcard')  # Download dataset
```
**What happens:** Downloads 284,807 credit card transactions with 30 features

---

#### **Step 2: Separate Features and Labels**
```python
X = features (V1, V2, ..., V28, Time, Amount)
y = labels (0=legitimate, 1=fraud)
```
**Why:** Model learns from X (inputs) to predict y (outputs)

---

#### **Step 3: Train/Validation/Test Split**

**The Three Sets:**

1. **Training Set (70%):** Model learns from this
   - Like practice problems when studying
   - Model adjusts weights to minimize errors here
   - Example: 199,365 transactions

2. **Validation Set (15%):** Tune hyperparameters
   - Like practice exams before the real test
   - Check if model is overfitting (memorizing instead of learning)
   - Helps decide when to stop training
   - Example: 42,721 transactions

3. **Test Set (15%):** Final evaluation (touch ONCE at the end!)
   - Like the real exam
   - Model has never seen this data during training
   - Gives honest assessment of real-world performance
   - Example: 42,721 transactions

**Critical Rule:** NEVER train on validation or test data! That's cheating!

**Stratified Splitting:**
```python
train_test_split(..., stratify=y)
```
- Ensures each split has the same fraud ratio (0.173%)
- Without this, test set might have no fraud cases!

---

#### **Step 4: Feature Standardization (Normalization)**

**Problem:** Features have wildly different ranges
- Time: [0, 172,792]
- V1: [-56.4, 2.5]
- Amount: [0, 25,691]

**Solution: StandardScaler**
```python
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)  # Learn mean & std, then transform
X_val = scaler.transform(X_val)          # Use same mean & std from training
X_test = scaler.transform(X_test)        # Use same mean & std from training
```

**What it does:** Transforms each feature to have:
- **Mean = 0** (centered around zero)
- **Standard deviation = 1** (similar spread)

**Formula:** `x_scaled = (x - mean) / std`

**Example:**
```
Original Amount: [10, 100, 1000]
Mean: 370, Std: 500
Scaled: [-0.72, -0.54, 1.26]  (all similar magnitude!)
```

**CRITICAL:** 
- **Fit only on training data** (compute mean & std from training)
- **Transform validation and test** using training statistics
- Why? Test set should represent "unseen future data" - we won't know its statistics!

---

#### **Step 5: Handle Class Imbalance with Weights**

**The Problem:**
- 284,315 normal transactions
- 492 fraud transactions
- Ratio: 577:1 (extreme imbalance!)

**Naive approach:** Model predicts "all normal" ‚Üí 99.8% accuracy, 0% fraud detection ‚ùå

**Solution: Class Weights**
```python
weight_normal = n_samples / (2 * n_normal) = 0.5
weight_fraud = n_samples / (2 * n_fraud) = 289.2
```

**What this means:**
- When model makes error on fraud: penalty √ó 289.2
- When model makes error on normal: penalty √ó 0.5
- Forces model to pay attention to rare fraud cases!

**In loss function:**
```python
loss = weight √ó error
```
- Fraud errors hurt much more ‚Üí model learns to catch fraud!

---

### Data Shape Summary

**Before preprocessing:**
```
Raw data: (284,807 transactions, 30 features)
```

**After preprocessing:**
```
X_train: (199,365, 29)  # 70% of data, standardized features
y_train: (199,365,)     # Labels (0 or 1)

X_val:   (42,721, 29)   # 15% of data, standardized
y_val:   (42,721,)      # Labels

X_test:  (42,721, 29)   # 15% of data, standardized (unseen until final eval)
y_test:  (42,721,)      # Labels (used only for final metrics)
```

**Note:** 29 features (not 30) because we dropped the 'Class' column (that's our label y!)

---

### Key Takeaways

‚úÖ **Standardization:** Makes features comparable (same scale)
‚úÖ **Train/Val/Test Split:** Prevents cheating, gives honest evaluation
‚úÖ **Stratified Split:** Preserves class distribution across splits
‚úÖ **Class Weights:** Handles extreme imbalance (577:1 ratio)
‚úÖ **Fit on Train Only:** Never let model "peek" at validation/test statistics
‚úÖ **Test Set is Sacred:** Use only once at the very end!

**Now we're ready to build the neural network!** üß†

In [1]:
# =============================================================================
# SETUP AND DATA LOADING
# =============================================================================

import jax
import jax.numpy as jnp
import torch
import torch.nn as nn
import torch.optim as optim
import polars as pl
import numpy as np
import time
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    precision_score, recall_score, f1_score, 
    confusion_matrix, classification_report,
    average_precision_score, roc_auc_score
)
from sklearn.datasets import fetch_openml

print("=" * 70)
print("LOADING CREDIT CARD FRAUD DETECTION DATASET")
print("=" * 70)

# Load dataset from OpenML
print("\nDownloading dataset from OpenML (may take a moment)...")
data = fetch_openml('creditcard', version=1, as_frame=True, parser='auto')
df = data.frame

print(f"‚úÖ Dataset loaded: {df.shape[0]:,} transactions, {df.shape[1]-1} features")

# Inspect the data
print(f"\nüìä Dataset Overview:")
print(f"  Shape: {df.shape}")
print(f"  Features: {df.columns.tolist()}")
print(f"\n  Class distribution:")
fraud_count = (df['Class'] == '1').sum()
normal_count = (df['Class'] == '0').sum()
total = len(df)
print(f"    Normal transactions: {normal_count:,} ({100*normal_count/total:.3f}%)")
print(f"    Fraud transactions:  {fraud_count:,} ({100*fraud_count/total:.3f}%)")
print(f"    Imbalance ratio: {normal_count//fraud_count}:1")

print(f"\n  First few rows:")
print(df.head())

# =============================================================================
# DATA PREPROCESSING
# =============================================================================

print("\n" + "=" * 70)
print("DATA PREPROCESSING")
print("=" * 70)

# Separate features and target
X = df.drop('Class', axis=1).values.astype(np.float32)
y = df['Class'].astype(int).values

# Split data: 70% train, 15% val, 15% test
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.15, random_state=42, stratify=y
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.1765, random_state=42, stratify=y_temp  # 0.1765 * 0.85 ‚âà 0.15
)

# Standardize features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

print(f"\nüìä Data Splits:")
print(f"  Train: {X_train.shape[0]:,} samples ({100*len(X_train)/total:.1f}%)")
print(f"  Val:   {X_val.shape[0]:,} samples ({100*len(X_val)/total:.1f}%)")
print(f"  Test:  {X_test.shape[0]:,} samples ({100*len(X_test)/total:.1f}%)")

print(f"\n  Class distribution in splits:")
print(f"    Train - Fraud: {y_train.sum():,} ({100*y_train.sum()/len(y_train):.3f}%)")
print(f"    Val   - Fraud: {y_val.sum():,} ({100*y_val.sum()/len(y_val):.3f}%)")
print(f"    Test  - Fraud: {y_test.sum():,} ({100*y_test.sum()/len(y_test):.3f}%)")

# Calculate class weights for imbalance
n_samples = len(y_train)
n_fraud = y_train.sum()
n_normal = n_samples - n_fraud
weight_fraud = n_samples / (2 * n_fraud)
weight_normal = n_samples / (2 * n_normal)

print(f"\n‚öñÔ∏è  Class Weights (for balanced loss):")
print(f"  Normal: {weight_normal:.4f}")
print(f"  Fraud:  {weight_fraud:.4f}")
print(f"  Ratio:  {weight_fraud/weight_normal:.2f}x (frauds weighted higher)")

LOADING CREDIT CARD FRAUD DETECTION DATASET

Downloading dataset from OpenML (may take a moment)...
‚úÖ Dataset loaded: 284,807 transactions, 29 features

üìä Dataset Overview:
  Shape: (284807, 30)
  Features: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount', 'Class']

  Class distribution:
    Normal transactions: 284,315 (99.827%)
    Fraud transactions:  492 (0.173%)
    Imbalance ratio: 577:1

  First few rows:
         V1        V2        V3        V4        V5        V6        V7  \
0 -1.359807 -0.072781  2.536347  1.378155 -0.338321  0.462388  0.239599   
1  1.191857  0.266151  0.166480  0.448154  0.060018 -0.082361 -0.078803   
2 -1.358354 -1.340163  1.773209  0.379780 -0.503198  1.800499  0.791461   
3 -0.966272 -0.185226  1.792993 -0.863291 -0.010309  1.247203  0.237609   
4 -1.158233  0.877737  1.548718  0.403034 -0.407193  0.0

## üìö Part 2: Understanding Neural Networks & Training

### What is a Neural Network? (Beginner-Friendly Explanation)

**Simple answer:** A neural network is a mathematical function that learns patterns from data!

**Analogy:** Think of it as a smart pattern recognition machine:
- **Input:** Transaction features (amount, time, location, etc.)
- **Processing:** Lots of mathematical operations (matrix multiplications + activations)
- **Output:** Probability ("This transaction is 87% likely to be fraud")

**How it "learns":** By adjusting internal numbers (weights and biases) to make better predictions!

---

### Neural Network Architecture Explained

**Our Architecture:** Input (29) ‚Üí Dense(64) ‚Üí Dense(32) ‚Üí Dense(16) ‚Üí Output(1)

#### **Layer-by-Layer Breakdown:**

**1. Input Layer (29 neurons)**
- Not really a "layer" - just your input data
- 29 features: V1, V2, ..., V28, Time
- Each neuron = one feature value

**2. Hidden Layer 1 (64 neurons)**
```python
x = W1 @ input + b1  # Linear transformation (matrix multiplication)
x = ReLU(x)          # Activation (keeps only positive values)
```
- **Purpose:** Learn 64 different "feature combinations"
- **Example feature combo:** "High amount + unusual time + rare location = suspicious"
- Each neuron learns one pattern

**3. Hidden Layer 2 (32 neurons)**
```python
x = W2 @ x + b2      # Another transformation
x = ReLU(x)          # Activation
```
- **Purpose:** Combine patterns from layer 1 into higher-level patterns
- **Example:** "Multiple suspicious patterns together = likely fraud"

**4. Hidden Layer 3 (16 neurons)**
```python
x = W3 @ x + b3      # Another transformation
x = ReLU(x)          # Activation
```
- **Purpose:** Further refine patterns
- Gets more abstract and specific to fraud detection

**5. Output Layer (1 neuron)**
```python
output = W4 @ x + b4  # Final transformation
output = Sigmoid(output)  # Squashes to probability [0, 1]
```
- **Output:** Single number between 0 and 1
- **Interpretation:** Probability that transaction is fraud
  - 0.05 = "5% chance of fraud" ‚Üí Predict: Legitimate
  - 0.92 = "92% chance of fraud" ‚Üí Predict: Fraud

---

### Key Components Explained

#### **1. Weights (W) and Biases (b)**

**Weights:** Numbers that determine "how much each input matters"
```python
W1 shape: (29, 64)  # Connects 29 inputs to 64 neurons
```
- Each connection has a weight
- Positive weight = "this feature increases fraud probability"
- Negative weight = "this feature decreases fraud probability"

**Biases:** Offset values (one per neuron)
```python
b1 shape: (64,)  # One bias per neuron in layer 1
```
- Shifts the activation threshold
- Allows neuron to activate even when inputs are zero

**Initialization:** He initialization
```python
W = sqrt(2/input_size) * random_normal()
```
- Not too small (learning would be slow)
- Not too large (training would be unstable)
- Magic formula that works well for ReLU activations!

---

#### **2. Activation Functions**

**ReLU (Rectified Linear Unit)** - Used in hidden layers
```python
ReLU(x) = max(0, x)
```
**What it does:**
- If x > 0: Keep it (output = x)
- If x ‚â§ 0: Set to 0 (output = 0)

**Why we need it:**
- Without activation: Network is just linear (like y = mx + b)
- With ReLU: Network can learn complex, non-linear patterns!
- **Example:** "Fraud happens when amount > 1000 AND time is between 2am-4am" (non-linear!)

**Visual:**
```
Input:  [-2, -1, 0, 1, 2]
ReLU:   [ 0,  0, 0, 1, 2]  (zeros out negatives)
```

**Sigmoid** - Used in output layer
```python
Sigmoid(x) = 1 / (1 + e^(-x))
```
**What it does:**
- Squashes any number to range [0, 1]
- Perfect for probabilities!

**Visual:**
```
Input:    [-10, -2, 0, 2, 10]
Sigmoid:  [0.00, 0.12, 0.5, 0.88, 1.00]  (all between 0 and 1)
```

---

#### **3. Loss Function: Binary Cross-Entropy**

**Purpose:** Measures "how wrong" the model's predictions are

**Formula:**
```python
loss = -[y * log(p) + (1-y) * log(1-p)]
```
Where:
- `y` = true label (0 or 1)
- `p` = predicted probability

**Example 1:** True label = 1 (fraud), Prediction = 0.95
```
loss = -[1 * log(0.95) + 0 * log(0.05)]
     = -log(0.95)
     = 0.051  (small loss - good prediction!)
```

**Example 2:** True label = 1 (fraud), Prediction = 0.10
```
loss = -[1 * log(0.10) + 0 * log(0.90)]
     = -log(0.10)
     = 2.303  (big loss - terrible prediction!)
```

**Key insight:** Loss is low when prediction matches reality, high when they differ!

**With Class Weights:**
```python
weighted_loss = weight * loss
```
- Fraud errors get weight = 289.2 (huge penalty!)
- Normal errors get weight = 0.5 (small penalty)
- Forces model to focus on catching fraud!

---

#### **4. Training Process (How the Model Learns)**

**The Learning Algorithm: Gradient Descent**

Think of it like hiking down a mountain in fog:
- You can't see the bottom (optimal weights)
- But you can feel the slope under your feet (gradient)
- Take small steps downhill (opposite of gradient)
- Eventually reach the valley (minimum loss)

**Step-by-Step Training Loop:**

**1. Forward Pass** (Make Predictions)
```python
predictions = model(X_batch)  # Run data through network
```
- Input flows through all layers
- Produces predictions (probabilities)

**2. Compute Loss** (How Wrong Are We?)
```python
loss = binary_cross_entropy(predictions, y_batch, weights)
```
- Compare predictions to true labels
- Higher loss = worse predictions

**3. Backward Pass** (Compute Gradients)
```python
grads = jax.grad(loss_fn)(params, X_batch, y_batch)
```
- **Gradient:** Direction and magnitude to change each weight
- JAX computes this automatically (magic of autodiff!)
- Tells us: "Increase W1[0,0] by 0.02 to reduce loss"

**4. Update Weights** (Learn!)
```python
W = W - learning_rate * gradient
b = b - learning_rate * gradient
```
- **Learning rate (0.001):** Step size
  - Too big: Overshoot the minimum, training unstable
  - Too small: Takes forever to converge
  - 0.001 is a good starting point!

**5. Repeat** for all batches, all epochs
- **Batch:** Subset of data (256 transactions)
- **Epoch:** One pass through entire dataset
- **10 epochs:** Model sees each transaction 10 times

---

### Training Hyperparameters Explained

**1. Batch Size = 256**
- Don't use all 199,365 transactions at once (too slow!)
- Don't use 1 transaction at a time (too noisy!)
- Use mini-batches of 256 (good balance)
- **Math:** 199,365 / 256 ‚âà 779 batches per epoch

**2. Learning Rate = 0.001**
- How big each weight update step is
- 0.001 is conservative but safe
- Prevents overshooting and instability

**3. Epochs = 10**
- How many times model sees entire dataset
- More epochs = more learning (but risk overfitting)
- 10 is reasonable for this dataset size

**4. Architecture: 64 ‚Üí 32 ‚Üí 16**
- **Why decreasing?** Funnel pattern
  - Layer 1: Learn many low-level patterns
  - Layer 2: Combine into mid-level patterns
  - Layer 3: Refine to high-level fraud indicators
- **Why these numbers?** Empirically work well!
  - Not too big: Faster training, less overfitting
  - Not too small: Enough capacity to learn complex patterns

---

### JAX vs PyTorch: Key Differences

**JAX (Functional Programming):**
```python
# Explicit parameter passing
params = init_network_params(sizes, key)
output = forward(params, x)
loss, grads = value_and_grad(loss_fn)(params, x, y)
params = update(params, grads, lr)  # Manual update
```
‚úÖ Explicit control
‚úÖ Easy to compose transformations (jit + grad + vmap)
‚úÖ Faster with JIT compilation
‚ùå More boilerplate code

**PyTorch (Object-Oriented):**
```python
# Stateful model class
model = NeuralNet()
output = model(x)
loss.backward()  # Automatic gradient computation
optimizer.step()  # Automatic update
```
‚úÖ Less boilerplate
‚úÖ Familiar to most ML practitioners
‚úÖ Great ecosystem
‚ùå Less control over transformations

**Both produce same results!** Just different coding styles.

---

### What Happens During Training?

**Epoch 1:**
- Weights are random ‚Üí Predictions are random ‚Üí Loss is high (‚âà0.7)
- Gradients computed ‚Üí Weights adjusted
- Model slightly better at end of epoch 1

**Epoch 2-5:**
- Model learning quickly
- Loss decreasing steadily (‚âà0.5 ‚Üí 0.3)
- Starting to recognize fraud patterns

**Epoch 6-10:**
- Model refining predictions
- Loss decreasing slowly (‚âà0.3 ‚Üí 0.2)
- Fine-tuning weights for best performance

**After Training:**
- Model has learned optimal weights
- Can make predictions on new transactions!

---

### Common Questions

**Q: Why not just use more layers?**
A: More layers = more capacity, but also:
- Slower training
- Risk of overfitting (memorizing training data)
- Diminishing returns (3 layers often enough for tabular data)

**Q: What if loss doesn't decrease?**
A: Could be:
- Learning rate too high (try 0.0001)
- Learning rate too low (try 0.01)
- Bad initialization (run again with different seed)
- Insufficient model capacity (add more neurons)

**Q: How do we know when to stop training?**
A: Monitor validation loss:
- If training loss ‚Üì but validation loss ‚Üë ‚Üí Overfitting! Stop.
- If both ‚Üì ‚Üí Keep training
- If both plateau ‚Üí Done! Converged.

---

### Key Takeaways

‚úÖ **Neural Network:** Function with adjustable weights that learns patterns
‚úÖ **Forward Pass:** Input ‚Üí Hidden Layers ‚Üí Output (predictions)
‚úÖ **Loss Function:** Measures prediction errors (we want to minimize this)
‚úÖ **Gradient:** Direction to change weights to reduce loss
‚úÖ **Training:** Repeatedly adjust weights to minimize loss
‚úÖ **Batch Training:** Process small batches for efficiency
‚úÖ **Class Weights:** Handle imbalance by penalizing fraud errors more
‚úÖ **Activation Functions:** Enable learning complex patterns

**Now let's see it in action with JAX!** üöÄ

In [3]:
# =============================================================================
# JAX IMPLEMENTATION
# =============================================================================

print("=" * 70)
print("JAX NEURAL NETWORK - FUNCTIONAL APPROACH")
print("=" * 70)

# Hyperparameters
input_dim = X_train.shape[1]  # Dynamically get input dimension from data
hidden_dims = [64, 32, 16]
output_dim = 1
learning_rate = 0.001
batch_size = 256
n_epochs = 10

# Initialize network parameters
def init_network_params(layer_sizes, key):
    """Initialize network with He initialization."""
    params = []
    for i in range(len(layer_sizes) - 1):
        key, subkey = jax.random.split(key)
        # He initialization: scale by sqrt(2/fan_in)
        scale = jnp.sqrt(2.0 / layer_sizes[i])
        W = scale * jax.random.normal(subkey, (layer_sizes[i], layer_sizes[i+1]))
        key, subkey = jax.random.split(key)
        b = jnp.zeros(layer_sizes[i+1])
        params.append({'W': W, 'b': b})
    return params

# Forward pass
def forward(params, x):
    """Forward pass through the network."""
    for i, layer in enumerate(params[:-1]):
        x = jnp.dot(x, layer['W']) + layer['b']
        x = jax.nn.relu(x)  # ReLU activation for hidden layers
    # Output layer (sigmoid activation)
    x = jnp.dot(x, params[-1]['W']) + params[-1]['b']
    return jax.nn.sigmoid(x)

# Weighted binary cross-entropy loss
def loss_fn(params, x, y, class_weights):
    """Binary cross-entropy with class weights."""
    predictions = forward(params, x).squeeze()
    # Apply class weights
    weights = jnp.where(y == 1, class_weights[1], class_weights[0])
    # Binary cross-entropy
    bce = -(y * jnp.log(predictions + 1e-7) + (1 - y) * jnp.log(1 - predictions + 1e-7))
    return jnp.mean(weights * bce)

# Prediction function
def predict(params, x, threshold=0.5):
    """Make predictions with threshold."""
    probs = forward(params, x).squeeze()
    return (probs >= threshold).astype(jnp.int32)

# Training step (JIT compiled)
@jax.jit
def update(params, x, y, class_weights, learning_rate):
    """Single training step with gradient descent."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y, class_weights)
    # Update parameters
    params = [
        {
            'W': layer['W'] - learning_rate * grad['W'],
            'b': layer['b'] - learning_rate * grad['b']
        }
        for layer, grad in zip(params, grads)
    ]
    return params, loss

# Initialize JAX model
print("\nüîß Initializing JAX model...")
layer_sizes = [input_dim] + hidden_dims + [output_dim]
jax_params = init_network_params(layer_sizes, jax.random.PRNGKey(42))
jax_class_weights = jnp.array([weight_normal, weight_fraud])

print(f"  Architecture: {' ‚Üí '.join(map(str, layer_sizes))}")
total_params = sum(layer['W'].size + layer['b'].size for layer in jax_params)
print(f"  Total parameters: {total_params:,}")

# Training loop
print("\nüèãÔ∏è  Training JAX model...")
jax_train_losses = []
jax_val_losses = []

# Convert to JAX arrays
X_train_jax = jnp.array(X_train)
y_train_jax = jnp.array(y_train, dtype=jnp.float32)
X_val_jax = jnp.array(X_val)
y_val_jax = jnp.array(y_val, dtype=jnp.float32)

start_time = time.time()

for epoch in range(n_epochs):
    # Shuffle training data
    perm = np.random.permutation(len(X_train_jax))
    X_shuffled = X_train_jax[perm]
    y_shuffled = y_train_jax[perm]
    
    # Mini-batch training
    epoch_losses = []
    for i in range(0, len(X_train_jax), batch_size):
        batch_X = X_shuffled[i:i+batch_size]
        batch_y = y_shuffled[i:i+batch_size]
        jax_params, batch_loss = update(jax_params, batch_X, batch_y, jax_class_weights, learning_rate)
        epoch_losses.append(batch_loss)
    
    # Compute validation loss
    train_loss = jnp.mean(jnp.array(epoch_losses))
    val_loss = loss_fn(jax_params, X_val_jax, y_val_jax, jax_class_weights)
    
    jax_train_losses.append(float(train_loss))
    jax_val_losses.append(float(val_loss))
    
    print(f"  Epoch {epoch+1:2d}/{n_epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

jax_train_time = time.time() - start_time
print(f"\n‚úÖ JAX training complete in {jax_train_time:.2f}s")

# Evaluate on test set
print("\nüìä JAX Test Set Evaluation:")
X_test_jax = jnp.array(X_test)
y_pred_jax = predict(jax_params, X_test_jax)
y_probs_jax = forward(jax_params, X_test_jax).squeeze()

# Convert to numpy for sklearn metrics
y_pred_jax_np = np.array(y_pred_jax)
y_probs_jax_np = np.array(y_probs_jax)

jax_precision = precision_score(y_test, y_pred_jax_np)
jax_recall = recall_score(y_test, y_pred_jax_np)
jax_f1 = f1_score(y_test, y_pred_jax_np)
jax_pr_auc = average_precision_score(y_test, y_probs_jax_np)
jax_roc_auc = roc_auc_score(y_test, y_probs_jax_np)

print(f"  Precision: {jax_precision:.4f}")
print(f"  Recall:    {jax_recall:.4f}")
print(f"  F1 Score:  {jax_f1:.4f}")
print(f"  PR-AUC:    {jax_pr_auc:.4f}")
print(f"  ROC-AUC:   {jax_roc_auc:.4f}")

print(f"\n  Confusion Matrix:")
cm_jax = confusion_matrix(y_test, y_pred_jax_np)
print(f"    TN: {cm_jax[0,0]:5d}  FP: {cm_jax[0,1]:5d}")
print(f"    FN: {cm_jax[1,0]:5d}  TP: {cm_jax[1,1]:5d}")

JAX NEURAL NETWORK - FUNCTIONAL APPROACH

üîß Initializing JAX model...
  Architecture: 29 ‚Üí 64 ‚Üí 32 ‚Üí 16 ‚Üí 1
  Total parameters: 4,545

üèãÔ∏è  Training JAX model...
  Epoch  1/10 - Train Loss: 0.5035, Val Loss: 0.3022
  Epoch  2/10 - Train Loss: 0.3023, Val Loss: 0.2485
  Epoch  1/10 - Train Loss: 0.5035, Val Loss: 0.3022
  Epoch  2/10 - Train Loss: 0.3023, Val Loss: 0.2485
  Epoch  3/10 - Train Loss: 0.2594, Val Loss: 0.2252
  Epoch  4/10 - Train Loss: 0.2316, Val Loss: 0.2039
  Epoch  3/10 - Train Loss: 0.2594, Val Loss: 0.2252
  Epoch  4/10 - Train Loss: 0.2316, Val Loss: 0.2039
  Epoch  5/10 - Train Loss: 0.2112, Val Loss: 0.1934
  Epoch  6/10 - Train Loss: 0.1954, Val Loss: 0.1868
  Epoch  5/10 - Train Loss: 0.2112, Val Loss: 0.1934
  Epoch  6/10 - Train Loss: 0.1954, Val Loss: 0.1868
  Epoch  7/10 - Train Loss: 0.1816, Val Loss: 0.1809
  Epoch  8/10 - Train Loss: 0.1711, Val Loss: 0.1757
  Epoch  7/10 - Train Loss: 0.1816, Val Loss: 0.1809
  Epoch  8/10 - Train Loss: 

In [4]:
# =============================================================================
# PYTORCH IMPLEMENTATION
# =============================================================================

print("\n" + "=" * 70)
print("PYTORCH NEURAL NETWORK - OBJECT-ORIENTED APPROACH")
print("=" * 70)

# Define PyTorch model
class FraudDetectionNet(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, output_dim))
        layers.append(nn.Sigmoid())
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.network(x).squeeze()

# Initialize PyTorch model
print("\nüîß Initializing PyTorch model...")
torch.manual_seed(42)
torch_model = FraudDetectionNet(input_dim, hidden_dims, output_dim)
torch_optimizer = optim.Adam(torch_model.parameters(), lr=learning_rate)

print(f"  Architecture: {input_dim} ‚Üí {' ‚Üí '.join(map(str, hidden_dims))} ‚Üí {output_dim}")
total_params = sum(p.numel() for p in torch_model.parameters())
print(f"  Total parameters: {total_params:,}")

# Weighted BCE loss
pos_weight = torch.tensor([weight_fraud / weight_normal])
criterion = nn.BCELoss(reduction='none')

# Convert to PyTorch tensors
X_train_torch = torch.FloatTensor(X_train)
y_train_torch = torch.FloatTensor(y_train)
X_val_torch = torch.FloatTensor(X_val)
y_val_torch = torch.FloatTensor(y_val)
X_test_torch = torch.FloatTensor(X_test)
y_test_torch = torch.FloatTensor(y_test)

# Create class weights tensor
class_weights_torch = torch.FloatTensor([weight_normal, weight_fraud])

# Training loop
print("\nüèãÔ∏è  Training PyTorch model...")
torch_train_losses = []
torch_val_losses = []

start_time = time.time()

for epoch in range(n_epochs):
    torch_model.train()
    
    # Shuffle training data
    perm = torch.randperm(len(X_train_torch))
    X_shuffled = X_train_torch[perm]
    y_shuffled = y_train_torch[perm]
    
    # Mini-batch training
    epoch_losses = []
    for i in range(0, len(X_train_torch), batch_size):
        batch_X = X_shuffled[i:i+batch_size]
        batch_y = y_shuffled[i:i+batch_size]
        
        # Forward pass
        torch_optimizer.zero_grad()
        predictions = torch_model(batch_X)
        
        # Compute weighted loss
        losses = criterion(predictions, batch_y)
        weights = torch.where(batch_y == 1, class_weights_torch[1], class_weights_torch[0])
        loss = (losses * weights).mean()
        
        # Backward pass
        loss.backward()
        torch_optimizer.step()
        
        epoch_losses.append(loss.item())
    
    # Validation
    torch_model.eval()
    with torch.no_grad():
        val_predictions = torch_model(X_val_torch)
        val_losses = criterion(val_predictions, y_val_torch)
        val_weights = torch.where(y_val_torch == 1, class_weights_torch[1], class_weights_torch[0])
        val_loss = (val_losses * val_weights).mean()
    
    train_loss = np.mean(epoch_losses)
    torch_train_losses.append(train_loss)
    torch_val_losses.append(val_loss.item())
    
    print(f"  Epoch {epoch+1:2d}/{n_epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

torch_train_time = time.time() - start_time
print(f"\n‚úÖ PyTorch training complete in {torch_train_time:.2f}s")

# Evaluate on test set
print("\nüìä PyTorch Test Set Evaluation:")
torch_model.eval()
with torch.no_grad():
    y_probs_torch = torch_model(X_test_torch).numpy()
    y_pred_torch = (y_probs_torch >= 0.5).astype(int)

torch_precision = precision_score(y_test, y_pred_torch)
torch_recall = recall_score(y_test, y_pred_torch)
torch_f1 = f1_score(y_test, y_pred_torch)
torch_pr_auc = average_precision_score(y_test, y_probs_torch)
torch_roc_auc = roc_auc_score(y_test, y_probs_torch)

print(f"  Precision: {torch_precision:.4f}")
print(f"  Recall:    {torch_recall:.4f}")
print(f"  F1 Score:  {torch_f1:.4f}")
print(f"  PR-AUC:    {torch_pr_auc:.4f}")
print(f"  ROC-AUC:   {torch_roc_auc:.4f}")

print(f"\n  Confusion Matrix:")
cm_torch = confusion_matrix(y_test, y_pred_torch)
print(f"    TN: {cm_torch[0,0]:5d}  FP: {cm_torch[0,1]:5d}")
print(f"    FN: {cm_torch[1,0]:5d}  TP: {cm_torch[1,1]:5d}")


PYTORCH NEURAL NETWORK - OBJECT-ORIENTED APPROACH

üîß Initializing PyTorch model...
  Architecture: 29 ‚Üí 64 ‚Üí 32 ‚Üí 16 ‚Üí 1
  Total parameters: 4,545

üèãÔ∏è  Training PyTorch model...
  Architecture: 29 ‚Üí 64 ‚Üí 32 ‚Üí 16 ‚Üí 1
  Total parameters: 4,545

üèãÔ∏è  Training PyTorch model...
  Epoch  1/10 - Train Loss: 0.2558, Val Loss: 0.1361
  Epoch  1/10 - Train Loss: 0.2558, Val Loss: 0.1361
  Epoch  2/10 - Train Loss: 0.1557, Val Loss: 0.1366
  Epoch  2/10 - Train Loss: 0.1557, Val Loss: 0.1366
  Epoch  3/10 - Train Loss: 0.1207, Val Loss: 0.1790
  Epoch  3/10 - Train Loss: 0.1207, Val Loss: 0.1790
  Epoch  4/10 - Train Loss: 0.1096, Val Loss: 0.1948
  Epoch  4/10 - Train Loss: 0.1096, Val Loss: 0.1948
  Epoch  5/10 - Train Loss: 0.1059, Val Loss: 0.1587
  Epoch  5/10 - Train Loss: 0.1059, Val Loss: 0.1587
  Epoch  6/10 - Train Loss: 0.0908, Val Loss: 0.1975
  Epoch  6/10 - Train Loss: 0.0908, Val Loss: 0.1975
  Epoch  7/10 - Train Loss: 0.0891, Val Loss: 0.1826
  Epoch 

In [5]:
# =============================================================================
# COMPARISON AND ANALYSIS
# =============================================================================

print("\n" + "=" * 70)
print("FINAL COMPARISON: JAX vs PYTORCH")
print("=" * 70)

print("\nüìä Performance Metrics:")
print(f"{'Metric':<15} {'JAX':>10} {'PyTorch':>10} {'Difference':>12}")
print("-" * 50)
print(f"{'Precision':<15} {jax_precision:>10.4f} {torch_precision:>10.4f} {abs(jax_precision-torch_precision):>12.4f}")
print(f"{'Recall':<15} {jax_recall:>10.4f} {torch_recall:>10.4f} {abs(jax_recall-torch_recall):>12.4f}")
print(f"{'F1 Score':<15} {jax_f1:>10.4f} {torch_f1:>10.4f} {abs(jax_f1-torch_f1):>12.4f}")
print(f"{'PR-AUC':<15} {jax_pr_auc:>10.4f} {torch_pr_auc:>10.4f} {abs(jax_pr_auc-torch_pr_auc):>12.4f}")
print(f"{'ROC-AUC':<15} {jax_roc_auc:>10.4f} {torch_roc_auc:>10.4f} {abs(jax_roc_auc-torch_roc_auc):>12.4f}")

print(f"\n‚è±Ô∏è  Training Time:")
print(f"  JAX:     {jax_train_time:.2f}s")
print(f"  PyTorch: {torch_train_time:.2f}s")
if jax_train_time < torch_train_time:
    print(f"  JAX is {torch_train_time/jax_train_time:.2f}x faster")
else:
    print(f"  PyTorch is {jax_train_time/torch_train_time:.2f}x faster")

print("\n" + "=" * 70)
print("KEY OBSERVATIONS")
print("=" * 70)
print("""
1. üìä Model Performance:
   Both frameworks achieve similar predictive performance on this real-world
   imbalanced dataset. The metrics (Precision, Recall, F1) are comparable,
   showing that both handle class-weighted loss effectively.

2. ‚è±Ô∏è  Training Speed:
   JAX's JIT compilation (@jax.jit on update function) provides faster
   training compared to standard PyTorch. The speedup is more pronounced
   with larger datasets and more complex models.

3. üíª Code Patterns:
   JAX: Functional style with explicit parameter passing. JIT compilation
        makes the update step extremely fast. Manual parameter management.
   
   PyTorch: Object-oriented with stateful modules. Automatic parameter
            tracking via nn.Module. Familiar to most ML practitioners.

4. üéØ Handling Imbalance:
   Both frameworks handle severe class imbalance (577:1) well with:
   - Class-weighted loss function
   - Proper evaluation metrics (F1, Precision, Recall, PR-AUC)
   - Stratified train/val/test splits

5. üöÄ Production Considerations:
   JAX: Better for research, custom algorithms, need for composability
   PyTorch: Better for production, larger ecosystem, easier debugging

6. üìà Scalability:
   Both scale well to this dataset size (284K samples). JAX's advantage
   grows with:
   - Larger batch sizes
   - More complex gradient operations (vmap for per-sample gradients)
   - Need for higher-order derivatives
""")

print("=" * 70)
print("CONCLUSION")
print("=" * 70)
print("""
On this real-world fraud detection task:

‚úÖ JAX Strengths:
   - Faster training (JIT compilation)
   - Functional composability (jit + grad + vmap)
   - Clean mathematical code
   - Better for research and custom algorithms

‚úÖ PyTorch Strengths:
   - Easier to learn and debug
   - Mature ecosystem (pretrained models, utilities)
   - Industry standard for production
   - Better documentation and community support

Both frameworks are excellent for production ML. Choose based on your
team's expertise and specific requirements rather than raw performance.
""")


FINAL COMPARISON: JAX vs PYTORCH

üìä Performance Metrics:
Metric                 JAX    PyTorch   Difference
--------------------------------------------------
Precision           0.0719     0.0694       0.0025
Recall              0.8514     0.8784       0.0270
F1 Score            0.1326     0.1287       0.0039
PR-AUC              0.6471     0.6472       0.0001
ROC-AUC             0.9536     0.9584       0.0049

‚è±Ô∏è  Training Time:
  JAX:     2.30s
  PyTorch: 9.17s
  JAX is 3.99x faster

KEY OBSERVATIONS

1. üìä Model Performance:
   Both frameworks achieve similar predictive performance on this real-world
   imbalanced dataset. The metrics (Precision, Recall, F1) are comparable,
   showing that both handle class-weighted loss effectively.

2. ‚è±Ô∏è  Training Speed:
   JAX's JIT compilation (@jax.jit on update function) provides faster
   training compared to standard PyTorch. The speedup is more pronounced
   with larger datasets and more complex models.

3. üíª Code Patterns