In [1]:
import os
import sys
import time
import sqlite3
import instaloader
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
from skimage import color, feature
from collections import Counter
from scipy.stats import skew, kurtosis
import cv2
from datetime import timedelta, datetime

# Check required libraries
def check_libraries():
    """
    Verify that the third-party packages listed below can be imported.

    Each package is probed with ``__import__`` using its import name. When at
    least one probe raises ImportError, the pip names of the missing packages
    are printed together with a matching ``pip install`` command and the
    interpreter exits with status 1. Returns None when nothing is missing.
    """
    # (name to pass to pip, name to pass to import)
    packages = (
        ('numpy', 'numpy'),
        ('Pillow', 'PIL'),
        ('scikit-learn', 'sklearn'),
        ('scikit-image', 'skimage'),
        ('opencv-python', 'cv2'),
        ('scipy', 'scipy'),
    )

    def _cannot_import(module_name):
        # True when importing module_name raises ImportError.
        try:
            __import__(module_name)
        except ImportError:
            return True
        return False

    not_installed = [
        pip_name
        for pip_name, module_name in packages
        if _cannot_import(module_name)
    ]

    if not not_installed:
        return

    print(f"The following libraries are required but not installed: {', '.join(not_installed)}")
    print("Please install them using pip:")
    print(f"pip install {' '.join(not_installed)}")
    sys.exit(1)

# Perform the library check
# This runs as soon as the script/cell executes; check_libraries() calls
# sys.exit(1) when a required package cannot be imported.
check_libraries()

# Define the database path
# NOTE(review): this is an absolute, machine-specific directory; the script
# will only find the database on a machine with this exact layout.
DATABASE_PATH = os.path.join("/Users/greyson/Projects/custom_gallery/gallery/prisma", 'image_analysis.db')

def connect_db():
    """
    Open a SQLite connection to DATABASE_PATH and return it.

    Foreign-key enforcement is switched on for the returned connection via
    ``PRAGMA foreign_keys = ON`` before it is handed back.
    """
    connection = sqlite3.connect(DATABASE_PATH)
    # Turn on foreign key enforcement for this connection
    connection.execute('PRAGMA foreign_keys = ON;')
    return connection




# NOTE(review): `posts` is not defined in the visible part of this file; it is
# presumably created earlier (e.g. in a previous notebook cell). It must
# support len() and be iterable, because it is iterated again further down.
total_posts = len(posts)  # Get total number of posts
print(f"Total posts to download: {total_posts}")

# Track progress and download time
# start_time is the reference point for the elapsed/remaining-time estimates
# printed after each post.
start_time = time.time()  # Track the time when the download starts

# Connect to the database
# conn and cursor are module-level and shared by the main loop below.
conn = connect_db()
cursor = conn.cursor()

def is_similar(row, tolerance=10):
    """
    Report whether the values in `row` stay within `tolerance` of each other.

    The spread (max minus min along axis 0) is computed per column of `row`;
    the result is true only when every spread is strictly below `tolerance`.
    """
    highest = np.max(row, axis=0)
    lowest = np.min(row, axis=0)
    spread = highest - lowest
    return (spread < tolerance).all()

def find_letterbox_height(image_np, tolerance=10, min_height=10, from_top=True):
    """
    Measure the letterbox bar at one edge of the image.

    Rows are visited starting at the top edge (from_top=True) or the bottom
    edge (from_top=False) and counted for as long as is_similar() accepts
    them with the given tolerance. The count is returned when it reaches at
    least `min_height`; otherwise 0 is returned.

    `image_np` must be a 3-dimensional array indexed as [row, column, channel].
    """
    num_rows, _, _ = image_np.shape
    row_order = range(num_rows) if from_top else reversed(range(num_rows))

    uniform_rows = 0
    for row_index in row_order:
        # Stop at the first row that is not uniform; only an unbroken run
        # starting at the edge counts as a bar.
        if not is_similar(image_np[row_index, :, :], tolerance):
            break
        uniform_rows += 1

    # Runs shorter than min_height are not reported as a letterbox
    return uniform_rows if uniform_rows >= min_height else 0

