In [None]:
import os
import pickle
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import sqlite3
import shutil
from datetime import datetime

# Input locations: a pickled mapping of uuid -> PCA-reduced embedding, and
# the SQLite database that stores each image's file path.
pca_embeddings_path = "combined_embeddings.pkl"
db_path = "image_metadata.db"

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at pickles produced by your own pipeline.
with open(pca_embeddings_path, "rb") as f:
    embeddings_data = pickle.load(f)

# Keys and values are read in the mapping's iteration order, so uuids[k]
# names the row embeddings[k].
uuids = [key for key in embeddings_data]
embeddings = np.array([vector for vector in embeddings_data.values()])


# Function to fetch image paths from SQLite database
def fetch_image_paths_from_db(db_path, uuids, batch_size=512):
    """Map each uuid to its ``file_path`` from the ``images`` table.

    The lookup runs one ``SELECT ... WHERE uuid IN (?, ?, ...)`` query per
    chunk of ``batch_size`` uuids, with the uuids passed as bound parameters.
    Chunking presumably keeps each statement under SQLite's limit on host
    parameters; keep ``batch_size`` below that limit.

    Args:
        db_path: Path of the SQLite database file.
        uuids: uuids to look up (any iterable).
        batch_size: Number of uuids bound per query.

    Returns:
        dict mapping uuid -> file_path for every uuid found in the table.
        uuids that are absent from the table are simply missing from the dict.
    """
    uuids = list(uuids)  # slicing below needs a sequence
    image_paths = {}
    conn = sqlite3.connect(db_path)
    # try/finally so the connection is closed even if a query raises.
    # (Using the connection as a context manager would NOT close it; that
    # only commits or rolls back.)
    try:
        cursor = conn.cursor()
        for i in tqdm(range(0, len(uuids), batch_size), desc="Fetching image paths"):
            uuid_batch = uuids[i : i + batch_size]
            placeholders = ", ".join("?" for _ in uuid_batch)
            query = f"SELECT uuid, file_path FROM images WHERE uuid IN ({placeholders})"
            cursor.execute(query, uuid_batch)
            # Rows are (uuid, file_path) pairs, which dict.update accepts directly.
            image_paths.update(cursor.fetchall())
    finally:
        conn.close()
    return image_paths


# Fetch image paths (uuids missing from the database are absent from the dict)
image_paths = fetch_image_paths_from_db(db_path, uuids)

# Reduce the size of images
image_size = (32, 32)
transform = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])

# Process images in batches
batch_size = 512
# Ceiling division. `len // batch_size + 1` yields an extra, empty batch when
# len(embeddings) is an exact multiple of batch_size, and torch.stack([])
# raises on an empty list.
num_batches = (len(embeddings) + batch_size - 1) // batch_size

# Global log directory for all experiments
global_log_dir = "logs/pca_embeddings"

# Experiment-specific subdirectory (using a timestamp for uniqueness)
experiment_name = datetime.now().strftime("%Y%m%d-%H%M%S")
experiment_log_dir = os.path.join(global_log_dir, experiment_name)

os.makedirs(experiment_log_dir, exist_ok=True)

writer = SummaryWriter(experiment_log_dir)
try:
    # Add progress bar with tqdm
    for i in tqdm(range(num_batches), desc="Processing Batches"):
        start_idx = i * batch_size
        end_idx = min(start_idx + batch_size, len(embeddings))

        batch_embeddings = embeddings[start_idx:end_idx]
        batch_uuids = uuids[start_idx:end_idx]

        batch_labels = []
        batch_images = []
        for uuid in batch_uuids:
            path = image_paths.get(uuid)
            if path is None:
                # No DB row for this uuid: keep the embedding and log a blank
                # thumbnail labelled with the uuid, instead of aborting the
                # whole run with a KeyError.
                print(f"No image path found for uuid {uuid}")
                batch_labels.append(str(uuid))
                batch_images.append(torch.zeros(3, *image_size))
                continue

            batch_labels.append(os.path.basename(path))
            try:
                # Context manager closes the file handle even if decoding fails.
                with Image.open(path) as img:
                    img_tensor = transform(img.convert("RGB"))
            except Exception as e:
                # Best effort: a blank thumbnail keeps images aligned with embeddings.
                print(f"Error processing image {path}: {e}")
                img_tensor = torch.zeros(3, *image_size)
            batch_images.append(img_tensor)

        batch_images = torch.stack(batch_images)

        # Debugging: Check tensor sizes before logging
        print(
            f"Logging batch {i+1}/{num_batches} with {len(batch_embeddings)} embeddings and {len(batch_images)} images."
        )
        print(f"  - Embeddings shape: {batch_embeddings.shape}")
        print(f"  - Images tensor shape: {batch_images.shape}")

        writer.add_embedding(
            torch.tensor(batch_embeddings),
            metadata=batch_labels,
            label_img=batch_images,
            global_step=i,  # Different step for each batch
        )
