In [None]:
import pandas as pd
from source import image_id_converter as img_idc
import matplotlib.pyplot as plt
from PIL import Image



In [None]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
from pathlib import Path
import tempfile


def load_image_paths(image_directory, extensions=('.jpg', '.jpeg', '.tif', '.tiff')):
    """
    Load all image file paths from a directory.

    A file matches when its name ends with one of ``extensions``, compared
    without regard to case, so ``.jpg``, ``.JPG`` and ``.Jpg`` are all found.
    The directory is scanned once, so every entry is returned at most once
    (globbing separately for the lower- and upper-case spelling can list the
    same file twice on case-insensitive file systems and misses mixed-case
    extensions).

    Args:
        image_directory (str or Path): Path to directory containing images
        extensions (tuple): Allowed file extensions (matched case-insensitively)

    Returns:
        list: Sorted list of image file paths (as strings), without duplicates

    Raises:
        FileNotFoundError: If image_directory does not exist
    """
    image_dir = Path(image_directory)

    if not image_dir.exists():
        raise FileNotFoundError(f"Directory {image_directory} does not exist")

    # str.endswith accepts a tuple, so one case-folded test covers all extensions.
    wanted_suffixes = tuple(ext.lower() for ext in extensions)

    # Sort paths to ensure consistent ordering
    image_paths = sorted(
        str(path)
        for path in image_dir.glob("*")
        if path.name.lower().endswith(wanted_suffixes)
    )

    print(f"Found {len(image_paths)} images in {image_directory}")
    return image_paths

def analyze_image_sizes(image_directory, extensions=('.jpg', '.jpeg', '.tif', '.tiff')):
    """
    Analyze image sizes in a directory to help determine appropriate target_size.

    Prints a human-readable summary plus a target_size recommendation and
    returns the same statistics as a dictionary.

    Args:
        image_directory (str): Path to directory containing images
        extensions (tuple): Allowed file extensions

    Returns:
        dict or None: Size analysis results, or None when no matching files
        are found or none of them can be opened. Keys:
            total_images, valid_images, failed_count (int)
            failed_images (list of (path, error message) tuples)
            all_same_size (bool), unique_sizes_count (int)
            min_size, max_size, avg_size ((width, height) tuples)
            min_width, max_width, avg_width
            min_height, max_height, avg_height
    """
    # Get all image paths
    image_paths = load_image_paths(image_directory, extensions)

    if not image_paths:
        print("No images found in directory")
        return None

    sizes = []
    failed_images = []

    print(f"Analyzing {len(image_paths)} images...")

    # Analyze each image
    for i, image_path in enumerate(image_paths):
        try:
            # img.size is read inside the with-block so the file is closed again right away.
            with Image.open(image_path) as img:
                sizes.append(img.size)
        except Exception as e:
            # Best effort: remember the failure and keep going with the other files.
            failed_images.append((image_path, str(e)))
            print(f"Failed to read {image_path}: {e}")

        # Progress indicator for large datasets
        if (i + 1) % 100 == 0:
            print(f"Processed {i + 1}/{len(image_paths)} images...")

    if not sizes:
        print("No valid images found")
        return None

    widths = [width for width, _ in sizes]
    heights = [height for _, height in sizes]

    # Calculate statistics
    unique_sizes = set(sizes)
    all_same_size = len(unique_sizes) == 1

    min_width = min(widths)
    max_width = max(widths)
    avg_width = sum(widths) / len(widths)

    min_height = min(heights)
    max_height = max(heights)
    avg_height = sum(heights) / len(heights)

    # These are per-dimension extremes; they need not be the size of any single image.
    min_size = (min_width, min_height)
    max_size = (max_width, max_height)
    avg_size = (avg_width, avg_height)

    # Create results dictionary.
    # 'failed_images' holds the list of failures; the count lives under its own
    # key ('failed_count') because a dict literal cannot hold the same key twice —
    # the later entry silently replaces the earlier one.
    results = {
        'total_images': len(image_paths),
        'valid_images': len(sizes),
        'failed_count': len(failed_images),
        'all_same_size': all_same_size,
        'unique_sizes_count': len(unique_sizes),
        'min_size': min_size,
        'max_size': max_size,
        'avg_size': avg_size,
        'min_width': min_width,
        'max_width': max_width,
        'avg_width': avg_width,
        'min_height': min_height,
        'max_height': max_height,
        'avg_height': avg_height,
        'failed_images': failed_images
    }

    # Print summary
    print("\n" + "="*50)
    print("IMAGE SIZE ANALYSIS SUMMARY")
    print("="*50)
    print(f"Total images found: {results['total_images']}")
    print(f"Valid images: {results['valid_images']}")
    print(f"Failed to read: {results['failed_count']}")
    print(f"\nAll images same size: {'Yes' if all_same_size else 'No'}")
    print(f"Number of unique sizes: {results['unique_sizes_count']}")

    print(f"\nSize ranges:")
    print(f"  Minimum size: {min_size[0]} x {min_size[1]}")
    print(f"  Maximum size: {max_size[0]} x {max_size[1]}")
    print(f"  Average size: {avg_size[0]:.1f} x {avg_size[1]:.1f}")

    print(f"\nWidth range: {min_width} - {max_width} (avg: {avg_width:.1f})")
    print(f"Height range: {min_height} - {max_height} (avg: {avg_height:.1f})")

    if failed_images:
        print(f"\nFailed images:")
        for path, error in failed_images[:5]:  # Show first 5 failures
            print(f"  {path}: {error}")
        if len(failed_images) > 5:
            print(f"  ... and {len(failed_images) - 5} more")

    # Suggest target size
    if all_same_size:
        print(f"\nRecommendation: Use target_size={min_size} (all images are the same size)")
    else:
        # Suggest a square target size: the common size closest to the smallest dimension seen
        smallest_dimension = min(min_width, min_height)
        common_sizes = [28, 32, 64, 128, 224, 256, 512]
        suggested_size = min(common_sizes, key=lambda x: abs(x - smallest_dimension))
        print(f"\nRecommendation: Consider target_size=({suggested_size}, {suggested_size})")
        print(f"  (Based on minimum dimension and common image sizes)")

    return results


