In [1]:
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Subset
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.amp import GradScaler, autocast
import os
import random
import pandas as pd
from math import comb

In [2]:
# Preprocessing pipeline: resize every image to 224x224, then convert it to a tensor
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Imagenette is read from a local copy (download=False), so the files must already exist
dataset = datasets.Imagenette(
    root='/home/j597s263/scratch/j597s263/Datasets/imagenette',
    download=False,
    transform=transform,
)

# Seed the global `random` module so the shuffled order (and hence the split) is repeatable
random.seed(42)
indices = list(range(len(dataset)))
random.shuffle(indices)

# First 7568 shuffled positions -> training, next 954 -> testing.
# Positions from 8522 onward are left out of both subsets.
train_indices, test_indices = indices[:7568], indices[7568:8522]

# Wrap the index lists as dataset views
train_data = Subset(dataset, train_indices)
test_data = Subset(dataset, test_indices)

# The training loader draws samples in a new random order on every pass;
# the test loader returns the whole test subset as one ordered batch.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False)

# Report the sizes of the full dataset and of both subsets
print(f"Total samples: {len(dataset)}")
print(f"Training samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")


Total samples: 9469
Training samples: 7568
Test samples: 954


In [3]:
def get_pixel_coords(flat_indices, width):
    """
    Convert flat (row-major) pixel indices into (row, col) pairs.

    Args:
        flat_indices (iterable): Flat indices to convert.
        width (int): Number of columns per row.

    Returns:
        list of tuples: One (idx // width, idx % width) pair per input index, in input order.
    """
    coords = []
    for flat_index in flat_indices:
        # divmod yields (quotient, remainder) == (row, col)
        coords.append(divmod(flat_index, width))
    return coords

In [4]:
def calculate_pixel_frequencies_from_loader(data_loader, pixel_coords):
    """
    Record, for each requested pixel location, which images show which RGB value.

    Args:
        data_loader (DataLoader): Sized iterable yielding (images, labels) batches;
            each image is permuted from (channels, height, width) to (height, width, channels).
        pixel_coords (list of tuples): (row, col) locations to inspect. Locations that
            fall outside an image are skipped for that image.

    Returns:
        dict: ``{(row, col): {pixel_tuple: [(image_position, pixel_tuple), ...]}}``.
            ``image_position`` is the 0-based position of the image in the order the
            loader produced it. If the loader shuffles, that position is only
            meaningful for this particular pass and is not a dataset index.
    """
    pixel_freq = {coord: {} for coord in pixel_coords}

    # Running count of images from earlier batches. Using batch_idx * len(images)
    # instead would produce colliding positions whenever a batch (typically the
    # last one) is smaller than the others.
    images_seen = 0

    for batch_idx, (images, _) in enumerate(data_loader):
        # Move batch to CPU for processing if it's on GPU
        images = images.cpu()

        for img_idx, img_tensor in enumerate(images):
            # Channels-last numpy view so a pixel is addressed as img_array[row, col]
            img_array = img_tensor.permute(1, 2, 0).numpy()
            height, width = img_array.shape[0], img_array.shape[1]
            image_position = images_seen + img_idx

            for (i, j) in pixel_coords:
                if i < height and j < width:
                    pixel = tuple(img_array[i, j])  # channel values as a hashable key
                    pixel_freq[(i, j)].setdefault(pixel, []).append((image_position, pixel))

        images_seen += len(images)
        print(f"Processed batch {batch_idx + 1}/{len(data_loader)}")

    return pixel_freq


In [5]:
# The 22 (row, col) pixel locations to analyse.
# NOTE(review): how these were ranked as "top" is not shown in this notebook section.
top_22_coords = [
    (118, 178), (100, 181), (75, 164), (137, 103),
    (126, 78), (74, 175), (146, 110), (86, 46),
    (158, 98), (90, 173), (106, 134), (84, 165),
    (97, 45), (74, 174), (77, 163), (84, 110),
    (90, 174), (137, 87), (86, 106), (186, 142),
    (74, 173), (138, 87),
]

