In [1]:
import numpy as np
import cv2
from skimage.color import rgb2lab
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
def read_bounding_boxes(label_path, img_width=512, img_height=512):
    """Read YOLO-format bounding boxes and convert them to pixel coordinates.

    Every non-blank line of the label file must contain exactly five
    whitespace-separated numbers: ``class x_center y_center width height``.
    The four geometry values are fractions of the image size; the class id
    is parsed but ignored.

    Args:
        label_path: ``pathlib.Path`` of the label file. A path that does not
            exist is treated as "no boxes".
        img_width: Width in pixels of the image the labels belong to.
            Defaults to 512, the size previously hard-coded here.
        img_height: Height in pixels of that image. Defaults to 512.

    Returns:
        A list of ``[x1, y1, x2, y2]`` integer pixel boxes, each coordinate
        clamped to ``[0, img_width]`` / ``[0, img_height]``.

    Raises:
        ValueError: If a non-blank line does not hold exactly five numeric
            fields.
    """
    boxes = []
    if not label_path.exists():
        return boxes

    with open(label_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank or whitespace-only line.
                continue

            # Parse YOLO format: class x_center y_center width height
            _, x_center, y_center, width, height = map(float, fields)

            # Centre/size fractions -> corner pixel coordinates
            # (int() truncates toward zero).
            x1 = int((x_center - width / 2) * img_width)
            y1 = int((y_center - height / 2) * img_height)
            x2 = int((x_center + width / 2) * img_width)
            y2 = int((y_center + height / 2) * img_height)

            # Ensure coordinates are within image bounds
            x1 = max(0, min(img_width, x1))
            y1 = max(0, min(img_height, y1))
            x2 = max(0, min(img_width, x2))
            y2 = max(0, min(img_height, y2))

            boxes.append([x1, y1, x2, y2])
    return boxes

In [3]:
def create_cielab_mask(lab_image, hue_lower=-14.98, hue_upper=96.22, chroma_threshold=6.11):
    """Create a binary mask by thresholding hue angle and chroma in CIELAB.

    Args:
        lab_image: 3-D array indexed as ``[row, col, channel]`` where
            channel 1 is a* and channel 2 is b* (channel 0 is not read).
        hue_lower: Inclusive lower bound on the hue angle, in degrees.
        hue_upper: Inclusive upper bound on the hue angle, in degrees.
        chroma_threshold: Inclusive lower bound on the chroma
            ``sqrt(a*^2 + b*^2)``.

    Returns:
        ``uint8`` array with the same height and width as ``lab_image``:
        255 where all three conditions hold, 0 elsewhere.
    """
    # Extract a* and b* components
    a = lab_image[:, :, 1]
    b = lab_image[:, :, 2]

    # Hue angle in degrees. arctan2 returns values in [-pi, pi], so the
    # result is already within [-180, 180] and needs no wrap-around fix.
    hue = np.degrees(np.arctan2(b, a))

    # Chroma is the distance from the neutral axis in the a*/b* plane.
    chroma = np.sqrt(a**2 + b**2)

    fire_pixels = (hue >= hue_lower) & (hue <= hue_upper) & (chroma >= chroma_threshold)

    # Typed scalars keep the result uint8 under any NumPy promotion rules.
    return np.where(fire_pixels, np.uint8(255), np.uint8(0))

In [4]:
def create_multichannel_image(image_path, label_path):
    """Create an 8-channel image from an RGB image and its bounding boxes.

    Channel layout along the last axis:
        0-2: the full RGB image
        3:   CIELAB mask (0/255), kept only inside the bounding boxes
        4-6: the RGB image with every pixel outside the boxes set to 0
        7:   the same CIELAB mask as channel 3

    Args:
        image_path: Path of the image file; read with ``cv2.imread``.
        label_path: Path of the YOLO label file, passed to
            ``read_bounding_boxes``.

    Returns:
        Array of shape ``(height, width, 8)``.

    Raises:
        ValueError: If the image cannot be read.
    """
    # Read image and convert to RGB
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Rasterise all boxes into one boolean region; overlapping boxes union.
    # NOTE: boxes are scaled with read_bounding_boxes' default image size
    # (512x512); slices reaching past the actual image are clipped by NumPy.
    in_box = np.zeros(image.shape[:2], dtype=bool)
    for x1, y1, x2, y2 in read_bounding_boxes(label_path):
        in_box[y1:y2, x1:x2] = True

    # The colour test is purely per-pixel, so evaluating it once over the
    # whole image and then clearing everything outside the boxes gives the
    # same mask as evaluating it box by box.
    cielab_mask = create_cielab_mask(rgb2lab(image))
    cielab_mask[~in_box] = 0

    # RGB copy that keeps only the pixels inside the boxes.
    masked_image = np.zeros_like(image)
    masked_image[in_box] = image[in_box]

    # (RGB + mask) followed by (boxed RGB + mask) -> 8 channels.
    return np.dstack((image, cielab_mask, masked_image, cielab_mask))

In [8]:
def process_dataset(dataset_path='../fire_detection_dataset'):
    """Convert every ``.jpg`` image of each dataset split to a ``.npy`` file.

    For each of the ``train``, ``val`` and ``test`` splits, images are read
    from ``<dataset_path>/<split>/images``, paired with the label file of
    the same stem in ``<dataset_path>/<split>/labels``, and the array
    returned by ``create_multichannel_image`` is saved to
    ``<dataset_path>/<split>/multichannel/<stem>.npy``.

    A failure on one image is printed and counted, and processing
    continues with the next image. A split without an ``images``
    directory is skipped. A summary of the counts is printed at the end.

    Args:
        dataset_path: Root directory of the dataset.
    """
    dataset_path = Path(dataset_path)
    total_processed = 0
    total_errors = 0

    for split in ['train', 'val', 'test']:
        print(f"\nProcessing {split} split...")

        split_dir = dataset_path / split
        images_dir = split_dir / 'images'
        if not images_dir.is_dir():
            # Nothing to convert; avoid crashing on (or creating) a
            # directory tree for a split that is not present.
            print(f"Skipping {split} split: {images_dir} not found")
            continue

        # Create output directory if it doesn't exist
        output_dir = split_dir / 'multichannel'
        output_dir.mkdir(exist_ok=True)

        # Sorted so the processing order is deterministic across runs.
        image_paths = sorted(images_dir.glob('*.jpg'))

        for image_path in tqdm(image_paths):
            try:
                # Label file shares the image's stem.
                label_path = split_dir / 'labels' / f"{image_path.stem}.txt"

                multichannel = create_multichannel_image(image_path, label_path)

                # Save as numpy array
                output_path = output_dir / f"{image_path.stem}.npy"
                np.save(str(output_path), multichannel)

                total_processed += 1

            except Exception as e:
                # Best effort: report the failure and move on to the next image.
                print(f"Error processing {image_path}: {e}")
                total_errors += 1

    print(f"\nFinished: {total_processed} images processed, {total_errors} errors")

In [9]:
# Convert the train/val/test splits under the default dataset path to .npy files.
process_dataset()


Processing train split...


100%|██████████| 1645/1645 [00:56<00:00, 29.15it/s]



Processing val split...


100%|██████████| 151/151 [00:04<00:00, 31.04it/s]



Processing test split...


100%|██████████| 207/207 [00:06<00:00, 30.01it/s]


In [None]:
def visualize_channels(multichannel):
    """Show the eight channels of a multichannel image as a 2x4 grid.

    Channels 0-3 fill the top row and channels 4-7 the bottom row. Each
    panel is drawn with the gray colormap, titled with its channel name
    and shown without axes.
    """
    titles = (
        'RGB-R (Full)', 'RGB-G (Full)', 'RGB-B (Full)', 'CIELAB Aux',
        'RGB-R (Boxed)', 'RGB-G (Boxed)', 'RGB-B (Boxed)', 'CIELAB Aux',
    )

    _, grid = plt.subplots(2, 4, figsize=(20, 10))

    # Row-major flattening of the grid visits panels in channel order.
    for index, (panel, title) in enumerate(zip(grid.flat, titles)):
        panel.imshow(multichannel[:, :, index], cmap='gray')
        panel.set_title(title)
        panel.axis('off')

    plt.tight_layout()
    plt.show()