In [None]:
import os
import sys

# Resolve the folder the notebook runs from. Jupyter does not define `__file__`, so the
# *string* '__file__' is made absolute against the current working directory and its
# dirname is therefore the working directory itself (normally the notebook's folder).
notebook_dir_path = os.path.dirname(os.path.abspath('__file__'))
# Project root, taken to be the parent of the notebook folder.
project_path = os.path.join(notebook_dir_path, '..')  # Change this

# Make modules in the notebook folder and in the project root importable (in that order).
for _import_root in (notebook_dir_path, project_path):
    sys.path.append(_import_root)
del _import_root

from struct import unpack
from collections import defaultdict
import math
import os
import matplotlib.pyplot as plt
import binascii
import json
import numpy as np
import time
from IPython.display import display, Markdown


# Setup

In [None]:
# Number of JPEG files to read (example value; set this to the desired number of images).
total_num_of_images = 150

# Directory holding the source images. Swap in one of the commented alternatives to change dataset.
# image_directory_path = "/share/openimage_validation"
image_directory_path = "/share/flickr30k_images"
# image_directory_path = f"{project_path}/data/test_images"

# Output files. The image count is part of each name, so runs with different counts
# write to different files.
jpeg_data_path_to_save = f"{project_path}/outputs/jpeg_data_{total_num_of_images}.bin"
metrics_path_to_save = f"{project_path}/outputs/metrics_{total_num_of_images}.json"

# Optional: Resize all images to (H, W) while H=W. Neither this nor `batch_size`
# is read by any of the cells shown here.
target_image_size = 256
batch_size = 32

In [None]:


# Human-readable names of the JPEG marker codes, keyed by the 16-bit big-endian marker
# value; used when printing markers in `JPEG.decode`.
marker_mapping = dict([
    (0xffd8, "Start of Image"),
    (0xfffe, "Comment"),
    (0xffe0, "Application Default Header"),
    (0xffdb, "Quantization Table"),
    (0xffc0, "Start of Frame"),
    (0xffc4, "Define Huffman Table"),
    (0xffda, "Start of Scan"),
    (0xffd9, "End of Image"),
])


