# AI Image Authenticity Checker - Analysis Notebook

This notebook provides comprehensive analysis and visualization for detecting AI-generated images.

## Contents
0. Setup and Imports
1. Dataset Overview
2. Feature Extraction and Analysis
3. FFT Spectrum Visualization
4. Model Training
5. Evaluation and Metrics
6. Feature Importance Analysis
7. Interactive Prediction
8. Visualize Individual Predictions
9. Summary

In [None]:
# Setup and Imports
import sys
import os
from pathlib import Path


def _find_project_root(start):
    """Return the nearest directory at or above *start* that holds the project packages.

    The project imports below (``features.*``, ``model.*``) need the
    ``features`` and ``model`` package directories to be importable, so the
    first directory (starting at *start* and walking up) that contains both
    is taken as the project root. This works whether the notebook is launched
    from the project root or from a sub-directory such as ``notebooks/``.
    Falls back to ``start.parent`` (notebook one level below the root) when
    no such directory is found.
    """
    for candidate in (start, *start.parents):
        if (candidate / 'features').is_dir() and (candidate / 'model').is_dir():
            return candidate
    return start.parent


# Add project root to path. Only prepend when it is not already first, so
# re-running this cell does not keep stacking duplicate sys.path entries.
PROJECT_ROOT = _find_project_root(Path().absolute())
if not sys.path or sys.path[0] != str(PROJECT_ROOT):
    sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Project imports
from config import DATA_DIR, REAL_IMAGES_DIR, FAKE_IMAGES_DIR, MODEL_DIR
from features.fft_features import FFTFeatureExtractor
from features.ela_features import ELAFeatureExtractor
from features.texture_features import TextureFeatureExtractor
from features.feature_fusion import FeatureFusion
from model.classifier import AIImageClassifier
from model.trainer import ModelTrainer

# Plotting config
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    # Matplotlib < 3.6 ships the same style under the name 'seaborn-whitegrid'.
    plt.style.use('seaborn-whitegrid')
plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['font.size'] = 12

print(f"Project Root: {PROJECT_ROOT}")
print(f"Real Images: {REAL_IMAGES_DIR}")
print(f"Fake Images: {FAKE_IMAGES_DIR}")

## 1. Dataset Overview

In [None]:
# Check dataset statistics
from data.download_datasets import DatasetDownloader

downloader = DatasetDownloader()
stats = downloader.get_dataset_stats()

print("=" * 50)
print("DATASET STATISTICS")
print("=" * 50)
print(f"Real images: {stats['real_images']}")
print(f"Fake images: {stats['fake_images']}")
print(f"Total: {stats['total']}")

# Visualize class distribution
fig, ax = plt.subplots(figsize=(8, 6))
categories = ['Real', 'AI-Generated']
counts = [stats['real_images'], stats['fake_images']]
colors = ['#2ecc71', '#e74c3c']

bars = ax.bar(categories, counts, color=colors, edgecolor='black', linewidth=1.5)
ax.set_ylabel('Number of Images')
ax.set_title('Dataset Class Distribution')

# Annotate each bar with its count. bar_label offsets the text in points, so
# the label sits just above the bar whatever the y-axis scale is (a fixed
# data-unit offset floats far above tiny bars and vanishes on tall ones).
ax.bar_label(bars, labels=[str(count) for count in counts], padding=3,
             fontsize=14, fontweight='bold')
# Extra top headroom so the count labels are not clipped by the axes edge.
ax.margins(y=0.1)

plt.tight_layout()
plt.show()

In [None]:
# Bootstrap a small dataset: when fewer than 10 images are present, fetch
# 50 real + 50 AI-generated samples and refresh `stats` afterwards.
MIN_IMAGES_BEFORE_DOWNLOAD = 10
needs_download = stats['total'] < MIN_IMAGES_BEFORE_DOWNLOAD

if needs_download:
    print("Downloading sample dataset...")
    downloader.download_sample_dataset(num_real=50, num_fake=50)
    # Re-read the counts so later cells see the downloaded images.
    stats = downloader.get_dataset_stats()
    print(f"\nDataset now contains {stats['total']} images")

## 2. Feature Extraction and Analysis

In [None]:
# Initialize feature extractors
fft_extractor = FFTFeatureExtractor()
ela_extractor = ELAFeatureExtractor()
texture_extractor = TextureFeatureExtractor()
fusion = FeatureFusion(include_deep=False)

# Collect sample images.
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order, so
# the listing is sorted before taking the first N; this keeps the sample
# reproducible across runs and machines. is_file() skips directories whose
# names happen to end in an image suffix.
supported = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
SAMPLES_PER_CLASS = 20


def _sample_images(directory, limit=SAMPLES_PER_CLASS):
    """Return up to *limit* image files (by `supported` suffix) from *directory*, sorted by path."""
    return sorted(
        p for p in directory.iterdir()
        if p.is_file() and p.suffix.lower() in supported
    )[:limit]


real_images = _sample_images(REAL_IMAGES_DIR)
fake_images = _sample_images(FAKE_IMAGES_DIR)

print(f"Sample real images: {len(real_images)}")
print(f"Sample fake images: {len(fake_images)}")

In [None]:
# Extract FFT features for visualization


def _extract_fft_features(paths, desc):
    """Run ``fft_extractor.extract`` on each path, skipping images that fail.

    Extraction is best-effort: one unreadable image should not abort the
    cell. Each failure is reported together with the offending file name so
    the bad input can be located.

    Returns the list of successfully extracted feature objects, in input order.
    """
    extracted = []
    for img_path in tqdm(paths, desc=desc):
        try:
            extracted.append(fft_extractor.extract(img_path))
        except Exception as e:
            print(f"Error extracting {img_path.name}: {e}")
    return extracted


real_fft_features = _extract_fft_features(real_images, "Extracting real FFT features")
fake_fft_features = _extract_fft_features(fake_images, "Extracting fake FFT features")

print(f"\nExtracted features from {len(real_fft_features)} real and {len(fake_fft_features)} fake images")

## 3. FFT Spectrum Visualization

In [None]:
# Side-by-side overlay of per-image spectral profiles: first 10 real images
# on the left panel, first 10 AI-generated images on the right.
# NOTE(review): the plotted attribute is `azimuthal_profile`; presumably the
# azimuthally averaged power per radial frequency bin — confirm in
# FFTFeatureExtractor.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = (
    (axes[0], real_fft_features, 'green', 'FFT Radial Power Profile - Real Images'),
    (axes[1], fake_fft_features, 'red', 'FFT Radial Power Profile - AI-Generated Images'),
)

for panel_ax, feature_list, line_color, panel_title in panels:
    for feat in feature_list[:10]:
        panel_ax.plot(feat.azimuthal_profile, alpha=0.5, color=line_color)
    panel_ax.set_title(panel_title, fontsize=14)
    panel_ax.set_xlabel('Frequency Bin')
    panel_ax.set_ylabel('Normalized Power')

plt.tight_layout()
plt.show()

In [None]:
# Compare key FFT features between real and AI-generated images.

# Plot name -> attribute read from each extracted FFT feature object.
FFT_METRIC_ATTRS = {
    'spectral_ratio': 'spectral_ratio',
    'spectral_entropy': 'spectral_entropy',
    'spectral_flatness': 'spectral_flatness',
    'gan_fingerprint': 'gan_fingerprint_score',
}


def _collect_metrics(feature_list):
    """Return ``{metric_name: [value for each feature object]}`` for FFT_METRIC_ATTRS."""
    return {
        name: [getattr(f, attr) for f in feature_list]
        for name, attr in FFT_METRIC_ATTRS.items()
    }


real_metrics = _collect_metrics(real_fft_features)
fake_metrics = _collect_metrics(fake_fft_features)

# Create comparison boxplots (one panel per metric)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

for ax, metric in zip(axes.flat, real_metrics):
    data = [real_metrics[metric], fake_metrics[metric]]
    # boxplot's `labels=` keyword was renamed `tick_labels=` in Matplotlib
    # 3.9 and is deprecated; setting the tick labels separately works on
    # every Matplotlib version. Boxes sit at x positions 1 and 2 by default.
    bp = ax.boxplot(data, patch_artist=True)
    ax.set_xticks([1, 2])
    ax.set_xticklabels(['Real', 'AI-Generated'])
    bp['boxes'][0].set_facecolor('#2ecc71')
    bp['boxes'][1].set_facecolor('#e74c3c')
    ax.set_title(metric.replace('_', ' ').title(), fontsize=12)
    ax.set_ylabel('Value')

plt.suptitle('FFT Feature Comparison: Real vs AI-Generated', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Model Training

In [None]:
# Trainer over the real/fake image folders using hand-crafted features only
# (include_deep=False), then a capped dataset load for a quick experiment.
trainer = ModelTrainer(real_dir=REAL_IMAGES_DIR, fake_dir=FAKE_IMAGES_DIR, include_deep=False)

X, y = trainer.load_dataset(max_samples=100)

# Report matrix/label shapes and per-class sample counts.
print("\nFeature matrix shape:", X.shape)
print("Labels shape:", y.shape)
print("Class distribution:", np.bincount(y))

In [None]:
# Partition X/y into train/validation/test subsets (stored on the trainer)
# and report the size of each.
trainer.split_data(X, y)

for split_name, split_attr in (("Training", "X_train"), ("Validation", "X_val"), ("Test", "X_test")):
    print(f"{split_name} samples: {len(getattr(trainer, split_attr))}")

In [None]:
# Fit a Random Forest ('rf') and print the metrics dict returned by train().
results = trainer.train(algorithm='rf')

banner = "=" * 50
print("\n" + banner)
print("TRAINING RESULTS")
print(banner)

# (caption, results key, format spec) for each summary line.
for caption, key, spec in (
    ("Algorithm", 'algorithm', ''),
    ("Test Accuracy", 'test_accuracy', '.4f'),
    ("Test AUC", 'test_auc', '.4f'),
    ("Model saved", 'model_path', ''),
):
    print(f"{caption}: {results[key]:{spec}}")

## 5. Evaluation and Metrics

In [None]:
# Load trained model and evaluate it on the held-out test split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve

# Class encoding, as implied by the target_names order below and by
# predict_proba[:, 1] being scored as the positive class: 0 = Real, 1 = AI-Generated.
CLASS_LABELS = [0, 1]
CLASS_NAMES = ['Real', 'AI-Generated']

model = AIImageClassifier.load(results['model_path'])

# Predictions: hard labels plus the probability of the AI-Generated class
y_pred = model.predict(trainer.X_test)
y_prob = model.predict_proba(trainer.X_test)[:, 1]

# Confusion Matrix. The labels are passed explicitly because a small test
# split may miss a class entirely; sklearn would then return a 1x1 matrix
# that no longer matches the two tick labels of the heatmap.
cm = confusion_matrix(trainer.y_test, y_pred, labels=CLASS_LABELS)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Confusion matrix heatmap (rows = actual class, columns = predicted class)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0],
            xticklabels=CLASS_NAMES,
            yticklabels=CLASS_NAMES)
axes[0].set_title('Confusion Matrix', fontsize=14)
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('Actual')

# ROC Curve. NOTE: only meaningful when the test split contains both
# classes; otherwise sklearn emits an UndefinedMetricWarning.
fpr, tpr, _ = roc_curve(trainer.y_test, y_prob)
roc_auc = auc(fpr, tpr)

axes[1].plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
axes[1].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance diagonal
axes[1].set_xlim([0.0, 1.0])
axes[1].set_ylim([0.0, 1.05])
axes[1].set_xlabel('False Positive Rate')
axes[1].set_ylabel('True Positive Rate')
axes[1].set_title('ROC Curve', fontsize=14)
axes[1].legend(loc='lower right')

plt.tight_layout()
plt.show()

# Classification report. labels= keeps the rows aligned with target_names
# even when a class is absent; without it sklearn raises a ValueError because
# the number of classes found does not match the size of target_names.
print("\nClassification Report:")
print(classification_report(trainer.y_test, y_pred, labels=CLASS_LABELS, target_names=CLASS_NAMES))

In [None]:
# Precision-Recall Curve for the AI-Generated (label 1) class. Like the ROC
# plot in the evaluation cell, it shows a summary score (average precision)
# and a chance-level reference.
from sklearn.metrics import average_precision_score

precision, recall, thresholds = precision_recall_curve(trainer.y_test, y_prob)
avg_precision = average_precision_score(trainer.y_test, y_prob)
# A random classifier's precision equals the fraction of positive samples,
# so the positive-class prevalence is the chance baseline for this plot.
positive_rate = float(np.mean(np.asarray(trainer.y_test) == 1))

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(recall, precision, color='blue', lw=2, label=f'PR curve (AP = {avg_precision:.2f})')
ax.axhline(positive_rate, color='navy', lw=2, linestyle='--',
           label=f'Chance (positive rate = {positive_rate:.2f})')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve', fontsize=14)
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.legend(loc='lower left')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 6. Feature Importance Analysis

In [None]:
# Get feature importance. Not every model type exposes it, so a missing
# attribute is treated the same as an explicit None.
importance = getattr(model, 'feature_importance_', None)
if importance is not None:
    # asarray so fancy indexing below also works if a plain list is stored.
    importance = np.asarray(importance)
    feature_names = fusion.get_feature_names()[:len(importance)]

    # Top features, most important first. Clamp k to the number of features
    # so the plot title stays accurate for small feature sets.
    top_k = min(20, len(importance))
    top_indices = np.argsort(importance)[-top_k:][::-1]

    # Use a positional placeholder when the fusion pipeline reports fewer
    # names than the model has importances.
    top_features = [feature_names[i] if i < len(feature_names) else f'feature_{i}' for i in top_indices]
    top_importance = importance[top_indices]

    # Plot
    fig, ax = plt.subplots(figsize=(10, 8))
    y_pos = np.arange(len(top_features))

    ax.barh(y_pos, top_importance, color='steelblue', edgecolor='black')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(top_features)
    ax.invert_yaxis()  # most important feature at the top
    ax.set_xlabel('Importance')
    ax.set_title(f'Top {top_k} Most Important Features', fontsize=14)

    plt.tight_layout()
    plt.show()
else:
    print("Feature importance not available for this model type")

## 7. Interactive Prediction

In [None]:
from inference.predict import ImagePredictor

# Initialize predictor
predictor = ImagePredictor(model_path=results['model_path'])

# Sample images for a quick sanity check: up to 3 per class.
# - Both classes are drawn from the same extension set, matched
#   case-insensitively. Picking e.g. only .jpg real and only .png fake images
#   would tie the label to the file format, and compression artefacts are
#   the kind of signal frequency/ELA-style features can pick up.
# - Listings are sorted so the selection is reproducible.
# NOTE: these files live in the same directories the trainer loaded from, so
# they may have been seen during training; this is not an unbiased evaluation.
DEMO_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
DEMO_PER_CLASS = 3


def _demo_images(directory, limit=DEMO_PER_CLASS):
    """Return up to *limit* image files from *directory*, sorted by path."""
    return sorted(
        p for p in directory.iterdir()
        if p.is_file() and p.suffix.lower() in DEMO_EXTENSIONS
    )[:limit]


test_images = _demo_images(REAL_IMAGES_DIR) + _demo_images(FAKE_IMAGES_DIR)

if test_images:
    print("Sample Predictions:")
    print("=" * 60)

    for img_path in test_images:
        try:
            result = predictor.predict(img_path)
            print(f"\nImage: {img_path.name}")
            print(f"Prediction: {result.prediction}")
            print(f"Confidence: {result.confidence:.2%} ({result.confidence_level})")
        except Exception as e:
            # Best-effort demo: report the failure and move on to the next image.
            print(f"Error predicting {img_path.name}: {e}")
else:
    print("No test images found")

## 8. Visualize Individual Predictions

In [None]:
import cv2
from PIL import Image

# Visualize predictions on a 2x3 grid (the first 6 entries of test_images)
if test_images:
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Turn every cell's axes off up front so that cells left unused (fewer
    # than 6 images) render blank instead of as empty framed plots.
    for ax in axes.flat:
        ax.axis('off')

    for ax, img_path in zip(axes.flat, test_images):
        try:
            # Load and display image. The context manager closes the file
            # handle that Image.open keeps for lazy loading, even if imshow
            # fails; imshow copies the pixels into its own array, so the
            # displayed image does not depend on the open file.
            with Image.open(img_path) as img:
                ax.imshow(img)

            # Get prediction
            result = predictor.predict(img_path)

            # Color based on prediction
            color = 'green' if result.prediction == 'Real' else 'red'

            ax.set_title(f"{result.prediction}\n{result.confidence:.1%} confidence",
                         color=color, fontsize=12, fontweight='bold')

        except Exception as e:
            ax.set_title(f"Error: {e}")

    plt.suptitle('Sample Predictions', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

## 9. Summary

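This notebook demonstrated:

1. **Dataset Loading** - Loaded real and AI-generated images and inspected the class balance
2. **Feature Extraction** - Set up FFT, ELA, and texture extractors and extracted FFT features from a sample of images; model training relies on the trainer's fused feature set (FFT, ELA, texture, and noise features)
3. **FFT Analysis** - Compared frequency spectra between real and fake images
4. **Model Training** - Trained a Random Forest classifier
5. **Evaluation** - Generated the confusion matrix, ROC and precision-recall curves, and a classification report
6. **Feature Importance** - Identified the most discriminative features
7. **Interactive Prediction** - Ran predictions on sample images from the dataset folders (these may overlap the training data, so treat them as a sanity check rather than an evaluation)
8. **Prediction Visualization** - Displayed the sample images with their predicted labels and confidence scores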

### Key Findings
- FFT spectral features show distinct patterns between real and AI-generated images
- GAN fingerprint detection captures periodic upsampling artifacts
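- Ensemble methods (such as the Random Forest used here) typically achieve the best accuracy; note that this notebook trains only a Random Forest on a 100-image subset, so confirm any comparison on larger data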

### Next Steps
- Train on larger datasets (COCO, DiffusionDB)
- Experiment with deep learning features
- Deploy as API service