# Dataset Filtering: Fire Class Only

This notebook filters the `MyData_Fire` dataset to create a new dataset containing **only** images and labels where the **Fire** class (ID 0) is present.

**Logic:**
- **Keep**: Labels containing Class 0 (Fire).
- **Discard**: Labels without any Class 0 annotation, i.e. Smoke-only (Class 1) or empty labels.

**Output** (the `train`, `valid` and `test` splits are merged into a single flat folder):
- `../Filtered_Fire_Dataset/images`
- `../Filtered_Fire_Dataset/labels`

In [None]:
import os
import glob
import shutil
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

# ---- Configuration -------------------------------------------------------
# Relative paths are resolved against the notebook's working directory.
SOURCE_ROOT = r'../MyData_Fire'            # input dataset, one sub-folder per split
OUTPUT_ROOT = r'../Filtered_Fire_Dataset'  # flat output with images/ and labels/
SPLITS = ['train', 'valid', 'test']        # split sub-folders read from SOURCE_ROOT
TARGET_CLASS = 0                           # class id a label must contain to be kept (Fire)
CLASS_NAMES = {0: 'Fire', 1: 'Smoke'}      # class id -> display name for plots

# Ensure both output folders exist; existing folders are left untouched.
for _subdir in ('images', 'labels'):
    os.makedirs(os.path.join(OUTPUT_ROOT, _subdir), exist_ok=True)

# Echo the resolved absolute locations so it is obvious where data is read/written.
for _tag, _root in (('Source', SOURCE_ROOT), ('Output', OUTPUT_ROOT)):
    print(f"{_tag}: {os.path.abspath(_root)}")

## 1. Filter and Copy Data

In [None]:
def _label_has_class(label_path, class_id):
    """Return True if any annotation row in *label_path* has class *class_id*.

    Only rows with at least 5 whitespace-separated fields are considered
    (``class x_center y_center width height``), so blank or empty label files
    yield False. Scanning stops at the first matching row.
    """
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 5 and int(parts[0]) == class_id:
                return True
    return False


# Image extensions tried, in order, when pairing a label with its image.
# Upper-case variants are listed explicitly because the existence check is
# case-sensitive on case-sensitive filesystems.
_IMAGE_EXTENSIONS = (
    '.jpg', '.jpeg', '.png', '.bmp', '.webp',
    '.JPG', '.JPEG', '.PNG', '.BMP', '.WEBP',
)


def _find_image(images_dir, basename):
    """Return the path of the image ``basename + ext`` in *images_dir*, or None if absent."""
    for ext in _IMAGE_EXTENSIONS:
        candidate = os.path.join(images_dir, basename + ext)
        if os.path.exists(candidate):
            return candidate
    return None


stats = {
    'processed': 0,
    'kept': 0,
    'discarded': 0,
    'missing_image': 0,  # contains Fire, but no matching image file was found
    'renamed': 0,        # kept under a split-prefixed name to avoid an overwrite
}
# Output names already written during this run (all splits share one flat folder).
used_names = set()

for split in SPLITS:
    print(f"Processing {split} split...")
    images_dir = os.path.join(SOURCE_ROOT, split, 'images')
    labels_dir = os.path.join(SOURCE_ROOT, split, 'labels')

    # glob order is arbitrary; sort so runs (and any collision renames) are reproducible.
    label_files = sorted(glob.glob(os.path.join(labels_dir, '*.txt')))

    for label_path in tqdm(label_files, desc=f"Filtering {split}"):
        stats['processed'] += 1

        if not _label_has_class(label_path, TARGET_CLASS):
            stats['discarded'] += 1
            continue

        basename = os.path.splitext(os.path.basename(label_path))[0]
        img_path = _find_image(images_dir, basename)
        if img_path is None:
            # Counted separately so that processed == kept + discarded + missing_image.
            stats['missing_image'] += 1
            # tqdm.write keeps the progress bar intact, unlike print.
            tqdm.write(f"Warning: Image not found for {label_path}")
            continue

        # Every split is copied into the same folder: if an earlier split already
        # produced this name, prefix the split so the earlier copy is not overwritten.
        out_name = basename
        if out_name in used_names:
            out_name = f"{split}_{basename}"
            stats['renamed'] += 1
            tqdm.write(f"Warning: duplicate name {basename!r}; saving {split} copy as {out_name!r}")
        used_names.add(out_name)

        img_ext = os.path.splitext(img_path)[1]
        shutil.copy(img_path, os.path.join(OUTPUT_ROOT, 'images', out_name + img_ext))
        shutil.copy(label_path, os.path.join(OUTPUT_ROOT, 'labels', out_name + '.txt'))
        stats['kept'] += 1

