In [7]:
import torch
import torch.nn.functional as F
import timm
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from pathlib import Path
import json
from datetime import datetime
import pandas as pd
import os
from typing import List, Dict

class ConsistentVanGoghInference:
    """Patch-based ensemble inference for authentic-vs-imitation classification.

    Every configured model (a ``timm`` Swin-Tiny or EfficientNet-B5 with a
    2-class head, loaded from a state-dict file) is run over adaptively
    extracted patches of each image. Per-patch outputs are combined with
    mean-logits aggregation, majority voting, or both, and the results for a
    whole folder are reported as models x images tables.
    """

    # Accepted values for the ``aggregation_method`` constructor argument.
    _VALID_AGGREGATIONS = ('majority', 'mean', 'both')
    # File suffixes (compared lower-cased) treated as images when scanning a folder.
    _IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.jp2'})

    def __init__(self, model_configs: List[Dict], device=None, target_size=256, aggregation_method='both'):
        """Load every configured model and build the shared patch transform.

        Args:
            model_configs: One dict per model with keys ``'name'``, ``'type'``
                (``'swin'`` or ``'efficientnet'``), ``'path'`` (state-dict
                file) and ``'category'``.
            device: Torch device (string or ``torch.device``). When ``None``,
                MPS is preferred, then CUDA, then CPU.
            target_size: Side length every patch is resized to before inference.
            aggregation_method: ``'mean'``, ``'majority'`` or ``'both'``.

        Raises:
            ValueError: If ``aggregation_method`` is not supported, or a config
                names an unsupported model type.
        """
        # Validate before the (slow) model loading so a typo fails fast instead
        # of silently producing empty result tables later on.
        if aggregation_method not in self._VALID_AGGREGATIONS:
            raise ValueError(
                f"aggregation_method must be one of {self._VALID_AGGREGATIONS}, "
                f"got {aggregation_method!r}"
            )

        self.model_configs = model_configs
        self.target_size = target_size
        self.aggregation_method = aggregation_method  # 'majority', 'mean', or 'both'

        if device is None:
            if torch.backends.mps.is_available() and torch.backends.mps.is_built():
                self.device = torch.device("mps")
            elif torch.cuda.is_available():
                self.device = torch.device("cuda")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)

        print(f"Using device: {self.device}")
        print(f"Aggregation method: {self.aggregation_method}")

        self.class_to_idx = {'authentic': 0, 'imitation': 1}
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}

        # Load all models (populates self.models).
        self.models = {}
        self._load_all_models()

        # Every patch is squashed to a target_size x target_size square and
        # normalised with the standard ImageNet mean/std values.
        self.transform = transforms.Compose([
            transforms.Resize((self.target_size, self.target_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"Successfully loaded {len(self.models)} models")

    def _load_all_models(self):
        """Build each configured architecture, load its weights and register it in ``self.models``.

        Raises:
            ValueError: For a config whose ``'type'`` is neither ``'swin'`` nor
                ``'efficientnet'``. Any loading error is printed and re-raised.
        """
        for config in self.model_configs:
            model_name = config['name']
            model_type = config['type'].lower()
            model_path = config['path']

            print(f"Loading {model_name} ({model_type})...")

            try:
                if model_type == 'swin':
                    # img_size must agree with the patch size produced by
                    # self.transform. The dropout settings only matter in
                    # training mode; the model is switched to eval() below.
                    model = timm.create_model(
                        'swin_tiny_patch4_window7_224',
                        pretrained=False,
                        num_classes=2,
                        img_size=self.target_size,
                        drop_rate=0.5,
                        drop_path_rate=0.4
                    )
                elif model_type == 'efficientnet':
                    model = timm.create_model(
                        'efficientnet_b5',
                        pretrained=False,
                        num_classes=2
                    )
                else:
                    raise ValueError(f"Unsupported model type: {model_type}")

                # NOTE: torch.load unpickles the file - only load trusted checkpoints.
                state_dict = torch.load(model_path, map_location=self.device)
                model.load_state_dict(state_dict)

                model.to(self.device)
                model.eval()

                self.models[model_name] = {
                    'model': model,
                    'type': model_type,
                    'category': config['category'],
                    'path': model_path
                }

                print(f"✓ {model_name} loaded successfully")

            except Exception as e:
                print(f"✗ Error loading {model_name}: {e}")
                raise

    def _adaptive_patch_extraction(self, image):
        """Split ``image`` into patches whose count depends on its longest side.

        * longest side > 1024 px  -> 4x4 grid (up to 16 patches)
        * longest side >= 512 px  -> 2x2 grid (up to 4 patches)
        * otherwise               -> one patch: a centred square crop of side
          ``min(w, h)``, or, when the short side is below ``target_size``, the
          whole image resized to ``target_size`` x ``target_size``.

        Grid patches do not overlap; the last row/column absorbs the leftover
        pixels. Zero-area patches (possible for extremely thin images) are
        skipped.
        """
        w, h = image.size
        max_dim = max(w, h)

        if max_dim > 1024:
            grid_size = 4
        elif max_dim >= 512:
            grid_size = 2
        else:
            grid_size = 1

        if grid_size == 1:
            min_dim = min(w, h)
            if min_dim < self.target_size:
                # Too small to crop meaningfully: stretch to the model input size.
                return [image.resize((self.target_size, self.target_size), Image.Resampling.LANCZOS)]
            left = (w - min_dim) // 2
            top = (h - min_dim) // 2
            return [image.crop((left, top, left + min_dim, top + min_dim))]

        patch_width = w // grid_size
        patch_height = h // grid_size
        patches = []
        for i in range(grid_size):
            for j in range(grid_size):
                left = j * patch_width
                upper = i * patch_height
                # The last patch in a row/column extends to the image edge.
                right = (j + 1) * patch_width if (j + 1) < grid_size else w
                bottom = (i + 1) * patch_height if (i + 1) < grid_size else h

                patch_img = image.crop((left, upper, right, bottom))
                if patch_img.size[0] > 0 and patch_img.size[1] > 0:
                    patches.append(patch_img)

        return patches

    def preprocess_image(self, image_path):
        """Load ``image_path`` from disk as an RGB PIL image.

        Resizing and normalisation happen later, per patch, via ``self.transform``.
        Loading errors are printed and re-raised.
        """
        try:
            # The context manager closes the underlying file, which PIL may
            # otherwise keep open (e.g. for multi-frame TIFFs).
            with Image.open(image_path) as image:
                return image.convert('RGB')
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            raise

    def predict_with_adaptive_patches_single_model(self, image, model_info):
        """Classify ``image`` with one model by aggregating its per-patch outputs.

        Args:
            image: RGB PIL image (as returned by ``preprocess_image``).
            model_info: An entry of ``self.models``; only ``'model'`` is used.

        Returns:
            Dict with a ``'mean_logits'`` and/or ``'majority_voting'`` entry
            (depending on ``self.aggregation_method``), each holding
            ``'predicted_class'``, ``'confidence'`` and ``'probabilities'``.
            Majority voting additionally reports ``'vote_counts'``.

        Raises:
            ValueError: If no patches could be extracted from the image.
        """
        model = model_info['model']

        patches = self._adaptive_patch_extraction(image)
        if not patches:
            raise ValueError("No patches extracted from image")

        # self.transform resizes every patch to the same square size, so all
        # patches can go through the model in one batched forward pass.
        batch = torch.stack([self.transform(patch) for patch in patches]).to(self.device)

        with torch.no_grad():
            logits = model(batch)
            probabilities = F.softmax(logits, dim=1)

        # Column 0 = authentic, column 1 = imitation (see self.class_to_idx).
        patch_predictions = torch.argmax(probabilities, dim=1).cpu().numpy()
        patch_confidences = probabilities.cpu().numpy()
        num_patches = len(patch_predictions)

        results = {}

        # Method 1: mean logits - average raw logits over patches, softmax once.
        if self.aggregation_method in ('mean', 'both'):
            avg_logits = torch.mean(logits, dim=0, keepdim=True)
            mean_probabilities = F.softmax(avg_logits, dim=1)
            mean_idx = torch.argmax(mean_probabilities, dim=1).item()
            mean_scores = mean_probabilities.cpu().numpy()[0]

            results['mean_logits'] = {
                'predicted_class': self.idx_to_class[mean_idx],
                'confidence': float(mean_scores[mean_idx]),
                'probabilities': {
                    'authentic': float(mean_scores[0]),
                    'imitation': float(mean_scores[1])
                }
            }

        # Method 2: majority voting over per-patch argmax predictions.
        if self.aggregation_method in ('majority', 'both'):
            authentic_votes = int(np.sum(patch_predictions == 0))
            imitation_votes = int(np.sum(patch_predictions == 1))

            if authentic_votes != imitation_votes:
                winner_idx = 0 if authentic_votes > imitation_votes else 1
            else:
                # Tie: break it with the patch-averaged class probabilities
                # (an exact tie goes to imitation).
                avg_authentic_conf = np.mean(patch_confidences[:, 0])
                avg_imitation_conf = np.mean(patch_confidences[:, 1])
                winner_idx = 0 if avg_authentic_conf > avg_imitation_conf else 1

            # Confidence = mean winning-class probability over the patches that
            # voted for the winner (falls back to all patches if none did).
            winner_probs = patch_confidences[:, winner_idx]
            supporting = winner_probs[patch_predictions == winner_idx]
            majority_final_confidence = np.mean(supporting) if supporting.size else np.mean(winner_probs)

            results['majority_voting'] = {
                'predicted_class': self.idx_to_class[winner_idx],
                'confidence': float(majority_final_confidence),
                # Vote fractions, reported in the 'probabilities' slot.
                'probabilities': {
                    'authentic': float(authentic_votes / num_patches),
                    'imitation': float(imitation_votes / num_patches)
                },
                'vote_counts': {
                    'authentic': authentic_votes,
                    'imitation': imitation_votes,
                    'total_patches': num_patches
                }
            }

        return results

    def predict_single_image(self, image_path):
        """Run every loaded model on one image file.

        Returns:
            ``{model_name: result}`` where ``result`` is the output of
            ``predict_with_adaptive_patches_single_model``, or
            ``{'error': message}`` if that model failed on this image.
            Image loading errors propagate to the caller.
        """
        image = self.preprocess_image(image_path)

        results = {}
        for model_name, model_info in self.models.items():
            try:
                results[model_name] = self.predict_with_adaptive_patches_single_model(image, model_info)
            except Exception as e:
                results[model_name] = {'error': str(e)}

        return results

    @staticmethod
    def _format_cell(aggregated):
        """Render one aggregated prediction as e.g. ``'Auth: 93%'`` / ``'Imit: 61%'``."""
        label = "Auth" if aggregated['predicted_class'] == 'authentic' else "Imit"
        return f"{label}: {aggregated['confidence']*100:.0f}%"

    @staticmethod
    def _build_table(table_data):
        """Turn ``{model: {image: cell}}`` into a DataFrame with models as rows, image columns sorted."""
        df = pd.DataFrame(table_data).T
        return df.reindex(sorted(df.columns), axis=1)

    @staticmethod
    def _print_table(title, df):
        """Print ``df`` framed by 150-character '=' rules under ``title``."""
        rule = "=" * 150
        print("\n" + rule)
        print(title)
        print(rule)
        print(df.to_string())
        print(rule)

    def predict_folder_and_create_tables(self, folder_path):
        """Run inference on all images in a folder and build one table per aggregation method.

        Images are the regular files directly inside ``folder_path`` whose
        suffix (case-insensitive) is in ``_IMAGE_EXTENSIONS``.

        Returns:
            ``(mean_df, majority_df)``: DataFrames with model names as rows and
            image file names as (sorted) columns; cells look like
            ``'Auth: 93%'``, ``'Imit: 61%'`` or ``'ERROR'``. When no images are
            found, both tables are empty (never ``None``).
        """
        folder_path = Path(folder_path)

        # Match suffixes case-insensitively in one pass. Globbing each extension
        # in lower- and upper-case missed mixed-case names (".Jpg") and could
        # list a file twice on case-insensitive filesystems.
        if folder_path.is_dir():
            image_files = sorted(
                p for p in folder_path.iterdir()
                if p.is_file() and p.suffix.lower() in self._IMAGE_EXTENSIONS
            )
        else:
            image_files = []

        print(f"Found {len(image_files)} images in {folder_path}")

        model_names = list(self.models.keys())
        mean_table_data = {model_name: {} for model_name in model_names}
        majority_table_data = {model_name: {} for model_name in model_names}

        if not image_files:
            print("No images found!")
            # Empty tables rather than None, so callers can always unpack the result.
            return self._build_table(mean_table_data), self._build_table(majority_table_data)

        for image_path in image_files:
            image_name = image_path.name
            print(f"Processing: {image_name}")

            try:
                results = self.predict_single_image(image_path)
            except Exception as e:
                print(f"Error processing {image_name}: {e}")
                for model_name in model_names:
                    mean_table_data[model_name][image_name] = "ERROR"
                    majority_table_data[model_name][image_name] = "ERROR"
                continue

            for model_name, result in results.items():
                if 'error' in result:
                    print(f"  {model_name} failed on {image_name}: {result['error']}")
                    mean_table_data[model_name][image_name] = "ERROR"
                    majority_table_data[model_name][image_name] = "ERROR"
                    continue

                if 'mean_logits' in result:
                    mean_table_data[model_name][image_name] = self._format_cell(result['mean_logits'])
                if 'majority_voting' in result:
                    majority_table_data[model_name][image_name] = self._format_cell(result['majority_voting'])

        mean_df = self._build_table(mean_table_data)
        majority_df = self._build_table(majority_table_data)

        self._print_table("PREDICTION RESULTS TABLE - MEAN LOGITS AGGREGATION", mean_df)
        self._print_table("PREDICTION RESULTS TABLE - MAJORITY VOTING AGGREGATION", majority_df)

        return mean_df, majority_df


def main(input_folder="/content/test_images", aggregation_method='both'):
    """Run the full model ensemble over ``input_folder`` and return the result tables.

    Args:
        input_folder: Directory whose image files are classified.
        aggregation_method: ``'mean'``, ``'majority'`` or ``'both'``.

    Returns:
        ``(mean_df, majority_df)`` as produced by
        ``ConsistentVanGoghInference.predict_folder_and_create_tables``, or
        ``(None, None)`` if that call returned no tables.
    """
    # Model configurations
    MODEL_CONFIGS = [
        {
            'path': '/swin_overfit_model_run1.pth',
            'type': 'swin',
            'name': 'swin_overfit',
            'category': 'overfit'
        },
        {
            'path': '/swin_regularized_model_run1.pth',
            'type': 'swin',
            'name': 'swin_regularized_1',
            'category': 'regularized'
        },
        {
            'path': '/swin_regularized_model_run2.pth',
            'type': 'swin',
            'name': 'swin_regularized_2',
            'category': 'regularized'
        },
        {
            'path': '/swin_regularized_model_run3.pth',
            'type': 'swin',
            'name': 'swin_regularized_3',
            'category': 'regularized'
        },
        {
            'path': '/swin_regularized_model_run5.pth',
            'type': 'swin',
            'name': 'swin_regularized_5',
            'category': 'regularized'
        },
        {
            'path': '/swin_regularized_model_run6.pth',
            'type': 'swin',
            'name': 'swin_regularized_6',
            'category': 'regularized'
        },
        {
            'path': '/swin_regularized_model_run7.pth',
            'type': 'swin',
            'name': 'swin_regularized_7',
            'category': 'regularized'
        },
        {
            'path': '/efficientnet_unstable_model_run1.pth',
            'type': 'efficientnet',
            'name': 'effnet_unstable',
            'category': 'unstable'
        },
        {
            'path': '/efficientnet_stable_model_run1.pth',
            'type': 'efficientnet',
            'name': 'effnet_stable_1',
            'category': 'stable'
        },
        {
            'path': '/efficientnet_stable_model_run2.pth',
            'type': 'efficientnet',
            'name': 'effnet_stable_2',
            'category': 'stable'
        },
        {
            'path': '/efficientnet_stable_model_run3.pth',
            'type': 'efficientnet',
            'name': 'effnet_stable_3',
            'category': 'stable'
        },
        {
            'path': '/efficientnet_stable_model_run5.pth',
            'type': 'efficientnet',
            'name': 'effnet_stable_5',
            'category': 'stable'
        },
        {
            'path': '/efficientnet_stable_model_run6.pth',
            'type': 'efficientnet',
            'name': 'effnet_stable_6',
            'category': 'stable'
        },
        {
            'path': '/efficientnet_stable_model_run7.pth',
            'type': 'efficientnet',
            'name': 'effnet_stable_7',
            'category': 'stable'
        }
    ]

    print("Initializing Van Gogh Inference System...")

    # aggregation_method may be 'mean', 'majority', or 'both'.
    inferencer = ConsistentVanGoghInference(MODEL_CONFIGS, aggregation_method=aggregation_method)

    tables = inferencer.predict_folder_and_create_tables(folder_path=input_folder)
    # Defensive: tolerate a None result (e.g. when no images are found)
    # instead of failing with a TypeError while unpacking it.
    if tables is None:
        return None, None

    mean_df, majority_df = tables
    return mean_df, majority_df

if __name__ == "__main__":
    main()

Initializing Van Gogh Inference System...
Using device: cuda
Aggregation method: both
Loading swin_overfit (swin)...
✓ swin_overfit loaded successfully
Loading swin_regularized_1 (swin)...
✓ swin_regularized_1 loaded successfully
Loading swin_regularized_2 (swin)...
✓ swin_regularized_2 loaded successfully
Loading swin_regularized_3 (swin)...
✓ swin_regularized_3 loaded successfully
Loading swin_regularized_5 (swin)...
✓ swin_regularized_5 loaded successfully
Loading swin_regularized_6 (swin)...
✓ swin_regularized_6 loaded successfully
Loading swin_regularized_7 (swin)...
✓ swin_regularized_7 loaded successfully
Loading effnet_unstable (efficientnet)...
✓ effnet_unstable loaded successfully
Loading effnet_stable_1 (efficientnet)...
✓ effnet_stable_1 loaded successfully
Loading effnet_stable_2 (efficientnet)...
✓ effnet_stable_2 loaded successfully
Loading effnet_stable_3 (efficientnet)...
✓ effnet_stable_3 loaded successfully
Loading effnet_stable_5 (efficientnet)...
✓ effnet_stable_5 