In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# To list all files under the input directory, walk it with os.walk("/kaggle/input") (this cell itself only imports os)

import os


# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## Cell 1: Install Dependencies

### We need transformers for the auto-captioning (BLIP) and diffusers for later training.

In [None]:
# Install dependencies for Preprocessing only (No xformers needed yet)
!pip install -q diffusers transformers accelerate

In [None]:
# Sanity check: confirm that the freshly installed libraries import cleanly
# before running the (slow) preprocessing cell.
try:
    import diffusers
    import transformers
    import torch
except ImportError as e:
    print(f"FAILED: {e}")
else:
    print("SUCCESS: All necessary libraries are imported correctly!")
    print(f"Torch version: {torch.__version__}")

## Cell 3: The Preprocessing Script

### This script fulfills the requirement to "Construct a dataset with paired images, conditions, and prompts". What this code does:

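- **Pairs:** Matches each RGB image with its 8-bit thermal counterpart (the condition image) and resizes both to 512×512.
- **Auto-Captions:** Generates a text prompt for each RGB image with BLIP (on the GPU when available, otherwise on the CPU).
- **Formats:** Saves the images, conditioning images and a `metadata.jsonl` per split under /kaggle/working/, which is the writable area in Kaggle, then zips the result.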

In [None]:
import os
import shutil
import json
import torch
from PIL import Image
from tqdm.notebook import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration

# ================= CONFIGURATION =================
# Root paths (Verify these match your Kaggle input structure)
BASE_PATH = "/kaggle/input/flir-thermal-images-dataset/FLIR_ADAS_1_3"

# One entry per output split: where to read the RGB targets and thermal
# conditions from, what to name the split folder, and an optional cap on the
# number of RGB files considered ("limit": None keeps them all).
TASKS = [
    dict(
        rgb_dir=os.path.join(BASE_PATH, "train", "RGB"),
        thermal_dir=os.path.join(BASE_PATH, "train", "thermal_8_bit"),
        split_name="train",
        limit=2000,  # Cap training images to save time/space (None = use all)
    ),
    dict(
        rgb_dir=os.path.join(BASE_PATH, "val", "RGB"),
        thermal_dir=os.path.join(BASE_PATH, "val", "thermal_8_bit"),
        split_name="validation",
        limit=None,  # Keep every validation image
    ),
]

# Output Location
WORKING_DIR = "/kaggle/working/controlnet_dataset"
RESOLUTION = 512  # Side length (pixels) of the square images written to disk
# =================================================

def setup_caption_model():
    """Load the BLIP base image-captioning processor and model.

    Returns:
        A ``(processor, model, device)`` tuple. ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``; the model
        has already been moved onto it.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading BLIP captioning model on {target_device}...")

    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    blip_model = blip_model.to(target_device)
    return blip_processor, blip_model, target_device

def _find_thermal_path(thermal_dir, filename):
    """Return the path of the thermal frame paired with RGB ``filename``, or None.

    Thermal files may use ``.jpeg`` where the RGB files use ``.jpg``, so the
    ``.jpeg`` variant of the stem is tried first, then ``.jpg``, then the RGB
    filename unchanged.
    """
    stem, _ = os.path.splitext(filename)
    # dict.fromkeys de-duplicates while preserving the lookup order.
    for candidate in dict.fromkeys((stem + ".jpeg", stem + ".jpg", filename)):
        path = os.path.join(thermal_dir, candidate)
        if os.path.exists(path):
            return path
    return None


def _caption_image(image, processor, model, device, max_new_tokens=50):
    """Return a BLIP caption for the PIL ``image``."""
    inputs = processor(image, return_tensors="pt").to(device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.decode(output_ids[0], skip_special_tokens=True)


def process_split(task, processor, model, device):
    """Build one output split (e.g. train or validation) of the ControlNet dataset.

    For every RGB ``.jpg``/``.jpeg`` file in ``task['rgb_dir']`` that has a
    matching frame in ``task['thermal_dir']``:
    - both images are resized to RESOLUTION x RESOLUTION;
    - the RGB image is captioned with BLIP;
    - the pair is saved as PNGs under ``WORKING_DIR/<split_name>/images`` and
      ``.../conditioning_images``.
    A ``metadata.jsonl`` whose paths are relative to the split folder is
    written alongside.

    Args:
        task: dict with keys 'rgb_dir', 'thermal_dir', 'split_name' and
            optionally 'limit' (max number of RGB files to consider; None or
            0 means all).
        processor, model, device: the BLIP captioner, as returned by
            setup_caption_model().

    Returns:
        None. Progress, skipped files and per-file errors are reported via print.
    """
    split = task['split_name']
    print(f"\n--- Processing Split: {split.upper()} ---")

    # 1. Create Split Directories
    split_dir = os.path.join(WORKING_DIR, split)
    images_dir = os.path.join(split_dir, "images")
    cond_dir = os.path.join(split_dir, "conditioning_images")
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(cond_dir, exist_ok=True)

    # 2. Get Files
    if not os.path.exists(task['rgb_dir']):
        print(f"WARNING: Path not found {task['rgb_dir']}")
        return
    if not os.path.exists(task['thermal_dir']):
        # Without this check every image would be skipped silently and an
        # empty metadata.jsonl written.
        print(f"WARNING: Path not found {task['thermal_dir']}")
        return

    # Case-insensitive so files like "X.JPG" are not silently ignored.
    all_files = sorted(
        name for name in os.listdir(task['rgb_dir'])
        if name.lower().endswith(('.jpg', '.jpeg'))
    )

    # Apply limit if set (None or 0 means "use every file")
    limit = task.get('limit')
    if limit:
        all_files = all_files[:limit]

    print(f"Found {len(all_files)} images. Starting processing...")
    metadata = []
    missing_thermal = 0

    for filename in tqdm(all_files):
        rgb_path = os.path.join(task['rgb_dir'], filename)
        thermal_path = _find_thermal_path(task['thermal_dir'], filename)
        if thermal_path is None:
            missing_thermal += 1
            continue

        try:
            # Load + resize; `with` closes the source files as soon as the
            # resized copies exist.
            with Image.open(rgb_path) as src:
                img_rgb = src.convert("RGB").resize((RESOLUTION, RESOLUTION))
            with Image.open(thermal_path) as src:
                img_thermal = src.convert("L").resize((RESOLUTION, RESOLUTION))

            # Caption once, from the RGB target; the same prompt serves the pair.
            caption = _caption_image(img_rgb, processor, model, device)

            # Save as PNG (lossless) under the RGB file's stem.
            save_name = os.path.splitext(filename)[0] + ".png"
            img_rgb.save(os.path.join(images_dir, save_name))
            img_thermal.save(os.path.join(cond_dir, save_name))

            # Paths are RELATIVE to split_dir, for HuggingFace ImageFolder.
            metadata.append({
                "file_name": f"images/{save_name}",
                "conditioning_image": f"conditioning_images/{save_name}",
                "text": caption,
            })
        except Exception as e:
            # Best-effort: one unreadable image must not abort the whole split.
            print(f"Error processing {filename}: {e}")

    if missing_thermal:
        print(f"Skipped {missing_thermal} images with no matching thermal frame.")

    # 3. Save JSONL Metadata
    # Hugging Face ImageFolder expects 'metadata.jsonl' in the split folder
    with open(os.path.join(split_dir, "metadata.jsonl"), 'w', encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in metadata)

    print(f"Completed {split}. Saved {len(metadata)} pairs.")

def main():
    """Rebuild the ControlNet dataset from scratch and zip it.

    - Deletes any previous WORKING_DIR.
    - Loads the BLIP captioner once.
    - Runs process_split() for every entry in TASKS.
    - Writes ``flir_controlnet_split.zip`` into the directory that contains
      WORKING_DIR.
    """
    # Start from a clean output folder so pairs left over from an earlier run
    # (e.g. one with a different limit) do not leak into the new dataset.
    if os.path.exists(WORKING_DIR):
        shutil.rmtree(WORKING_DIR)
    processor, model, device = setup_caption_model()

    # Process both Train and Val
    for task in TASKS:
        process_split(task, processor, model, device)

    # Zip Everything next to WORKING_DIR (derived from it rather than
    # hard-coded, so changing WORKING_DIR keeps the two in sync).
    parent_dir = os.path.dirname(os.path.normpath(WORKING_DIR))
    archive_base = os.path.join(parent_dir, "flir_controlnet_split")
    print("\nZipping dataset...")
    archive_path = shutil.make_archive(archive_base, 'zip', WORKING_DIR)
    print(f"DONE! Download '{os.path.basename(archive_path)}' from Output.")


# Run the pipeline when executed as a script (a Jupyter cell also runs with
# __name__ == "__main__").
if __name__ == "__main__":
    main()

## Cell 4: Visualize Your Data (Quality Check)

### The guidelines require you to show "Condition → Output grids". Use this code to verify that your data looks correct before training.

In [None]:
import matplotlib.pyplot as plt
import random
import json
import os
from PIL import Image

# ================= CONFIGURATION =================
# Set this to the same directory used in the processing script
BASE_DIR = "/kaggle/working/controlnet_dataset"
SPLIT = "validation"  # Change to 'train' to check the other split
# =================================================

# Show one random (thermal condition, RGB target + prompt) pair from SPLIT.

# 1. Define paths based on the split
split_dir = os.path.join(BASE_DIR, SPLIT)
metadata_path = os.path.join(split_dir, "metadata.jsonl")

# 2. Load the metadata
if not os.path.exists(metadata_path):
    print(f"Error: Metadata file not found at {metadata_path}")
else:
    with open(metadata_path, 'r', encoding="utf-8") as f:
        # Drop blank lines so random.choice can never pick an empty string,
        # which json.loads would reject.
        lines = [line for line in f if line.strip()]

    if not lines:
        print("Error: Metadata file is empty.")
    else:
        # 3. Pick a random sample
        sample = json.loads(random.choice(lines))

        # 4. Construct full image paths
        # Note: metadata stores paths like "images/file.png", so we join with split_dir
        target_path = os.path.join(split_dir, sample['file_name'])
        cond_path = os.path.join(split_dir, sample['conditioning_image'])

        # 5. Load Images as fully decoded in-memory copies so the files are
        #    closed right away.
        try:
            with Image.open(target_path) as src:
                img_target = src.copy()
            with Image.open(cond_path) as src:
                img_cond = src.copy()
        except FileNotFoundError:
            print(f"Error: Could not find image files.\nTarget: {target_path}\nCond: {cond_path}")
        else:
            # 6. Plot
            fig, ax = plt.subplots(1, 2, figsize=(12, 6))

            # Show Thermal Condition. The grey range is pinned to 0-255 (the
            # conditioning images are saved as 8-bit "L"), so matplotlib does not
            # auto-stretch contrast and the panel shows the image as stored.
            ax[0].imshow(img_cond, cmap='gray', vmin=0, vmax=255)
            ax[0].set_title("Condition (Thermal Input)")
            ax[0].axis('off')

            # Show RGB Target with Caption
            ax[1].imshow(img_target)
            caption = sample['text']
            # Truncate long captions so the title stays compact; the full
            # caption is printed after the figure.
            shown = caption if len(caption) <= 50 else f"{caption[:50]}..."
            ax[1].set_title(f"Target (RGB)\nPrompt: {shown}", fontsize=10)
            ax[1].axis('off')

            plt.tight_layout()
            plt.show()

            print(f"Sampled from: {SPLIT}")
            print(f"Full Caption: {caption}")