print("\nFiltering Complete!")
print(f"Total Processed: {stats['processed']}")
print(f"Kept (Fire): {stats['kept']}")
print(f"Discarded (No Fire): {stats['discarded']}")
print(f"Skipped (Fire, image missing): {stats['missing_image']}")
if stats['renamed']:
    print(f"Renamed (name collision across splits): {stats['renamed']}")

## 2. Verify Filtered Data
Check that the kept labels actually contain class 0.

In [None]:
def _has_target_class(path):
    """Return True if some annotation row in *path* has class TARGET_CLASS.

    A row counts only when it has at least 5 whitespace-separated fields; rows
    are checked in file order and the scan stops at the first match.
    """
    with open(path, 'r') as handle:
        rows = handle.readlines()
        return any(
            len(fields) >= 5 and int(fields[0]) == TARGET_CLASS
            for fields in (row.strip().split() for row in rows)
        )


# Re-read every label in the output folder and confirm it still contains Fire.
filtered_labels = glob.glob(os.path.join(OUTPUT_ROOT, 'labels', '*.txt'))
print(f"Verifying {len(filtered_labels)} filtered labels...")

bad_labels = []
for candidate in tqdm(filtered_labels):
    if _has_target_class(candidate):
        continue
    print(f"Error: {candidate} does not contain Fire!")
    bad_labels.append(candidate)

error_count = len(bad_labels)
if error_count:
    print(f"Verification Failed: {error_count} files incorrect.")
else:
    print("Verification Successful: All filtered files contain Fire.")

## 3. Visualize Samples

In [None]:
def visualize_sample(img_path, label_path, class_map=CLASS_NAMES):
    """Draw the boxes from a YOLO label file on its image and display the result.

    Args:
        img_path: Path to the image, loaded with ``cv2.imread``.
        label_path: Path to the label file. Each line with at least five fields
            is read as ``class x_center y_center width height``, with the
            coordinates normalized to the image width/height.
        class_map: Mapping from class id to the caption drawn next to a box;
            ids missing from the map are drawn as their number.

    Raises:
        FileNotFoundError: If the image cannot be read, or the label file
            does not exist.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check cvtColor fails with an opaque OpenCV error.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h_img, w_img = img.shape[:2]

    with open(label_path, 'r') as f:
        lines = f.readlines()

    for line in lines:
        parts = line.split()
        if len(parts) < 5:
            continue
        cls_id = int(parts[0])
        x_center, y_center, w, h = map(float, parts[1:5])

        # Convert normalized YOLO center/size to pixel corner coordinates.
        x1 = int((x_center - w / 2) * w_img)
        y1 = int((y_center - h / 2) * h_img)
        x2 = int((x_center + w / 2) * w_img)
        y2 = int((y_center + h / 2) * h_img)

        # The image was converted to RGB above, so (255, 0, 0) renders red.
        color = (255, 0, 0) if cls_id == 0 else (0, 255, 0)  # Red for Fire, Green for Smoke
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        label = class_map.get(cls_id, str(cls_id))
        # Caption goes just above the box; when the box is near the top edge it
        # would be clipped there, so draw it just inside the box instead.
        text_y = y1 - 10 if y1 >= 30 else y1 + 25
        cv2.putText(img, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Filtered Sample: {os.path.basename(img_path)}")
    plt.show()

import random

# Number of distinct samples to display.
NUM_VIS_SAMPLES = 3
# Only real image files are candidates; the folder may also hold other files
# (e.g. OS metadata) that cv2.imread cannot decode.
VIS_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

filtered_images = [
    path
    for path in glob.glob(os.path.join(OUTPUT_ROOT, 'images', '*'))
    if path.lower().endswith(VIS_IMAGE_SUFFIXES)
]
if filtered_images:
    print("Visualizing random samples from Filtered Dataset:")
    # random.sample draws without replacement, so no image is shown twice and
    # fewer than NUM_VIS_SAMPLES images are handled gracefully.
    for img_path in random.sample(filtered_images, min(NUM_VIS_SAMPLES, len(filtered_images))):
        basename = os.path.splitext(os.path.basename(img_path))[0]
        label_path = os.path.join(OUTPUT_ROOT, 'labels', basename + '.txt')
        visualize_sample(img_path, label_path)
else:
    print("No images in filtered dataset.")