In [1]:
import torch
import torch.nn.functional as F
import timm
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from pathlib import Path
import pandas as pd
from collections import Counter, defaultdict
from scipy import stats
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss, log_loss
import matplotlib.pyplot as plt
import seaborn as sns

class CalibrationTester:
    def __init__(self, overfit_model_path, regularized_model_path, device=None):
        if device is None:
            if torch.backends.mps.is_available() and torch.backends.mps.is_built():
                self.device = torch.device("mps")
            elif torch.cuda.is_available():
                self.device = torch.device("cuda")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)
            
        self.class_to_idx = {'authentic': 0, 'imitation': 1}
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
        
        # Load models
        self.overfit_model = self._load_swin_model(overfit_model_path, regularized=False)
        self.regularized_model = self._load_swin_model(regularized_model_path, regularized=True)
        
        # Transform
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _load_swin_model(self, model_path, regularized=False):
        if regularized:
            model = timm.create_model(
                'swin_tiny_patch4_window7_224',
                pretrained=False,
                num_classes=2,
                img_size=256,
                drop_rate=0.5,
                drop_path_rate=0.4
            )
        else:
            model = timm.create_model(
                'swin_tiny_patch4_window7_224',
                pretrained=False,
                num_classes=2,
                img_size=256
            )
        
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def _extract_patches(self, image):
        w, h = image.size
        max_dim = max(w, h)
        
        if max_dim > 1024:
            grid_size = 4
        elif max_dim >= 512:
            grid_size = 2
        else:
            grid_size = 1
            
        patches = []
        
        if grid_size == 1:
            min_dim = min(w, h)
            if min_dim < 256:
                patches.append(image.resize((256, 256)))
            else:
                left = (w - min_dim) // 2
                top = (h - min_dim) // 2
                patches.append(image.crop((left, top, left + min_dim, top + min_dim)))
        else:
            patch_width = w // grid_size
            patch_height = h // grid_size
            
            for i in range(grid_size):
                for j in range(grid_size):
                    left = j * patch_width
                    upper = i * patch_height
                    right = (j + 1) * patch_width if (j + 1) < grid_size else w
                    bottom = (i + 1) * patch_height if (i + 1) < grid_size else h
                    
                    patch = image.crop((left, upper, right, bottom))
                    if patch.size[0] > 0 and patch.size[1] > 0:
                        patches.append(patch)
        
        return patches

    def collect_patch_level_data(self, test_folder_path, excel_path, sheet_name='vg_cv_data_july31'):
        """Collect patch-level probabilities following the pattern from your code"""
        folder_path = Path(test_folder_path)
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
        
        image_files = []
        for ext in image_extensions:
            image_files.extend(folder_path.glob(f'*{ext}'))
            image_files.extend(folder_path.glob(f'*{ext.upper()}'))
        
        # Get true labels from Excel
        excel_df = pd.read_excel(excel_path, sheet_name=sheet_name)
        
        print(f"Found {len(image_files)} images to process...")
        
        # Collections for patch-level data
        painting_ids = []
        true_labels = []
        overfit_probs = []
        regularized_probs = []
        
        valid_images = 0
        
        for i, image_path in enumerate(image_files):
            if i % 50 == 0:
                print(f"Processing image {i+1}/{len(image_files)}")
            
            # Get true label from Excel
            excel_row = excel_df[excel_df['image'] == image_path.name]
            if excel_row.empty:
                continue
                
            is_authentic = excel_row.iloc[0]['is_wikiart_vangogh_oil_painting']
            true_label = 0 if is_authentic == 1 else 1  # 0=authentic, 1=imitation
            
            try:
                # Extract patches
                image = Image.open(image_path).convert('RGB')
                patches = self._extract_patches(image)
                
                # Process each patch
                for patch_idx, patch in enumerate(patches):
                    patch_tensor = self.transform(patch).unsqueeze(0).to(self.device)
                    
                    with torch.no_grad():
                        # Overfit model
                        overfit_logits = self.overfit_model(patch_tensor)
                        overfit_prob = F.softmax(overfit_logits, dim=1)[0, 1].cpu().numpy()  # P(imitation)
                        
                        # Regularized model
                        reg_logits = self.regularized_model(patch_tensor)
                        reg_prob = F.softmax(reg_logits, dim=1)[0, 1].cpu().numpy()  # P(imitation)
                    
                    # Store patch-level data (using image name as painting ID)
                    painting_ids.append(image_path.stem)  # Use filename without extension as ID
                    true_labels.append(true_label)
                    overfit_probs.append(float(overfit_prob))
                    regularized_probs.append(float(reg_prob))
                
                valid_images += 1
                
            except Exception as e:
                print(f"Error processing {image_path.name}: {e}")
                continue
        
        print(f"Successfully processed {valid_images} images with {len(painting_ids)} total patches")
        
        return painting_ids, true_labels, overfit_probs, regularized_probs

    def aggregate_per_painting(self, painting_ids, true_labels, overfit_probs, regularized_probs):
        """Aggregate patch-level predictions per painting using majority voting"""
        
        # Group patches by painting
        paintings = defaultdict(list)
        for pid, y, p1, p2 in zip(painting_ids, true_labels, overfit_probs, regularized_probs):
            paintings[pid].append((y, p1, p2))
        
        # Aggregate results
        painting_true = []
        painting_overfit_major = []
        painting_reg_major = []
        painting_overfit_probs = []
        painting_reg_probs = []
        painting_overfit_variances = []
        painting_reg_variances = []
        painting_overfit_confidences = []
        painting_reg_confidences = []
        
        for pid, patches in paintings.items():
            y_true = patches[0][0]  # All patches have same true label
            overfit_patch_probs = [t[1] for t in patches]
            reg_patch_probs = [t[2] for t in patches]
            
            # Votes at 0.5 threshold
            overfit_votes = [int(p >= 0.5) for p in overfit_patch_probs]
            reg_votes = [int(p >= 0.5) for p in reg_patch_probs]
            
            # Majority voting
            overfit_majority = Counter(overfit_votes).most_common(1)[0][0]
            reg_majority = Counter(reg_votes).most_common(1)[0][0]
            
            # Calculate confidence as proportion of patches voting for majority
            overfit_vote_counts = Counter(overfit_votes)
            reg_vote_counts = Counter(reg_votes)
            
            overfit_confidence = overfit_vote_counts[overfit_majority] / len(overfit_votes)
            reg_confidence = reg_vote_counts[reg_majority] / len(reg_votes)
            
            # Store results
            painting_true.append(y_true)
            painting_overfit_major.append(overfit_majority)
            painting_reg_major.append(reg_majority)
            painting_overfit_probs.append(overfit_patch_probs)
            painting_reg_probs.append(reg_patch_probs)
            painting_overfit_variances.append(np.var(overfit_patch_probs))
            painting_reg_variances.append(np.var(reg_patch_probs))
            painting_overfit_confidences.append(overfit_confidence)
            painting_reg_confidences.append(reg_confidence)
        
        return {
            'true_labels': np.array(painting_true),
            'overfit_predictions': np.array(painting_overfit_major),
            'reg_predictions': np.array(painting_reg_major),
            'overfit_patch_probs': painting_overfit_probs,
            'reg_patch_probs': painting_reg_probs,
            'overfit_variances': np.array(painting_overfit_variances),
            'reg_variances': np.array(painting_reg_variances),
            'overfit_confidences': np.array(painting_overfit_confidences),
            'reg_confidences': np.array(painting_reg_confidences)
        }

    def calculate_ece_for_voting(self, confidences, correct_predictions, n_bins=10):
        """Expected Calibration Error for majority voting confidences"""
        bin_boundaries = np.linspace(0.5, 1.0, n_bins + 1)  # Only bins from 0.5 to 1.0 for vote proportions
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]
        
        ece = 0
        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
            in_bin = (confidences >= bin_lower) & (confidences <= bin_upper)
            prop_in_bin = in_bin.mean()
            
            if prop_in_bin > 0:
                accuracy_in_bin = correct_predictions[in_bin].mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        
        return ece

    def bootstrap_calibration_comparison(self, overfit_confidences, reg_confidences, 
                                       overfit_correct, reg_correct, metric='ece', n_bootstrap=1000):
        """Bootstrap test comparing calibration metrics"""
        n_samples = len(overfit_correct)
        
        # Calculate observed metrics
        if metric == 'ece':
            overfit_metric = self.calculate_ece_for_voting(overfit_confidences, overfit_correct)
            reg_metric = self.calculate_ece_for_voting(reg_confidences, reg_correct)
        elif metric == 'brier':
            overfit_metric = brier_score_loss(overfit_correct, overfit_confidences)
            reg_metric = brier_score_loss(reg_correct, reg_confidences)
        elif metric == 'log_loss':
            overfit_clipped = np.clip(overfit_confidences, 1e-15, 1-1e-15)
            reg_clipped = np.clip(reg_confidences, 1e-15, 1-1e-15)
            overfit_metric = log_loss(overfit_correct, overfit_clipped)
            reg_metric = log_loss(reg_correct, reg_clipped)
        
        observed_diff = reg_metric - overfit_metric  # Negative = regularized better
        
        # Bootstrap
        np.random.seed(42)
        bootstrap_diffs = []
        
        for _ in range(n_bootstrap):
            indices = np.random.choice(n_samples, size=n_samples, replace=True)
            boot_overfit_conf = overfit_confidences[indices]
            boot_reg_conf = reg_confidences[indices]
            boot_overfit_correct = overfit_correct[indices]
            boot_reg_correct = reg_correct[indices]
            
            # Check if bootstrap sample has both classes (needed for log_loss)
            if metric == 'log_loss':
                if len(np.unique(boot_overfit_correct)) < 2 or len(np.unique(boot_reg_correct)) < 2:
                    continue  # Skip this bootstrap sample
            
            try:
                if metric == 'ece':
                    boot_overfit_metric = self.calculate_ece_for_voting(boot_overfit_conf, boot_overfit_correct)
                    boot_reg_metric = self.calculate_ece_for_voting(boot_reg_conf, boot_reg_correct)
                elif metric == 'brier':
                    boot_overfit_metric = brier_score_loss(boot_overfit_correct, boot_overfit_conf)
                    boot_reg_metric = brier_score_loss(boot_reg_correct, boot_reg_conf)
                elif metric == 'log_loss':
                    boot_overfit_clipped = np.clip(boot_overfit_conf, 1e-15, 1-1e-15)
                    boot_reg_clipped = np.clip(boot_reg_conf, 1e-15, 1-1e-15)
                    boot_overfit_metric = log_loss(boot_overfit_correct, boot_overfit_clipped)
                    boot_reg_metric = log_loss(boot_reg_correct, boot_reg_clipped)
                
                bootstrap_diffs.append(boot_reg_metric - boot_overfit_metric)
                
            except ValueError:
                # Skip bootstrap samples that cause errors
                continue
        
        bootstrap_diffs = np.array(bootstrap_diffs)
        
        # P-value for regularized being better (lower metric)
        p_value_reg_better = np.mean(bootstrap_diffs <= 0)
        
        # Confidence intervals
        ci_95_lower = np.percentile(bootstrap_diffs, 2.5)
        ci_95_upper = np.percentile(bootstrap_diffs, 97.5)
        
        return {
            'metric_name': metric,
            'overfit_metric': overfit_metric,
            'regularized_metric': reg_metric,
            'observed_difference': observed_diff,
            'p_value_reg_better': p_value_reg_better,
            'bootstrap_differences': bootstrap_diffs,
            'ci_95': (ci_95_lower, ci_95_upper),
            'regularized_is_better': observed_diff < 0
        }

    def patch_consistency_test(self, overfit_variances, reg_variances):
        """Test which model has more consistent patch predictions"""
        if len(overfit_variances) > 1:
            stat, p_value = stats.wilcoxon(reg_variances, overfit_variances, 
                                         alternative='less')  # reg < overfit
        else:
            stat, p_value = None, None
        
        return {
            'overfit_mean_variance': np.mean(overfit_variances),
            'regularized_mean_variance': np.mean(reg_variances),
            'difference': np.mean(reg_variances) - np.mean(overfit_variances),
            'wilcoxon_statistic': stat,
            'p_value_reg_more_consistent': p_value,
            'regularized_more_consistent': np.mean(reg_variances) < np.mean(overfit_variances)
        }

    def comprehensive_calibration_analysis(self, test_folder_path, excel_path, 
                                         sheet_name='vg_cv_data_july31'):
        """Run complete calibration analysis following the patch aggregation pattern"""
        
        print("COMPREHENSIVE CALIBRATION ANALYSIS - PATCH AGGREGATION")
        print("=" * 60)
        
        # 1. Collect patch-level data
        print("\n1. COLLECTING PATCH-LEVEL DATA...")
        painting_ids, true_labels, overfit_probs, reg_probs = self.collect_patch_level_data(
            test_folder_path, excel_path, sheet_name)
        
        if len(painting_ids) == 0:
            print("ERROR: No patch data collected!")
            return None
        
        # 2. Aggregate per painting
        print("\n2. AGGREGATING PER PAINTING WITH MAJORITY VOTING...")
        painting_data = self.aggregate_per_painting(painting_ids, true_labels, overfit_probs, reg_probs)
        
        # Extract aggregated data
        true_labels = painting_data['true_labels']
        overfit_preds = painting_data['overfit_predictions']
        reg_preds = painting_data['reg_predictions']
        overfit_confidences = painting_data['overfit_confidences']
        reg_confidences = painting_data['reg_confidences']
        overfit_variances = painting_data['overfit_variances']
        reg_variances = painting_data['reg_variances']
        
        # Calculate correctness
        overfit_correct = (overfit_preds == true_labels).astype(int)
        reg_correct = (reg_preds == true_labels).astype(int)
        
        # Calculate accuracies
        overfit_acc = np.mean(overfit_correct)
        reg_acc = np.mean(reg_correct)
        
        print(f"\n3. BASIC PERFORMANCE:")
        print(f"   Number of paintings: {len(true_labels)}")
        print(f"   Total patches processed: {len(painting_ids)}")
        print(f"   Average patches per painting: {len(painting_ids) / len(true_labels):.1f}")
        print(f"   Overfit accuracy: {overfit_acc:.4f}")
        print(f"   Regularized accuracy: {reg_acc:.4f}")
        print(f"   True labels: {np.sum(true_labels == 0)} authentic, {np.sum(true_labels == 1)} imitation")
        
        # Print confidence distribution
        print(f"   Overfit confidence range: [{np.min(overfit_confidences):.3f}, {np.max(overfit_confidences):.3f}]")
        print(f"   Regularized confidence range: [{np.min(reg_confidences):.3f}, {np.max(reg_confidences):.3f}]")
        
        # 4. Calibration Analysis
        print(f"\n4. CALIBRATION METRICS:")
        
        superiority_evidence = []
        
        # ECE Analysis
        ece_result = self.bootstrap_calibration_comparison(overfit_confidences, reg_confidences,
                                                         overfit_correct, reg_correct, metric='ece')
        print(f"\n   Expected Calibration Error (ECE):")
        print(f"     Overfit ECE: {ece_result['overfit_metric']:.4f}")
        print(f"     Regularized ECE: {ece_result['regularized_metric']:.4f}")
        print(f"     Difference (reg - overfit): {ece_result['observed_difference']:.4f}")
        print(f"     P(Regularized better): {ece_result['p_value_reg_better']:.4f}")
        print(f"     95% CI of difference: [{ece_result['ci_95'][0]:.4f}, {ece_result['ci_95'][1]:.4f}]")
        
        if ece_result['regularized_is_better'] and ece_result['p_value_reg_better'] > 0.95:
            print("     Result: Regularized model is significantly better calibrated")
            superiority_evidence.append("Better calibration (ECE)")
        elif ece_result['regularized_is_better']:
            print("     Result: Regularized model is better calibrated")
            superiority_evidence.append("Better calibration (ECE)")
        else:
            print("     Result: Overfit model is better calibrated")
        
        # Brier Score Analysis
        brier_result = self.bootstrap_calibration_comparison(overfit_confidences, reg_confidences,
                                                           overfit_correct, reg_correct, metric='brier')
        print(f"\n   Brier Score:")
        print(f"     Overfit Brier: {brier_result['overfit_metric']:.4f}")
        print(f"     Regularized Brier: {brier_result['regularized_metric']:.4f}")
        print(f"     Difference (reg - overfit): {brier_result['observed_difference']:.4f}")
        print(f"     P(Regularized better): {brier_result['p_value_reg_better']:.4f}")
        print(f"     95% CI of difference: [{brier_result['ci_95'][0]:.4f}, {brier_result['ci_95'][1]:.4f}]")
        
        if brier_result['regularized_is_better'] and brier_result['p_value_reg_better'] > 0.95:
            print("     Result: Regularized model has significantly better Brier score")
            superiority_evidence.append("Better Brier score")
        elif brier_result['regularized_is_better']:
            print("     Result: Regularized model has better Brier score")
            superiority_evidence.append("Better Brier score")
        else:
            print("     Result: Overfit model has better Brier score")
        
        # Log Loss Analysis
        logloss_result = self.bootstrap_calibration_comparison(overfit_confidences, reg_confidences,
                                                             overfit_correct, reg_correct, metric='log_loss')
        print(f"\n   Log Loss:")
        print(f"     Overfit Log Loss: {logloss_result['overfit_metric']:.4f}")
        print(f"     Regularized Log Loss: {logloss_result['regularized_metric']:.4f}")
        print(f"     Difference (reg - overfit): {logloss_result['observed_difference']:.4f}")
        print(f"     P(Regularized better): {logloss_result['p_value_reg_better']:.4f}")
        print(f"     95% CI of difference: [{logloss_result['ci_95'][0]:.4f}, {logloss_result['ci_95'][1]:.4f}]")
        
        if logloss_result['regularized_is_better'] and logloss_result['p_value_reg_better'] > 0.95:
            print("     Result: Regularized model has significantly better log loss")
            superiority_evidence.append("Better log loss")
        elif logloss_result['regularized_is_better']:
            print("     Result: Regularized model has better log loss")
            superiority_evidence.append("Better log loss")
        else:
            print("     Result: Overfit model has better log loss")
        
        # 5. Patch Consistency Analysis
        print(f"\n5. PATCH CONSISTENCY ANALYSIS:")
        consistency_result = self.patch_consistency_test(overfit_variances, reg_variances)
        print(f"   Overfit mean patch variance: {consistency_result['overfit_mean_variance']:.4f}")
        print(f"   Regularized mean patch variance: {consistency_result['regularized_mean_variance']:.4f}")
        print(f"   Difference (reg - overfit): {consistency_result['difference']:.4f}")
        
        if consistency_result['p_value_reg_more_consistent'] is not None:
            print(f"   P(Regularized more consistent): {consistency_result['p_value_reg_more_consistent']:.4f}")
            
            if (consistency_result['regularized_more_consistent'] and 
                consistency_result['p_value_reg_more_consistent'] < 0.05):
                print("   Result: Regularized model is significantly more consistent")
                superiority_evidence.append("Better patch consistency")
            elif consistency_result['regularized_more_consistent']:
                print("   Result: Regularized model is more consistent")
                superiority_evidence.append("Better patch consistency")
            else:
                print("   Result: Overfit model is more consistent")
        
        # 6. Overall Assessment
        print(f"\n6. OVERALL SUPERIORITY ASSESSMENT:")
        print(f"   Evidence for regularized model: {len(superiority_evidence)} out of 4 metrics")
        for evidence in superiority_evidence:
            print(f"   - {evidence}")
        
        if len(superiority_evidence) >= 3:
            print(f"\n   CONCLUSION: Regularized model is CLEARLY SUPERIOR")
        elif len(superiority_evidence) >= 2:
            print(f"\n   CONCLUSION: Regularized model is SUPERIOR")
        elif len(superiority_evidence) == 1:
            print(f"\n   CONCLUSION: Regularized model is MARGINALLY BETTER")
        else:
            print(f"\n   CONCLUSION: No clear superiority based on calibration")
        
        
        return {
            'ece': ece_result,
            'brier': brier_result,
            'log_loss': logloss_result,
            'consistency': consistency_result,
            'superiority_evidence': superiority_evidence,
            'is_superior': len(superiority_evidence) >= 2,
            'sample_size': len(true_labels),
            'accuracies': {'overfit': overfit_acc, 'regularized': reg_acc},
            'painting_data': painting_data
        }

