# AI-Generated Image Detection Demo

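This notebook demonstrates the AI-generated image detection system using several detection methods: frequency-domain analysis, statistical features, and an ensemble of the two. Run the cells in order: later sections reuse the `image` loaded in section 1 and the detectors created in sections 2 and 3.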

In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Add parent directory to path
sys.path.append('..')

from src.detectors.frequency_detector import FrequencyDomainDetector
from src.detectors.statistical_detector import StatisticalDetector
from src.detectors.ensemble_detector import EnsembleDetector
from src.utils.image_processing import load_image

## 1. Load and Visualize an Image

In [None]:
# Path of the image to analyse -- point this at your own file.
image_path = '../data/test_image.jpg'

try:
    # `image` is reused by the analysis cells that follow.
    image = load_image(image_path)

    # Display the picture on its own axes, without ticks.
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.imshow(image)
    ax.set_title('Input Image')
    ax.axis('off')
    plt.show()

    print(f"Image shape: {image.shape}")
except Exception as load_error:
    # Best-effort demo behaviour: report the failure instead of raising.
    print(f"Error loading image: {load_error}")
    print("Please provide a valid image path")

## 2. Frequency Domain Analysis

In [None]:
# Score `image` with the frequency-domain detector on its own.
# `freq_detector` is kept at notebook level because section 7 reads its feature names.
freq_detector = FrequencyDomainDetector()

try:
    # Feature vector first, then the detector's probability for that vector.
    freq_features = freq_detector.extract_features(image)
    freq_prob = freq_detector.predict_proba(freq_features)

    print("Frequency Domain Analysis:")
    print(f"  Features extracted: {len(freq_features)}")
    # NOTE(review): the percent format assumes predict_proba returns a scalar -- confirm.
    print(f"  AI-generated probability: {freq_prob:.2%}")
    print("\nFeature values:")
    # Pair each feature name with its value; zip stops at the shorter sequence.
    for feat_name, feat_value in zip(freq_detector.feature_names, freq_features):
        print(f"  {feat_name}: {feat_value:.4f}")
except Exception as err:
    # Best-effort demo behaviour: report the failure instead of raising.
    print(f"Error: {err}")

## 3. Statistical Feature Analysis

In [None]:
# Score `image` with the statistical detector on its own.
# `stat_detector` is kept at notebook level because section 7 reads its feature names.
stat_detector = StatisticalDetector()

try:
    # Feature vector first, then the detector's probability for that vector.
    stat_features = stat_detector.extract_features(image)
    stat_prob = stat_detector.predict_proba(stat_features)

    print("Statistical Analysis:")
    print(f"  Features extracted: {len(stat_features)}")
    # NOTE(review): the percent format assumes predict_proba returns a scalar -- confirm.
    print(f"  AI-generated probability: {stat_prob:.2%}")
    print("\nFeature values:")
    # Pair each feature name with its value; zip stops at the shorter sequence.
    for feat_name, feat_value in zip(stat_detector.feature_names, stat_features):
        print(f"  {feat_name}: {feat_value:.4f}")
except Exception as err:
    # Best-effort demo behaviour: report the failure instead of raising.
    print(f"Error: {err}")

## 4. Ensemble Detection

In [None]:
# Combine the individual detectors into a single verdict for `image`.
# The CNN branch is switched off (use_cnn=False) for a faster demo.
ensemble_detector = EnsembleDetector(use_cnn=False)

try:
    # The result is read below as a mapping with the keys 'is_ai_generated',
    # 'probability', 'confidence' and 'individual_predictions'.
    result = ensemble_detector.predict(image, threshold=0.5)

    print("="*60)
    print("ENSEMBLE DETECTION RESULT")
    print("="*60)

    # Emoji are written as \U escapes (robot face / camera) so the verdict
    # line stays intact even if the notebook file is re-encoded.
    if result['is_ai_generated']:
        print("\U0001F916 Verdict: AI-GENERATED")
    else:
        print("\U0001F4F7 Verdict: REAL IMAGE")

    print(f"\nProbability: {result['probability']:.2%}")
    print(f"Confidence: {result['confidence']:.2%}")
    print("\nIndividual Predictions:")
    for method, prob in result['individual_predictions'].items():
        # Entries whose probability is None are skipped
        # (presumably methods that were disabled or did not run -- confirm).
        if prob is not None:
            print(f"  {method.capitalize()}: {prob:.2%}")
    print("="*60)
except Exception as e:
    # Best-effort demo behaviour: report the failure instead of raising.
    print(f"Error: {e}")

## 5. Visualize Frequency Domain

In [None]:
try:
    # Work on a single-channel version of the image. cvtColor needs a
    # multi-channel input, so an already 2-D (grayscale) array is used as-is.
    # NOTE(review): assumes load_image returns RGB channel order -- confirm.
    if len(image.shape) == 2:
        gray = image
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Compute FFT: shift the zero frequency to the centre and log-scale the
    # magnitude so the weak high-frequency content is visible.
    f = np.fft.fft2(gray)
    fshift = np.fft.fftshift(f)
    magnitude_spectrum = np.log1p(np.abs(fshift))

    # Compute DCT. cv2.dct only supports even-sized arrays, so trim at most
    # one trailing row/column first (axes of length 1 are left untouched).
    dct_rows, dct_cols = (n - n % 2 if n > 1 else n for n in gray.shape)
    dct = cv2.dct(np.float32(gray[:dct_rows, :dct_cols]))
    dct_log = np.log1p(np.abs(dct))

    # Plot the three views side by side: (data, title, colormap) per panel.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        (gray, 'Original (Grayscale)', 'gray'),
        (magnitude_spectrum, 'FFT Magnitude Spectrum', 'viridis'),
        (dct_log, 'DCT Coefficients', 'viridis'),
    ]
    for ax, (panel_data, panel_title, panel_cmap) in zip(axes, panels):
        ax.imshow(panel_data, cmap=panel_cmap)
        ax.set_title(panel_title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()
except Exception as e:
    # Best-effort demo behaviour: report the failure instead of raising.
    print(f"Error: {e}")

## 6. Compare Multiple Images

In [None]:
# Candidate files to score side by side
image_paths = [
    '../data/real/image1.jpg',
    '../data/ai_generated/image1.jpg',
    # Add more paths here
]

# Fresh ensemble (CNN branch off) so this cell does not rely on section 4.
ensemble_detector = EnsembleDetector(use_cnn=False)

# One entry per successfully processed file: its path, pixels and prediction.
results = []

for img_path in image_paths:
    if os.path.exists(img_path):
        try:
            img = load_image(img_path)
            result = ensemble_detector.predict(img)
            results.append(dict(path=img_path, image=img, result=result))
        except Exception as err:
            # A failing file is reported and skipped; the rest still run.
            print(f"Error processing {img_path}: {err}")
    else:
        print(f"Skipping {img_path} (not found)")

# Visualize results: one panel per image, titled with verdict and probability.
if not results:
    print("No images to compare. Please update image_paths with valid image files.")
else:
    n_panels = len(results)
    # squeeze=False always yields a 2-D axes array, so a single image needs
    # no special handling; row 0 holds the panels.
    fig, axes = plt.subplots(1, n_panels, figsize=(5*n_panels, 5), squeeze=False)

    for ax, entry in zip(axes[0], results):
        ax.imshow(entry['image'])
        verdict = "AI" if entry['result']['is_ai_generated'] else "Real"
        prob = entry['result']['probability']
        ax.set_title(f"{verdict}\n({prob:.1%})")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## 7. Analyze Feature Importance

In [None]:
# Feature importances come from a trained ML model stored on disk.
ml_model_path = '../trained_models/ml_model_random_forest.pkl'

if not os.path.exists(ml_model_path):
    print(f"ML model not found at {ml_model_path}")
    print("Train a model first using: python src/train.py ...")
else:
    detector = EnsembleDetector(ml_model_path=ml_model_path, use_cnn=False)
    importance = detector.get_feature_importance()

    if importance is None:
        print("Model doesn't support feature importance")
    else:
        # Names come from the detectors created in sections 2 and 3.
        # NOTE(review): assumes the model's feature order is the frequency
        # features followed by the statistical features -- confirm against training.
        all_feature_names = freq_detector.feature_names + stat_detector.feature_names

        # Positions of the 15 largest importances, largest first.
        indices = np.argsort(importance)[::-1][:15]
        bar_positions = range(len(indices))

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.barh(bar_positions, importance[indices])
        ax.set_yticks(bar_positions)
        ax.set_yticklabels([all_feature_names[i] for i in indices])
        ax.set_xlabel('Feature Importance')
        ax.set_title('Top 15 Most Important Features')
        # Flip the axis so the most important feature sits at the top.
        ax.invert_yaxis()
        plt.tight_layout()
        plt.show()

## Summary

This notebook demonstrated:
1. Loading and visualizing images
2. Frequency domain analysis
3. Statistical feature extraction
4. Ensemble detection combining multiple methods
5. Visualizing frequency domain representations
6. Comparing multiple images
7. Analyzing feature importance

To train models and achieve >80% accuracy, run the training script:
```bash
python src/train.py --data_dir data --model_type both --epochs 50
```