class JPEG:
    """Minimal reader for the raw bytes of a JPEG file.

    The whole file is loaded into memory on construction. `decode` prints the marker
    segments for inspection, and `extract_data` returns the bytes between the
    "Start of Image" (0xffd8) and "End of Image" (0xffd9) markers.
    """

    def __init__(self, image_file):
        """Read the entire file at path ``image_file`` into ``self.img_data`` (bytes)."""
        with open(image_file, 'rb') as f:
            self.img_data = f.read()

    def decode(self):
        """Print the marker segments of the image in file order (a debugging aid).

        For each marker the name from the module-level ``marker_mapping`` is printed
        (``Unknown marker 0x....`` when it is not listed). Length-prefixed segments are
        skipped using their big-endian 2-byte length. At "Start of Scan" the scan data is
        skipped by jumping to the first "End of Image" marker after it. Walking stops at
        "End of Image", or early when the data runs out or a segment header is truncated.
        Returns ``None``.
        """
        data = self.img_data
        while len(data) >= 2:
            marker, = unpack(">H", data[0:2])
            print(marker_mapping.get(marker, f"Unknown marker 0x{marker:04x}"))
            if marker == 0xffd8:  # "Start of Image", there is no `length` field after it.
                data = data[2:]
            elif marker == 0xffd9:  # "End of Image"
                return
            elif marker == 0xffda:  # "Start of Scan": skip the scan data up to "End of Image".
                # In entropy-coded data a 0xff byte is followed by 0x00 (stuffing) or a
                # restart marker, so the first 0xffd9 after this marker ends the image.
                eoi_idx = data.find(b'\xff\xd9', 2)
                # Bytes from this marker up to (excluding) EOI; everything left if EOI is missing.
                len_start_of_scan_and_image_data = eoi_idx if eoi_idx != -1 else len(data)
                print(f"Length of \"Start of Scan\" segment (including \"Image Data\" segment): {len_start_of_scan_and_image_data}")
                if eoi_idx == -1:
                    return
                data = data[eoi_idx:]
            else:
                if len(data) < 4:  # Truncated header: no room for the 2-byte length field.
                    return
                lenchunk, = unpack(">H", data[2:4])
                print(f"Length of this segment: {lenchunk}")
                data = data[2 + lenchunk:]

    def _find_eoi_search_start(self, pos):
        """Return the offset from which `extract_data` should search for "End of Image".

        Walks the marker segments starting at ``pos`` (just after SOI), jumping over each
        segment body via its 2-byte length, and returns the offset of the first
        "Start of Scan" or "End of Image" marker reached. Skipping the bodies matters because
        a segment can embed a complete JPEG (e.g. an Exif thumbnail in APP1) with its own
        0xffd9. If the segment chain cannot be followed (non-marker byte, truncated header,
        invalid length), ``pos`` itself is returned, i.e. a plain first-match search.
        """
        data = self.img_data
        cursor = pos
        while cursor + 1 < len(data) and data[cursor] == 0xff:
            marker = data[cursor + 1]
            if marker in (0xda, 0xd9):  # "Start of Scan" / "End of Image"
                return cursor
            if marker == 0xff:  # Fill byte preceding a marker.
                cursor += 1
            elif marker == 0x01 or 0xd0 <= marker <= 0xd8:  # TEM, RSTn, SOI: no length field.
                cursor += 2
            else:
                if cursor + 3 >= len(data):
                    break
                segment_length = (data[cursor + 2] << 8) | data[cursor + 3]
                if segment_length < 2:  # The length field counts itself, so < 2 is invalid.
                    break
                cursor += 2 + segment_length
        return pos

    def extract_data(self):
        '''
        Extract the binaries between `0xffd8` and `0xffd9`, i.e., all the data between "Start of Image" and "End of Image" segments.

        Returns ``(payload, length)``, where ``payload`` excludes both markers and
        ``length == len(payload)``. Returns ``None`` when either marker is missing.
        The EOI searched for is the one that follows the image's own segments, so an
        embedded JPEG inside a segment does not truncate the payload.
        '''
        data = self.img_data
        start_idx = data.find(b'\xff\xd8')  # Start of Image
        if start_idx == -1:
            return None  # Start marker not found

        payload_start = start_idx + 2  # Exclude the start marker
        end_idx = data.find(b'\xff\xd9', self._find_eoi_search_start(payload_start))  # End of Image
        if end_idx == -1:
            return None  # End marker not found

        return data[payload_start:end_idx], end_idx - payload_start  # Exclude the end marker.



def calculate_ngram_distribution(tokens, N=20):
    """Return the empirical probability of each length-``N`` n-gram of ``tokens``.

    N-grams are read from every start position (overlapping windows) and keyed as
    tuples. The result maps each observed n-gram to ``count / total`` in first-seen
    order; it is empty when ``tokens`` holds fewer than ``N`` items.
    """
    windows = (tuple(tokens[start:start + N]) for start in range(len(tokens) - N + 1))

    counts = {}
    for gram in windows:
        counts[gram] = counts.get(gram, 0) + 1

    total = sum(counts.values())
    return {gram: count / total for gram, count in counts.items()}





def calculate_ngram_frequencies(tokens, N=20):
    """Count every length-``N`` n-gram of ``tokens`` (overlapping windows).

    Returns a ``defaultdict(int)`` mapping each observed n-gram tuple to its number of
    occurrences; it is empty when ``tokens`` has fewer than ``N`` items.
    """
    counts = defaultdict(int)
    last_start = len(tokens) - N
    start = 0
    while start <= last_start:
        counts[tuple(tokens[start:start + N])] += 1
        start += 1
    return counts

def calculate_entropy(probabilities):
    """Shannon entropy, in bits, of a mapping ``key -> probability``.

    Non-positive probabilities are skipped (the 0 * log 0 = 0 convention, since log2(0)
    is undefined). Returns the int ``0`` for an empty mapping.
    """
    entropy = 0
    positive_probs = (p for p in probabilities.values() if p > 0)
    for p in positive_probs:
        entropy -= p * math.log2(p)
    return entropy


def calculate_kl_divergence(ngram_distribution):
    """KL divergence, in bits, of ``ngram_distribution`` from the uniform distribution.

    The uniform reference is taken over the keys present in the mapping, so each
    reference probability is ``1 / len(ngram_distribution)``. Non-positive probabilities
    contribute nothing. Returns the int ``0`` for an empty mapping.
    """
    if not ngram_distribution:
        return 0

    reference_prob = 1 / len(ngram_distribution)
    return sum(
        p * math.log2(p / reference_prob)
        for p in ngram_distribution.values()
        if p > 0
    )