# One pass over the training loader, collecting pixel values at those locations
pixel_freq = calculate_pixel_frequencies_from_loader(
    data_loader=train_loader,
    pixel_coords=top_22_coords,
)

# Spot-check the first location in the list, i.e. (118, 178)
print(pixel_freq[top_22_coords[0]])

Processed batch 1/119
Processed batch 2/119
Processed batch 3/119
Processed batch 4/119
Processed batch 5/119
Processed batch 6/119
Processed batch 7/119
Processed batch 8/119
Processed batch 9/119
Processed batch 10/119
Processed batch 11/119
Processed batch 12/119
Processed batch 13/119
Processed batch 14/119
Processed batch 15/119
Processed batch 16/119
Processed batch 17/119
Processed batch 18/119
Processed batch 19/119
Processed batch 20/119
Processed batch 21/119
Processed batch 22/119
Processed batch 23/119
Processed batch 24/119
Processed batch 25/119
Processed batch 26/119
Processed batch 27/119
Processed batch 28/119
Processed batch 29/119
Processed batch 30/119
Processed batch 31/119
Processed batch 32/119
Processed batch 33/119
Processed batch 34/119
Processed batch 35/119
Processed batch 36/119
Processed batch 37/119
Processed batch 38/119
Processed batch 39/119
Processed batch 40/119
Processed batch 41/119
Processed batch 42/119
Processed batch 43/119
Processed batch 44/1

In [10]:
def aggregate_rgb_frequencies(pixel_freq, top_coords):
    """
    Flatten per-coordinate pixel records into a table of RGB occurrence counts.

    Args:
        pixel_freq (dict): Maps a coordinate tuple to a dict whose keys are RGB
                           tuples and whose values are lists of occurrences.
        top_coords (list of tuples): Coordinates to include; any that are absent
                                     from pixel_freq are ignored.

    Returns:
        pd.DataFrame: Columns x, y, R, G, B, frequency, with one row per distinct
                      (x, y, R, G, B) combination; frequency is the number of
                      recorded occurrences. 'x' holds the first element of each
                      coordinate tuple and 'y' the second.
    """
    column_names = ['x', 'y', 'R', 'G', 'B', 'frequency']

    # dict.fromkeys removes repeated coordinates while keeping first-seen order
    wanted_coords = [coord for coord in dict.fromkeys(top_coords) if coord in pixel_freq]

    # One record per (coordinate, RGB value); the count is the length of its occurrence list
    records = []
    for coord in wanted_coords:
        first, second = coord
        for rgb, occurrences in pixel_freq[coord].items():
            records.append((first, second, rgb[0], rgb[1], rgb[2], len(occurrences)))

    flat_table = pd.DataFrame(records, columns=column_names)

    # Sum counts for identical (x, y, R, G, B) keys
    grouped = flat_table.groupby(column_names[:-1], as_index=False)
    return grouped.agg({'frequency': 'sum'})

In [11]:
# Build the aggregated frequency table from the `pixel_freq` produced by
# `calculate_pixel_frequencies_from_loader`, restricted to the 22 chosen coordinates
result_df = aggregate_rgb_frequencies(pixel_freq=pixel_freq, top_coords=top_22_coords)

# Show the table
print(result_df)

          x    y    R         G         B  frequency
0        74  173  0.0  0.000000  0.000000         27
1        74  173  0.0  0.000000  0.007843          3
2        74  173  0.0  0.003922  0.003922          1
3        74  173  0.0  0.003922  0.007843          1
4        74  173  0.0  0.003922  0.027451          1
...     ...  ...  ...       ...       ...        ...
159611  186  142  1.0  0.996078  1.000000          2
159612  186  142  1.0  1.000000  0.356863          1
159613  186  142  1.0  1.000000  0.992157          4
159614  186  142  1.0  1.000000  0.996078          1
159615  186  142  1.0  1.000000  1.000000         78

[159616 rows x 6 columns]


