# ECONTRAIL Detection - Dataset Loading and Testing

This notebook demonstrates how to load and test the modified dataset.

## Purpose

- Load and explore the modified dataset
- Verify data integrity
- Visualize samples
- Test data preprocessing pipeline

In [None]:
# Import required libraries
import json
from datetime import datetime, timezone
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Import ECONTRAIL detection utilities
from econtrail_detection.utils import (
    load_image,
    preprocess_image
)

print("Setup complete!")

## 1. Dataset Overview

Explore the structure and contents of the modified dataset.

In [None]:
# Define dataset directory (resolved relative to the notebook's working directory)
data_dir = Path('data')

# Check if dataset exists. On a first run the expected layout is scaffolded so
# that the cells below have directories to look in.
if not data_dir.exists():
    print(f"Dataset directory not found: {data_dir}")
    print("Creating directory structure...")
    data_dir.mkdir(parents=True, exist_ok=True)
    for subdir_name in ('images', 'masks', 'ground_truth'):
        (data_dir / subdir_name).mkdir(exist_ok=True)
    print("Directory structure created!")
else:
    print(f"Dataset directory found: {data_dir}")

# List subdirectories. iterdir() yields entries in arbitrary order, so sort
# them to keep the printed listing (and the per-directory counts in the next
# cell) stable between runs.
subdirs = sorted(d for d in data_dir.iterdir() if d.is_dir())
print(f"\nSubdirectories: {[d.name for d in subdirs]}")

In [None]:
# Tally the PNG / JPG / TIF files sitting directly inside each subdirectory
# (non-recursive) and print one aligned line per directory.
for subdir in subdirs:
    image_files = [
        path
        for pattern in ('*.png', '*.jpg', '*.tif')
        for path in subdir.glob(pattern)
    ]
    print(f"{subdir.name:20s}: {len(image_files)} files")

## 2. Load Dataset Samples

Load and display sample images from the dataset.

In [None]:
# Load images from 'images' subdirectory
images_dir = data_dir / 'images'

if images_dir.exists():
    # Collect PNG / JPG / TIF files. glob() order depends on the filesystem,
    # so sort to make the "first few samples" the same on every run.
    image_files = sorted(
        path
        for pattern in ('*.png', '*.jpg', '*.tif')
        for path in images_dir.glob(pattern)
    )

    if image_files:
        print(f"Found {len(image_files)} images in dataset")

        # Display first few samples on a fixed 2x3 grid (at most 6 images)
        n_samples = min(6, len(image_files))
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()

        for ax, img_path in zip(axes, image_files[:n_samples]):
            img = load_image(img_path)
            ax.imshow(img)
            ax.set_title(f"{img_path.name}\nShape: {img.shape}")
            ax.axis('off')

        # Hide unused subplots so empty grid slots do not show bare axes
        for ax in axes[n_samples:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()
    else:
        print("No images found in 'data/images' directory.")
        print("Add your dataset images to this directory.")
else:
    print(f"Images directory not found: {images_dir}")
    print("Create 'data/images' directory and add your dataset images.")

## 3. Verify Image-Mask Pairs

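Check whether each image has a corresponding mask and visualize the pairs side by side. A mask is matched by looking in `data/masks/` for a file with the same name as the image, falling back to `<image stem>.png`.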

In [None]:
# Check for corresponding masks
masks_dir = data_dir / 'masks'

if images_dir.exists() and masks_dir.exists():
    # Use the same extensions as the dataset overview (PNG / JPG / TIF) so the
    # pair and missing-mask counts cover every image counted there; sort for a
    # reproducible choice of visualized samples.
    image_files = sorted(
        path
        for pattern in ('*.png', '*.jpg', '*.tif')
        for path in images_dir.glob(pattern)
    )

    if image_files:
        print("Checking image-mask pairs...\n")

        paired_data = []
        for img_path in image_files:
            # Look for a mask with the exact same file name first...
            mask_path = masks_dir / img_path.name
            if not mask_path.exists():
                # ...then fall back to the same stem with a .png extension
                mask_path = masks_dir / f"{img_path.stem}.png"

            if mask_path.exists():
                paired_data.append((img_path, mask_path))

        print(f"Found {len(paired_data)} paired image-mask samples")
        print(f"Missing masks: {len(image_files) - len(paired_data)}")

        # Visualize paired samples (image on the left, mask on the right)
        if paired_data:
            n_samples = min(3, len(paired_data))
            # squeeze=False always returns a 2-D axes array, so a single row
            # needs no special-case reshape.
            fig, axes = plt.subplots(
                n_samples, 2, figsize=(12, 4 * n_samples), squeeze=False
            )

            for row, (img_path, mask_path) in zip(axes, paired_data[:n_samples]):
                img_ax, mask_ax = row

                # Load and display image
                img = load_image(img_path)
                img_ax.imshow(img)
                img_ax.set_title(f"Image: {img_path.name}")
                img_ax.axis('off')

                # Load and display mask
                mask = load_image(mask_path)
                if mask.ndim == 3:
                    mask = mask[:, :, 0]  # Use first channel
                mask_ax.imshow(mask, cmap='gray')
                mask_ax.set_title(f"Mask: {mask_path.name}")
                mask_ax.axis('off')

            plt.tight_layout()
            plt.show()
    else:
        print("No images found to check for pairs.")
else:
    print("Images or masks directory not found.")
    print("Create 'data/images' and 'data/masks' directories with your data.")

## 4. Test Data Preprocessing

Test the preprocessing pipeline on sample data.

In [None]:
# Test preprocessing with different configurations
if images_dir.exists():
    # Same extensions as the dataset overview (PNG / JPG / TIF); sorted so the
    # sample used for the test is the same on every run.
    image_files = sorted(
        path
        for pattern in ('*.png', '*.jpg', '*.tif')
        for path in images_dir.glob(pattern)
    )

    if image_files:
        # Load a sample image
        sample_img = load_image(image_files[0])
        print(f"Original image shape: {sample_img.shape}")
        print(f"Original data type: {sample_img.dtype}")
        print(f"Original value range: [{sample_img.min()}, {sample_img.max()}]")

        # Run each preprocessing configuration once, then plot the results on
        # a 2x2 grid next to the untouched original.
        normalized = preprocess_image(sample_img, normalize=True)
        resized = preprocess_image(sample_img, target_size=(256, 256))
        both = preprocess_image(sample_img, target_size=(256, 256), normalize=True)

        # NOTE(review): imshow clips float RGB data to [0, 1]; whether the
        # normalized output stays in that range depends on preprocess_image —
        # check the printed range in the panel title.
        panels = [
            (sample_img, f"Original\nShape: {sample_img.shape}"),
            (normalized, f"Normalized\nRange: [{normalized.min():.2f}, {normalized.max():.2f}]"),
            (resized, f"Resized\nShape: {resized.shape}"),
            (both, f"Normalized + Resized\nShape: {both.shape}"),
        ]

        fig, axes = plt.subplots(2, 2, figsize=(12, 12))
        for ax, (panel_img, panel_title) in zip(axes.flatten(), panels):
            ax.imshow(panel_img)
            ax.set_title(panel_title)
            ax.axis('off')

        plt.tight_layout()
        plt.show()

        print("\nPreprocessing test complete!")
    else:
        print("No images available for preprocessing test.")
else:
    print("Images directory not found.")

## 5. Dataset Statistics

Calculate and display dataset statistics.

In [None]:
# Calculate dataset statistics
if images_dir.exists():
    # Same extensions as the dataset overview (PNG / JPG / TIF) so that
    # "Total images" agrees with the count reported when loading samples.
    image_files = sorted(
        path
        for pattern in ('*.png', '*.jpg', '*.tif')
        for path in images_dir.glob(pattern)
    )

    if image_files:
        print("Calculating dataset statistics...\n")

        # Array shape of every image as returned by load_image
        shapes = [load_image(img_path).shape for img_path in image_files]
        # On-disk file size of every image, in KB
        sizes = [img_path.stat().st_size / 1024 for img_path in image_files]

        # Display statistics
        print(f"Total images: {len(image_files)}")
        print(f"\nImage shapes (unique): {set(shapes)}")
        print(f"\nFile sizes:")
        print(f"  Min:  {min(sizes):.2f} KB")
        print(f"  Max:  {max(sizes):.2f} KB")
        print(f"  Mean: {np.mean(sizes):.2f} KB")

        # Plot file size distribution
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.hist(sizes, bins=20, edgecolor='black')
        ax.set_xlabel('File Size (KB)')
        ax.set_ylabel('Count')
        ax.set_title('Distribution of Image File Sizes')
        ax.grid(True, alpha=0.3)
        plt.show()
    else:
        print("No images found for statistics calculation.")
else:
    print("Images directory not found.")

## 6. Create Dataset Metadata (Optional)

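Generate a metadata file for the dataset. The cell writes `dataset_metadata.json` into the dataset directory, recording the dataset name, the number of images, and the image file names.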

In [None]:
# Create dataset metadata
if images_dir.exists():
    # Same extensions as the dataset overview (PNG / JPG / TIF); sorted so the
    # file list in the JSON is stable across runs and diffs cleanly.
    image_files = sorted(
        path
        for pattern in ('*.png', '*.jpg', '*.tif')
        for path in images_dir.glob(pattern)
    )

    if image_files:
        metadata = {
            'dataset_name': 'ECONTRAIL Modified Dataset',
            'num_images': len(image_files),
            'image_files': [img.name for img in image_files],
            # Creation time as an ISO-8601 UTC timestamp. Previously this field
            # held the current working directory, which is not a creation
            # record and wrote an absolute local path into the shared file.
            'created': datetime.now(timezone.utc).isoformat(),
        }

        # Save metadata next to the data it describes
        metadata_path = data_dir / 'dataset_metadata.json'
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        print(f"Metadata saved to: {metadata_path}")
        print(f"\nMetadata summary:")
        print(f"  Dataset: {metadata['dataset_name']}")
        print(f"  Images:  {metadata['num_images']}")
    else:
        print("No images found to create metadata.")
else:
    print("Images directory not found.")

## Summary

This notebook demonstrated:
1. Loading and exploring the modified dataset structure
2. Visualizing dataset samples
3. Verifying image-mask pairs
4. Testing the preprocessing pipeline
5. Calculating dataset statistics
6. Creating dataset metadata

### Next Steps

- Add more images to `data/images/`
- Add corresponding masks to `data/masks/`
- Use the `evaluation.ipynb` notebook to run model predictions
- See the [research paper](https://doi.org/10.1109/TGRS.2025.3629628) for more details