def calculate_token_frequencies(tokens):
    """Count occurrences of each token.

    Returns a ``defaultdict(int)`` mapping token -> count, in first-seen order.
    """
    frequencies = defaultdict(int)
    for tok in tokens:
        frequencies[tok] = frequencies[tok] + 1
    return frequencies

def filter_top_frequent_token(tokens):
    """Return a new list of ``tokens`` with every occurrence of the most frequent token removed.

    Ties are broken in favour of the token seen first (the ``max`` over first-seen order).
    An empty input yields an empty list, since there is no most-frequent token to remove.
    """
    # Calculate token frequencies
    token_counts = calculate_token_frequencies(tokens)
    if not token_counts:  # Empty input: max() below would raise ValueError.
        return []

    # Identify the most frequent token
    most_frequent_token = max(token_counts, key=token_counts.get)

    # Create a new list excluding the most frequent token
    return [token for token in tokens if token != most_frequent_token]

def remove_top_ngram(ngram_frequencies, ngram_distribution):
    """Drop the most frequent n-gram and recompute the normalized distribution of the rest.

    Args:
        ngram_frequencies: mapping n-gram -> occurrence count (e.g. the output of
            ``calculate_ngram_frequencies``). It is not modified; a copy is returned.
        ngram_distribution: unused. The distribution is always rebuilt from
            ``ngram_frequencies``; the parameter is kept for signature compatibility.

    Returns:
        ``(remaining_frequencies, remaining_distribution)``: the counts without the top
        n-gram (same mapping type as the input), and those counts divided by their new
        total. Both are empty when the input is empty or holds a single n-gram. Ties for
        the top spot go to the first maximal key in iteration order.
    """
    # Work on a copy so the caller's counts are left intact (copy() keeps defaultdict-ness).
    remaining_frequencies = ngram_frequencies.copy()
    if not remaining_frequencies:  # Nothing to remove; max() would raise on an empty mapping.
        return remaining_frequencies, {}

    # Identify and remove the n-gram with the highest frequency
    top_ngram = max(remaining_frequencies, key=remaining_frequencies.get)
    del remaining_frequencies[top_ngram]

    # Renormalize over what is left (no division happens when nothing is left)
    total_ngrams = sum(remaining_frequencies.values())
    remaining_distribution = {k: v / total_ngrams for k, v in remaining_frequencies.items()}

    return remaining_frequencies, remaining_distribution
    

## Byte pair encoding

In [None]:
from collections import defaultdict

def get_byte_pairs(tokens):
    """Count every pair of adjacent tokens.

    Args:
        tokens: sequence of hashable tokens (byte values / merged-token ids here).

    Returns:
        ``defaultdict(int)`` mapping ``(left, right)`` to the number of positions where
        that pair occurs. Overlapping occurrences are all counted, e.g. ``[3, 3, 3]``
        gives ``{(3, 3): 2}``. Empty when ``tokens`` has fewer than two items, including
        the empty sequence (which previously raised ``IndexError``).
    """
    pairs = defaultdict(int)
    for pair in zip(tokens, tokens[1:]):
        pairs[pair] += 1
    return pairs

def merge_byte_pair(tokens, pair_to_merge, new_token):
    """Replace each occurrence of ``pair_to_merge`` with ``new_token``.

    The sequence is scanned left to right and matches do not overlap: after a merge the
    scan resumes past both tokens, so ``[1, 1, 1]`` with pair ``(1, 1)`` becomes
    ``[new_token, 1]``. Adjacent tokens are compared as a tuple, so ``pair_to_merge``
    should be a tuple. Returns a new list.
    """
    result = []
    n = len(tokens)
    pos = 0
    while pos < n:
        is_match = pos + 1 < n and (tokens[pos], tokens[pos + 1]) == pair_to_merge
        result.append(new_token if is_match else tokens[pos])
        pos += 2 if is_match else 1
    return result

