# Color Conversion Notebook for Semantic Scene Layouts

This notebook provides a complete pipeline to recolor an existing dataset of scene layout images. It replaces the original colors with a new, semantically meaningful color scheme derived from text embeddings.

### Workflow:
1.  **Collect Labels**: Scan all `_tokens.json` files to gather a unique set of all object labels.
2.  **Generate Semantic Colors**: Use a fine-tuned Sentence Transformer model to create embeddings for each label, then use PCA to map these embeddings to RGB colors. Labels with similar meanings will receive similar colors.
3.  **Map Old to New**: Load the original `color_legend.json` to create a mapping from the old color values to the new semantic ones.
4.  **Process Images**: Read each `.png` image, replace every old color with its new color using NumPy masks, and save the result to a new directory.

**Hardcoded Colors:**
* `floor`: White (255, 255, 255)
* `wall`: Black (0, 0, 0)
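
The color assignment in step 2 comes down to a per-channel min-max rescaling of three projected coordinates, followed by the two hardcoded overrides. The sketch below illustrates that rescaling in plain Python on made-up 3D points. The `points` values and the `to_rgb` helper are hypothetical stand-ins for the PCA output, not part of the notebook.

```python
def to_rgb(points):
    """Rescale each of the 3 columns to the 0-255 range independently."""
    columns = list(zip(*points))
    lows = [min(c) for c in columns]
    # Use a span of 1 when a column is constant to avoid dividing by zero.
    spans = [(max(c) - min(c)) or 1 for c in columns]
    return [
        tuple(int((v - lo) / span * 255) for v, lo, span in zip(p, lows, spans))
        for p in points
    ]

# Hypothetical 3D coordinates standing in for PCA-reduced label embeddings.
points = {
    "chair": (-1.0, 0.5, 2.0),
    "stool": (-0.8, 0.4, 2.0),
    "lamp": (1.0, -0.5, 2.0),
}
color_map = dict(zip(points, to_rgb(list(points.values()))))

# Hardcoded architectural colors override whatever the projection produced.
color_map["wall"] = (0, 0, 0)
color_map["floor"] = (255, 255, 255)
print(color_map)
```

In this toy input, `chair` and `stool` sit close together, so they should end up with nearby colors, while `lamp` lands at the opposite end of the first two channels.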

## 1. Setup

First, this cell installs all the necessary libraries for the script.

In [1]:
!pip install numpy Pillow sentence-transformers scikit-learn tqdm -q

## 2. Imports and Configuration

This cell imports the required libraries and sets up the main configuration paths. Make sure the paths point to the correct directories in your project.

In [2]:
import os
import json
import numpy as np
from PIL import Image
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm # Use notebook-friendly tqdm
from pathlib import Path

# --- Configuration ---
# Directory containing your original images and token files.
INPUT_DIR = Path("output_pairs_old")
# Directory where the recolored images will be saved.
OUTPUT_DIR = Path("recolored_images")
# Path to your original color legend.
ORIGINAL_LEGEND_PATH = INPUT_DIR / "color_legend.json"
# Path to the fine-tuned model you created in the embeddings_analysis notebook.
FINETUNED_MODEL_PATH = './fine_tuned_bert'

print("Imports and configuration loaded.")

Imports and configuration loaded.
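
The functions in the next section read two kinds of JSON files from `INPUT_DIR`. Judging from how they parse them, a token file is a list whose triplet entries hold three strings, and `color_legend.json` maps each label to an `[R, G, B]` list. The sketch below writes a tiny fixture in that shape to a temporary directory. The file name `scene_0001_tokens.json`, the labels, the colors, and the reading of the middle element as a relation are all assumptions made for illustration.

```python
import json
import tempfile
from pathlib import Path

fixture_dir = Path(tempfile.mkdtemp())

# Hypothetical token file: triplets of [subject, relation, object].
tokens = [["chair", "left of", "table"], ["lamp", "on", "table"]]
(fixture_dir / "scene_0001_tokens.json").write_text(json.dumps(tokens))

# Hypothetical legend: label -> original RGB color.
legend = {"chair": [200, 30, 30], "table": [30, 200, 30], "wall": [90, 90, 90]}
(fixture_dir / "color_legend.json").write_text(json.dumps(legend, indent=4))

# Read the token files back and gather subject and object labels.
labels = set()
for path in fixture_dir.glob("*_tokens.json"):
    for subj, _, obj in json.loads(path.read_text()):
        labels.update([subj, obj])
print(sorted(labels))
```

Pointing `INPUT_DIR` at a small fixture like this is one way to dry-run the pipeline before touching the full dataset.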


## 3. Function Definitions

Here we define all the necessary functions for the color conversion pipeline.

In [3]:
def collect_all_labels(directory: Path) -> list:
    """Scans all *_tokens.json files to find unique furniture labels."""
    print("--- Collecting all unique labels from token files... ---")
    all_labels = set()
    token_files = list(directory.glob("*_tokens.json"))

    if not token_files:
        print(f"[Warning] No token files found in '{directory}'.")
        return []
        
    for fname in tqdm(token_files, desc="Scanning labels"):
        try:
            with open(fname, "r") as f:
                tokens = json.load(f)
            # A triplet is a list of 3 strings
            triplets = [t for t in tokens if isinstance(t, list) and len(t) == 3]
            for subj, _, obj in triplets:
                # Add subject and object labels
                if isinstance(subj, str): all_labels.add(subj)
                if isinstance(obj, str): all_labels.add(obj)
        except (json.JSONDecodeError, FileNotFoundError):
            continue
            
    # Remove architectural elements that we will hardcode
    all_labels.discard('wall')
    all_labels.discard('floor')
    
    print(f"Found {len(all_labels)} unique furniture labels.")
    return sorted(all_labels)

def create_semantic_color_map(labels: list, model_path: str) -> dict:
    """Generates a new color map based on sentence transformer embeddings."""
    print(f"... Generating semantic colors using '{model_path}'... ---")
    if not labels:
        print("[Warning] No labels to process for color map generation.")
        return {}

    try:
        model = SentenceTransformer(model_path)
    except Exception as e:
        print(f"[Error] Could not load model from '{model_path}'. Make sure the path is correct.")
        print(f"Details: {e}")
        return {}

    embeddings = model.encode(labels, show_progress_bar=True)
    
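    # Reduce the embedding dimensionality (e.g. 384D, depending on the model) to 3D for RGB mapping.
    # PCA cannot produce 3 components from fewer than 3 samples.
    if len(labels) < 3:
        print("[Warning] Need at least 3 labels to compute 3 PCA components.")
        return {}
    pca = PCA(n_components=3)
    components = pca.fit_transform(embeddings)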

    # Normalize components to a 0-1 range
    min_vals = components.min(axis=0)
    max_vals = components.max(axis=0)
    range_vals = max_vals - min_vals
    # Avoid division by zero if all values in a component are the same
    range_vals[range_vals == 0] = 1 
    
    normalized = (components - min_vals) / range_vals
    rgb_values = (normalized * 255).astype(int)

    # Create the label-to-color map, converting NumPy integers to plain Python ints
    color_map = {label: tuple(int(c) for c in rgb) for label, rgb in zip(labels, rgb_values)}
    
    # --- Apply Hardcoded Colors ---
    color_map['wall'] = (0, 0, 0)      # Black
    color_map['floor'] = (255, 255, 255)  # White
    
    print("Semantic color map created successfully.")
    return color_map

def create_recolor_map(original_legend_path: Path, new_color_map: dict) -> dict:
    """Creates a map from old RGB values to new RGB values."""
    print("--- Mapping old colors to new semantic colors... ---")
    try:
        with open(original_legend_path, 'r') as f:
            original_map = json.load(f)
    except FileNotFoundError:
        print(f"[Error] Original color legend not found at '{original_legend_path}'. Cannot proceed.")
        return {}

    recolor_map = {}
    for label, new_color in new_color_map.items():
        if label in original_map:
            old_color = tuple(original_map[label])
            recolor_map[old_color] = new_color
            
    print(f"Created a mapping for {len(recolor_map)} colors.")
    return recolor_map

def process_images(input_dir: Path, output_dir: Path, recolor_map: dict):
    """Applies the color conversion to all images in the directory."""
    print("--- Processing and recoloring images... ---")
    output_dir.mkdir(parents=True, exist_ok=True)
    image_files = list(input_dir.glob("*.png"))

    if not image_files:
        print(f"[Warning] No PNG images found in '{input_dir}'.")
        return

    for img_path in tqdm(image_files, desc="Recoloring images"):
        try:
            img = Image.open(img_path).convert('RGB')
            data = np.array(img)
            
            # Create a copy to modify
            new_data = data.copy()

            # Replace colors efficiently using numpy masks
            for old_color, new_color in recolor_map.items():
                mask = np.all(data == old_color, axis=-1)
                new_data[mask] = new_color
            
            new_img = Image.fromarray(new_data)
            new_img.save(output_dir / img_path.name)
        except Exception as e:
            print(f"Could not process {img_path.name}. Error: {e}")
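
One detail of `process_images` is worth spelling out: each mask is computed on the original `data`, while the writes go to `new_data`. This matters whenever a new color coincides with some other label's old color. The plain-Python sketch below (no NumPy, with invented colors) contrasts that approach with a naive in-place replacement.

```python
RED, GREEN, BLUE = (255, 0, 0), (0, 255, 0), (0, 0, 255)

# Hypothetical recolor map in which RED's new color is GREEN's old color.
recolor_map = {RED: GREEN, GREEN: BLUE}
pixels = [RED, GREEN, RED]

# Safe: look every pixel up against the ORIGINAL values (what the masks do).
safe = [recolor_map.get(p, p) for p in pixels]

# Naive: rewrite in place, one color at a time, so earlier results get re-mapped.
naive = list(pixels)
for old, new in recolor_map.items():
    naive = [new if p == old else p for p in naive]

print(safe)
print(naive)
```

With the safe variant the red pixels become green and the green pixel becomes blue. The naive loop turns everything blue, because the freshly written green pixels are matched again on the second pass.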

## 4. Main Execution

This final cell runs the entire pipeline. It calls the functions defined above in sequence to perform the color conversion.

In [4]:
if not INPUT_DIR.exists():
    print(f"Error: Input directory '{INPUT_DIR}' not found. Please check the path.")
elif not ORIGINAL_LEGEND_PATH.exists():
    print(f"Error: The original 'color_legend.json' file is required but was not found in '{INPUT_DIR}'.")
else:
    # Step 1: Collect labels
    unique_labels = collect_all_labels(INPUT_DIR)
    
    # Step 2: Create the new semantic color map
    semantic_map = create_semantic_color_map(unique_labels, FINETUNED_MODEL_PATH)
    
    if semantic_map:
        # Save the new color legend for reference
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        semantic_legend_path = OUTPUT_DIR / "semantic_color_legend.json"
        with open(semantic_legend_path, 'w') as f:
            # Convert color tuples to lists of plain Python ints for JSON serialization
            serializable_map = {k: [int(i) for i in v] for k, v in semantic_map.items()}
            json.dump(serializable_map, f, indent=4)
        print(f"New semantic color legend saved to '{semantic_legend_path}'")

        # Step 3: Create the final old-to-new color mapping
        final_recolor_map = create_recolor_map(ORIGINAL_LEGEND_PATH, semantic_map)

        # Step 4: Process the images
        if final_recolor_map:
            process_images(INPUT_DIR, OUTPUT_DIR, final_recolor_map)
            print(f"\n✨ Done! All images have been recolored and saved in '{OUTPUT_DIR}'.")
        else:
            print("\nCould not create a recolor map. Aborting image processing.")
    else:
        print("\nCould not create a semantic color map. Aborting.")

--- Collecting all unique labels from token files... ---


Scanning labels:   0%|          | 0/1372 [00:00<?, ?it/s]

Found 598 unique furniture labels.
... Generating semantic colors using './fine_tuned_bert'... ---


Batches:   0%|          | 0/19 [00:00<?, ?it/s]

Semantic color map created successfully.
New semantic color legend saved to 'recolored_images\semantic_color_legend.json'
--- Mapping old colors to new semantic colors... ---
Created a mapping for 176 colors.
--- Processing and recoloring images... ---


Recoloring images:   0%|          | 0/1372 [00:00<?, ?it/s]


✨ Done! All images have been recolored and saved in 'recolored_images'.
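
In the run above, 598 labels were collected but the recolor map holds only 176 entries. `create_recolor_map` skips any label that is missing from `color_legend.json`, and labels that share the same old color collapse into a single entry, so either effect may account for the gap. A small check like the following can list the skipped labels. `find_unmapped_labels` is a hypothetical helper, and the two dictionaries are toy stand-ins for `semantic_map` and the loaded legend.

```python
def find_unmapped_labels(semantic_map, original_legend):
    """Return labels that have a new color but no entry in the old legend."""
    return sorted(label for label in semantic_map if label not in original_legend)

# Toy stand-ins for the real semantic map and the original legend.
semantic_map = {"chair": (10, 20, 30), "sofa": (12, 25, 28), "wall": (0, 0, 0)}
original_legend = {"chair": [200, 30, 30], "wall": [90, 90, 90]}

print(find_unmapped_labels(semantic_map, original_legend))
```

Pixels belonging to unmapped labels keep their original color in the output images, so inspecting this list is a quick way to judge how complete the recoloring is.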