finally:
    # Flush and close the writer even if a batch fails part-way.
    writer.flush()
    writer.close()
print(f"Embeddings logged to {experiment_log_dir}.")

In [None]:
import os
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import pandas as pd
import torchvision.transforms as transforms
from sklearn.metrics.pairwise import cosine_similarity

# Root directory containing one sub-folder per batch written by the first
# cell's SummaryWriter (the experiment name is a "%Y%m%d-%H%M%S" timestamp).
# Built with os.path.join rather than a hard-coded "\\" path, so it also
# resolves on non-Windows systems.
root_dir = os.path.join("logs", "pca_embeddings", "20240812-130708")
folders = [
    name
    for name in sorted(os.listdir(root_dir))
    if os.path.isdir(os.path.join(root_dir, name))
]

# Initialize lists to hold combined data
all_embeddings = []
all_metadata = []
all_label_imgs = []

# Size (width, height) of one thumbnail inside each sprite.sprite.png. It must
# match the size the first cell resized images to (image_size = (32, 32)).
sprite_size = (32, 32)
# Maximum sprite side length in pixels used to cap how much is merged into one
# log. NOTE(review): believed to be a TensorBoard projector limit; confirm.
max_sprite_dim = 4096

# Number of folders merged into one saved batch. It is compared against
# len(all_embeddings), which holds one tensor per folder. The merge loop
# adjusts it between 10 and max_batch_size; because it starts at
# max_batch_size, it can only go down and come back up.
batch_size = 100
max_batch_size = 100  # Upper bound on folders per batch, to limit memory use
batch_counter = 0


# Predicate used to filter out black or nearly black thumbnails
def is_black_or_nearly_black(img_tensor, threshold=10):
    """Return whether the mean intensity of ``img_tensor`` is below ``threshold``.

    ``threshold`` is on the 0-255 scale. Callers pass ``ToTensor`` output,
    which is scaled to [0, 1], so the threshold is rescaled before comparing.
    The result is a 0-d boolean tensor, usable directly in ``if``/``not``.
    """
    cutoff = threshold / 255.0
    mean_intensity = img_tensor.mean()
    return mean_intensity < cutoff


# Split a large sprite into chunks no larger than max_sprite_dim per side
def split_sprite(sprite_img, sprite_size, max_sprite_dim):
    """Cut ``sprite_img`` into sub-sprites of at most ``max_sprite_dim`` pixels per side.

    Chunks are whole multiples of the thumbnail size ``sprite_size``
    ``(width, height)``, except where clipped at the right or bottom edge.
    They are produced row of chunks by row of chunks, left to right.

    Returns:
        A pair ``(chunks, labels)``. ``chunks`` holds the cropped images.
        ``labels`` holds one ``"sprite_{chunk}_{x}_{y}"`` string per full
        thumbnail in each chunk, listed row by row within the chunk.
    """
    cell_w, cell_h = sprite_size
    full_w, full_h = sprite_img.size
    cols_total = full_w // cell_w
    rows_total = full_h // cell_h

    # Thumbnails per chunk along each axis.
    cols_per_chunk = max_sprite_dim // cell_w
    rows_per_chunk = max_sprite_dim // cell_h

    chunks = []
    chunk_labels = []
    for row0 in range(0, rows_total, rows_per_chunk):
        top = row0 * cell_h
        bottom = min((row0 + rows_per_chunk) * cell_h, full_h)
        for col0 in range(0, cols_total, cols_per_chunk):
            left = col0 * cell_w
            right = min((col0 + cols_per_chunk) * cell_w, full_w)

            chunk_id = len(chunks)
            chunks.append(sprite_img.crop((left, top, right, bottom)))

            tiles_x = (right - left) // cell_w
            tiles_y = (bottom - top) // cell_h
            chunk_labels += [
                f"sprite_{chunk_id}_{x}_{y}"
                for y in range(tiles_y)
                for x in range(tiles_x)
            ]

    return chunks, chunk_labels