def byte_pair_encoding(tokens, num_merges):
    """Perform up to ``num_merges`` rounds of byte pair encoding on ``tokens``.

    Each round counts adjacent pairs, picks the most frequent one (first-seen wins ties)
    and replaces it with a fresh token id. New ids start at ``max(tokens) + 1`` and grow
    by one per merge.

    Args:
        tokens: sequence of integer tokens (e.g. ``list(some_bytes)``).
        num_merges: maximum number of merge rounds; stops early when no pairs remain.

    Returns:
        ``(encoded_tokens, merge_operations)``, where ``merge_operations`` is a list of
        ``(new_token, (left, right))`` in the order the merges were applied. For empty
        input, ``tokens`` is returned unchanged with no merges (``max`` would otherwise raise).
    """
    merge_operations = []
    if len(tokens) == 0:
        return tokens, merge_operations

    current_token_value = max(tokens) + 1

    for _ in range(num_merges):
        pairs = get_byte_pairs(tokens)
        if not pairs:
            break
        most_frequent_pair = max(pairs, key=pairs.get)
        tokens = merge_byte_pair(tokens, most_frequent_pair, current_token_value)
        merge_operations.append((current_token_value, most_frequent_pair))
        current_token_value += 1

    return tokens, merge_operations

def decode_bpe(encoded_tokens, merge_operations):
    """Undo BPE merges, expanding merged tokens back to the original sequence.

    ``merge_operations`` is the list of ``(new_token, pair)`` records produced by
    ``byte_pair_encoding``. The merges are undone newest-first, so a token built from an
    earlier merge is expanded again on a later pass. Returns a new list, or
    ``encoded_tokens`` itself when there are no merges.
    """
    tokens = encoded_tokens
    for merged_token, pair in reversed(merge_operations):
        expanded = []
        for token in tokens:
            if token == merged_token:
                expanded.extend(pair)
            else:
                expanded.append(token)
        tokens = expanded
    return tokens

def compression_ratio(original_data, compressed_data):
    """Return ``len(compressed_data) / len(original_data)``.

    This is the fraction of the original length that remains: lower means better
    compression, and 1.0 means none. Raises ``ZeroDivisionError`` when
    ``original_data`` is empty.
    """
    original_length = len(original_data)
    compressed_length = len(compressed_data)
    return compressed_length / original_length


In [None]:
# encoded_data = byte_pair_encoding(b'FFAAdwdwd212114241423432421431', num_merges=10)

# Read binary data from JPEG files

In [None]:
# Collect candidate JPEG paths: regular, non-hidden files whose extension is .jpg/.jpeg
# (any case) or that have no extension. The list is sorted because os.listdir returns
# entries in arbitrary order; sorting makes the first `total_num_of_images` files the
# same on every run.
image_paths = sorted(
    os.path.join(image_directory_path, name)
    for name in os.listdir(image_directory_path)
    if not name.startswith('.')
    and os.path.splitext(name)[1].lower() in ('.jpg', '.jpeg', '')
    and os.path.isfile(os.path.join(image_directory_path, name))
)

binaries_and_lengths = []

for image_path in image_paths:
    if len(binaries_and_lengths) >= total_num_of_images:
        break
    extracted = JPEG(image_path).extract_data()
    if extracted is None:  # extract_data returns None when the SOI/EOI markers are missing.
        print(f"Skipping {os.path.basename(image_path)}: no JPEG start/end markers found")
        continue
    data, length = extracted
    binaries_and_lengths.append({"data": data, "length": length})
    # Optional: print information about each processed image
    print(f"File: {os.path.basename(image_path)}")
    print(f"Length of data: {length} bytes")

# Concatenate all binary data for byte pair distribution analysis
all_binary_data = b''.join(item['data'] for item in binaries_and_lengths if item['data'])

print(f"Number of images: {len(binaries_and_lengths)}")
print(f"Size in bytes: {len(all_binary_data)}")


In [None]:
# BPE over the concatenated JPEG payloads.

# Convert binary data to a list of integers (one token per byte)
tokens = list(all_binary_data)

# Apply BPE on the byte data
num_merges = 5  # Adjust this value based on experimentation
encoded_tokens, merge_operations = byte_pair_encoding(tokens, num_merges)

# Print a short summary. Printing `all_binary_data` itself would dump the entire
# concatenated payload into the notebook output, so only a prefix is shown.
print(f"Original data: {len(all_binary_data)} bytes, first 64: {all_binary_data[:64]}")
print(f"Merge operations (new_token, pair): {merge_operations}")
print(f"Encoded tokens: {len(encoded_tokens)}")
print(f"Compression ratio: {compression_ratio(tokens, encoded_tokens):.2f}")