def convert_to_grayscale(image):
    """
    Return a mode 'L' version of a PIL image.

    Args:
        image (PIL.Image): Image to convert

    Returns:
        PIL.Image: Result of ``image.convert('L')``
    """
    grayscale_mode = 'L'
    grayscale_image = image.convert(grayscale_mode)
    return grayscale_image


def apply_aging_effect(image):
    """
    Apply aging effects to a PIL image.

    The steps are, in order: heavy JPEG compression (quality 15), brightness
    reduction (factor 0.8), contrast reduction (factor 0.9), additive Gaussian
    pixel noise, and a Gaussian blur (radius 1.5).

    Args:
        image (PIL.Image): Input image (converted to RGB if it is in another mode)

    Returns:
        PIL.Image: Aged RGB image
    """
    import io  # local import: only needed for the in-memory JPEG round trip

    # Ensure image is in RGB mode for aging effects
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Heavy JPEG compression, done through an in-memory buffer. This avoids
    # re-opening / deleting a temporary file while a handle to it is still
    # open (which fails on Windows) and leaves nothing on disk.
    jpeg_buffer = io.BytesIO()
    image.save(jpeg_buffer, 'JPEG', quality=15)
    jpeg_buffer.seek(0)
    image = Image.open(jpeg_buffer)
    image.load()  # decode now; Image.open is lazy

    # Brightness reduction
    image = ImageEnhance.Brightness(image).enhance(0.8)

    # Lower contrast
    image = ImageEnhance.Contrast(image).enhance(0.9)

    # Add noise.
    # The noise must be added in floating point: casting N(0, 0.1) samples to
    # uint8 first truncates every sample to 0, which makes this step a no-op.
    # NOTE(review): 0.1 is interpreted here as a fraction of the 0-255 pixel
    # range (sigma = 25.5 grey levels) — confirm this is the intended strength.
    pixels = np.asarray(image, dtype=np.float32)
    noise = np.random.normal(0.0, 0.1 * 255.0, pixels.shape)
    noisy_pixels = np.clip(np.rint(pixels + noise), 0, 255).astype(np.uint8)
    image = Image.fromarray(noisy_pixels)

    # Blur
    image = image.filter(ImageFilter.GaussianBlur(radius=1.5))

    return image