# Function to save batch data
def save_batch(
    batch_counter,
    all_embeddings,
    all_metadata,
    all_label_imgs,
    log_root="logs/combined_embeddings_try2",
):
    """Write one merged batch of embeddings to its own TensorBoard log directory.

    Args:
        batch_counter: Index of this batch; names the sub-directory
            ``<log_root>/<batch_counter:05d>``.
        all_embeddings: List of 2-D embedding tensors, concatenated along dim 0.
        all_metadata: Flat list of labels, one per embedding row.
        all_label_imgs: List of stacked thumbnail tensors, concatenated along
            dim 0. May be empty, in which case embeddings are logged without
            thumbnails.
        log_root: Parent directory for the per-batch log directories.

    Returns:
        None. Prints where the batch was written, or why it was skipped.
    """
    if not all_embeddings:
        # torch.cat([]) raises, so there is nothing we can log.
        print(f"Batch {batch_counter:05d}: no embeddings to save, skipping.")
        return

    embeddings_combined = torch.cat(all_embeddings, dim=0)
    label_img_combined = torch.cat(all_label_imgs, dim=0) if all_label_imgs else None

    # add_embedding needs one label (and one thumbnail) per embedding row.
    lengths = [len(embeddings_combined), len(all_metadata)]
    if label_img_combined is not None:
        lengths.append(len(label_img_combined))
    min_length = min(lengths)
    if len(set(lengths)) > 1:
        # Truncating keeps the writer from failing. However, rows past
        # min_length are dropped, and if the inputs were filtered
        # independently the remaining rows may no longer line up, so say so.
        print(
            f"Warning: batch {batch_counter:05d} has mismatched lengths "
            f"{lengths}; truncating to {min_length}."
        )
    if min_length == 0:
        print(f"Batch {batch_counter:05d}: nothing left after truncation, skipping.")
        return

    embeddings_combined = embeddings_combined[:min_length]
    metadata = list(all_metadata[:min_length])
    if label_img_combined is not None:
        label_img_combined = label_img_combined[:min_length]

    log_dir = os.path.join(log_root, f"{batch_counter:05d}")
    os.makedirs(log_dir, exist_ok=True)

    # The context manager closes (and thereby flushes) the writer even if
    # add_embedding raises.
    with SummaryWriter(log_dir) as writer:
        # Log embeddings under a common name
        writer.add_embedding(
            embeddings_combined,
            metadata=metadata,
            label_img=label_img_combined,
            tag="Combined Embeddings",
            global_step=0,
        )

    print(f"Combined embeddings logged to {log_dir}.")


