# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from pathlib import Path
import os
from tqdm import tqdm
import warnings
# Silence every warning for the whole process so notebook output stays readable.
# NOTE(review): this also hides genuinely useful warnings (numpy/pandas/OpenCV);
# consider narrowing with category=... if results look suspicious.
warnings.filterwarnings('ignore')

# Set up the global plotting style.
# Both calls configure matplotlib's global color cycle; the seaborn palette is
# applied second, so it should be the one left in effect — keep this order.
plt.style.use('ggplot')
sns.set_palette("husl")

def decode_mask(mask_path, img_shape):
    """
    Decode a saved ``.npy`` mask into a binary ``(H, W)`` uint8 mask.

    Layouts recognised (dispatch is on ``ndim`` and the leading dimension):

    * ``(2, N)`` -- coordinate list: row 0 holds y (row) indices and row 1
      holds x (column) indices of foreground pixels.  Points outside the
      image are ignored.
    * ``(C, H', W')`` with ``C`` in {1, 2, 3} -- channel-first stack; a pixel
      is foreground if it is > 0 in any channel.
    * ``(H', W')`` -- a single dense mask; pixels > 0 are foreground.

    Any other layout prints a warning and yields an all-zero mask.  A decoded
    mask whose size differs from ``(H, W)`` is resized with nearest-neighbour
    interpolation so it stays binary.

    Parameters
    ----------
    mask_path : str or path-like
        Path to the ``.npy`` file.
    img_shape : tuple
        Shape of the reference image; only the first two entries (H, W) are used.

    Returns
    -------
    numpy.ndarray
        ``(H, W)`` array of dtype ``uint8`` with values in {0, 1}.
    """
    # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
    # execute arbitrary code -- only use this on trusted dataset files.
    mask_data = np.load(mask_path, allow_pickle=True)
    H, W = img_shape[:2]

    # ---- (2, N) format -> coordinate list ----
    # NOTE(review): a dense 2-row mask (H' == 2) is indistinguishable from this
    # layout and will be decoded as coordinates.
    if mask_data.ndim == 2 and mask_data.shape[0] == 2:
        y, x = mask_data.astype(int)
        # Drop out-of-range points: too-large indices would raise IndexError and
        # negative ones would silently wrap around to the opposite edge.
        valid = (y >= 0) & (y < H) & (x >= 0) & (x < W)
        mask = np.zeros((H, W), dtype=np.uint8)
        mask[y[valid], x[valid]] = 1

    # ---- (C, H, W) with C in {1, 2, 3} -> union of all channels ----
    elif mask_data.ndim == 3 and mask_data.shape[0] in (1, 2, 3):
        # Threshold BEFORE any integer cast: summing uint8 channels can wrap
        # (1 + 255 == 0) and casting floats to uint8 truncates values in (0, 1).
        mask = np.any(mask_data > 0, axis=0).astype(np.uint8)

    # ---- (H, W) dense mask ----
    elif mask_data.ndim == 2:
        mask = (mask_data > 0).astype(np.uint8)

    else:
        print(f"Warning: Unknown mask shape: {mask_data.shape}")
        mask = np.zeros((H, W), dtype=np.uint8)

    # Ensure same size as image; nearest-neighbour keeps the values in {0, 1}.
    if mask.shape != (H, W):
        mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)

    return mask