# Optional: check that decoding the merges reproduces the original bytes.
verify_bpe_roundtrip = False
if verify_bpe_roundtrip:
    decoded_binary_data = bytes(decode_bpe(encoded_tokens, merge_operations))
    assert all_binary_data == decoded_binary_data, "Decoded data does not match original data!"
    print("Decoding successful and data matches the original.")

# Optional: drop the single most frequent token before the n-gram analysis below.
filter_most_frequent_token = False
if filter_most_frequent_token:
    encoded_tokens = filter_top_frequent_token(encoded_tokens)
    print(f"Tokens left after filtering: {len(encoded_tokens)}")

# Iterate over N

In [None]:
%%time

# For each n-gram length N, compare the n-gram distribution of the BPE-encoded token
# stream with a uniform one, after discarding the single most frequent n-gram.
metrics = []
show_top_ngrams = False  # Set to True to also print the 5 most frequent n-grams for each N.

for N in range(1, 21):  # 21 is exclusive, so it iterates up to 20
    start_time = time.perf_counter()

    # Count the n-grams once. remove_top_ngram takes (frequencies, distribution) and returns
    # (frequencies, distribution). It rebuilds the distribution from the counts and never
    # reads its second argument, so a separate distribution pass is unnecessary.
    ngram_frequencies = calculate_ngram_frequencies(encoded_tokens, N=N)
    ngram_frequencies, ngram_distribution = remove_top_ngram(ngram_frequencies, None)  # Remove top-1 frequent ngram

    # Calculate statistical measures over the remaining n-gram counts
    frequency_values = list(ngram_frequencies.values())
    mean_frequency = np.mean(frequency_values)
    median_frequency = np.median(frequency_values)

    # Calculate entropy, and divergence from uniform over the observed alphabet
    entropy = calculate_entropy(ngram_distribution)
    alphabet_size = len(ngram_distribution)
    entropy_if_uniform = math.log2(alphabet_size) if alphabet_size > 0 else 0
    difference = abs(entropy - entropy_if_uniform)
    kl_divergence = calculate_kl_divergence(ngram_distribution)

    # Counts are positive integers. In a descending sort, every n-gram seen more than once
    # therefore precedes those seen exactly once, so the index where the frequency first
    # drops to 1 is the number of repeated n-grams (None if no n-gram occurs exactly once).
    # This takes O(n) and needs no sort.
    num_repeated = sum(1 for freq in frequency_values if freq > 1)
    index_frequency_one = num_repeated if num_repeated < len(frequency_values) else None

    execution_time = time.perf_counter() - start_time

    metrics.append({
        "N": N,
        "alphabet_size": alphabet_size,
        "frequency_values": frequency_values,
        "entropy": entropy,
        "kl_divergence": kl_divergence,
        "index_frequency_one": index_frequency_one,
        "execution_time": execution_time
    })

    # Output results
    print(f"N-gram where N={N}")
    print(f"Alphabet size (unique tokens): {alphabet_size}")
    print(f"Mean frequency: {mean_frequency:.2f}, Median frequency: {median_frequency}")
    print(f"The entropy of the token distribution is: {entropy:.4f} bits")
    print(f"The entropy of the ideal token uniform distribution is: {entropy_if_uniform:.4f} bits")
    print(f"The KL divergence from uniform distribution: {kl_divergence:.4f} bits")
    print(f"The difference: {difference:.4f} bits")

    if show_top_ngrams:  # Diagnostic output; done after timing so it is not measured.
        if index_frequency_one is not None:
            print(f"Index after which all n-grams have frequencies of 1: {index_frequency_one}")
        else:
            print("No n-grams with frequency of 1 found.")
        print("Top 5 N-grams and their frequencies:")
        for ngram, frequency in sorted(ngram_frequencies.items(), key=lambda x: x[1], reverse=True)[:5]:
            print(f"{ngram}: {frequency}")

    print(f"Execution time: {execution_time:.4f} s")

    print("-" * 50)  # Separator for readability between different N outputs


# Save JPEG data and metrics

In [None]:
%%time

# Create the output directory if needed: open() does not create missing parent folders.
for output_path in (jpeg_data_path_to_save, metrics_path_to_save):
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

# Save binary data to file
with open(jpeg_data_path_to_save, 'wb') as file:
    file.write(all_binary_data)