def remove_letterbox(image_np, tolerance=10, min_letterbox_height=10):
    """
    Detect and remove letterboxing from the top and bottom of the image.

    The bar height at each edge is measured with find_letterbox_height().
    The image is cropped only when all of the following hold:
      * at least one bar was detected,
      * the top and bottom heights differ by less than
        `min_letterbox_height`,
      * at least one row would remain after cropping.

    Returns a tuple (top_height, bottom_height, image): the two bar heights
    and the cropped image, or (0, 0, image_np) with the input unchanged when
    nothing was removed. A one-line status message is printed either way.
    """
    top_height = find_letterbox_height(image_np, tolerance, min_letterbox_height, from_top=True)
    bottom_height = find_letterbox_height(image_np, tolerance, min_letterbox_height, from_top=False)

    image_height = image_np.shape[0]
    has_bars = top_height > 0 or bottom_height > 0
    # Bars are only accepted when the top and bottom heights are within
    # min_letterbox_height of each other.
    is_balanced = abs(top_height - bottom_height) < min_letterbox_height
    # A uniformly coloured image is measured as "all bar" from both edges;
    # cropping it would leave an empty array, so treat it as not letterboxed.
    leaves_content = top_height + bottom_height < image_height

    if has_bars and is_balanced and leaves_content:
        cropped_image = image_np[top_height: image_height - bottom_height, :, :]
        print(f"Removed letterbox: Top={top_height}px, Bottom={bottom_height}px")
        return top_height, bottom_height, cropped_image

    print("No letterbox detected.")
    return 0, 0, image_np

def insert_luminance(conn, image_id, luminance_metrics):
    """
    Add one row to the Luminance table for `image_id` and commit it.

    `luminance_metrics` is a mapping; only the keys listed in `metric_keys`
    below are read from it. A missing key raises KeyError.
    """
    # Keys in the same order as the columns of the INSERT statement
    metric_keys = (
        "Mean Luminance",
        "Median Luminance",
        "Std Luminance",
        "Dynamic Range",
        "RMS Contrast",
        "Michelson Contrast",
        "Luminance Skewness",
        "Luminance Kurtosis",
        "Min Luminance",
        "Max Luminance",
    )
    row_values = (image_id, *(luminance_metrics[key] for key in metric_keys))

    insert_sql = '''
    INSERT INTO Luminance (
        image_id, mean_luminance, median_luminance, std_luminance,
        dynamic_range, rms_contrast, michelson_contrast,
        luminance_skewness, luminance_kurtosis,
        min_luminance, max_luminance
    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    '''
    conn.cursor().execute(insert_sql, row_values)
    conn.commit()

def insert_saturation(conn, image_id, saturation_metrics):
    """
    Add one row to the Saturation table for `image_id` and commit it.

    Reads the "Mean Saturation", "Median Saturation" and "Std Saturation"
    entries of the `saturation_metrics` mapping; other entries are ignored.
    """
    mean_value = saturation_metrics["Mean Saturation"]
    median_value = saturation_metrics["Median Saturation"]
    std_value = saturation_metrics["Std Saturation"]

    db_cursor = conn.cursor()
    db_cursor.execute('''
    INSERT INTO Saturation (
        image_id, mean_saturation, median_saturation, std_saturation
    ) VALUES (?, ?, ?, ?)
    ''', (image_id, mean_value, median_value, std_value))
    conn.commit()

def insert_glcm(conn, image_id, glcm_metrics):
    """
    Add one row to the GLCM table for `image_id` and commit it.

    Reads the "GLCM Contrast" and "GLCM Correlation" entries of the
    `glcm_metrics` mapping; other entries are ignored.
    """
    params = [image_id]
    for metric_name in ("GLCM Contrast", "GLCM Correlation"):
        params.append(glcm_metrics[metric_name])

    statement = '''
    INSERT INTO GLCM (
        image_id, contrast, correlation
    ) VALUES (?, ?, ?)
    '''
    conn.cursor().execute(statement, tuple(params))
    conn.commit()

def insert_laplacian(conn, image_id, laplacian_var):
    """
    Add one row to the Laplacian table and commit it.

    `laplacian_var` is stored in the `variance` column next to `image_id`.
    """
    record = (image_id, laplacian_var)
    statement = '''
    INSERT INTO Laplacian (
        image_id, variance
    ) VALUES (?, ?)
    '''
    db_cursor = conn.cursor()
    db_cursor.execute(statement, record)
    conn.commit()