class ScientificImageEDA:
    """
    Exploratory data analysis (EDA) helpers for the scientific image forgery dataset.

    Directory layout assumed under ``data_dir`` (paths are built in ``__init__``)::

        train_images/authentic/   authentic training images (*.png)
        train_images/forged/      forged training images (*.png)
        train_masks/              forgery masks (*.npy); stem == forged image stem
        test_images/              test images (*.png)

    The ``analyze_*`` methods store their results on the instance
    (``img_properties_df`` / ``mask_properties_df``) so that
    ``visualize_image_distributions`` and ``create_summary_report`` can reuse them.
    """

    def __init__(self, data_dir):
        """
        Parameters
        ----------
        data_dir : str or path-like
            Root directory of the dataset.
        """
        self.data_dir = Path(data_dir)
        self.train_authentic_dir = self.data_dir / 'train_images' / 'authentic'
        self.train_forged_dir = self.data_dir / 'train_images' / 'forged'
        self.train_masks_dir = self.data_dir / 'train_masks'
        self.test_images_dir = self.data_dir / 'test_images'

    @staticmethod
    def _load_raw_mask(mask_path):
        """
        Load a raw ``.npy`` mask file without decoding it.

        Uses ``allow_pickle=True``, the same as ``decode_mask``, so that
        object-dtype files load here too instead of raising ``ValueError``.
        Unpickling can execute arbitrary code -- only use on trusted files.
        """
        return np.load(mask_path, allow_pickle=True)

    @staticmethod
    def _has_rows(df):
        """Return True if *df* is a non-empty DataFrame (False for None)."""
        return df is not None and not df.empty

    def explore_directory_structure(self):
        """Print file count, sample names and extensions for each dataset directory."""
        print("📁 DIRECTORY STRUCTURE ANALYSIS")
        print("=" * 50)

        directories = {
            'Train Authentic': self.train_authentic_dir,
            'Train Forged': self.train_forged_dir,
            'Train Masks': self.train_masks_dir,
            'Test Images': self.test_images_dir,
        }

        for name, path in directories.items():
            if path.exists():
                files = list(path.glob('*'))
                print(f"{name}: {len(files)} files")
                if files:
                    print(f"  Sample files: {[f.name for f in files[:3]]}")
                # Sorted so the output is stable from run to run.
                print(f"  Extensions: {sorted({f.suffix for f in files})}")
            else:
                print(f"{name}: Directory not found")
            print("-" * 30)

    def get_dataset_stats(self):
        """
        Count dataset files and print a short summary.

        Returns
        -------
        dict
            ``authentic_images``, ``forged_images`` and ``test_images``
            (``*.png`` counts) and ``mask_files`` (``*.npy`` count).
            Missing directories count as 0.
        """
        print("\n📊 DATASET STATISTICS")
        print("=" * 50)

        stats = {
            'authentic_images': len(list(self.train_authentic_dir.glob('*.png'))),
            'forged_images': len(list(self.train_forged_dir.glob('*.png'))),
            'mask_files': len(list(self.train_masks_dir.glob('*.npy'))),
            'test_images': len(list(self.test_images_dir.glob('*.png'))),
        }

        for key, value in stats.items():
            print(f"{key.replace('_', ' ').title()}: {value}")

        total_train = stats['authentic_images'] + stats['forged_images']
        print(f"\nTotal Training Images: {total_train}")
        # Guard against ZeroDivisionError when no training images are present.
        if total_train:
            print(f"Forgery Ratio: {stats['forged_images'] / total_train:.2%}")
        else:
            print("Forgery Ratio: N/A (no training images found)")

        return stats

    def analyze_image_properties(self, sample_size=50):
        """
        Collect per-image size and intensity statistics for a sample of images.

        Reads up to ``sample_size`` ``*.png`` files (in glob order) from each
        of the authentic, forged and test directories.

        Returns
        -------
        pandas.DataFrame
            One row per readable image with columns ``category``, ``filename``,
            ``height``, ``width``, ``channels``, ``dtype``, ``min_val``,
            ``max_val``, ``mean_val``, ``std_val``; empty if nothing was read.
            Also stored as ``self.img_properties_df``.
        """
        print("\n🖼️ IMAGE PROPERTIES ANALYSIS")
        print("=" * 50)

        categories = [
            ('authentic', self.train_authentic_dir),
            ('forged', self.train_forged_dir),
            ('test', self.test_images_dir),
        ]

        properties = []
        for category, directory in categories:
            if not directory.exists():
                continue
            image_files = list(directory.glob('*.png'))[:sample_size]
            for img_path in tqdm(image_files, desc=f"Analyzing {category}"):
                try:
                    img = cv2.imread(str(img_path))
                    if img is None:
                        # cv2.imread reports unreadable files with None, not an exception.
                        print(f"Could not read {img_path}")
                        continue
                    height, width = img.shape[:2]
                    properties.append({
                        'category': category,
                        'filename': img_path.stem,
                        'height': height,
                        'width': width,
                        'channels': img.shape[2] if img.ndim == 3 else 1,
                        'dtype': str(img.dtype),
                        'min_val': img.min(),
                        'max_val': img.max(),
                        'mean_val': img.mean(),
                        'std_val': img.std(),
                    })
                except Exception as e:
                    print(f"Error processing {img_path}: {e}")

        self.img_properties_df = pd.DataFrame(properties)

        print("\nImage Dimensions Summary:")
        # groupby('category') raises KeyError on an empty, column-less frame.
        if self.img_properties_df.empty:
            print("No images could be analyzed.")
        else:
            print(self.img_properties_df.groupby('category')[['height', 'width']].describe())

        return self.img_properties_df

    def analyze_mask_properties(self, sample_size=50):
        """
        Decode a sample of masks and collect per-mask forgery statistics.

        Only masks whose matching forged image (``<mask stem>.png`` in
        ``train_images/forged``) exists are analysed, because the image shape
        is needed to decode the mask.

        Returns
        -------
        pandas.DataFrame or None
            One row per decoded mask (also stored as ``self.mask_properties_df``),
            or None if the masks directory does not exist.
        """
        print("\n🎭 MASK PROPERTIES ANALYSIS")
        print("=" * 50)

        if not self.train_masks_dir.exists():
            print("Masks directory not found!")
            return None

        mask_files = list(self.train_masks_dir.glob('*.npy'))[:sample_size]
        mask_properties = []

        for mask_path in tqdm(mask_files, desc="Analyzing masks"):
            img_path = self.train_forged_dir / f"{mask_path.stem}.png"
            if not img_path.exists():
                continue
            try:
                img = cv2.imread(str(img_path))
                if img is None:
                    print(f"Could not read {img_path}")
                    continue
                mask = decode_mask(mask_path, img.shape)
                # Load the raw file once; its shape feeds two columns below.
                raw_shape = self._load_raw_mask(mask_path).shape
                forged_pixels = int(np.count_nonzero(mask > 0))
                mask_properties.append({
                    'filename': mask_path.stem,
                    'mask_shape': mask.shape,
                    'original_mask_shape': raw_shape,
                    'unique_values': len(np.unique(mask)),
                    'max_value': mask.max(),
                    'total_pixels': mask.size,
                    'forged_pixels': forged_pixels,
                    'forgery_ratio': forged_pixels / mask.size if mask.size > 0 else 0,
                    'mask_format': str(raw_shape),
                })
            except Exception as e:
                print(f"Error processing {mask_path}: {e}")

        self.mask_properties_df = pd.DataFrame(mask_properties)

        print("\nMask Properties Summary:")
        # DataFrame.describe() raises ValueError on a frame without columns.
        if self.mask_properties_df.empty:
            print("No masks could be analyzed.")
        else:
            print(self.mask_properties_df.describe())

        print("\n📋 MASK FORMATS DISTRIBUTION:")
        if not self.mask_properties_df.empty:
            format_counts = self.mask_properties_df['mask_format'].value_counts()
            for format_type, count in format_counts.items():
                print(f"  {format_type}: {count} masks")

        return self.mask_properties_df

    def visualize_image_distributions(self):
        """
        Plot image and mask property distributions in a 2x3 grid.

        Top row (data from ``analyze_image_properties``): width/height scatter,
        aspect-ratio histogram and mean-intensity histogram, per category.
        Bottom row (data from ``analyze_mask_properties``): forgery-ratio
        histogram, unique-value histogram and mask-format bar chart.
        Panels whose source data is missing or empty are left blank.

        Side effect: adds an ``aspect_ratio`` (width / height) column to
        ``self.img_properties_df``.
        """
        print("\n📈 VISUALIZING DISTRIBUTIONS")

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('Image Properties Distribution Analysis', fontsize=16, fontweight='bold')

        img_df = getattr(self, 'img_properties_df', None)
        if self._has_rows(img_df):
            colors = {'authentic': 'green', 'forged': 'red', 'test': 'blue'}
            img_df['aspect_ratio'] = img_df['width'] / img_df['height']

            # sort=False keeps categories in order of first appearance.
            for category, subset in img_df.groupby('category', sort=False):
                color = colors.get(category, 'gray')
                # 1. Image dimensions scatter plot
                axes[0, 0].scatter(subset['width'], subset['height'],
                                   alpha=0.6, label=category, color=color)
                # 2. Aspect ratio distribution
                axes[0, 1].hist(subset['aspect_ratio'], alpha=0.7, label=category,
                                bins=20, color=color)
                # 3. Mean intensity distribution
                axes[0, 2].hist(subset['mean_val'], alpha=0.7, label=category,
                                bins=20, color=color)

            axes[0, 0].set_xlabel('Width')
            axes[0, 0].set_ylabel('Height')
            axes[0, 0].set_title('Image Dimensions Scatter Plot')
            axes[0, 0].legend()
            axes[0, 0].grid(True)

            axes[0, 1].set_xlabel('Aspect Ratio (Width/Height)')
            axes[0, 1].set_ylabel('Frequency')
            axes[0, 1].set_title('Aspect Ratio Distribution')
            axes[0, 1].legend()

            axes[0, 2].set_xlabel('Mean Pixel Intensity')
            axes[0, 2].set_ylabel('Frequency')
            axes[0, 2].set_title('Mean Intensity Distribution')
            axes[0, 2].legend()

        mask_df = getattr(self, 'mask_properties_df', None)
        if self._has_rows(mask_df):
            # 4. Forgery area distribution
            axes[1, 0].hist(mask_df['forgery_ratio'], bins=20, alpha=0.7, color='purple')
            axes[1, 0].set_xlabel('Forgery Pixel Ratio')
            axes[1, 0].set_ylabel('Frequency')
            axes[1, 0].set_title('Forgery Area Distribution')

            # 5. Unique values in masks
            axes[1, 1].hist(mask_df['unique_values'], bins=20, alpha=0.7, color='orange')
            axes[1, 1].set_xlabel('Number of Unique Values in Mask')
            axes[1, 1].set_ylabel('Frequency')
            axes[1, 1].set_title('Mask Complexity (Unique Values)')

            # 6. Mask format distribution
            format_counts = mask_df['mask_format'].value_counts()
            axes[1, 2].bar(format_counts.index.astype(str), format_counts.values,
                           color='teal', alpha=0.7)
            axes[1, 2].set_xlabel('Mask Format (shape)')
            axes[1, 2].set_ylabel('Count')
            axes[1, 2].set_title('Mask Format Distribution')
            axes[1, 2].tick_params(axis='x', rotation=45)

        plt.tight_layout()
        plt.show()

    def display_sample_images(self, n_samples=5):
        """
        Show up to ``n_samples`` images per category, one figure per category.

        Row 0 holds the images.  Row 1 holds the decoded mask for forged
        images (from ``train_masks/<stem>.npy``) and a placeholder otherwise.
        """
        print("\n🖼️ SAMPLE IMAGES VISUALIZATION")
        print("=" * 50)

        categories = {
            'Authentic': self.train_authentic_dir,
            'Forged': self.train_forged_dir,
            'Test': self.test_images_dir,
        }

        for category, directory in categories.items():
            if not directory.exists():
                continue

            print(f"\n{category} Images:")
            image_files = list(directory.glob('*.png'))[:n_samples]

            if not image_files:
                print(f"No images found in {category}")
                continue

            # squeeze=False always returns a 2-D axes array, even for one column.
            fig, axes = plt.subplots(2, n_samples, figsize=(4 * n_samples, 8), squeeze=False)

            for idx, img_path in enumerate(image_files):
                img_ax, mask_ax = axes[0, idx], axes[1, idx]
                img_ax.axis('off')
                mask_ax.axis('off')

                img = cv2.imread(str(img_path))
                if img is None:
                    img_ax.text(0.5, 0.5, f'Could not read\n{img_path.name}',
                                ha='center', va='center', transform=img_ax.transAxes)
                    continue

                img_ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                img_ax.set_title(f'{category}\n{img_path.name}')

                if category != 'Forged':
                    mask_ax.text(0.5, 0.5, 'No mask',
                                 ha='center', va='center', transform=mask_ax.transAxes)
                    continue

                mask_path = self.train_masks_dir / f"{img_path.stem}.npy"
                if mask_path.exists():
                    mask = decode_mask(mask_path, img.shape)
                    mask_ax.imshow(mask, cmap='hot')
                    mask_ax.set_title(f'Decoded Mask\n{mask_path.name}')
                else:
                    mask_ax.text(0.5, 0.5, 'Mask not found',
                                 ha='center', va='center', transform=mask_ax.transAxes)

            # Blank out columns left unused when fewer than n_samples images exist.
            for idx in range(len(image_files), n_samples):
                axes[0, idx].axis('off')
                axes[1, idx].axis('off')

            plt.tight_layout()
            plt.show()

    def display_overlay_images(self, n_samples=3):
        """
        For up to ``n_samples`` forged images, plot the image, the raw mask, the
        decoded mask, an alpha overlay, a contour overlay and coverage statistics.
        Images without a matching mask file are skipped with a message.
        """
        print("\n🎭 IMAGE-MASK OVERLAY VISUALIZATION")
        print("=" * 50)

        if not self.train_forged_dir.exists():
            print("Forged images directory not found!")
            return

        forged_images = list(self.train_forged_dir.glob('*.png'))[:n_samples]

        for img_path in forged_images:
            mask_path = self.train_masks_dir / f"{img_path.stem}.npy"
            if not mask_path.exists():
                print(f"Mask not found for {img_path.name}")
                continue

            img = cv2.imread(str(img_path))
            if img is None:
                print(f"Could not read {img_path.name}")
                continue
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            mask = decode_mask(mask_path, img.shape)
            original_mask_data = self._load_raw_mask(mask_path)

            fig, axes = plt.subplots(2, 3, figsize=(18, 12))

            # Original image
            axes[0, 0].imshow(img_rgb)
            axes[0, 0].set_title(f'Original Image\n{img_path.name}')
            axes[0, 0].axis('off')

            # Raw mask data: first channel of a 3-D array, the array itself if 2-D.
            raw_ax = axes[0, 1]
            if original_mask_data.ndim == 3 and original_mask_data.shape[0] >= 1:
                raw_ax.imshow(original_mask_data[0])
                raw_ax.set_title(f'Original Mask Channel 0\nShape: {original_mask_data.shape}')
            elif original_mask_data.ndim == 2:
                raw_ax.imshow(original_mask_data)
                raw_ax.set_title(f'Original Mask Data\nShape: {original_mask_data.shape}')
            else:
                raw_ax.text(0.5, 0.5, f'Cannot display raw mask\nShape: {original_mask_data.shape}',
                            ha='center', va='center', transform=raw_ax.transAxes)
            raw_ax.axis('off')

            # Decoded binary mask
            axes[0, 2].imshow(mask, cmap='hot')
            axes[0, 2].set_title('Decoded Binary Mask')
            axes[0, 2].axis('off')

            # Overlay 1: simple alpha overlay
            axes[1, 0].imshow(img_rgb)
            axes[1, 0].imshow(mask, cmap='hot', alpha=0.5)
            axes[1, 0].set_title('Image with Mask Overlay (Alpha=0.5)')
            axes[1, 0].axis('off')

            # Overlay 2: contour overlay.  [-2] selects the contour list under both
            # OpenCV 3 (three return values) and OpenCV 4 (two return values).
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
            contour_img = img_rgb.copy()
            cv2.drawContours(contour_img, contours, -1, (0, 255, 0), 2)
            axes[1, 1].imshow(contour_img)
            axes[1, 1].set_title('Image with Contour Overlay')
            axes[1, 1].axis('off')

            # Forgery statistics
            axes[1, 2].axis('off')
            forged_pixels = int(np.count_nonzero(mask > 0))
            forgery_ratio = forged_pixels / mask.size if mask.size else 0.0
            stats_text = (
                "Forgery Statistics:\n"
                f"Coverage: {forgery_ratio:.2%}\n"
                f"Forged Pixels: {forged_pixels:,}\n"
                f"Total Pixels: {mask.size:,}\n"
                f"Original Mask Shape: {original_mask_data.shape}\n"
                f"Decoded Mask Shape: {mask.shape}"
            )
            axes[1, 2].text(0.1, 0.9, stats_text, transform=axes[1, 2].transAxes,
                            fontsize=10, verticalalignment='top',
                            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            plt.tight_layout()
            plt.show()

            print("-" * 50)

    def analyze_mask_formats(self):
        """
        Scan every ``*.npy`` mask and print, per raw array shape, the file count,
        the dtypes seen and up to three sample file names.
        """
        print("\n🔍 DETAILED MASK FORMAT ANALYSIS")
        print("=" * 50)

        if not self.train_masks_dir.exists():
            print("Masks directory not found!")
            return

        mask_files = list(self.train_masks_dir.glob('*.npy'))
        format_stats = {}

        for mask_path in tqdm(mask_files, desc="Analyzing mask formats"):
            try:
                mask_data = self._load_raw_mask(mask_path)
            except Exception as e:
                print(f"Error analyzing {mask_path}: {e}")
                continue

            entry = format_stats.setdefault(
                str(mask_data.shape), {'count': 0, 'dtypes': set(), 'sample_files': []}
            )
            entry['count'] += 1
            entry['dtypes'].add(str(mask_data.dtype))
            if len(entry['sample_files']) < 3:
                entry['sample_files'].append(mask_path.name)

        print("\nMask Format Summary:")
        for format_type, stats in format_stats.items():
            print(f"\nFormat: {format_type}")
            print(f"  Count: {stats['count']}")
            print(f"  Dtypes: {sorted(stats['dtypes'])}")
            print(f"  Sample files: {stats['sample_files']}")

    def create_summary_report(self):
        """Run every analysis and plotting step in order, then print key insights."""
        print("🚀 COMPREHENSIVE EDA SUMMARY REPORT")
        print("=" * 60)

        # 1. Directory structure
        self.explore_directory_structure()

        # 2. Basic statistics
        self.get_dataset_stats()

        # 3. Image properties
        img_df = self.analyze_image_properties(sample_size=100)

        # 4. Mask properties (None when the masks directory is missing)
        mask_df = self.analyze_mask_properties(sample_size=100)

        # 5. Mask format analysis
        self.analyze_mask_formats()

        # 6. Visualizations
        self.visualize_image_distributions()
        self.display_sample_images(n_samples=3)
        self.display_overlay_images(n_samples=2)

        # 7. Key insights
        print("\n💡 KEY INSIGHTS AND OBSERVATIONS")
        print("=" * 40)

        if self._has_rows(img_df):
            widths, heights = img_df['width'], img_df['height']
            # Computed here instead of relying on the column that
            # visualize_image_distributions adds as a side effect; rounded so
            # near-identical float ratios are counted together.
            aspect_ratio = (widths / heights).round(2)
            print(f"• Image widths range from {widths.min()} to {widths.max()}, "
                  f"heights from {heights.min()} to {heights.max()}")
            print(f"• Average image size: {widths.mean():.0f}x{heights.mean():.0f}")
            print(f"• Common aspect ratios: {aspect_ratio.value_counts().head(3).to_dict()}")

        if self._has_rows(mask_df):
            print(f"• Average forgery coverage: {mask_df['forgery_ratio'].mean():.2%}")
            print(f"• Mask complexity (avg unique values): {mask_df['unique_values'].mean():.1f}")
            print(f"• Forgery size range: {mask_df['forged_pixels'].min()} to "
                  f"{mask_df['forged_pixels'].max()} pixels")
        

# Usage example
def main(data_dir="/kaggle/input/recodai-luc-scientific-image-forgery-detection"):
    """
    Run the complete EDA report on the dataset rooted at *data_dir*.

    The default is the Kaggle input path used by this notebook; pass a
    different directory to analyse a local copy.
    """
    eda = ScientificImageEDA(data_dir)
    eda.create_summary_report()


if __name__ == "__main__":
    main()