In [None]:
import pandas as pd
from source import image_id_converter as img_idc
import matplotlib.pyplot as plt
from PIL import Image


### Older version: 

### New version:

### Version 7:

In [None]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
from pathlib import Path
import tempfile


def load_image_paths(image_directory, extensions=('.jpg', '.jpeg', '.tif', '.tiff')):
    """
    Load all image file paths from a directory.

    Extensions are matched case-insensitively against each file's suffix and
    every file is listed at most once. (Globbing for both ``*.jpg`` and
    ``*.JPG`` returned every file twice on case-insensitive file systems such
    as the Windows/macOS defaults, and still missed mixed-case suffixes like
    ``.Tif``.) Sub-directories are ignored.

    Args:
        image_directory (str or Path): Path to directory containing images
        extensions (tuple): Allowed file extensions (leading dot optional)

    Returns:
        list: Sorted list of image file paths (as strings)

    Raises:
        FileNotFoundError: If ``image_directory`` does not exist
        NotADirectoryError: If ``image_directory`` exists but is not a directory
    """
    image_dir = Path(image_directory)

    if not image_dir.exists():
        raise FileNotFoundError(f"Directory {image_directory} does not exist")
    if not image_dir.is_dir():
        raise NotADirectoryError(f"{image_directory} is not a directory")

    # Normalize to lower-case suffixes with a leading dot, e.g. 'JPG' -> '.jpg'.
    allowed = {'.' + ext.lower().lstrip('.') for ext in extensions}

    # Sort paths to ensure consistent ordering
    image_paths = sorted(
        str(path)
        for path in image_dir.iterdir()
        if path.is_file() and path.suffix.lower() in allowed
    )

    print(f"Found {len(image_paths)} images in {image_directory}")
    return image_paths


def convert_to_grayscale(image):
    """
    Return a single-channel ('L' mode) copy of a PIL image.

    Args:
        image (PIL.Image): Image in any mode PIL can convert from

    Returns:
        PIL.Image: The image converted to 8-bit grayscale ('L')
    """
    grayscale_mode = 'L'
    return image.convert(grayscale_mode)


def apply_aging_effect(image, noise_std=0.1):
    """
    Apply aging effects to a PIL image (adapted from your existing function).

    Steps: heavy JPEG re-compression (quality 15), brightness x0.8,
    contrast x0.9, additive Gaussian noise, Gaussian blur (radius 1.5).
    The noise uses NumPy's global random state.

    Args:
        image (PIL.Image): Input image; converted to RGB if it is not already.
        noise_std (float): Standard deviation of the additive Gaussian noise as
            a fraction of the 0-255 pixel range (the default 0.1 is about 25.5
            gray levels).

    Returns:
        PIL.Image: Aged RGB image
    """
    import io  # function-local so the notebook's import cells need no change

    # Ensure image is in RGB mode for aging effects
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Heavy JPEG compression, round-tripped through an in-memory buffer. The
    # former temp-file round trip re-opened the file lazily, leaving a handle
    # open (and on Windows os.unlink fails on a file that is still open).
    buffer = io.BytesIO()
    image.save(buffer, 'JPEG', quality=15)
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # decode now rather than lazily on first use

    # Significant brightness reduction
    image = ImageEnhance.Brightness(image).enhance(0.8)

    # Low contrast
    image = ImageEnhance.Contrast(image).enhance(0.9)

    # Add significant noise. The noise must stay signed floating point: the
    # previous `.astype(np.uint8)` truncated N(0, 0.1) samples to 0, so no
    # noise was ever added.
    # NOTE(review): the 0.1 default is read as a fraction of full scale (as it
    # would be for images in [0, 1]); confirm this matches the intended strength.
    img_array = np.asarray(image, dtype=np.float32)
    noise = np.random.normal(0.0, noise_std * 255.0, img_array.shape)
    img_array = np.clip(np.rint(img_array + noise), 0, 255).astype(np.uint8)
    image = Image.fromarray(img_array)

    # Strong blur
    image = image.filter(ImageFilter.GaussianBlur(radius=1.5))

    return image


def load_and_preprocess_image(image_path, target_size=(28, 28), convert_grayscale=True, apply_aging=False):
    """
    Load and preprocess a single image.

    Processing order: optional aging effect, optional grayscale conversion,
    then a LANCZOS resize to ``target_size``.

    Args:
        image_path (str): Path to image file
        target_size (tuple): Target size for resizing (width, height)
        convert_grayscale (bool): Whether to convert to grayscale
        apply_aging (bool): Whether to apply aging effects

    Returns:
        PIL.Image or None: Preprocessed image, or None if any step failed
        (the error is printed). Callers substitute a placeholder for None.
    """
    try:
        # Image.open is lazy and keeps the file handle open. Opening it in a
        # context manager closes the handle; otherwise a full-dataset load
        # leaks one descriptor per image until garbage collection.
        with Image.open(image_path) as source:
            image = source

            # Apply aging effects first (works best on RGB images)
            if apply_aging:
                image = apply_aging_effect(image)

            # Convert to grayscale if requested
            if convert_grayscale:
                image = convert_to_grayscale(image)

            # resize() always returns a new, fully loaded image, so the result
            # stays valid after the source file is closed.
            return image.resize(target_size, Image.Resampling.LANCZOS)

    except Exception as e:  # deliberate best effort: report and let the caller fall back
        print(f"Error processing {image_path}: {str(e)}")
        return None