# Save metrics to file
with open(metrics_path_to_save, 'w') as file:
    json.dump(metrics, file, indent=4)


# Plot the metrics for each N
See `./analyze_jepg_stats.ipynb`

In [None]:
%%time 

# Extract values for plotting
N_values = [metric['N'] for metric in metrics]
entropies = [metric['entropy'] for metric in metrics]
kl_divergences = [metric['kl_divergence'] for metric in metrics]
# A missing index (no n-gram with frequency 1) is plotted as -1.
index_frequencies = [metric['index_frequency_one'] if metric['index_frequency_one'] is not None else -1 for metric in metrics]
alphabet_sizes = [metric['alphabet_size'] for metric in metrics]
# Metrics without 'frequency_values' contribute no points (also safe when `metrics` is empty).
frequency_values = [metric.get('frequency_values', []) for metric in metrics]
execution_times = [metric['execution_time'] for metric in metrics]


def _plot_metric_vs_n(ax, values, title, ylabel, color, four_decimals):
    """Draw one per-N line panel on ``ax`` and annotate every point with its value.

    When ``four_decimals`` is True, y tick labels and point annotations use 4-decimal
    formatting; otherwise annotations use the value's default string form.
    """
    ax.plot(N_values, values, marker='o', linestyle='-', color=color)
    ax.set_title(title)
    ax.set_xlabel('N (N-gram size)')
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.set_xticks(N_values)  # Set x-ticks to be integers
    annotation_format = '.4f' if four_decimals else ''
    if four_decimals:
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.4f}'))
    for n, value in zip(N_values, values):
        ax.annotate(format(value, annotation_format), (n, value), textcoords="offset points", xytext=(0, 5), ha='center')


fig, axes = plt.subplots(6, 1, figsize=(15, 18))

_plot_metric_vs_n(axes[0], entropies, 'Entropy of tokens (BPEed N-gram) Distributions for Different N',
                  'Entropy (bits)', 'b', four_decimals=True)
_plot_metric_vs_n(axes[1], kl_divergences, 'KL Divergence of tokens (BPEed N-gram) Distributions from Uniform Distributions for Different N',
                  'KL Divergence (bits)', 'g', four_decimals=True)
_plot_metric_vs_n(axes[2], index_frequencies, 'Index Frequency One of tokens (BPEed N-gram) Distributions for Different N',
                  'Index Frequency One', 'r', four_decimals=False)
_plot_metric_vs_n(axes[3], alphabet_sizes, 'Alphabet Size of tokens (BPEed N-gram) Distributions for Different N',
                  'Alphabet Size', 'm', four_decimals=False)

# Frequency values (log scale): one column of points per N.
freq_ax = axes[4]
for n, freq_vals in zip(N_values, frequency_values):
    freq_ax.plot([n] * len(freq_vals), freq_vals, 'bo', alpha=0.5)
freq_ax.set_yscale('log')
freq_ax.set_title('Frequency Values of tokens (BPEed N-gram) Distributions for Different N (Log Scale)')
freq_ax.set_xlabel('N (N-gram size)')
freq_ax.set_ylabel('Frequency Values')
freq_ax.grid(True)
freq_ax.set_xticks(N_values)  # Set x-ticks to be integers

_plot_metric_vs_n(axes[5], execution_times, 'Execution Time for Different N',
                  'Execution Time (seconds)', 'c', four_decimals=True)

fig.tight_layout()
plt.show()


In [None]:
# Build a Markdown summary table with one row per N and render it inline.
table_rows = [
    "| N | Entropy (bits) | KL Divergence (bits) | Index Frequency One | Alphabet Size | Execution Time (seconds) |",
    "|---|----------------|----------------------|----------------------|---------------|-------------------------|",
]
for metric in metrics:
    # Show "N/A" when no n-gram had a frequency of exactly 1.
    index_cell = "N/A" if metric['index_frequency_one'] is None else metric['index_frequency_one']
    table_rows.append(
        f"| {metric['N']} | {metric['entropy']:.4f} | {metric['kl_divergence']:.4f} "
        f"| {index_cell} | {metric['alphabet_size']} | {metric['execution_time']:.4f} |"
    )

# Every row, including the last, ends with a newline.
markdown_table = "".join(f"{row}\n" for row in table_rows)

# Render the Markdown table
display(Markdown(markdown_table))