def load_and_preprocess_image(image_path, target_size=(28, 28), convert_grayscale=True, apply_aging=False):
    """
    Load and preprocess a single image.

    Processing order: optional aging effects, optional grayscale conversion,
    then resizing to ``target_size`` with Lanczos resampling.

    Args:
        image_path (str): Path to image file
        target_size (tuple): Target size for resizing (width, height)
        convert_grayscale (bool): Whether to convert to grayscale
        apply_aging (bool): Whether to apply aging effects

    Returns:
        PIL.Image or None: Preprocessed image, or None if loading or
        processing failed (the error is printed, not raised)
    """
    try:
        # The context manager closes the underlying file as soon as processing
        # is done; this function is called once per sample, so handles left
        # open would accumulate. resize() below always returns a new image, so
        # the returned object does not depend on the closed file.
        with Image.open(image_path) as source_image:
            image = source_image

            # Apply aging effects first (works best on RGB images)
            if apply_aging:
                image = apply_aging_effect(image)

            # Convert to grayscale if requested
            if convert_grayscale:
                image = convert_to_grayscale(image)

            # Resize to target size
            image = image.resize(target_size, Image.Resampling.LANCZOS)

        return image

    except Exception as e:
        # Deliberate best effort: callers treat None as "image unavailable".
        print(f"Error processing {image_path}: {str(e)}")
        return None