def insert_kmeans_clustering(conn, image_id, clustering_data):
    """
    Store one k-means result: a KMeansClustering row plus one Clusters row
    per cluster.

    `clustering_data` is a mapping with these entries:
      * "Number of Clusters"    - stored as num_clusters
      * "Cluster Centers (RGB)" - sequence of (r, g, b) triples
      * "Cluster Centers (LAB)" - sequence of (l, a, b) triples
      * "Cluster Counts"        - sequence of per-cluster counts
      * "Cluster Percentages"   - sequence of per-cluster percentages
    The four sequences are read in lockstep; cluster_index starts at 1.

    The parent row and all child rows are written in a single transaction:
    either everything is committed, or (when any statement fails) everything
    is rolled back and the exception propagates. This prevents a
    KMeansClustering row from being left behind without its Clusters rows.
    """
    # Read every input before touching the database so a malformed mapping
    # raises KeyError without leaving partial data behind.
    num_clusters = clustering_data["Number of Clusters"]
    clusters_rgb = clustering_data["Cluster Centers (RGB)"]
    clusters_lab = clustering_data["Cluster Centers (LAB)"]
    counts = clustering_data["Cluster Counts"]
    percentages = clustering_data["Cluster Percentages"]

    # Used as a context manager, the connection commits on success and rolls
    # back if the block raises.
    with conn:
        cursor = conn.cursor()
        cursor.execute('''
        INSERT INTO KMeansClustering (
            image_id, num_clusters
        ) VALUES (?, ?)
        ''', (
            image_id,
            num_clusters
        ))
        # Row id of the parent record, referenced by every child row
        clustering_id = cursor.lastrowid

        cluster_rows = [
            (
                clustering_id,
                cluster_index,  # cluster_index starting from 1
                red,
                green,
                blue,
                l_val,
                a_val,
                b_val,
                count,
                float(pct),
            )
            for cluster_index, ((red, green, blue), (l_val, a_val, b_val), count, pct)
            in enumerate(zip(clusters_rgb, clusters_lab, counts, percentages), start=1)
        ]
        cursor.executemany('''
        INSERT INTO Clusters (
            clustering_id, cluster_index, r, g, b, l, a, b_channel, count, percentage
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', cluster_rows)

def process_image(image_path, image_id, conn):
    """
    Process a single image: remove letterbox, perform k-means clustering,
    compute metrics, and save them to the SQLite database.

    Parameters:
        image_path: Path of the image file to analyse.
        image_id:   Identifier of the image; matched against Images.id when
                    the letterbox heights are recorded and passed on to the
                    insert_* helpers.
        conn:       Open sqlite3 connection used for all writes.

    Returns True if processed successfully. Returns False, after printing a
    message, when the file is missing or cannot be opened or decoded. Errors
    raised later (during analysis or by the database) are not caught here
    and propagate to the caller.
    """
    try:
        # Image.open() is lazy: pixel data is only decoded by convert() /
        # np.array(). Doing both inside the try means a corrupt file is
        # reported as a failure, and the with-block releases the file handle.
        with Image.open(image_path) as image:
            # Ensure the image is in RGB format (this also drops any alpha
            # channel, so the array always has exactly three channels)
            rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
            image_np = np.array(rgb_image)
    except FileNotFoundError:
        print(f"Image not found at path: {image_path}")
        return False
    except Exception as e:
        print(f"Error opening image {image_path}: {e}")
        return False

    # *** Detect and Remove Letterbox ***
    # Parameters can be adjusted based on the expected letterbox characteristics
    tolerance = 2              # Tolerance for color similarity (0-255)
    min_letterbox_height = 10  # Minimum height in pixels to consider as letterbox

    top_height, bottom_height, image_np = remove_letterbox(image_np, tolerance, min_letterbox_height)

    # Update the Images table with letterbox heights
    cursor = conn.cursor()
    cursor.execute('''
    UPDATE Images
    SET letterbox_top = ?, letterbox_bottom = ?
    WHERE id = ?
    ''', (top_height, bottom_height, image_id))
    conn.commit()

    # *** Colour Clustering ***
    # Scale to [0, 1] and convert to LAB color space
    image_normalized = image_np / 255.0
    image_lab = color.rgb2lab(image_normalized)

    # One row per pixel, one column per LAB channel
    pixels_lab = image_lab.reshape(-1, 3)

    # k-means with a fixed seed so repeated runs give the same clusters
    k = 8  # Set the number of clusters
    kmeans = KMeans(n_clusters=k, random_state=42)
    kmeans.fit(pixels_lab)

    cluster_centers_lab = kmeans.cluster_centers_

    # Convert cluster centers from LAB to RGB (lab2rgb is given an image-like
    # array, hence the temporary leading axis)
    cluster_centers_rgb = color.lab2rgb(cluster_centers_lab.reshape(1, -1, 3))
    cluster_centers_rgb = np.squeeze(cluster_centers_rgb)
    cluster_centers_rgb_uint8 = np.clip(cluster_centers_rgb * 255, 0, 255).astype(int)

    # Count the pixels assigned to each cluster
    counts = Counter(kmeans.labels_)
    clustered_pixel_count = sum(counts.values())

    # Sort clusters by the number of pixels, largest first
    sorted_counts = counts.most_common()
    sorted_cluster_indices = [item[0] for item in sorted_counts]
    sorted_cluster_sizes = [item[1] for item in sorted_counts]
    sorted_cluster_percentages = [
        (count / clustered_pixel_count) * 100 for count in sorted_cluster_sizes
    ]
    sorted_colors_rgb = np.array([cluster_centers_rgb_uint8[i] for i in sorted_cluster_indices])
    sorted_colors_lab = np.array([cluster_centers_lab[i] for i in sorted_cluster_indices])

    # *** Compute Additional Image Metrics ***

    # Convert RGB to HSV for saturation metrics
    image_hsv = color.rgb2hsv(image_normalized)
    saturation = image_hsv[:, :, 1]
    mean_saturation = np.mean(saturation)
    median_saturation = np.median(saturation)
    std_saturation = np.std(saturation)

    # Luminance Metrics (L channel of LAB)
    l_channel = image_lab[:, :, 0]

    # *** Outlier Removal based on Pixel Frequency ***
    l_flat = l_channel.flatten()
    luminance_pixel_count = l_flat.size

    # Compute histogram of luminance values
    hist, bin_edges = np.histogram(l_flat, bins=256, range=(0, 100))

    # Define a frequency threshold (pixels that make up less than 0.05% of the image)
    frequency_threshold = luminance_pixel_count * 0.0005  # Adjust this value as needed

    # Identify bins where the frequency exceeds the threshold
    valid_bins = np.where(hist > frequency_threshold)[0]

    if valid_bins.size > 0:
        # Lower edge of the first populated bin, upper edge of the last one
        l_min = bin_edges[valid_bins[0]]
        l_max = bin_edges[valid_bins[-1] + 1]
    else:
        l_min = l_flat.min()
        l_max = l_flat.max()

    # Filter out the outliers
    l_filtered = l_flat[(l_flat >= l_min) & (l_flat <= l_max)]

    mean_luminance = np.mean(l_filtered)
    median_luminance = np.median(l_filtered)
    std_luminance = np.std(l_filtered)
    dynamic_range = l_max - l_min

    # Contrast Metrics using adjusted min and max
    # NOTE(review): this is the root-mean-square of the luminance values
    # themselves, not of their deviation from the mean - confirm that this
    # is the intended definition of "RMS Contrast".
    rms_contrast = np.sqrt(np.mean(l_filtered**2))
    michelson_contrast = (l_max - l_min) / (l_max + l_min)

    # Skewness and Kurtosis of Luminance
    lum_skew = skew(l_filtered)
    lum_kurt = kurtosis(l_filtered)

    # Texture Features using GLCM (use original l_channel, truncated to uint8)
    glcm = feature.graycomatrix(l_channel.astype(np.uint8), distances=[5], angles=[0],
                                levels=256, symmetric=True, normed=True)
    glcm_contrast = feature.graycoprops(glcm, 'contrast')[0, 0]
    glcm_correlation = feature.graycoprops(glcm, 'correlation')[0, 0]

    # Sharpness using Variance of Laplacian (use original l_channel)
    laplacian_var = cv2.Laplacian(l_channel, cv2.CV_64F).var()

    # *** Compile All Metrics ***
    # The insert_* helpers each pick the keys they need out of this mapping.
    all_metrics = {
        'Mean Luminance': mean_luminance,
        'Median Luminance': median_luminance,
        'Std Luminance': std_luminance,
        'Dynamic Range': dynamic_range,
        'RMS Contrast': rms_contrast,
        'Michelson Contrast': michelson_contrast,
        'Mean Saturation': mean_saturation,
        'Median Saturation': median_saturation,
        'Std Saturation': std_saturation,
        'GLCM Contrast': glcm_contrast,
        'GLCM Correlation': glcm_correlation,
        'Laplacian Variance': laplacian_var,
        'Luminance Skewness': lum_skew,
        'Luminance Kurtosis': lum_kurt,
        'Min Luminance': l_min,
        'Max Luminance': l_max,
        'KMeans Clustering': {
            'Number of Clusters': k,
            'Cluster Centers (RGB)': sorted_colors_rgb.tolist(),
            'Cluster Centers (LAB)': sorted_colors_lab.tolist(),
            'Cluster Counts': sorted_cluster_sizes,
            'Cluster Percentages': list(sorted_cluster_percentages)
        }
    }

    # *** Insert Data into SQLite Database ***
    insert_luminance(conn, image_id, all_metrics)
    insert_saturation(conn, image_id, all_metrics)
    insert_glcm(conn, image_id, all_metrics)
    insert_laplacian(conn, image_id, all_metrics['Laplacian Variance'])
    insert_kmeans_clustering(conn, image_id, all_metrics['KMeans Clustering'])

    return True

# Main loop over posts
# For every post: record it in the Posts table, download its files, then
# register and analyse each downloaded image.
# NOTE(review): `posts`, `username` and `L` are not defined in the visible part
# of this file; `L` is presumably an instaloader.Instaloader instance - confirm.

# 'public' must match a whole path component; a plain substring search would
# also hit names such as 'republic' or 'public_html' earlier in the path.
public_marker = f"{os.sep}public{os.sep}"

try:
    for index, post in enumerate(posts, start=1):
        # Extract post data
        shortcode = post.shortcode
        post_date = post.date_utc.strftime('%Y-%m-%d %H:%M:%S')  # Format date as string
        caption = post.caption if post.caption else ""
        post_username = post.owner_username

        # Insert or ignore the post into Posts table
        cursor.execute('''
    INSERT OR IGNORE INTO Posts (shortcode, username, caption, post_date)
    VALUES (?, ?, ?, ?)
    ''', (shortcode, post_username, caption, post_date))

        # Retrieve the post ID (works for both new and already-existing rows)
        cursor.execute('SELECT id FROM Posts WHERE shortcode = ?', (shortcode,))
        post_id = cursor.fetchone()[0]

        # Set the directory for the post
        dirname = f"../gallery/public/img/{username}/{post.date_utc:%Y-%m-%d_%H-%M-%S}_{post.shortcode}"
        L.dirname_pattern = dirname

        # Start the timer for this post
        post_start_time = time.time()

        # Download the post (images only)
        L.download_post(post, target='')

        # End the timer for this post
        post_end_time = time.time()

        # Calculate time taken for the current post
        time_for_post = post_end_time - post_start_time

        # Calculate total elapsed time so far
        elapsed_time = time.time() - start_time

        # Estimate remaining time based on the average time per post
        avg_time_per_post = elapsed_time / index
        remaining_posts = total_posts - index
        estimated_remaining_time = avg_time_per_post * remaining_posts

        # List the image files in 'dirname'
        image_files = []
        for root, _dirs, files in os.walk(dirname):
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')):
                    image_files.append(os.path.join(root, file))

        # Process each image
        for image_file in image_files:
            # Absolute file path
            absolute_file_path = os.path.abspath(image_file)
            # Relative file path (everything after the 'public' directory)
            marker_start = absolute_file_path.find(public_marker)
            if marker_start >= 0:
                tail_start = marker_start + len(public_marker)
                relative_file_path = absolute_file_path[tail_start:].replace('\\', '/')
            else:
                # 'public' not in path, fall back to the expected URL layout
                print(f"'public' not found in the path: {absolute_file_path}")
                relative_file_path = f"/img/{username}/{post.date_utc:%Y-%m-%d_%H-%M-%S}_{post.shortcode}/{os.path.basename(image_file)}"
            # Prepend '/' to make it an absolute URL path
            relative_file_path = f"/{relative_file_path.lstrip('/')}"
            # Filename
            image_filename = os.path.basename(image_file)
            # Processed at timestamp
            processed_at = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')

            # Insert into Images table
            cursor.execute('''
        INSERT OR IGNORE INTO Images (absolute_file_path, relative_file_path, filename, processed_at, post_id)
        VALUES (?, ?, ?, ?, ?)
        ''', (absolute_file_path, relative_file_path, image_filename, processed_at, post_id))

            # Retrieve the image ID
            cursor.execute('SELECT id FROM Images WHERE absolute_file_path = ?', (absolute_file_path,))
            image_id = cursor.fetchone()[0]

            # Process the image
            # NOTE(review): the image is analysed even when the INSERT above
            # was ignored because the row already existed, so re-running the
            # script presumably writes the metric rows again - confirm
            # against the table constraints.
            success = process_image(image_file, image_id, conn)
            if success:
                print(f"Processed image {image_filename}")
            else:
                print(f"Failed to process image {image_filename}")

        conn.commit()

        # Print progress
        print(f"Downloaded {index}/{total_posts} posts.")
        print(f"Time for last post: {time_for_post:.2f} seconds.")
        print(f"Estimated remaining time: {timedelta(seconds=int(estimated_remaining_time))}")
        print("-------------------------------------------------------")
finally:
    # Close the database connection, also when the run is interrupted or a
    # post fails, so the connection is not left open.
    conn.close()

# Calculate total time spent after all posts are downloaded
total_time_spent = time.time() - start_time
print(f"All {total_posts} posts downloaded and processed.")
print(f"Total time spent: {timedelta(seconds=int(total_time_spent))}")


Database initialized successfully.


KeyboardInterrupt: 