In [None]:
def analyze_max_x_for_epsilon(df, t, epsilon):
    """
    Compute, for every row of a frequency table, the largest x that passes the epsilon test.

    For a row with frequency f, x runs over 1 .. max(0, f - t) and is accepted while
    comb(f, x) / comb(max(0, f - t), x) <= epsilon; the search stops at the first
    x that fails, and the last accepted x (0 if none) is reported.

    Args:
        df (pd.DataFrame): Table with columns x, y, R, G, B and frequency.
        t (int): Amount subtracted from each frequency to obtain the search range.
        epsilon (float): Upper bound the ratio must satisfy.

    Returns:
        pd.DataFrame: Columns (x, y, R, G, B, max_x), one row per input row.
    """
    def _largest_accepted_x(frequency):
        total = int(frequency)  # rows from iterrows may carry the count as a float
        remaining = max(0, total - int(t))  # never negative
        accepted = 0
        candidate = 1
        while candidate <= remaining:
            # NOTE(review): comb(total, x) >= comb(remaining, x) whenever
            # remaining <= total, so this ratio is always >= 1 and any
            # epsilon < 1 yields 0. Confirm whether the ratio was meant to be
            # inverted before relying on the result.
            ratio = comb(total, candidate) / comb(remaining, candidate)
            if not (ratio <= epsilon):
                break
            accepted = candidate
            candidate += 1
        return accepted

    records = [
        (row['x'], row['y'], row['R'], row['G'], row['B'], _largest_accepted_x(row['frequency']))
        for _, row in df.iterrows()
    ]

    return pd.DataFrame(records, columns=['x', 'y', 'R', 'G', 'B', 'max_x'])


In [None]:
# Run the epsilon analysis on the aggregated table with t = 2 and epsilon = 0.05
results_df = analyze_max_x_for_epsilon(df=result_df, t=2, epsilon=0.05)
print(results_df)

In [None]:
def sample_rgb_values(pixel_freq, pixel_coords, results_df, original_df):
    """
    Sample image positions for each (coordinate, RGB value) pair, capped at max_x.

    Despite the name, the sampled items are the first element of each
    ``pixel_freq`` entry (the image position), not RGB values.

    Args:
        pixel_freq (dict): ``{coord: {rgb_tuple: [(image_position, ...), ...]}}``.
        pixel_coords (list of tuples): Coordinates to process; matched against the
            'x' and 'y' columns of both DataFrames.
        results_df (pd.DataFrame): Columns x, y, R, G, B, max_x.
        original_df (pd.DataFrame): Columns x, y, R, G, B (others are ignored).
            An RGB value is only sampled if it has a matching row here; if several
            rows match, its positions are repeated once per matching row.

    Returns:
        dict: ``{coord: {(R, G, B): [image_position, ...]}}``. Each list is a random
            sample of max_x positions, or all positions when fewer than max_x exist.

    Raises:
        KeyError: If an RGB value present in both DataFrames is missing from pixel_freq.
    """
    # Index original_df once. Filtering the whole frame for every row of
    # results_df is quadratic in the table size.
    key_columns = ['x', 'y', 'R', 'G', 'B']
    row_multiplicity = {}
    for key in zip(*(original_df[name] for name in key_columns)):
        row_multiplicity[key] = row_multiplicity.get(key, 0) + 1

    sampled_rgb_values = {coord: {} for coord in pixel_coords}

    for (i, j) in pixel_coords:
        # Rows of the results table that belong to this coordinate
        coord_df = results_df[(results_df['x'] == i) & (results_df['y'] == j)]

        for _, row in coord_df.iterrows():
            r, g, b, max_x = row['R'], row['G'], row['B'], int(row['max_x'])

            # Number of original_df rows with exactly this coordinate and colour
            repeats = row_multiplicity.get((i, j, r, g, b), 0)
            if repeats:
                image_positions = [entry[0] for entry in pixel_freq[(i, j)][(r, g, b)]] * repeats
            else:
                image_positions = []

            # Sample max_x positions, or keep everything when there are fewer than max_x
            sampled_rgb_values[(i, j)][(r, g, b)] = (
                random.sample(image_positions, max_x)
                if len(image_positions) >= max_x
                else image_positions
            )

    return sampled_rgb_values