<a href="https://colab.research.google.com/github/Nandhini1008/classification/blob/main/predicting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ===============================================================================
# IMPROVED DEEPFAKE PREDICTOR WITH CONSISTENT PROCESSING
# ===============================================================================

import torch
import torch.nn as nn
import pickle
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import io
import os
import warnings
import joblib
import dill
import hashlib
warnings.filterwarnings('ignore')

# ===============================================================================
# EXACT MODEL ARCHITECTURE (Must match training)
# ===============================================================================

class SimpleClassifier(nn.Module):
    """CNN classifier - MUST exactly match your trained model architecture.

    ``features`` is four stages of Conv3x3(padding=1) -> BatchNorm -> ReLU.
    The first three stages end in a 2x2 max-pool and the last one in a global
    average pool to 1x1, giving 512 features. ``classifier`` maps
    512 -> 256 -> ``num_classes`` with dropout in between.

    The module order (and therefore the ``features.N`` / ``classifier.N``
    indices) is part of the state_dict key names. Do not reorder layers, or
    saved checkpoints will stop loading.

    Args:
        num_classes: size of the output logit vector.
    """

    # Output channels of each conv stage; the network input has 3 channels.
    _STAGE_CHANNELS = (64, 128, 256, 512)

    def __init__(self, num_classes=3):
        super().__init__()

        stage_layers = []
        in_channels = 3
        final_stage = len(self._STAGE_CHANNELS) - 1
        for stage_idx, out_channels in enumerate(self._STAGE_CHANNELS):
            stage_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            stage_layers.append(nn.BatchNorm2d(out_channels))
            stage_layers.append(nn.ReLU(inplace=True))
            # Intermediate stages halve the resolution; the last one pools globally.
            if stage_idx == final_stage:
                stage_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                stage_layers.append(nn.MaxPool2d(2, 2))
            in_channels = out_channels
        self.features = nn.Sequential(*stage_layers)

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_channels, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes) for an NCHW input."""
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

# ===============================================================================
# CONSISTENT IMAGE PREPROCESSOR
# ===============================================================================

class ConsistentImagePreprocessor:
    """
    Ensures consistent image preprocessing that matches training pipeline.

    Every image is resized (bilinear) to ``target_size`` and converted to a
    tensor. A Normalize step is then applied depending on ``normalization``:

      * ``'imagenet'`` - ImageNet mean/std
      * ``'custom'``   - mean = std = 0.5 per channel
      * anything else  - no normalization

    Args:
        target_size: size passed to ``transforms.Resize``.
        normalization: one of the modes above. Must match training.
    """

    # (mean, std) per normalization mode. Modes not listed here skip Normalize.
    _NORMALIZATION_STATS = {
        'imagenet': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        'custom': ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    }

    def __init__(self, target_size=(224, 224), normalization='imagenet'):
        self.target_size = target_size
        self.normalization = normalization

        # Resize + ToTensor are shared by every mode; only Normalize varies.
        steps = [
            transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
        ]
        stats = (self._NORMALIZATION_STATS.get(normalization)
                 if isinstance(normalization, str) else None)
        if stats is not None:
            mean, std = stats
            steps.append(transforms.Normalize(mean=mean, std=std))
        self.transform = transforms.Compose(steps)

    def preprocess_image(self, image_path):
        """
        Preprocess image consistently.

        Args:
            image_path: a filesystem path (``str`` or any ``os.PathLike`` such
                as ``pathlib.Path``), or an already-loaded PIL image.

        Returns:
            The transformed tensor. Returns None if loading or preprocessing
            fails; the error is printed, not raised.
        """
        try:
            if isinstance(image_path, (str, os.PathLike)):
                # The context manager closes the file handle. convert() returns
                # an independent in-memory copy, so it outlives the block.
                with Image.open(image_path) as raw:
                    image = raw.convert('RGB')
                print(f"Loaded image: {os.path.basename(image_path)} - Size: {image.size}")
            else:
                image = image_path
                print(f"Using provided image - Size: {image.size}")
                # Caller-supplied images may be RGBA / L / P. The model and the
                # 3-channel Normalize step both expect RGB.
                if isinstance(image, Image.Image) and image.mode != 'RGB':
                    image = image.convert('RGB')

            tensor = self.transform(image)

            # Log preprocessing details
            print(f"Preprocessed tensor shape: {tensor.shape}")
            print(f"Tensor range: [{tensor.min():.3f}, {tensor.max():.3f}]")
            print(f"Tensor mean: {tensor.mean():.3f}, std: {tensor.std():.3f}")

            return tensor

        except Exception as e:
            # Best-effort by design: callers check for None.
            print(f"Error preprocessing image: {e}")
            return None

    def create_tensor_sample(self, image_path, label=None):
        """
        Create tensor sample in the exact format used during training.

        Returns:
            ``{'pixel_values': tensor}`` plus ``'labels': label`` when a label
            is given, or None if preprocessing failed.
        """
        tensor = self.preprocess_image(image_path)
        if tensor is None:
            return None

        sample = {'pixel_values': tensor}
        if label is not None:
            sample['labels'] = label
        return sample

# ===============================================================================
# ROBUST MODEL LOADER
# ===============================================================================

class RobustModelLoader:
    """
    Handles different model file formats and corruption.

    ``load_model_safely`` tries each strategy in order and returns the first
    model produced. The torch.load and plain-pickle strategies all map tensors
    onto the CPU (see ``_cpu_unpickle``), so a checkpoint saved on a GPU
    machine can be opened on a CPU-only one.

    SECURITY: every strategy ends up unpickling the file, which can execute
    arbitrary code. Only load model files from a trusted source.
    """

    # Checkpoint keys that commonly wrap the actual state_dict.
    _STATE_DICT_KEYS = ('state_dict', 'model_state_dict', 'model', 'net', 'network')

    @staticmethod
    def load_model_safely(model_path, model_class=SimpleClassifier, num_classes=3):
        """
        Safely load model with multiple fallback methods.

        Returns:
            ``(model, method_name)`` on success. On failure, ``(None,
            "file_not_found")`` or ``(None, "failed")``.
        """
        print(f"Attempting to load model from: {model_path}")

        if not os.path.exists(model_path):
            print(f"Model file not found: {model_path}")
            return None, "file_not_found"

        # Get file hash for consistency check
        file_hash = RobustModelLoader._get_file_hash(model_path)
        print(f"Model file hash: {file_hash}")

        # Try different loading methods
        loading_methods = [
            ("PyTorch state_dict", RobustModelLoader._load_pytorch_state_dict),
            ("PyTorch model object", RobustModelLoader._load_pytorch_model),
            ("Pickle state_dict", RobustModelLoader._load_pickle_state_dict),
            ("Pickle model object", RobustModelLoader._load_pickle_model),
            ("Joblib", RobustModelLoader._load_joblib),
            ("Manual reconstruction", RobustModelLoader._manual_load)
        ]

        for method_name, method_func in loading_methods:
            print(f"Trying method: {method_name}")
            try:
                model = method_func(model_path, model_class, num_classes)
                if model is not None:
                    print(f"Successfully loaded model using: {method_name}")
                    return model, method_name
            except Exception as e:
                print(f"Failed with {method_name}: {str(e)[:100]}")

        print("All loading methods failed")
        return None, "failed"

    @staticmethod
    def _get_file_hash(file_path, chunk_size=1 << 20):
        """Return the first 8 hex chars of the file's MD5 (consistency check only, not security).

        The file is streamed in ``chunk_size`` pieces so large checkpoints are
        not read into memory at once.
        """
        digest = hashlib.md5()
        with open(file_path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                digest.update(chunk)
        return digest.hexdigest()[:8]

    @staticmethod
    def _torch_load_cpu(source):
        """torch.load onto the CPU, allowing full pickled objects.

        PyTorch >= 2.6 defaults to ``weights_only=True``, which rejects pickled
        ``nn.Module`` objects. Versions that predate the argument raise
        TypeError and are retried without it.
        """
        try:
            return torch.load(source, map_location='cpu', weights_only=False)
        except TypeError:
            if hasattr(source, 'seek'):
                source.seek(0)  # rewind file-like sources before retrying
            return torch.load(source, map_location='cpu')

    @staticmethod
    def _cpu_unpickle(file_obj):
        """``pickle.load`` replacement that remaps CUDA-saved torch storages to CPU.

        Plain ``pickle.load`` on an object holding CUDA tensors fails on CPU-only
        machines with "Attempting to deserialize object on a CUDA device". torch
        pickles storages via ``torch.storage._load_from_bytes``, so that hook is
        redirected to ``torch.load(..., map_location='cpu')``.
        """
        class _CPUUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                if module == 'torch.storage' and name == '_load_from_bytes':
                    return lambda b: RobustModelLoader._torch_load_cpu(io.BytesIO(b))
                return super().find_class(module, name)

        return _CPUUnpickler(file_obj).load()

    @staticmethod
    def _build_from_state_dict(state_dict, model_class, num_classes):
        """Instantiate ``model_class`` and load ``state_dict`` (unwrapping checkpoint dicts)."""
        for key in RobustModelLoader._STATE_DICT_KEYS:
            inner = state_dict.get(key)
            if isinstance(inner, dict):
                state_dict = inner
                break
        model = model_class(num_classes=num_classes)
        model.load_state_dict(state_dict)
        return model

    @staticmethod
    def _load_pytorch_state_dict(model_path, model_class, num_classes):
        """Load as PyTorch state dict (or a checkpoint dict wrapping one)."""
        state_dict = RobustModelLoader._torch_load_cpu(model_path)
        if isinstance(state_dict, dict):
            return RobustModelLoader._build_from_state_dict(state_dict, model_class, num_classes)
        return None

    @staticmethod
    def _load_pytorch_model(model_path, model_class, num_classes):
        """Load as complete PyTorch model."""
        model = RobustModelLoader._torch_load_cpu(model_path)
        if hasattr(model, 'forward'):
            return model
        return None

    @staticmethod
    def _load_pickle_state_dict(model_path, model_class, num_classes):
        """Load as pickled state dict."""
        with open(model_path, 'rb') as f:
            state_dict = RobustModelLoader._cpu_unpickle(f)
        if isinstance(state_dict, dict):
            return RobustModelLoader._build_from_state_dict(state_dict, model_class, num_classes)
        return None

    @staticmethod
    def _load_pickle_model(model_path, model_class, num_classes):
        """Load as pickled model."""
        with open(model_path, 'rb') as f:
            model = RobustModelLoader._cpu_unpickle(f)
        if hasattr(model, 'forward'):
            return model
        return None

    @staticmethod
    def _load_joblib(model_path, model_class, num_classes):
        """Load using joblib (no CPU remapping is applied on this path)."""
        loaded = joblib.load(model_path)
        if hasattr(loaded, 'forward'):
            return loaded
        elif isinstance(loaded, dict):
            return RobustModelLoader._build_from_state_dict(loaded, model_class, num_classes)
        return None

    @staticmethod
    def _manual_load(model_path, model_class, num_classes):
        """Last resort: retry unpickling at a few byte offsets in case of header corruption."""
        with open(model_path, 'rb') as f:
            data = f.read()

        for offset in [0, 8, 16, 32, 64]:
            try:
                loaded = RobustModelLoader._cpu_unpickle(io.BytesIO(data[offset:]))
                if isinstance(loaded, dict):
                    return RobustModelLoader._build_from_state_dict(loaded, model_class, num_classes)
                elif hasattr(loaded, 'forward'):
                    return loaded
            except Exception:
                # Expected for most offsets; move on to the next one.
                continue

        return None

# ===============================================================================
# CONSISTENT DEEPFAKE PREDICTOR
# ===============================================================================

class ConsistentDeepfakePredictor:
    """
    Consistent predictor that ensures reproducible results.

    On construction it:
      * builds a ConsistentImagePreprocessor from ``preprocessing_config``,
      * seeds torch,
      * loads a 3-class model through RobustModelLoader, falling back to an
        untrained SimpleClassifier if every loader fails,
      * switches the model to eval mode on ``device``.

    Args:
        model_path: path to the model file.
        preprocessing_config: kwargs for ConsistentImagePreprocessor. Defaults
            to a 224x224 resize with ImageNet normalization.
        device: torch.device used for inference. Defaults to CPU.
    """
    def __init__(self, model_path, preprocessing_config=None, device=None):
        self.model_path = model_path
        self.device = device or torch.device('cpu')
        self.class_names = ['AI-generated', 'Deepfake', 'Real']

        # Set up consistent preprocessing
        if preprocessing_config is None:
            preprocessing_config = {
                'target_size': (224, 224),
                'normalization': 'imagenet'  # Change this to match your training
            }

        self.preprocessor = ConsistentImagePreprocessor(**preprocessing_config)
        self.model = None
        self.loading_method = None
        self.model_hash = None  # never populated by the code in this class

        # Seed BEFORE loading. The untrained fallback model is randomly
        # initialised inside _load_model, so seeding afterwards would give
        # different fallback weights on every run.
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)

        # Load model
        self._load_model()

    def _load_model(self):
        """Load the model (or the untrained fallback), then set it to eval mode on ``self.device``."""
        print("="*60)
        print("CONSISTENT MODEL LOADING")
        print("="*60)

        self.model, self.loading_method = RobustModelLoader.load_model_safely(
            self.model_path, SimpleClassifier, 3
        )

        if self.model is None:
            print("WARNING: Could not load trained model. Creating untrained fallback.")
            print("This will give random predictions!")
            self.model = SimpleClassifier(num_classes=3)
            self.loading_method = "untrained_fallback"

        # Ensure model is in evaluation mode
        self.model.eval()
        self.model = self.model.to(self.device)

        # eval() already disables dropout. Zeroing p additionally keeps it
        # disabled if the model is ever switched back to train().
        for module in self.model.modules():
            if isinstance(module, nn.Dropout):
                module.p = 0.0

        print(f"Model loading status: {self.loading_method}")
        print("="*60)

    def predict_image(self, image_path, show_details=True):
        """
        Make consistent prediction on image.

        Args:
            image_path: path to the image file.
            show_details: print probabilities and display the image.

        Returns:
            A dict with these keys, or None if the image is missing or
            preprocessing fails:
              * ``predicted_class``: index into ``class_names``
              * ``predicted_label``
              * ``confidence``: softmax probability of the predicted class
              * ``all_probabilities``: numpy array of all class probabilities
              * ``model_status``
        """
        if show_details:
            print(f"\nPREDICTING: {os.path.basename(image_path)}")
            print("-" * 50)
            print(f"Model status: {self.loading_method}")

        if not os.path.exists(image_path):
            print(f"Image not found: {image_path}")
            return None

        # Preprocess image
        tensor = self.preprocessor.preprocess_image(image_path)
        if tensor is None:
            return None

        # Add a batch dimension of 1 and move to the model's device.
        batch_tensor = tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(batch_tensor)
            probabilities = torch.softmax(outputs, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence_scores = probabilities[0].cpu().numpy()

        result = {
            'predicted_class': predicted_class,
            'predicted_label': self.class_names[predicted_class],
            'confidence': confidence_scores[predicted_class],
            'all_probabilities': confidence_scores,
            'model_status': self.loading_method
        }

        if show_details:
            print(f"Prediction: {self.class_names[predicted_class]}")
            print(f"Confidence: {confidence_scores[predicted_class]*100:.2f}%")

            if self.loading_method != "untrained_fallback":
                print("\nAll probabilities:")
                for class_name, prob in zip(self.class_names, confidence_scores):
                    bar = "█" * int(prob * 20)
                    print(f"  {class_name:12}: {prob*100:6.2f}% |{bar}")
            else:
                print("WARNING: Using untrained model - predictions are random!")

            # Show image with prediction
            self._display_result(image_path, result)

        return result

    def _display_result(self, image_path, result):
        """Display image with prediction (best effort: errors are printed, not raised)."""
        try:
            # imshow copies the pixel data, so the file can be closed right after.
            with Image.open(image_path) as image:
                plt.figure(figsize=(10, 6))
                plt.imshow(image)
            plt.axis('off')

            title = f"Prediction: {result['predicted_label']} ({result['confidence']*100:.1f}%)\n"
            title += f"Status: {result['model_status']}"

            if result['model_status'] == "untrained_fallback":
                title += " (RANDOM PREDICTIONS)"

            plt.title(title, fontsize=12, pad=20)
            plt.tight_layout()
            plt.show()

        except Exception as e:
            print(f"Could not display image: {e}")

    def create_tensor_samples(self, image_paths, labels=None):
        """
        Create consistent tensor samples for training/testing.

        ``labels``, when given, is indexed in parallel with ``image_paths``.
        Images that fail preprocessing are skipped.
        """
        print("Creating consistent tensor samples...")
        tensor_samples = []

        for i, image_path in enumerate(image_paths):
            label = labels[i] if labels is not None else None
            sample = self.preprocessor.create_tensor_sample(image_path, label)
            if sample is not None:
                tensor_samples.append(sample)
                print(f"Processed: {os.path.basename(image_path)}")

        print(f"Created {len(tensor_samples)} tensor samples")
        return tensor_samples

    def test_consistency(self, image_path, num_runs=5):
        """
        Test prediction consistency across multiple runs.

        Returns:
            The list of successful ``predict_image`` result dicts. The list is
            empty if every run failed.
        """
        from collections import Counter

        print(f"\nTESTING CONSISTENCY: {os.path.basename(image_path)}")
        print("=" * 60)

        results = []

        for run in range(num_runs):
            print(f"Run {run + 1}:")
            result = self.predict_image(image_path, show_details=False)
            if result:
                results.append(result)
                print(f"  {result['predicted_label']} ({result['confidence']*100:.2f}%)")

        if not results:
            print("No successful predictions - consistency could not be assessed.")
            return results

        predictions = [r['predicted_label'] for r in results]
        confidences = [r['confidence'] for r in results]

        unique_predictions = set(predictions)
        consistent = len(unique_predictions) == 1

        print("\nConsistency Analysis:")
        print(f"  Unique predictions: {len(unique_predictions)}")
        print(f"  Consistent: {'Yes' if consistent else 'No'}")
        print(f"  Confidence std: {np.std(confidences):.4f}")

        if consistent:
            print(f"  Stable prediction: {predictions[0]}")
        else:
            # Plain str/int keys and values (np.unique would print numpy scalar reprs).
            counts = dict(sorted(Counter(predictions).items()))
            print(f"  Prediction counts: {counts}")

        return results

# ===============================================================================
# PTH FILE SPECIFIC HANDLER
# ===============================================================================

class PTHModelHandler:
    """
    Specialized handler for PyTorch .pth files.

    SECURITY: files are loaded with ``weights_only=False`` so that full pickled
    model objects can be restored. Unpickling can execute arbitrary code, so
    only use this with model files from a trusted source.
    """

    # Checkpoint keys that commonly hold a state_dict (or a whole model).
    _STATE_DICT_KEYS = ('state_dict', 'model_state_dict', 'model', 'net', 'network')

    @staticmethod
    def _torch_load_cpu(model_path):
        """torch.load onto the CPU, allowing full pickled objects.

        PyTorch >= 2.6 defaults to ``weights_only=True``, which rejects pickled
        ``nn.Module`` objects. Versions that predate the argument raise
        TypeError and are retried without it.
        """
        try:
            return torch.load(model_path, map_location='cpu', weights_only=False)
        except TypeError:
            return torch.load(model_path, map_location='cpu')

    @staticmethod
    def _is_model(obj):
        """True for objects that look like a complete model (have forward and state_dict)."""
        return hasattr(obj, 'forward') and hasattr(obj, 'state_dict')

    @staticmethod
    def inspect_pth_file(model_path):
        """
        Inspect .pth file contents and structure.

        Returns:
            The loaded checkpoint, or None if it could not be loaded (the error
            is printed).
        """
        print(f"INSPECTING PTH FILE: {model_path}")
        print("=" * 60)

        try:
            checkpoint = PTHModelHandler._torch_load_cpu(model_path)

            print(f"File type: {type(checkpoint)}")

            if isinstance(checkpoint, dict):
                print(f"Dictionary keys: {list(checkpoint.keys())}")

                # Check each key
                for key, value in checkpoint.items():
                    print(f"  {key}: {type(value)}")

                    if isinstance(value, dict):
                        print(f"    -> Dict with {len(value)} items")
                        if len(value) < 20:  # Show keys if not too many
                            print(f"    -> Keys: {list(value.keys())[:10]}")
                    elif hasattr(value, 'shape'):
                        print(f"    -> Shape: {value.shape}")
                    elif isinstance(value, (int, float, str)):
                        print(f"    -> Value: {value}")

            elif hasattr(checkpoint, 'state_dict'):
                print("Complete model object detected")
                state_dict = checkpoint.state_dict()
                print(f"State dict keys: {len(state_dict)} parameters")

            return checkpoint

        except Exception as e:
            print(f"Error inspecting file: {e}")
            return None

    @staticmethod
    def load_pth_model(model_path, model_class, num_classes=3):
        """
        Load model from .pth file with comprehensive handling.

        Strategies, in order:
          1. The file is a complete model object.
          2. The file is a dict in which one of ``_STATE_DICT_KEYS`` holds
             either a complete model object or a state_dict.
          3. The whole dict is itself the state_dict.

        Args:
            model_path: path to the checkpoint.
            model_class: callable ``model_class(num_classes=...)`` returning a
                fresh module to load a state_dict into.
            num_classes: forwarded to ``model_class``.

        Returns:
            The model in eval mode, or None if every strategy failed. Errors
            are printed, not raised.
        """
        print(f"LOADING PTH MODEL: {model_path}")
        print("=" * 50)

        try:
            checkpoint = PTHModelHandler._torch_load_cpu(model_path)
            print(f"Loaded checkpoint type: {type(checkpoint)}")

            # Method 1: Direct model object
            if PTHModelHandler._is_model(checkpoint):
                print("Method: Direct model object")
                return checkpoint.eval()

            if isinstance(checkpoint, dict):
                # Method 2: Dictionary with known keys
                for key in PTHModelHandler._STATE_DICT_KEYS:
                    if key not in checkpoint:
                        continue
                    value = checkpoint[key]

                    if PTHModelHandler._is_model(value):
                        print(f"Method: Model object from key '{key}'")
                        return value.eval()

                    if isinstance(value, dict):
                        print(f"Method: State dict from key '{key}'")
                        model = model_class(num_classes=num_classes)
                        model.load_state_dict(value)
                        return model.eval()

                # Method 3: treat the whole dict (e.g. a bare tensor dict) as the state_dict
                print("Method: Direct state dict")
                try:
                    model = model_class(num_classes=num_classes)
                    model.load_state_dict(checkpoint)
                    return model.eval()
                except Exception as e:
                    print(f"Direct state dict failed: {e}")

        except Exception as e:
            print(f"PTH loading error: {e}")

        return None

# ===============================================================================
# ENHANCED PREDICTOR WITH PTH SUPPORT
# ===============================================================================

class EnhancedPTHPredictor(ConsistentDeepfakePredictor):
    """
    Enhanced predictor with specialized .pth file support.

    PyTorch checkpoint files (.pth / .pt, any case, ``str`` or ``Path``) are
    inspected and loaded with PTHModelHandler first. Any other file, or a
    PTH-specific failure, falls back to the parent's general loader.
    """

    # Extensions routed through the torch.load-based PTHModelHandler.
    _TORCH_EXTENSIONS = ('.pth', '.pt')

    def _load_model(self):
        """Enhanced model loading with PTH file inspection."""
        print("=" * 60)
        print("ENHANCED PTH MODEL LOADING")
        print("=" * 60)

        extension = os.path.splitext(str(self.model_path))[1].lower()

        if extension in self._TORCH_EXTENSIONS:
            print(f"Detected {extension} file - running inspection...")
            PTHModelHandler.inspect_pth_file(self.model_path)
            print()

            # Try PTH-specific loading first
            self.model = PTHModelHandler.load_pth_model(
                self.model_path, SimpleClassifier, 3
            )

            if self.model is not None:
                self.loading_method = "pth_specialized"
                # Mirror the parent's post-load setup. Without .to(self.device),
                # predict_image would send inputs to a device the model is not on.
                self.model.eval()
                self.model = self.model.to(self.device)
                for module in self.model.modules():
                    if isinstance(module, nn.Dropout):
                        module.p = 0.0
                print("Successfully loaded using PTH-specific handler")
            else:
                print("PTH-specific loading failed, trying general methods...")
                # Fall back to general loading methods
                super()._load_model()
        else:
            # Use general loading for non-PyTorch checkpoint files
            super()._load_model()

        print(f"Final model loading status: {self.loading_method}")
        print("=" * 60)

# ===============================================================================
# USAGE EXAMPLES
# ===============================================================================

def run_pth_prediction_pipeline(model_path, test_images, preprocessing_config=None,
                                consistency_runs=3):
    """
    Run prediction pipeline optimized for .pth files.

    Args:
        model_path: path to the model checkpoint.
        test_images: iterable of image paths. Missing files are reported and
            skipped.
        preprocessing_config: kwargs for ConsistentImagePreprocessor. Defaults
            to a 224x224 resize with ImageNet normalization.
        consistency_runs: number of repeated predictions per image for
            ``test_consistency``. Use 0 to skip the check.

    Returns:
        The EnhancedPTHPredictor used, so callers can keep predicting.
    """
    print("LAUNCHING PTH-OPTIMIZED DEEPFAKE PREDICTION PIPELINE")
    print("=" * 70)

    if preprocessing_config is None:
        preprocessing_config = {
            'target_size': (224, 224),
            'normalization': 'imagenet'  # Adjust based on your training
        }

    predictor = EnhancedPTHPredictor(
        model_path=model_path,
        preprocessing_config=preprocessing_config
    )

    for image_path in test_images:
        if not os.path.exists(image_path):
            print(f"Image not found: {image_path}")
            continue

        print(f"\nProcessing: {os.path.basename(image_path)}")

        # Single prediction (prints and displays its own result)
        predictor.predict_image(image_path)

        # Repeated runs help verify the model loaded correctly
        if consistency_runs > 0:
            predictor.test_consistency(image_path, num_runs=consistency_runs)

        print("\n" + "=" * 70)

    return predictor

def inspect_and_predict(model_path, image_path):
    """
    Quick function to inspect model and make single prediction.

    For .pth checkpoints, EnhancedPTHPredictor already runs
    ``PTHModelHandler.inspect_pth_file`` while loading. Inspecting again here
    would torch.load the whole file a second time and print the report twice.

    Returns:
        The ``predict_image`` result dict, or None if prediction failed.
    """
    print("QUICK INSPECT AND PREDICT")
    print("=" * 40)

    predictor = EnhancedPTHPredictor(model_path)
    return predictor.predict_image(image_path)

# Example usage for .pth files:
if __name__ == "__main__":
    # Update these paths to point at your .pth checkpoint and test images.
    model_path = "/content/sample_data/deepfake_classifier_simple.pth"
    test_images = ["/content/sample_data/chk.jpg", "/content/sample_data/chk2.jpg"]

    # Run full pipeline
    predictor = run_pth_prediction_pipeline(model_path, test_images)

    # Or quick inspect and predict a single image:
    # result = inspect_and_predict(model_path, test_images[0])

    print("\nPTH PREDICTION PIPELINE COMPLETED!")

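💡 USAGE EXAMPLES (stale output from an earlier version of this cell: `DeepfakePredictor`, `test_different_preprocessing`, `predict_batch` and `main_prediction_pipeline` are not defined in the code above — use `EnhancedPTHPredictor`, `predict_image`, `run_pth_prediction_pipeline` and `inspect_and_predict` instead):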

# Initialize predictor with your .pkl file:
predictor = DeepfakePredictor('your_model.pkl')

# Predict single image:
result = predictor.predict_image('path/to/image.jpg')

# Test different preprocessing methods:
test_different_preprocessing(predictor, 'path/to/image.jpg')

# Batch prediction:
results = predictor.predict_batch(['image1.jpg', 'image2.jpg'])

# Full pipeline:
main_prediction_pipeline('your_model.pkl', ['test1.jpg', 'test2.jpg'])

🚀 DEEPFAKE PREDICTION PIPELINE
🔄 Loading model from /content/three_class_model_20250905_110508.pkl
❌ Error loading model: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
💡 Trying alternative loading methods...
❌ All loading methods failed: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CP

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.