def create_transforms(mean=0.5, std=1.0):
    """
    Build the image transform pipeline used with the MNIST-style dataset.

    The pipeline is ``ToTensor()`` followed by ``Normalize((mean,), (std,))``.
    With the defaults, ToTensor's [0, 1] output is shifted to [-0.5, 0.5].

    Args:
        mean (float): Normalization mean
        std (float): Normalization standard deviation

    Returns:
        torchvision.transforms.Compose: Transform pipeline
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize((mean,), (std,)),
    ]
    return transforms.Compose(steps)


def _label_to_long_tensor(label):
    """Wrap a single integer label in a 1-element ``torch.long`` tensor (shape [1])."""
    return torch.LongTensor([label])


def create_label_transform():
    """
    Create label transform matching your MNIST setup.

    Each integer label becomes a shape-[1] long tensor. A module-level function
    is used instead of a lambda so the transform, and any dataset holding it,
    can be pickled. ``DataLoader`` needs this for ``num_workers > 0`` under the
    'spawn' start method (the Windows/macOS default), at least once this code
    lives in an importable module.

    Returns:
        torchvision.transforms.Compose: Label transform pipeline
    """
    return transforms.Compose([_label_to_long_tensor])


def _collect_image_sizes(image_paths):
    """Return ``(sizes, failed)``: (width, height) per readable image and (path, error) per failure."""
    sizes = []
    failed = []
    total = len(image_paths)
    for i, image_path in enumerate(image_paths, start=1):
        try:
            # Opening only reads the header, so this is cheap even for large files.
            with Image.open(image_path) as img:
                sizes.append(img.size)  # PIL sizes are (width, height)
        except Exception as e:  # best effort: record the failure and keep going
            failed.append((image_path, str(e)))
            print(f"Failed to read {image_path}: {e}")

        # Progress indicator for large datasets
        if i % 100 == 0:
            print(f"Processed {i}/{total} images...")
    return sizes, failed


def _print_size_report(results):
    """Print a human-readable summary of ``results`` and a target_size suggestion."""
    min_width, min_height = results['min_size']
    max_width, max_height = results['max_size']
    avg_width, avg_height = results['avg_size']

    print("\n" + "=" * 50)
    print("IMAGE SIZE ANALYSIS SUMMARY")
    print("=" * 50)
    print(f"Total images found: {results['total_images']}")
    print(f"Valid images: {results['valid_images']}")
    print(f"Failed to read: {results['failed_count']}")
    print(f"\nAll images same size: {'Yes' if results['all_same_size'] else 'No'}")
    print(f"Number of unique sizes: {results['unique_sizes_count']}")

    print("\nSize ranges:")
    print(f"  Minimum size: {min_width} x {min_height}")
    print(f"  Maximum size: {max_width} x {max_height}")
    print(f"  Average size: {avg_width:.1f} x {avg_height:.1f}")

    print(f"\nWidth range: {min_width} - {max_width} (avg: {avg_width:.1f})")
    print(f"Height range: {min_height} - {max_height} (avg: {avg_height:.1f})")

    failed = results['failed_images']
    if failed:
        print("\nFailed images:")
        for path, error in failed[:5]:  # Show first 5 failures
            print(f"  {path}: {error}")
        if len(failed) > 5:
            print(f"  ... and {len(failed) - 5} more")

    # Suggest target size
    if results['all_same_size']:
        print(f"\nRecommendation: Use target_size={results['min_size']} (all images are the same size)")
    else:
        # Snap the smallest dimension to the nearest common square size
        # (ties go to the smaller size, the first in the list).
        smallest_dim = min(min_width, min_height)
        common_sizes = [28, 32, 64, 128, 224, 256, 512]
        suggested_size = min(common_sizes, key=lambda s: abs(s - smallest_dim))
        print(f"\nRecommendation: Consider target_size=({suggested_size}, {suggested_size})")
        print("  (Based on minimum dimension and common image sizes)")


def analyze_image_sizes(image_directory, extensions=('.jpg', '.jpeg', '.tif', '.tiff')):
    """
    Analyze image sizes in a directory to help determine appropriate target_size.

    Args:
        image_directory (str): Path to directory containing images
        extensions (tuple): Allowed file extensions

    Returns:
        dict or None: Size statistics, or None when no (readable) image is found.
        ``'failed_images'`` holds the ``(path, error_message)`` pairs of
        unreadable files and ``'failed_count'`` their number. (The dict used
        to define ``'failed_images'`` twice, so the count was silently
        replaced by the list and the summary printed the list.)
        ``'min_size'``/``'max_size'``/``'avg_size'`` combine per-dimension
        statistics, so they need not match any single image.
    """
    # Get all image paths
    image_paths = load_image_paths(image_directory, extensions)

    if not image_paths:
        print("No images found in directory")
        return None

    print(f"Analyzing {len(image_paths)} images...")
    sizes, failed_images = _collect_image_sizes(image_paths)

    if not sizes:
        print("No valid images found")
        return None

    widths = [w for w, _ in sizes]
    heights = [h for _, h in sizes]
    unique_sizes_count = len(set(sizes))

    min_width, max_width = min(widths), max(widths)
    avg_width = sum(widths) / len(widths)
    min_height, max_height = min(heights), max(heights)
    avg_height = sum(heights) / len(heights)

    results = {
        'total_images': len(image_paths),
        'valid_images': len(sizes),
        'failed_images': failed_images,
        'failed_count': len(failed_images),
        'all_same_size': unique_sizes_count == 1,
        'unique_sizes_count': unique_sizes_count,
        'min_size': (min_width, min_height),
        'max_size': (max_width, max_height),
        'avg_size': (avg_width, avg_height),
        'min_width': min_width,
        'max_width': max_width,
        'avg_width': avg_width,
        'min_height': min_height,
        'max_height': max_height,
        'avg_height': avg_height,
    }

    _print_size_report(results)
    return results


class CustomImageDataset(Dataset):
    """
    Custom dataset class that mimics the structure of torchvision.datasets.MNIST.

    Like MNIST it exposes ``.data`` (a uint8 image tensor) and ``.targets``
    (a long label tensor). ``__getitem__`` turns the raw image back into a PIL
    image before applying ``transform``. Images are loaded lazily: the whole
    ``.data`` tensor is built on first access and then cached.
    """

    def __init__(self, image_paths, labels, target_size=(28, 28),
                 convert_grayscale=True, apply_aging=False,
                 transform=None, target_transform=None):
        """
        Initialize the dataset with MNIST-compatible structure.

        Args:
            image_paths (list): List of image file paths
            labels (list): List of integer labels corresponding to images
            target_size (tuple): Target size for resizing images, (width, height)
            convert_grayscale (bool): Whether to convert images to grayscale
            apply_aging (bool): Whether to apply aging effects
            transform (callable): Transform to apply to images
            target_transform (callable): Transform to apply to labels

        Raises:
            ValueError: If ``image_paths`` and ``labels`` differ in length.
        """
        self.image_paths = image_paths
        self.target_size = target_size
        self.convert_grayscale = convert_grayscale
        self.apply_aging = apply_aging
        self.transform = transform
        self.target_transform = target_transform

        # Validate that we have equal number of images and labels
        if len(image_paths) != len(labels):
            raise ValueError(f"Number of images ({len(image_paths)}) must match number of labels ({len(labels)})")

        # Create MNIST-compatible attributes
        self.targets = torch.tensor(labels, dtype=torch.long)  # Shape: [N]
        self._data = None  # Built lazily by the `data` property

        print(f"Initializing dataset with {len(image_paths)} images...")

    @property
    def data(self):
        """
        MNIST-compatible data property that returns all images as a tensor.

        Shape: [N, H, W] for grayscale or [N, H, W, 3] for color.
        Dtype: torch.uint8 (same as MNIST).

        Images are loaded and cached on first access.
        """
        if self._data is None:
            print("Loading all images into .data tensor (this may take a moment)...")
            self._load_all_images()
        return self._data

    def _load_all_images(self):
        """Load, preprocess and stack every image into the uint8 ``.data`` tensor."""
        width, height = self.target_size  # PIL sizes are (width, height)
        total = len(self.image_paths)
        all_images = []

        for i, image_path in enumerate(self.image_paths, start=1):
            image = load_and_preprocess_image(
                image_path,
                target_size=self.target_size,
                convert_grayscale=self.convert_grayscale,
                apply_aging=self.apply_aging
            )

            # Unreadable file: substitute a black image so that indices stay
            # aligned with .targets.
            if image is None:
                if self.convert_grayscale:
                    image = Image.new('L', self.target_size, 0)
                else:
                    image = Image.new('RGB', self.target_size, (0, 0, 0))

            # Convert to numpy array with uint8 dtype (like MNIST)
            all_images.append(np.array(image, dtype=np.uint8))

            # Progress indicator
            if i % 100 == 0:
                print(f"Loaded {i}/{total} images...")

        if all_images:
            stacked = np.stack(all_images, axis=0)
        else:
            # np.stack([]) raises ValueError; build an empty batch with the
            # shape a non-empty load would have.
            channels = () if self.convert_grayscale else (3,)
            stacked = np.empty((0, height, width) + channels, dtype=np.uint8)
        self._data = torch.from_numpy(stacked)

        print(f"Loaded .data tensor with shape: {self._data.shape}, dtype: {self._data.dtype}")

    def __len__(self):
        """Return the total number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Get a single sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve

        Returns:
            tuple: (image, label). These are transformed when the respective
            transforms are set; otherwise a PIL image and a Python int.

        Note: This applies transforms to the raw data, just like MNIST.
        """
        # Get raw image from .data tensor (this will load all images if needed)
        raw_image = self.data[idx]  # Shape: [H, W] or [H, W, C]

        # Image.fromarray infers 'L' for 2-D uint8 arrays and 'RGB' for
        # [H, W, 3] uint8 arrays, so the deprecated `mode=` argument is not needed.
        image = Image.fromarray(raw_image.numpy())

        # Get label from targets tensor
        label = self.targets[idx].item()  # Convert to Python int

        # Apply transforms (same as MNIST)
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label


def collate_ae_dataset(samples, noise_rate=0.0, device='cpu'):
    """
    Collate (image, label) samples into a batch for a (denoising) autoencoder.

    When ``noise_rate > 0`` a salt-and-pepper-corrupted copy of the inputs is
    made. Each element is independently replaced, with probability
    ``noise_rate``, by -0.5 or +0.5 (equally likely).

    Args:
        samples: List of (image, label) tuples from the dataset. Images must be
            tensors of identical shape. Labels may be tensors (e.g. shape [1],
            as produced by ``create_label_transform``) or plain ints.
        noise_rate (float): Probability in [0, 1] that an element is replaced
            by noise (0.0 = no noise)
        device (str or torch.device): Device to move tensors to

    Returns:
        tuple: (noisy_images, clean_images, labels), all on ``device``.
        Without noise, ``noisy_images`` is the clean batch itself. The noisy
        batch is fed to the encoder; the clean one is the reconstruction target.

    Raises:
        ValueError: If ``noise_rate`` is outside [0, 1].
    """
    if not 0.0 <= noise_rate <= 1.0:
        raise ValueError(f"noise_rate must be in [0, 1], got {noise_rate}")

    # Stack inputs along a new batch dimension: [B, *image_shape].
    xs = torch.stack([s[0] for s in samples])
    # Concatenate labels along dim 0. atleast_1d lets plain ints and 0-d
    # tensors (which torch.cat rejects) join the usual shape-[1] label tensors.
    ys = torch.cat([torch.atleast_1d(torch.as_tensor(s[1])) for s in samples])

    if noise_rate > 0.0:
        shape = xs.shape
        # 1 where an element is replaced by noise, 0 where it is kept.
        noise_mask = torch.bernoulli(torch.full(shape, noise_rate, device=xs.device))
        # Bernoulli(0.5) gives {0, 1}; shifting by -0.5 gives salt/pepper {-0.5, +0.5}.
        sp_noise = torch.bernoulli(torch.full(shape, 0.5, device=xs.device)) - 0.5
        # Keep original values where the mask is 0, noise where it is 1.
        xns = xs * (1 - noise_mask) + sp_noise * noise_mask
    else:
        xns = xs

    return xns.to(device), xs.to(device), ys.to(device)


def create_image_dataset(image_paths, labels, target_size=(28, 28),
                        convert_grayscale=True, apply_aging=False,
                        mean=0.5, std=1.0):
    """
    Build a ``CustomImageDataset`` wired with the standard image and label transforms.

    The image transform comes from ``create_transforms(mean, std)`` and the
    label transform from ``create_label_transform()``.

    Args:
        image_paths (list): List of image file paths
        labels (list): Integer labels, one per entry of ``image_paths``, in the same order
        target_size (tuple): Target size for resizing images, (width, height)
        convert_grayscale (bool): Whether to convert images to grayscale
        apply_aging (bool): Whether to apply aging effects
        mean (float): Normalization mean
        std (float): Normalization standard deviation

    Returns:
        CustomImageDataset: Dataset ready for use with DataLoader
    """
    return CustomImageDataset(
        image_paths=image_paths,
        labels=labels,
        target_size=target_size,
        convert_grayscale=convert_grayscale,
        apply_aging=apply_aging,
        transform=create_transforms(mean=mean, std=std),
        target_transform=create_label_transform(),
    )






In [None]:
def collate_ae_dataset(samples, noise_rate=None, device=None):
    """
    Collate (image, label) samples into a batch for a (denoising) autoencoder.

    This cell redefines ``collate_ae_dataset``. ``noise_rate`` and ``device``
    are optional, so calls written for either signature keep working. When
    omitted they fall back to the notebook globals ``NOISE_RATE`` (NameError if
    it is undefined) and ``device`` (``'cpu'`` if it is undefined).

    When the noise rate is > 0, each input element is independently replaced,
    with that probability, by salt-and-pepper noise: -0.5 or +0.5, equally likely.

    Args:
        samples: List of (image, label) tuples. Images are tensors of identical
            shape. Labels are tensors (e.g. shape [1]) or plain ints.
        noise_rate (float, optional): Probability in [0, 1] of replacing an element
        device (str or torch.device, optional): Device to move tensors to

    Returns:
        tuple: (noisy_images, clean_images, labels), all on ``device``.
        Without noise, ``noisy_images`` is the clean batch itself.

    Raises:
        ValueError: If the noise rate is outside [0, 1].
    """
    if noise_rate is None:
        noise_rate = NOISE_RATE  # notebook-level global, as in the original cell
    if device is None:
        # The parameter shadows the global name, so read the notebook global explicitly.
        device = globals().get('device', 'cpu')
    if not 0.0 <= noise_rate <= 1.0:
        raise ValueError(f"noise_rate must be in [0, 1], got {noise_rate}")

    # Stack inputs along a new batch dimension: [B, *image_shape].
    xs = torch.stack([s[0] for s in samples])
    # Concatenate labels along dim 0. atleast_1d lets plain ints and 0-d
    # tensors (which torch.cat rejects) join the usual shape-[1] label tensors.
    ys = torch.cat([torch.atleast_1d(torch.as_tensor(s[1])) for s in samples])

    if noise_rate > 0.0:
        shape = xs.shape
        # 1 where an element is replaced by noise, 0 where it is kept.
        noise_mask = torch.bernoulli(torch.full(shape, noise_rate, device=xs.device))
        # Bernoulli(0.5) gives {0, 1}; shifting by -0.5 gives -0.5 or +0.5.
        sp_noise = torch.bernoulli(torch.full(shape, 0.5, device=xs.device)) - 0.5
        # Keep original values where the mask is 0, noise where it is 1.
        xns = xs * (1 - noise_mask) + sp_noise * noise_mask
    else:
        xns = xs

    # Noisy inputs feed the encoder; the clean batch is the reconstruction target.
    return xns.to(device), xs.to(device), ys.to(device)

In [None]:
def get_samples(valid_loader, n_sample=None, n_vis_sample=None):
    """
    Collect a class-balanced set of validation samples for visualization.

    The result has the same number of samples per class (when every class is
    large enough), noisy and clean versions aligned index-by-index, and a
    subset of indices for compact visualizations.

    Args:
        valid_loader: Iterable of ``(noisy_images, images, labels)`` tensor
            batches, e.g. a DataLoader using ``collate_ae_dataset``.
        n_sample (int, optional): Samples drawn per class, without replacement.
            Defaults to the notebook global ``N_SAMPLE``.
        n_vis_sample (int, optional): Number of leading samples per class whose
            indices go into ``single_el_idx``. Defaults to the global ``N_VIS_SAMPLE``.

    Returns:
        dict: ``'images_noisy'``, ``'images'``, ``'labels'`` (NumPy arrays with
        samples grouped by class, in class order) and ``'single_el_idx'``
        (indices of the first ``n_vis_sample`` samples of each class).

    Classes are assumed to be labeled 0..K-1 with K = max label + 1. A class
    with fewer than ``n_sample`` images contributes all of them, and a message
    is printed. Previously ``np.random.choice(..., replace=False)`` raised
    ValueError for such classes, and the index arithmetic assumed exactly
    ``n_sample`` samples per class.

    Raises:
        ValueError: If ``valid_loader`` yields no batches.
    """
    if n_sample is None:
        n_sample = N_SAMPLE
    if n_vis_sample is None:
        n_vis_sample = N_VIS_SAMPLE

    # 1. Gather all validation images and labels as NumPy arrays.
    noisy_batches, clean_batches, label_batches = [], [], []
    for noisy_data, data, target in valid_loader:
        noisy_batches.append(noisy_data.detach().cpu().numpy())
        clean_batches.append(data.detach().cpu().numpy())
        label_batches.append(target.detach().cpu().numpy())
    if not label_batches:
        raise ValueError("valid_loader yielded no batches")

    val_images_noisy = np.concatenate(noisy_batches, axis=0)
    val_images = np.concatenate(clean_batches, axis=0)
    val_labels = np.concatenate(label_batches, axis=0)

    # 2. Draw up to n_sample images per class (no duplicates).
    sample_images_noisy, sample_images, sample_labels = [], [], []
    single_el_idx = []  # indexes of the first elements per class for visualization
    offset = 0  # start of the current class in the concatenated sample arrays

    n_class = np.max(val_labels) + 1
    for class_idx in range(n_class):
        class_mask = val_labels == class_idx
        ims_c_noisy = val_images_noisy[class_mask]
        ims_c = val_images[class_mask]
        print('class label:')
        print(class_idx)
        print('shape selected class')
        print(ims_c.shape)

        n_take = min(n_sample, len(ims_c))
        if n_take < n_sample:
            print(f'class {class_idx} has only {len(ims_c)} samples; using all of them')
        samples_idx = np.random.choice(len(ims_c), n_take, replace=False)

        sample_images_noisy.append(ims_c_noisy[samples_idx])
        sample_images.append(ims_c[samples_idx])
        sample_labels.append(np.full(n_take, class_idx, dtype=val_labels.dtype))

        single_el_idx.extend(range(offset, offset + min(n_vis_sample, n_take)))
        offset += n_take

    samples = {
        'images_noisy': np.concatenate(sample_images_noisy, axis=0),
        'images': np.concatenate(sample_images, axis=0),
        'labels': np.concatenate(sample_labels, axis=0),
        'single_el_idx': np.array(single_el_idx, dtype=np.int64),
    }
    return samples






In [None]:
def to_np_showable(pt_img):
    """
    Convert a PyTorch image tensor into a NumPy array for matplotlib.

    - A 4-D input [B, C, H, W] is reduced to its first image.
    - A 3-D input [C, H, W] with more than 3 channels keeps only the last 3,
      then becomes [H, W, C].
    - A 2-D input [H, W] is used as is.

    Values are mapped from [-1, 1] to [0, 1] (x / 2 + 0.5) and clipped.

    NOTE(review): create_transforms() defaults to mean=0.5, std=1.0, which
    yields values in [-0.5, 0.5]; such images display at reduced contrast here.

    Args:
        pt_img (torch.Tensor): Image tensor (may require grad or live on GPU)

    Returns:
        np.ndarray: Array with values in [0, 1]
    """
    np_im = pt_img.detach().cpu().numpy()
    if np_im.ndim == 4:
        np_im = np_im[0]  # keep the first image of the batch

    if np_im.ndim == 2:
        return (np_im / 2 + .5).clip(0., 1.)

    if np_im.shape[0] > 3:
        np_im = np_im[-3:]  # keep the last three channels

    # [C, H, W] -> [H, W, C]. np.transpose replaces einops' rearrange, whose
    # `eo` alias is never imported in this notebook.
    hwc = np.transpose(np_im, (1, 2, 0))
    return (hwc / 2 + .5).clip(0., 1.)

def plot_im(im, is_torch=True):
    """
    Show a single image with a grayscale colormap, then close the figure.

    Args:
        im: The image. If ``is_torch`` is True it is a PyTorch tensor converted
            with ``to_np_showable``; otherwise anything ``plt.imshow`` accepts.
        is_torch (bool): Whether ``im`` needs tensor-to-NumPy conversion.
    """
    if is_torch:
        im = to_np_showable(im)
    plt.imshow(im, cmap='gray')
    plt.show()
    plt.close()  # free the figure's memory

def plot_im_samples(ds, n=5, is_torch=False):
    """
    Plot the first ``n`` images of ``ds`` side by side in one row (grayscale).

    Args:
        ds: Sliceable collection of images (``ds[:n]`` must work).
        n (int): Number of subplot slots; the figure is 16 x n inches.
        is_torch (bool): Whether the images are PyTorch tensors that need
            ``to_np_showable`` conversion.
    """
    # squeeze=False always returns a 2-D array of Axes. With the default
    # squeeze=True, n == 1 returns a bare Axes and `axs[i]` raises TypeError.
    fig, axs = plt.subplots(1, n, figsize=(16, n), squeeze=False)
    row_axes = axs[0]
    for ax in row_axes:
        ax.set_axis_off()  # also hides empty slots when ds has fewer than n images
    for ax, image in zip(row_axes, ds[:n]):
        ax.imshow(to_np_showable(image) if is_torch else image, cmap='gray')
    plt.show()
    plt.close()

In [None]:
def get_mean_std(x):
    """Return ``(mean, std)`` over all elements of ``x`` (NumPy defaults, i.e. population std)."""
    return np.mean(x), np.std(x)

def get_min_max(x):
    """Return ``(min, max)`` over all elements of ``x`` using NumPy."""
    lowest, highest = np.min(x), np.max(x)
    return lowest, highest

def is_iterable(obj):
    """
    Report whether ``obj`` can be iterated over.

    Returns True when ``iter(obj)`` succeeds and False when it raises any
    exception.
    """
    try:
        iter(obj)
        return True
    except Exception:
        return False

def type_len(obj):
    """
    Print an object's type and length, for quick debugging.

    NumPy arrays also get their shape. Objects without a length (ints,
    generators, 0-d arrays, ...) are reported as ``len: n/a``. The previous
    version called ``len()`` on them anyway and raised TypeError.

    Args:
        obj: Any object.
    """
    t = type(obj)
    try:
        length = len(obj)
    except TypeError:
        length = 'n/a'
    sfx = f', shape: {obj.shape}' if isinstance(obj, np.ndarray) else ''
    print(f'type: {t}, len: {length}{sfx}')


In [None]:
def mosaic(mtr_of_ims):
    """
    Merge a 2-D grid (list of rows) of images into a single image.

    Every image is assumed to have the shape of the top-left one: (h, w) or
    (h, w, c). Images are separated by 1-pixel gutters filled with 0.5 (mid
    gray), and the canvas is float32.

    Args:
        mtr_of_ims: Non-empty sequence of non-empty rows of images.

    Returns:
        np.ndarray: Canvas of shape (ny*(h+1) - 1, nx*(w+1) - 1[, c]), where ny
        is the number of rows and nx the length of the first row.
    """
    n_rows = len(mtr_of_ims)
    assert n_rows != 0
    n_cols = len(mtr_of_ims[0])
    assert n_cols != 0

    first_shape = mtr_of_ims[0][0].shape
    assert 2 <= len(first_shape) <= 3

    h, w = first_shape[:2]
    canvas_shape = (n_rows * (h + 1) - 1, n_cols * (w + 1) - 1) + tuple(first_shape[2:])
    canvas = np.full(canvas_shape, 0.5, dtype=np.float32)

    for row_idx, row in enumerate(mtr_of_ims):
        top = row_idx * (h + 1)
        for col_idx, im in enumerate(row):
            left = col_idx * (w + 1)
            canvas[top:top + h, left:left + w] = im
    return canvas

## Define file paths: 

In [None]:

# The notebook's working directory; the project root is taken to be its parent.
project_path = Path.cwd()
root_path = (project_path / '..').resolve()

# Define paths
# Directory holding the images (the metadata CSV is also read from here further down).
image_dir = root_path/'data'  # Replace with your directory containing images

## Before running the workflow below, manually copy the image files into `root_path / 'data'` (the directory that `image_dir` points to).

## Get information about image size: 

In [None]:
# Survey image dimensions (printed report + returned dict, or None if nothing was found) to pick a target_size.
results = analyze_image_sizes(image_dir, extensions=('.jpg', '.jpeg', '.tif', '.tiff'))
results

## Load metadata about images:

In [None]:
# Image metadata; its with_person column supplies the labels further down.
with_without_person = pd.read_csv(image_dir/'with_without_person_mod.csv')
with_without_person

## Map image paths to labels:

In [None]:
# Raw image ids exactly as stored in the metadata CSV.
img_ids = list(with_without_person.image_id)

In [None]:
# NOTE(review): presumably maps the CSV ids to the ids parsed from image file
# names (they must match for the merge on 'image_id' below) — confirm in
# source.image_id_converter.
with_without_person['image_id'] = img_idc.reconvert_image_ids(img_ids)

In [None]:
# Preview the metadata after the id conversion.
with_without_person.head()

In [None]:
# Sorted list of image file paths (as strings) found in image_dir.
image_paths = load_image_paths(image_dir)
image_paths[0:3]

In [None]:
# Extract each image's id from its file name: the last three characters of the
# file stem (e.g. '.../img_001.tif' -> '001'). Path.stem works for every
# extension load_image_paths returns. Splitting on the literal '.tif' raised
# IndexError for .jpg/.jpeg/.TIF files.
img_ids = [Path(image_path).stem[-3:] for image_path in image_paths]

In [None]:
# One row per image file: the id parsed from its file name, plus its path.
image_paths_for_mapping = pd.DataFrame({'image_id': img_ids, 'image_paths': image_paths})
image_paths_for_mapping.head()

In [None]:
# Attach each metadata row to its image path via the shared 'image_id' column.
# how='inner' drops metadata rows without an image file, and image files without metadata.
# NOTE(review): the file-name ids are strings, so the converted CSV ids must be
# strings too or no rows will match — confirm reconvert_image_ids' output type.
ids_labels_mapping = with_without_person.merge(image_paths_for_mapping, how='inner', on='image_id')
ids_labels_mapping.head()

In [None]:
# Independent copy of the mapped rows whose with_person flag equals 1.
with_person_only = ids_labels_mapping.loc[ids_labels_mapping.with_person == 1].copy()
with_person_only.shape

In [None]:
# One label (the with_person value) per row of ids_labels_mapping.
# NOTE(review): these labels line up with ids_labels_mapping.image_paths, not with
# the image_paths list from load_image_paths. The inner merge can drop rows and
# orders them like with_without_person, so pair them with
# list(ids_labels_mapping.image_paths) when building a dataset.
labels = list(ids_labels_mapping.with_person)
print(type(labels))
print(labels[0:3])

## Test convert_to_grayscale function: 

In [None]:

# Load and process one sample image (needs at least 101 images in image_paths)
image_path = image_paths[100]
original = Image.open(image_path)
processed = convert_to_grayscale(original)  # Or any other function

# Plot side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(original)
ax1.set_title('Original')
ax1.axis('off')

# A mode-'L' image is a single-channel array. Without cmap='gray', matplotlib
# renders it with its default color colormap. vmin/vmax pin the 0-255 range
# so the display is not contrast-stretched to the image's own min/max.
if processed.mode == 'L':
    ax2.imshow(processed, cmap='gray', vmin=0, vmax=255)
else:
    ax2.imshow(processed)
ax2.set_title('After convert_to_grayscale')
ax2.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Pixel arrays of the original and grayscale images for numeric inspection.
processed = convert_to_grayscale(original)
original_array, processed_array = np.array(original), np.array(processed)


In [None]:
# Summary statistics of the original image's pixel array.
inspected = original_array
print(f"Shape: {inspected.shape}")
print(f"Data type: {inspected.dtype}")
print(f"Min value: {inspected.min()}")
print(f"Max value: {inspected.max()}")
print(f"Mean value: {inspected.mean():.2f}")
print(f"Unique values (first 10): {np.unique(inspected)[:10]}")

In [None]:
# Summary statistics of the processed (grayscale) image's pixel array.
inspected = processed_array
print(f"Shape: {inspected.shape}")
print(f"Data type: {inspected.dtype}")
print(f"Min value: {inspected.min()}")
print(f"Max value: {inspected.max()}")
print(f"Mean value: {inspected.mean():.2f}")
print(f"Unique values (first 10): {np.unique(inspected)[:10]}")

## Test apply_aging_effect function:

In [None]:

# Compare an image before and after the simulated aging effect.
image_path = image_paths[0]
original = Image.open(image_path)
processed = apply_aging_effect(original)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

aged_cmap = 'gray' if processed.mode == 'L' else None
panels = (
    (ax1, original, 'Original', None),
    (ax2, processed, 'After aging effect', aged_cmap),
)
for ax, shown, title, cmap in panels:
    ax.imshow(shown, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Pixel arrays of the image before and after aging.
original_array, processed_array = (np.array(im) for im in (original, processed))

In [None]:

# Histogram of the original image's pixel values (array flattened over all axes).
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(original_array.flatten(), bins=50, alpha=0.7, edgecolor='black')
ax.set_title('Histogram of Pixel Values')
ax.set_xlabel('Pixel Value')
ax.set_ylabel('Frequency')
ax.grid(True, alpha=0.3)
plt.show()

In [None]:

# Histogram of the aged image's pixel values (array flattened over all axes).
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(processed_array.flatten(), bins=50, alpha=0.7, edgecolor='black')
ax.set_title('Histogram of Pixel Values')
ax.set_xlabel('Pixel Value')
ax.set_ylabel('Frequency')
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# Grayscale, resized to 512x512, no aging.
test_image_1 = load_and_preprocess_image(
    image_paths[0],
    target_size=(512, 512),
    convert_grayscale=True,
    apply_aging=False,
)
type(test_image_1)

In [None]:
# Same preprocessing as test_image_1, but with the aging effect applied.
test_image_2 = load_and_preprocess_image(
    image_paths[0],
    target_size=(512, 512),
    convert_grayscale=True,
    apply_aging=True,
)
type(test_image_2)

In [None]:
# Display the non-aged preprocessed image in grayscale.
plt.imshow(test_image_1, cmap='gray')

In [None]:
# Display the aged preprocessed image in grayscale.
plt.imshow(test_image_2, cmap='gray')

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Same preprocessing (grayscale, 512x512) without and with aging, side by side.
for ax, shown, title in ((ax1, test_image_1, 'Original'), (ax2, test_image_2, 'After aging')):
    ax.imshow(shown, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Very small (50x50) grayscale version to see the effect of heavy downsizing.
test_image_3 = load_and_preprocess_image(
    image_paths[0],
    target_size=(50, 50),
    convert_grayscale=True,
    apply_aging=False,
)
type(test_image_3)

In [None]:
# Display the 50x50 preprocessed image in grayscale.
plt.imshow(test_image_3, cmap='gray')

In [None]:
# Global settings used when building the DataLoader (no loader is created here).
from functools import partial

# Noise rate for the denoising setup; presumably read as a global by
# collate_ae_dataset (it is described as "using globals") — confirm.
NOISE_RATE = 0.1  # or whatever you use
# torch.backends.mps.is_built() only reports that PyTorch was compiled with MPS
# support. It can be True on machines where MPS cannot actually be used.
# is_available() checks that the backend is usable at runtime.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

## Try out the workflow on a small subset of the data:

In [None]:
# First three images with dummy labels (one per image), used only to exercise
# the pipeline; these labels are not taken from the metadata.
image_paths_few = image_paths[:3]
labels_few = list(range(3))

In [None]:
# Global settings for the training workflow.
NOISE_RATE = 0.1
BATCH_SIZE = 32
# Use MPS only when it is actually usable on this machine; is_built() merely
# says PyTorch was compiled with MPS support.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)


In [None]:


# image_paths_few and labels_few come from the previous cells.

# Build the dataset for the small subset. mean/std are forwarded to
# create_image_dataset; presumably they define an (x - mean) / std
# normalization — confirm there.
dataset_options = dict(
    target_size=(572, 572),
    convert_grayscale=True,
    apply_aging=False,
    mean=0.5,
    std=1.0,
)
dataset = create_image_dataset(
    image_paths=image_paths_few, labels=labels_few, **dataset_options
)

In [None]:
# Shape of the dataset's stored image tensor.
# NOTE(review): presumably (n_images, channels, H, W) — confirm in create_image_dataset.
dataset.data.shape

In [None]:


# Batch the dataset with the project's collate function (which relies on
# notebook globals such as NOISE_RATE); keep the final partial batch.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_ae_dataset,
    drop_last=False,
)
train_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)


In [None]:
# Number of batches versus number of individual samples.
n_batches = len(train_loader)
print(f"DataLoader length: {n_batches}")
n_items = len(dataset)
print(f"Dataset length: {n_items}")

In [None]:
# Look at the first batch only: (noisy inputs, clean targets, labels).
for batch_idx, batch in enumerate(train_loader):
    noisy_data, data, target = batch
    print(f"Batch {batch_idx}: {noisy_data.shape}, {data.shape}, {target.shape}")
    break

In [None]:
# Sample-selection constants; presumably read as globals by get_samples /
# the plotting helpers defined earlier in the notebook — confirm.
N_SAMPLE, N_VIS_SAMPLE = 1, 2

In [None]:
# Fixed set of samples drawn from train_loader, used for visual checks.
samples = get_samples(train_loader)

In [None]:
# Show channel 0 of the selected samples: noisy inputs first, then clean images.
single_el_idx = samples['single_el_idx']
for key in ('images_noisy', 'images'):
    plot_im_samples(samples[key][single_el_idx, 0], n=4, is_torch=False)

In [None]:
# Inspect the first sample of the training `dataset` (this is not a validation
# set): types/lengths and shapes of image and label, plus a histogram of the
# image's pixel values.
for sample in dataset:
    img, label = sample
    print(type_len(img))
    print(type_len(label))
    # NOTE(review): label.shape assumes the dataset returns labels as
    # tensors/arrays; it fails if labels come back as plain Python ints.
    print(img.shape, label.shape)
    plt.hist(img.flatten(), bins=100)
    break

In [None]:
# Histogram of all pixel values stored in `dataset.data`, the small training
# subset built above (not MNIST, and not a validation set).

plt.hist(dataset.data.numpy().flatten(), bins=100);

In [None]:
# Inspect the first batch produced by train_loader (there is no separate
# validation loader in this part of the notebook).
for sample in train_loader:
    # Each batch is (noisy images, clean images, labels). The noisy images are
    # the autoencoder input; the clean images are the reconstruction target.
    noisy_img, img, label = sample
    print(type_len(noisy_img))
    print(type_len(img))
    print(type_len(label))

    print(noisy_img.shape, img.shape, label.shape)

    # Optional: histogram of all pixel values in the clean batch.
    #plt.hist(img.flatten(), bins=100)

    break

In [None]:
import einops as eo

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import einops as eo
import pathlib as pl

import matplotlib.cm as cm
from matplotlib import collections  as mc
from matplotlib import animation
%matplotlib inline

from scipy.stats import norm
from scipy.stats import entropy

import pandas as pd
import pickle
from PIL import Image
from time import time as timer
#import umap

from IPython.display import HTML
from IPython.display import Audio
import IPython

import tqdm.auto as tqdm

import torch
from torchvision import datasets, transforms
from torch import nn
from torch import optim
import torch.nn.functional as F

from torchvision import transforms

import sys

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected (dense) autoencoder.

    The encoder flattens each sample and maps it through one hidden ReLU layer
    to a ``code_size``-dimensional code squashed into [0, 1] by a sigmoid. The
    decoder mirrors this, ends in tanh, and reshapes the output back to
    ``input_size``. ``decode`` scales the tanh output by 1.1.

    Args:
        input_size: Shape of one data sample without the batch dimension,
            e.g. ``(1, H, W)`` for single-channel images.
        code_size: Dimension of the bottleneck (latent) representation.
        hidden_size: Width of the hidden layer on both sides. Defaults to 128,
            the value that was previously hard-coded.
    """

    def __init__(self, input_size, code_size, hidden_size=128):
        # nn.Module must be initialized before any attributes are assigned.
        super().__init__()

        self.input_size = list(input_size)  # shape of one data sample
        # np.prod returns a NumPy integer; nn.Linear expects a Python int.
        self.flat_data_size = int(np.prod(self.input_size))
        self.hidden_size = hidden_size
        self.code_size = code_size

        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.flat_data_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.code_size),
            nn.Sigmoid(),  # code values constrained to [0, 1]
        )

        self.decoder = nn.Sequential(
            nn.Linear(self.code_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.flat_data_size),
            # tanh gives outputs in [-1, 1], a symmetric range around 0. If
            # pixels in [0, 1] are normalized with mean=0.5, std=1.0 (as in the
            # dataset cell), targets lie in [-0.5, 0.5], inside this range.
            nn.Tanh(),
            nn.Unflatten(1, self.input_size),
        )

    def forward(self, x, return_z=False):
        """Encode then decode ``x``; also return the code if ``return_z``."""
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return (decoded, encoded) if return_z else decoded

    def encode(self, x):
        """Map a batch of samples to codes of shape (batch, code_size)."""
        return self.encoder(x)

    def decode(self, z):
        """Map codes back to samples of shape (batch, *input_size)."""
        # The 1.1 factor slightly widens tanh's output range, presumably so
        # targets near the range ends are reachable without saturating tanh.
        return self.decoder(z) * 1.1

    def get_n_params(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def eval_on_samples(ae_model, epoch, samples):
    """Run the autoencoder on the fixed monitoring samples.

    Called at the end of each training epoch to record reconstructions and
    latent codes for later visualization.

    Args:
        ae_model: nn.Module whose forward accepts ``return_z=True`` and then
            returns ``(reconstruction, code, ...)``.
        epoch: Epoch index stored alongside the results.
        samples: dict whose ``'images_noisy'`` entry holds the noisy input
            batch (NumPy array or torch tensor).

    Returns:
        dict: ``{'z': z, 'y': y, 'epoch': epoch}``. ``y`` is the reconstruction
        as a NumPy array. ``z`` is a list of the remaining model outputs (the
        code) as NumPy arrays.
    """
    # Send the inputs to the device the model's parameters live on, instead of
    # depending on a notebook-level `device` global.
    first_param = next(ae_model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    # torch.as_tensor accepts NumPy arrays and tensors alike. torch.tensor()
    # on an existing tensor copies it and emits a UserWarning.
    xns = torch.as_tensor(samples['images_noisy'], dtype=torch.float32).to(model_device)

    with torch.no_grad():  # evaluation only, no autograd graph needed
        outputs = ae_model(xns, return_z=True)
        outputs = [el.detach().cpu().numpy() for el in outputs]

    y = outputs[0]
    # Kept as a list (every output after the reconstruction) for compatibility
    # with existing consumers of sample_history.
    z = outputs[1:]

    return {'z': z, 'y': y, 'epoch': epoch}

In [None]:
def plot_hist(history, logscale=True):
    """
    Plot training and validation loss against epoch.

    Args:
        history: dict with per-epoch sequences under 'loss', 'val_loss'
            and 'epoch'.
        logscale: if True (the default) use a logarithmic y axis, which suits
            loss curves that drop over orders of magnitude.
    """
    train_loss, valid_loss, epochs = history['loss'], history['val_loss'], history['epoch']

    plot_fn = plt.plot
    if logscale:
        plot_fn = plt.semilogy

    for curve, name in ((train_loss, 'training'), (valid_loss, 'validation')):
        plot_fn(epochs, curve, label=name)

    plt.legend()
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.show()
    plt.close()



def plot_samples(sample_history, samples, epoch_stride=5, fig_scale=1, vmin=-.5, vmax=.5):
    """
    Plot noisy inputs, reconstructions and clean targets for selected epochs.

    One figure is drawn for every ``epoch_stride``-th epoch and always for the
    last recorded epoch. Each figure is a mosaic of three image rows: noisy
    inputs, the model's reconstructions after that epoch, and the clean images.

    Args:
        sample_history: dict mapping epoch index to an eval_on_samples result
            (reconstructions under 'y').
        samples: dict with 'single_el_idx', 'images_noisy' and 'images'.
        epoch_stride: plot every n-th epoch (plus the last one).
        fig_scale: figure size per image, in inches.
        vmin, vmax: display range of the gray colormap. The defaults keep the
            previously hard-coded range.
    """
    if not sample_history:
        # Nothing recorded yet; taking the max of no epochs would raise.
        print('plot_samples: sample_history is empty, nothing to plot')
        return

    single_el_idx = samples['single_el_idx']
    # Channel 0 of each selected sample.
    images_noisy = samples['images_noisy'][single_el_idx, 0]
    images = samples['images'][single_el_idx, 0]

    last_epoch = max(sample_history)

    for epoch_idx, hist_el in sample_history.items():
        if epoch_idx % epoch_stride != 0 and epoch_idx != last_epoch:
            continue

        samples_arr = [images_noisy, hist_el['y'][single_el_idx, 0], images]
        ny = len(samples_arr)
        nx = len(samples_arr[0])

        plt.figure(figsize=(fig_scale * nx, fig_scale * ny))
        m = mosaic(samples_arr)  # tile the three image rows into one canvas

        plt.title(f'after epoch {int(epoch_idx)}')
        plt.imshow(m, cmap='gray', vmin=vmin, vmax=vmax)

        plt.tight_layout(pad=0.1, h_pad=0, w_pad=0)
        plt.show()
        plt.close()  # free the figure's memory

In [None]:
# Utility functions for running an analysis on model checkpoints saved at different stages of training:

def run_on_trained(model, root_dir, run_fn, ep=None, model_filename=None):
    """
    Helper function to execute any function on the model in its state after training epoch `ep`.

    The checkpoint is chosen in this order of precedence:
      1. ``model_filename`` if given;
      2. ``root_dir/model_{ep:03d}.pth`` if ``ep`` is given;
      3. otherwise the last ``*.pth`` file in ``root_dir`` by sorted name.

    Args:
        model: nn.Module that loads the checkpoint's ``'model_state_dict'``
            entry (modified in place).
        root_dir: directory containing the checkpoints (str or Path).
        run_fn: callable invoked as ``run_fn(model)`` after loading.
        ep: epoch index of the checkpoint to load (optional).
        model_filename: explicit checkpoint path (optional).

    Raises:
        FileNotFoundError: if neither ``ep`` nor ``model_filename`` is given
            and ``root_dir`` contains no ``*.pth`` files.
    """
    from pathlib import Path

    root_dir = Path(root_dir)
    if model_filename is None:
        if ep is not None:
            model_filename = root_dir / f'model_{ep:03d}.pth'
        else:
            checkpoints = sorted(root_dir.glob('*.pth'))
            if not checkpoints:
                raise FileNotFoundError(f'no *.pth checkpoints found in {root_dir}')
            model_filename = checkpoints[-1]  # last model state

    # Load tensors onto the device the model already lives on, so checkpoints
    # written on one device (e.g. 'mps') can be restored on another.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None

    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code from the file. It is kept because checkpoints written by the training
    # loop have stored a pickled nn.MSELoss module. Only load trusted files.
    model_dict = torch.load(model_filename, map_location=map_location, weights_only=False)
    model.load_state_dict(model_dict['model_state_dict'])

    run_fn(model)
    # Calls the provided function on the loaded model

def run_on_all_training_history(model, root_dir, run_fn, n_ep=None):
    """
    Helper function to execute any function on the model state after each training epoch.

    Args:
        model: nn.Module into which each checkpoint is loaded (in place).
        root_dir: directory containing the checkpoints (str or Path).
        run_fn: callable invoked as ``run_fn(model)`` once per checkpoint.
        n_ep: if given, process epochs 0 .. n_ep-1, loading each through
            run_on_trained(ep=...); otherwise process every ``*.pth`` file in
            ``root_dir`` in sorted-name order.
    """
    from pathlib import Path

    root_dir = Path(root_dir)
    if n_ep is not None:
        for ep in range(n_ep):
            print(f'running on epoch {ep+1}/{n_ep}...')
            run_on_trained(model, root_dir, run_fn, ep=ep)
    else:
        # With zero-padded names such as model_007.pth, name order is epoch order.
        for model_filename in sorted(root_dir.glob('*.pth')):
            print(f'running on checkpoint {model_filename}...')
            run_on_trained(model, root_dir, run_fn, model_filename=model_filename)

    print('done')

## Try out some models to check if the code works:

### Try untrained model:

In [None]:
# Smoke test: build an untrained autoencoder and run it on one training batch.
train_batch = next(iter(train_loader))  # first batch only, no full epoch
xns, xs, ys = train_batch  # noisy inputs, clean images, labels

print('sample shapes:', xns.shape, xs.shape, ys.shape)

in_size = xns.shape[1:]  # per-sample shape, batch dimension dropped
print(in_size)

# Bottleneck of 10 values; the model is moved to the selected device.
ae = AutoEncoder(input_size=in_size, code_size=10).to(device)

# Inference only: no_grad avoids building an autograd graph through the large
# first dense layer (flattened image size x 128 weights). It also gives
# plot_im_samples a tensor that does not require grad.
with torch.no_grad():
    y = ae(xns)

print('output shape:', y.shape)  # should equal the input shape

plot_im_samples(xns, is_torch=True)  # noisy inputs

plot_im_samples(y, is_torch=True)  # reconstructions: pure noise, since ae is untrained

In [None]:
# Compare the pixel-value distribution of the first noisy input with that of
# its reconstruction by the untrained model (channel 0 of each).
x, d = xns[0], y[0]
im0, im1 = (t[0].detach().cpu().numpy() for t in (x, d))

# 100 evenly spaced bin edges (i.e. 99 bins) covering [-1, 1], roughly the
# decoder's tanh output range.
bins = np.linspace(-1, 1, 100)

# Overlay both histograms semi-transparently on the same axes.
for im in (im0, im1):
    plt.hist(im.flatten(), bins, alpha=0.3);

### Overfit the dense model (no validation set) to check that it learns anything:

In [None]:
# Configuration for the overfitting check.

# Size of the latent code (bottleneck). Larger than the code_size=10 used in
# the untrained smoke test, i.e. a less compressed representation.
CODE_SIZE = 20

# No artificial noise. Assuming collate_ae_dataset reads this global at batch
# time (it is described as using globals), the model trains as a plain rather
# than a denoising autoencoder.
NOISE_RATE = 0

# Name of the directory the training cell writes checkpoints to.
MODEL_NAME = 'ae_model'

model = AutoEncoder(input_size=in_size, code_size=CODE_SIZE).to(device)

# Fixed samples drawn from train_loader. Their reconstructions are recorded
# after every epoch to monitor reconstruction quality.
samples = get_samples(train_loader)


In [None]:
# Train the autoencoder for N_EPOCHS epochs. For every epoch, record the
# training and "validation" loss, the reconstructions/latent codes of the
# monitoring samples, and a model checkpoint.
#
# NOTE: this is a deliberate overfitting check without a held-out set. The
# "validation" loss is therefore computed on train_loader again (eval mode, no
# gradient updates) and only measures fit to the training data.

N_EPOCHS = 50
LR = 0.0009

model_root = pl.Path(MODEL_NAME)
model_root.mkdir(exist_ok=True)  # checkpoints go to ./<MODEL_NAME>/

optimizer = optim.Adam(model.parameters(), lr=LR)

# Reconstruction loss: mean squared error between output and clean image.
loss = nn.MSELoss()

history = {'loss': [], 'val_loss': [], 'epoch': []}
sample_history = {}  # epoch index -> eval_on_samples result

pbar = tqdm.tqdm(range(N_EPOCHS), postfix=f'epoch 0/{N_EPOCHS}')

for epoch_idx in pbar:
    # --- training pass ---
    epoch_loss = 0
    model.train()
    for batch_idx, (noisy_data, data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(noisy_data)       # reconstruct from the noisy input...
        loss_value = loss(output, data)  # ...and compare with the clean image
        loss_value.backward()
        optimizer.step()
        epoch_loss += loss_value.detach().cpu().item()
    epoch_loss /= len(train_loader)  # mean loss per batch
    history['loss'].append(epoch_loss)
    history['epoch'].append(epoch_idx)

    # --- evaluation pass (same data, see NOTE above) ---
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch_idx, (noisy_data, data, target) in enumerate(train_loader):
            output = model(noisy_data)
            loss_value = loss(output, data)
            val_loss += loss_value.detach().cpu().item()
        val_loss /= len(train_loader)
        history['val_loss'].append(val_loss)

    pbar.set_postfix({'epoch': f'{epoch_idx+1}/{N_EPOCHS}', 'loss':f'{epoch_loss:.4f}', 'val_loss':f'{val_loss:.4f}'})

    # Reconstructions and latent codes of the fixed monitoring samples.
    sample_res = eval_on_samples(model, epoch_idx, samples=samples)
    sample_history[epoch_idx] = sample_res

    # Checkpoint for this epoch (e.g. model_001.pth). Store the loss *values*
    # as plain floats, not the nn.MSELoss module itself. A pickled module
    # cannot be read by torch.load(weights_only=True), the safe mode (and the
    # default in recent PyTorch), and would force unsafe unpickling.
    torch.save({
                'epoch': epoch_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                'val_loss': val_loss,
                }, model_root/f'model_{epoch_idx:03d}.pth')
    

In [None]:
# Training vs. "validation" loss on a logarithmic y axis.
plot_hist(history, logscale=True)

In [None]:
# Reconstructions of the monitoring samples every 5th epoch and at the last epoch.
plot_samples(sample_history, samples, epoch_stride=5, fig_scale=1)