def create_transforms(mean=0.5, std=1.0):
    """
    Build the image transform pipeline: ToTensor followed by Normalize.

    Args:
        mean (float): Normalization mean (passed as a one-element tuple)
        std (float): Normalization standard deviation (passed as a one-element tuple)

    Returns:
        torchvision.transforms.Compose: Transform pipeline
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((mean,), (std,))
    return transforms.Compose([to_tensor, normalize])


def create_label_transform():
    """
    Build the label transform pipeline.

    The pipeline has a single step that wraps a label in a one-element
    ``torch.LongTensor``.

    Returns:
        torchvision.transforms.Compose: Label transform pipeline
    """
    def to_long_tensor(label):
        # One-element tensor holding the label.
        return torch.LongTensor([label])

    return transforms.Compose([to_long_tensor])


class CustomImageDataset(Dataset):
    """
    Dataset over image files, structured like torchvision.datasets.MNIST.

    Images are read from disk only when a sample is requested (lazy loading);
    each access runs the preprocessing step and then the optional transforms.
    """

    def __init__(self, image_paths, labels, target_size=(28, 28),
                 convert_grayscale=True, apply_aging=False,
                 transform=None, target_transform=None):
        """
        Store the sample list and preprocessing options.

        Args:
            image_paths (list): List of image file paths
            labels (list): List of integer labels corresponding to images
            target_size (tuple): Target size for resizing images
            convert_grayscale (bool): Whether to convert images to grayscale
            apply_aging (bool): Whether to apply aging effects
            transform (callable): Transform to apply to images
            target_transform (callable): Transform to apply to labels

        Raises:
            ValueError: If image_paths and labels differ in length
        """
        # Every image needs exactly one label.
        n_images, n_labels = len(image_paths), len(labels)
        if n_images != n_labels:
            raise ValueError(f"Number of images ({n_images}) must match number of labels ({n_labels})")

        self.image_paths, self.labels = image_paths, labels
        self.target_size = target_size
        self.convert_grayscale, self.apply_aging = convert_grayscale, apply_aging
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        """Number of samples, i.e. number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load, preprocess and transform one sample.

        Args:
            idx (int): Index of the sample to retrieve

        Returns:
            tuple: (image, label), each passed through its transform when one is set
        """
        # Lazy loading: the file is only read now.
        sample = load_and_preprocess_image(
            self.image_paths[idx],
            target_size=self.target_size,
            convert_grayscale=self.convert_grayscale,
            apply_aging=self.apply_aging,
        )

        if sample is None:
            # Loading failed: substitute an all-black image in the expected mode.
            if self.convert_grayscale:
                mode, fill = 'L', 0
            else:
                mode, fill = 'RGB', (0, 0, 0)
            sample = Image.new(mode, self.target_size, fill)

        target = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            target = self.target_transform(target)

        return sample, target


def create_image_dataset(image_paths, labels, target_size=(28, 28),
                        convert_grayscale=True, apply_aging=False,
                        mean=0.5, std=1.0):
    """
    Assemble a CustomImageDataset with the image and label transforms attached.

    Args:
        image_paths (list): List of image file paths
        labels (list): List of integer labels for the images
        target_size (tuple): Target size for resizing images
        convert_grayscale (bool): Whether to convert images to grayscale
        apply_aging (bool): Whether to apply aging effects
        mean (float): Normalization mean
        std (float): Normalization standard deviation

    Returns:
        CustomImageDataset: Dataset ready for use with DataLoader
    """
    # The image transform is built from mean/std; the label transform takes no options.
    return CustomImageDataset(
        image_paths=image_paths,
        labels=labels,
        target_size=target_size,
        convert_grayscale=convert_grayscale,
        apply_aging=apply_aging,
        transform=create_transforms(mean=mean, std=std),
        target_transform=create_label_transform(),
    )



In [None]:
def get_samples(valid_loader, n_sample=None, n_vis_sample=None):
    """
    Collect a per-class random sample of (noisy, clean) image pairs from a loader.

    Args:
        valid_loader: Iterable yielding ``(noisy_data, data, target)`` batches of
            tensors (each is converted with ``.detach().cpu().numpy()``).
        n_sample (int, optional): Number of samples to draw per class. Defaults
            to the module-level ``N_SAMPLE``.
        n_vis_sample (int, optional): Number of leading samples per class whose
            positions are recorded for visualization. Defaults to the
            module-level ``N_VIS_SAMPLE``.

    Returns:
        dict: with keys
            'images_noisy'  - sampled noisy images, classes in ascending order
            'images'        - the clean images at the same positions
            'labels'        - class index of every sampled row
            'single_el_idx' - row indices of the first ``n_vis_sample`` samples
                              of each class

    Raises:
        ValueError: If valid_loader yields no batches.

    Notes:
        A class with fewer than ``n_sample`` items contributes all of its items
        instead of raising, so the result is only balanced when every class has
        at least ``n_sample`` items.
    """
    # Fall back to the notebook-level constants only when no value is passed in.
    if n_sample is None:
        n_sample = N_SAMPLE
    if n_vis_sample is None:
        n_vis_sample = N_VIS_SAMPLE

    # 1. get numpy arrays of all validation images and labels:
    noisy_batches = []
    clean_batches = []
    label_batches = []

    for noisy_data, data, target in valid_loader:
        noisy_batches.append(noisy_data.detach().cpu().numpy())
        clean_batches.append(data.detach().cpu().numpy())
        label_batches.append(target.detach().cpu().numpy())

    if not label_batches:
        raise ValueError("valid_loader yielded no batches; nothing to sample from")

    val_images_noisy = np.concatenate(noisy_batches, axis=0)
    val_images = np.concatenate(clean_batches, axis=0)
    val_labels = np.concatenate(label_batches, axis=0)

    # 2. draw the per-class samples:
    sample_images_noisy = []
    sample_images = []
    sample_labels = []
    single_el_idx = []  # indexes of the first elements per class for visualization

    # Labels are taken to be 0..max, so the class count is the largest label + 1.
    n_class = int(np.max(val_labels)) + 1
    offset = 0  # row at which the current class starts in the concatenated output

    for class_idx in range(n_class):
        class_mask = val_labels == class_idx
        ims_c_noisy = val_images_noisy[class_mask]
        ims_c = val_images[class_mask]

        # Never ask for more items than the class has (replace=False would raise).
        n_take = min(n_sample, len(ims_c))
        samples_idx = np.random.choice(len(ims_c), n_take, replace=False)

        # The same indices are used for both arrays so noisy/clean rows stay paired.
        sample_images_noisy.append(ims_c_noisy[samples_idx])
        sample_images.append(ims_c[samples_idx])
        sample_labels.append(np.full(n_take, class_idx, dtype=int))

        single_el_idx.extend(range(offset, offset + min(n_vis_sample, n_take)))
        offset += n_take

    samples = {
        'images_noisy': np.concatenate(sample_images_noisy, axis=0),
        'images': np.concatenate(sample_images, axis=0),
        'labels': np.concatenate(sample_labels, axis=0),
        'single_el_idx': np.array(single_el_idx)
    }
    return samples
# Creates and returns a dictionary with all collected samples.


# This function ensures we have:
#
# - A balanced number of samples for each class (equal representation)
# - Both the noisy and the clean version of each sampled image
# - A mapping between the noisy and clean versions (rows at the same index belong together)
# - A subset of indices for visualization purposes
#
# This is particularly useful for creating visualizations that show how the model behaves
# across different classes, or for comparing reconstruction quality across digits.






In [None]:
import numpy as np
# Scratch check: evaluate 1 / (1 + exp(-(a + b*x))) for a = 7, b = 1, x = 0.
a = 7
b = 1
x = 0
1 / (1 + np.exp(-(a + b*x)))

In [None]:
# Scratch check: value of e**(-3).
np.exp(-3)

## Define file paths: 

In [None]:

# The current working directory is taken as the project folder; the repository
# root is its parent directory.
project_path = Path.cwd()
root_path = (project_path / '..').resolve()

# Define paths
# image_dir is used below both as the image folder and as the location of the label CSV.
image_dir = root_path/'data'  # Replace with your directory containing images

In [None]:
# Print the size summary for all images in image_dir and show the returned statistics dict.
results = analyze_image_sizes(image_dir, extensions=('.jpg', '.jpeg', '.tif', '.tiff'))
results

## Load image ids and labels:

#### Using 'recognisable' as the label (indicating whether one or more persons who are recognisable as such appear in the image).

In [None]:
# Label table stored next to the images; its 'image_id' and 'recognisable' columns are used below.
with_without_person = pd.read_csv(image_dir/'with_without_person_mod.csv')
with_without_person

In [None]:
# Image ids exactly as stored in the CSV (input for the id re-conversion in the next cell).
img_ids = list(with_without_person.image_id)

In [None]:
# Replace the 'image_id' column with the output of the project helper reconvert_image_ids.
# NOTE(review): this mutates with_without_person in place and reads img_ids, a name that a
# later cell rebinds to ids parsed from file names — re-run from the CSV-loading cell
# instead of re-running this cell on its own.
with_without_person['image_id'] = img_idc.reconvert_image_ids(img_ids)

In [None]:
# Inspect the table after the image_id column was replaced.
with_without_person.head()

## Load image paths:

In [None]:
# Sorted list of all image file paths (as strings) found in image_dir; show the first three.
image_paths = load_image_paths(image_dir)
image_paths[0:3]

In [None]:
# Sanity check of the container type returned by load_image_paths.
type(image_paths)

In [None]:
# The image id is the last three characters of the file name without its extension.
# Path.stem is used instead of splitting on '.tif' so that every extension returned
# by load_image_paths ('.jpg', '.jpeg', '.tif', '.tiff', in any letter case) is
# handled; splitting on '.tif' raises IndexError for paths that do not contain it.
# NOTE: this rebinds img_ids, which earlier held the ids read from the CSV.
img_ids = [Path(image_path).stem[-3:] for image_path in image_paths]

In [None]:
# First three ids parsed from the file names.
img_ids[0:3]

In [None]:
# Lookup table from image id to file path; img_ids and image_paths are aligned element-wise.
image_paths_for_mapping = pd.DataFrame({'image_id': img_ids, 'image_paths': image_paths})
image_paths_for_mapping.head()

## Map image paths to labels:

In [None]:
# Inner join on image_id: only ids present in both tables are kept, so labels without a
# file and files without a label are dropped.
# NOTE(review): the join only matches if both image_id columns hold the same type — confirm.
ids_labels_mapping = with_without_person.merge(image_paths_for_mapping, how='inner', on='image_id')
ids_labels_mapping.head()

In [None]:
# Labels in the row order of the merged table.
labels = list(ids_labels_mapping.recognisable)
print(type(labels))
print(labels[0:3])

In [None]:
# Rebind image_paths to the paths from the merged table, i.e. only paths that have a
# label, in the same row order as `labels`.
image_paths = list(ids_labels_mapping.image_paths)
print(type(image_paths))
print(image_paths[0:3])

## Test convert_to_grayscale function:

In [None]:

# Load and process image
image_path = image_paths[100]
original = Image.open(image_path)
processed = convert_to_grayscale(original)  # Or any other function

# Plot side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(original)
ax1.set_title('Original')
ax1.axis('off')

# A mode 'L' image is a single-channel array; without an explicit colormap imshow
# renders it in matplotlib's default (coloured) colormap instead of shades of gray.
ax2.imshow(processed, cmap='gray' if processed.mode == 'L' else None)
ax2.set_title('After convert_to_grayscale')
ax2.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Convert the original and the grayscale image to NumPy arrays for inspection.
processed = convert_to_grayscale(original)
original_array = np.array(original)
processed_array = np.array(processed)
# Redundant: pyplot is already imported in the first cell.
import matplotlib.pyplot as plt

In [None]:
# Check the values of the original image array
print(f"Shape: {original_array.shape}")
print(f"Data type: {original_array.dtype}")
print(f"Min value: {original_array.min()}")
print(f"Max value: {original_array.max()}")
print(f"Mean value: {original_array.mean():.2f}")
print(f"Unique values (first 10): {np.unique(original_array)[:10]}")

In [None]:
# Check the values of the grayscale image array
print(f"Shape: {processed_array.shape}")
print(f"Data type: {processed_array.dtype}")
print(f"Min value: {processed_array.min()}")
print(f"Max value: {processed_array.max()}")
print(f"Mean value: {processed_array.mean():.2f}")
print(f"Unique values (first 10): {np.unique(processed_array)[:10]}")

## Test the apply_aging_effect function:

In [None]:

# Load the first labelled image and apply the aging effect to it
image_path = image_paths[0]
original = Image.open(image_path)
processed = apply_aging_effect(original)  # Or any other function

# Plot original and aged image side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(original)
ax1.set_title('Original')
ax1.axis('off')

# Use the gray colormap only if the processed image is single-channel (mode 'L').
ax2.imshow(processed, cmap='gray' if processed.mode == 'L' else None)
ax2.set_title('After aging effect')
ax2.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Pixel arrays of the original and the aged image (these rebind the names used in the grayscale test).
original_array = np.array(original)
processed_array = np.array(processed)

In [None]:

# Histogram of all pixel values of the original image
plt.figure(figsize=(10, 6))
plt.hist(original_array.flatten(), bins=50, alpha=0.7, edgecolor='black')
plt.title('Histogram of Pixel Values')
plt.xlabel('Pixel Value')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:

# Histogram of all pixel values of the aged image (same settings as for the original)
plt.figure(figsize=(10, 6))
plt.hist(processed_array.flatten(), bins=50, alpha=0.7, edgecolor='black')
plt.title('Histogram of Pixel Values')
plt.xlabel('Pixel Value')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Full preprocessing without aging: grayscale, resized to 512 x 512.
test_image = load_and_preprocess_image(image_paths[0], target_size=(512, 512), convert_grayscale=True, apply_aging=False)
type(test_image)

In [None]:
# Display the preprocessed image in shades of gray.
plt.imshow(test_image, cmap='gray')

In [None]:
# Same image and size as above, with the aging effect enabled.
test_image_2 = load_and_preprocess_image(image_paths[0], target_size=(512, 512), convert_grayscale=True, apply_aging=True)
type(test_image_2)

In [None]:
# Display the aged, preprocessed image in shades of gray.
plt.imshow(test_image_2, cmap='gray')

In [None]:
# Side-by-side comparison of the preprocessed image without (left) and with (right) aging.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(test_image, cmap='gray')
ax1.set_title('Original')
ax1.axis('off')

ax2.imshow(test_image_2, cmap='gray')
ax2.set_title('After aging')
ax2.axis('off')

plt.tight_layout()
plt.show()

## Test target_size parameter:

In [None]:
# Same image with aging, resized to a much smaller 50 x 50 target.
test_image_3 = load_and_preprocess_image(image_paths[0], target_size=(50, 50), convert_grayscale=True, apply_aging=True)
type(test_image_3)

In [None]:
# Display the 50 x 50 result in shades of gray.
plt.imshow(test_image_3, cmap='gray')

In [None]:
# Example of how to use the pipeline

# Uses the variables prepared above:
# - image_paths: list of image file paths that have a label
# - labels: list of labels in the same order


# Create dataset
dataset = create_image_dataset(
    image_paths=image_paths,
    labels=labels,
    target_size=(512, 512),
    convert_grayscale=True,
    apply_aging=False,
    mean=0.5,
    std=1.0
)

# Test the dataset
print(f"Dataset length: {len(dataset)}")

# Get first sample
if len(dataset) > 0:
    image, label = dataset[0]
    print(f"Image shape: {image.shape}")
    print(f"Label: {label}")
    print(f"Image dtype: {image.dtype}")
    print(f"Label dtype: {label.dtype}")

# Create DataLoader (same as your MNIST setup)
from torch.utils.data import DataLoader

# You can use your existing collate function
BATCH_SIZE = 32
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Test with one batch.
# The batch is unpacked into batch_* names on purpose: unpacking into `labels`
# would replace the full label list with one batch of labels, and re-running this
# cell would then fail the image/label length check in CustomImageDataset.
for batch in train_loader:
    batch_images, batch_labels = batch
    print(f"Batch images shape: {batch_images.shape}")
    print(f"Batch labels shape: {batch_labels.shape}")
    break

In [None]:
# This code compares the pixel value distributions of an input image and its reconstruction.
# NOTE(review): `xns` and `y` are not defined in any cell visible in this notebook;
# presumably they are a batch of noisy inputs and the corresponding model outputs —
# they must be created before this cell can run.
# Here's what each line does:

x = xns[0]# - y[1]
# Selects the first element of xns (described here as the batch of noisy inputs).
# Note that there's a commented-out subtraction (# - y[1])

d = y[0]# - y[1]
# Selects the first element of y (described here as the batch of reconstructed outputs).
# Again, there's a commented-out subtraction

im0 = x[0].detach().cpu().numpy()
# Takes the first channel of the selected input image
# Detaches it from the computation graph (no gradients needed)
# Moves it to CPU if it was on GPU
# Converts it to a NumPy array

im1 = d[0].detach().cpu().numpy()
# Does the same conversion process for the reconstructed image

# plt.imshow(im, cmap='gray', vmin=-1, vmax=1)
# This is a commented-out visualization that would display the image

bins = np.linspace(-1, 1, 100)
# Creates 100 evenly spaced bin edges from -1 to 1 (i.e. 99 bins)
# This range is chosen to match the expected range of pixel values

plt.hist(im0.flatten(), bins, alpha=0.3);
# Creates a histogram of all pixel values in the input image
# flatten() converts the image to a 1D array
# alpha=0.3 makes the histogram semi-transparent

plt.hist(im1.flatten(), bins, alpha=0.3);
# Creates a histogram of all pixel values in the reconstructed image
# Using the same bins and transparency
# Overlaid on the same plot as the input image histogram


# This visualization allows comparing the distribution of pixel values between
# the noisy input and the reconstruction. It helps assess how well the autoencoder
# is preserving the overall pixel value distribution and whether
# it's correctly mapping values from the input distribution to the expected output distribution.
# The semi-transparent overlapping histograms make it easy to see differences
# in how pixel values are distributed between the original and reconstructed images.

In [None]:
# NOTE(review): get_samples unpacks three values per batch (noisy_data, data, target),
# while train_loader is built from a dataset whose __getitem__ returns (image, label)
# pairs. N_SAMPLE and N_VIS_SAMPLE are also not defined in any visible cell. Confirm the
# intended loader and constants before running this cell.
samples = get_samples(train_loader)