# Merge the per-batch projector logs written by the first cell into larger
# TensorBoard embedding logs.
to_tensor = transforms.ToTensor()
# Most thumbnails that fit in a square sprite of side max_sprite_dim.
# TensorBoard lays N thumbnails out on a ceil(sqrt(N))-wide square grid.
max_sprite_images = (
    min(max_sprite_dim // sprite_size[0], max_sprite_dim // sprite_size[1]) ** 2
)
pending_images = 0  # rows currently buffered in all_embeddings/all_metadata/all_label_imgs

for folder in folders:
    # Each batch was logged with the default tag, so its files live in "<step>/default".
    folder_path = os.path.join(root_dir, folder, "default")

    # Load metadata: one label per line, no header. dtype=str and
    # keep_default_na=False keep labels such as "007" or "NA" verbatim instead
    # of letting pandas coerce them to numbers or NaN.
    metadata_file = os.path.join(folder_path, "metadata.tsv")
    metadata = pd.read_csv(
        metadata_file, sep="\t", header=None, dtype=str, keep_default_na=False
    )
    folder_metadata = metadata.values.flatten().tolist()

    # Load tensors: one embedding per row.
    tensors_file = os.path.join(folder_path, "tensors.tsv")
    tensors = pd.read_csv(tensors_file, sep="\t", header=None)
    tensor_embeddings = torch.tensor(tensors.values, dtype=torch.float32)
    n_items = len(tensor_embeddings)
    if len(folder_metadata) != n_items:
        raise ValueError(
            f"{folder_path}: {len(folder_metadata)} labels for {n_items} embeddings"
        )

    # NOTE(review): similarities / sorted_indices are computed but not used in
    # this cell. Kept in case a later cell reads them; remove if not.
    similarities = cosine_similarity(tensor_embeddings)
    sorted_indices = np.argsort(similarities, axis=1)[
        :, ::-1
    ]  # Sort by descending similarity

    # The sprite holds the thumbnails in item order, row-major, followed by
    # zero-filled padding cells. Taking exactly the first n_items cells keeps
    # thumbnail k paired with embedding k and label k. (Filtering all cells by
    # darkness, as before, dropped padding AND real dark images without
    # dropping their embeddings, which misaligned the three.)
    sprite_file = os.path.join(folder_path, "sprite.png")
    with Image.open(sprite_file) as sprite_src:
        sprite_img = sprite_src.convert("RGB")
    tiles_x = sprite_img.size[0] // sprite_size[0]
    tiles_y = sprite_img.size[1] // sprite_size[1]
    if n_items > tiles_x * tiles_y:
        raise ValueError(
            f"{sprite_file}: sprite holds {tiles_x * tiles_y} thumbnails "
            f"but there are {n_items} embeddings"
        )

    thumbs = []
    for idx in range(n_items):
        row, col = divmod(idx, tiles_x)
        box = (
            col * sprite_size[0],
            row * sprite_size[1],
            (col + 1) * sprite_size[0],
            (row + 1) * sprite_size[1],
        )
        thumbs.append(to_tensor(sprite_img.crop(box)))

    # Filter out black or nearly black images (e.g. the zero placeholders the
    # first cell logs for unreadable images). Each dropped thumbnail's
    # embedding and label are dropped with it, so the three stay aligned.
    keep = [k for k, thumb in enumerate(thumbs) if not is_black_or_nearly_black(thumb)]
    if not keep:
        continue

    # Save what is already buffered if adding this folder would push the
    # merged sprite past the size limit. A single folder larger than the limit
    # is still saved on its own.
    if all_embeddings and pending_images + len(keep) > max_sprite_images:
        save_batch(batch_counter, all_embeddings, all_metadata, all_label_imgs)
        batch_counter += 1
        # The sprite limit was hit before the folder count: merge fewer folders next time.
        batch_size = max(batch_size - 10, 10)
        all_embeddings, all_metadata, all_label_imgs = [], [], []
        pending_images = 0

    all_embeddings.append(tensor_embeddings[keep])
    all_metadata.extend(folder_metadata[k] for k in keep)
    all_label_imgs.append(torch.stack([thumbs[k] for k in keep]))
    pending_images += len(keep)

    # batch_size counts folders: all_embeddings holds one tensor per folder.
    if len(all_embeddings) >= batch_size:
        save_batch(batch_counter, all_embeddings, all_metadata, all_label_imgs)
        batch_counter += 1
        batch_size = min(batch_size + 10, max_batch_size)  # Increase batch size if possible
        all_embeddings, all_metadata, all_label_imgs = [], [], []
        pending_images = 0

# Save any remaining data after the loop ends
if all_embeddings:
    save_batch(batch_counter, all_embeddings, all_metadata, all_label_imgs)
    batch_counter += 1

# batch_counter now equals the number of batches actually saved.
print(f"Processed {batch_counter} batches.")