# Main execution function
def run_calibration_test():
    OVERFIT_MODEL = "/kaggle/input/statistical-test-dataset/swin_overfit_run1.pth"
    REGULARIZED_MODEL = "/kaggle/input/statistical-test-dataset/swin_regularized_run1.pth"
    TEST_FOLDER = "/kaggle/input/statistical-test-dataset/test_folder/test_folder"
    EXCEL_PATH = "/kaggle/input/excel-file/vg_cv_data_july31_v1_with_train_val_test_split.xlsx"
    
    # Initialize calibration tester
    tester = CalibrationTester(OVERFIT_MODEL, REGULARIZED_MODEL)
    
    # Run comprehensive analysis
    results = tester.comprehensive_calibration_analysis(TEST_FOLDER, EXCEL_PATH)
    
    return tester, results

if __name__ == "__main__":
    tester, results = run_calibration_test()

COMPREHENSIVE CALIBRATION ANALYSIS - PATCH AGGREGATION

1. COLLECTING PATCH-LEVEL DATA...
Found 227 images to process...
Processing image 1/227
Processing image 51/227
Processing image 101/227
Processing image 151/227
Processing image 201/227
Successfully processed 227 images with 3140 total patches

2. AGGREGATING PER PAINTING WITH MAJORITY VOTING...

3. BASIC PERFORMANCE:
   Number of paintings: 227
   Total patches processed: 3140
   Average patches per painting: 13.8
   Overfit accuracy: 0.9692
   Regularized accuracy: 0.9692
   True labels: 109 authentic, 118 imitation
   Overfit confidence range: [0.500, 1.000]
   Regularized confidence range: [0.500, 1.000]

4. CALIBRATION METRICS:

   Expected Calibration Error (ECE):
     Overfit ECE: 0.0179
     Regularized ECE: 0.0165
     Difference (reg - overfit): -0.0014
     P(Regularized better): 0.5790
     95% CI of difference: [-0.0209, 0.0174]
     Result: Regularized model is better calibrated

   Brier Score:
     Overfit Brier: 