# In [None]:
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
import albumentations as A
import cv2

# -----------------------------
# PATHS
# -----------------------------
# Every stage directory hangs off a single project root. The root defaults to
# the original hard-coded location and can be overridden per machine through
# the WILDFIRE_PROJECT_ROOT environment variable (an empty value is ignored).
# Paths are joined with "/" so the default values stay exactly as before.
project_root = (
    os.environ.get("WILDFIRE_PROJECT_ROOT")
    or r"E:/SLIIT/Year 2 Semester 1/IT2011 - Artficial Intelligence and Machine Learning/Assignment/WildFireDetection"
).rstrip("/\\")
pipeline_root = f"{project_root}/results/outputs/group_pipeline"

base_dir    = f"{project_root}/data/raw"        # input images
resize_dir  = f"{pipeline_root}/resized"        # stage 1 output
color_dir   = f"{pipeline_root}/color_balanced"  # stage 2 output
denoise_dir = f"{pipeline_root}/denoised"       # stage 3 output
aug_dir     = f"{pipeline_root}/augmented"      # stage 4 output
edge_dir    = f"{pipeline_root}/edge"           # stage 5 output
norm_dir    = f"{pipeline_root}/normalized"     # stage 6 output

# Sub-directory layout expected under every stage directory: <split>/<class>.
splits = ['train', 'val', 'test']
classes = ['fire', 'nofire']

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# TRANSFORMS
# Stand-alone converters between PIL images and tensors.
to_pil = transforms.ToPILImage()
to_tensor = transforms.ToTensor()

# Normalization transform: tensor conversion followed by per-channel
# standardisation with the mean/std values given here.
normalize_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            std=[0.229, 0.224, 0.225],
            mean=[0.485, 0.456, 0.406],
        ),
    ]
)

# Gaussian blur for denoising (3x3 kernel).
# NOTE(review): torchvision documents a (min, max) sigma tuple as a range that
# is sampled on every call, so the blur strength would differ per image and per
# run - confirm this is intended for a denoising step.
gaussian_blur = transforms.GaussianBlur(sigma=(0.1, 2.0), kernel_size=(3, 3))

# Augmentations
# 'fire' images: geometric changes plus brightness/contrast and
# hue/saturation jitter.
fire_aug = A.Compose(
    [
        A.Rotate(p=0.7, limit=30),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.HueSaturationValue(p=0.5),
    ]
)

# 'nofire' images: geometric changes only, with a smaller rotation limit and
# lower probabilities than the 'fire' pipeline.
nofire_aug = A.Compose(
    [
        A.Rotate(p=0.3, limit=15),
        A.HorizontalFlip(p=0.3),
    ]
)


# 1. RESIZE by Sanchala K.A.N (IT24100260)
#
# Reads every '.jpg' file in base_dir/<split>/<class>, resizes it to 600x600
# and writes it under the same file name to resize_dir/<split>/<class>.

print("\n=== RESIZING IMAGES ===")

# The transform is identical for every image, so build it once instead of
# once per file.
resize_transform = transforms.Compose([
    transforms.Resize((600, 600)),
    transforms.ToTensor()
])

for split in splits:
    for cls in classes:
        input_path = os.path.join(base_dir, split, cls)
        output_path = os.path.join(resize_dir, split, cls)
        os.makedirs(output_path, exist_ok=True)

        img_files = [f for f in os.listdir(input_path) if f.endswith('.jpg')]

        for img_name in tqdm(img_files, desc=f"Resizing {split}/{cls}", unit="img"):
            # Log and skip unreadable files, like the later stages do, rather
            # than aborting the whole run on the first bad image.
            try:
                # The context manager closes the file handle once the pixels
                # have been converted.
                with Image.open(os.path.join(input_path, img_name)) as src:
                    img_tensor = resize_transform(src.convert("RGB"))

                # No device transfer here: the tensor is only converted back
                # to a PIL image, so moving it to `device` and straight back
                # to the CPU did no work.
                img_pil = to_pil_image(img_tensor)
                img_pil.save(os.path.join(output_path, img_name))

            except Exception as e:
                tqdm.write(f" Error processing {img_name}: {e}")

# 2. COLOR BALANCE by Rajapakshe R.P.P.S (IT24100368)
#
# Reads every '.jpg' file in resize_dir/<split>/<class>, equalizes the
# histogram of the V channel in HSV space and writes the result under the same
# file name to color_dir/<split>/<class>.

print("\n=== COLOR BALANCING ===")

for split in splits:
    for cls in classes:
        in_path  = os.path.join(resize_dir, split, cls)
        out_path = os.path.join(color_dir, split, cls)
        os.makedirs(out_path, exist_ok=True)

        img_files = [f for f in os.listdir(in_path) if f.endswith('.jpg')]

        for img_name in tqdm(img_files, desc=f"Color balancing {split}/{cls}", unit="img"):
            try:
                img = cv2.imread(os.path.join(in_path, img_name))
                # cv2.imread reports an unreadable file by returning None
                # instead of raising, so check explicitly for a clear message.
                if img is None:
                    tqdm.write(f" Error processing {img_name}: could not read image")
                    continue

                hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
                # Equalize only the V channel (index 2); H and S are untouched.
                hsv[:, :, 2] = cv2.equalizeHist(hsv[:, :, 2])
                balanced = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

                # cv2.imwrite returns False on failure rather than raising.
                if not cv2.imwrite(os.path.join(out_path, img_name), balanced):
                    tqdm.write(f" Error processing {img_name}: could not write image")

            except Exception as e:
                tqdm.write(f" Error processing {img_name}: {e}")

# 3. DENOISE + ENHANCE by Vithanage M.M (IT24100288)
#
# Reads every '.jpg' file in color_dir/<split>/<class>, applies the Gaussian
# blur, then CLAHE on the L channel in LAB space, and writes the result under
# the same file name to denoise_dir/<split>/<class>.

print("\n=== DENOISING AND ENHANCEMENT ===")

# The CLAHE operator has fixed settings, so create it once for all images.
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

for split in splits:
    for cls in classes:
        input_path = os.path.join(color_dir, split, cls)
        output_path = os.path.join(denoise_dir, split, cls)
        os.makedirs(output_path, exist_ok=True)

        img_files = [f for f in os.listdir(input_path) if f.endswith('.jpg')]

        for img_name in tqdm(img_files, desc=f"Denoising {split}/{cls}", unit="img"):

            try:
                # The context manager closes the file handle once the pixels
                # have been copied into the tensor.
                with Image.open(os.path.join(input_path, img_name)) as src:
                    img_tensor = to_tensor(src.convert("RGB")).to(device)

                # Apply denoising (Gaussian blur); the blur is given a batch
                # dimension which is removed again afterwards.
                img_denoised = gaussian_blur(img_tensor.unsqueeze(0)).squeeze(0)

                # Back to an HxWxC uint8 array for OpenCV. Round to the nearest
                # level: a bare astype(np.uint8) truncates and would lower the
                # blurred pixel values.
                img_float = img_denoised.cpu().permute(1, 2, 0).numpy()
                img_np = np.ascontiguousarray(
                    np.clip(np.rint(img_float * 255.0), 0, 255).astype(np.uint8)
                )

                # Apply enhancement (CLAHE) to the L channel only.
                lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
                lab[:, :, 0] = clahe.apply(lab[:, :, 0])
                enhanced_np = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

                # Save result straight from the uint8 RGB array; going through
                # a float tensor and back adds nothing.
                Image.fromarray(enhanced_np).save(os.path.join(output_path, img_name))

            except Exception as e:
                # tqdm.write keeps the progress bar intact, unlike print.
                tqdm.write(f"Error with {img_name}: {e}")


# 4. DATA AUGMENTATION by Nimneth P.B.Y (IT24100304)
#
# Re-saves every '.jpg' file from denoise_dir/<split>/<class> into
# aug_dir/<split>/<class>. For the 'train' split only, the class with fewer
# images is additionally topped up with augmented copies (named
# "aug_<i>_<original name>") until both classes have the same file count.

print("\n=== DATA AUGMENTATION ===")

for split in splits:
    for cls in classes:
        input_path = os.path.join(denoise_dir, split, cls)
        output_path = os.path.join(aug_dir, split, cls)
        os.makedirs(output_path, exist_ok=True)

        img_files = [f for f in os.listdir(input_path) if f.endswith('.jpg')]

        # Save the originals first. All splits get the same guarded handling,
        # so one unreadable file is logged instead of aborting the run.
        for img_name in img_files:
            try:
                with Image.open(os.path.join(input_path, img_name)) as src:
                    src.convert("RGB").save(os.path.join(output_path, img_name))
            except Exception as e:
                tqdm.write(f"⚠️ Error processing {img_name}: {e}")

        # val/test are only copied; balancing applies to the training split.
        if split != 'train':
            continue

        # Class balancing logic: bring this class up to the larger class size.
        fire_count = len([f for f in os.listdir(os.path.join(denoise_dir, split, 'fire')) if f.endswith('.jpg')])
        nofire_count = len([f for f in os.listdir(os.path.join(denoise_dir, split, 'nofire')) if f.endswith('.jpg')])
        target_size = max(fire_count, nofire_count)
        needed = target_size - len(img_files)

        if needed <= 0:
            continue

        # With no source images there is nothing to augment; without this
        # guard the loop below would never terminate.
        if not img_files:
            tqdm.write(f"⚠️ No images in {input_path}; cannot create {needed} augmented images")
            continue

        # Select augmentation pipeline
        aug_transform = fire_aug if cls == 'fire' else nofire_aug

        # Augment until class is balanced, cycling over the source images as
        # many times as required.
        i = 0
        while i < needed:
            created_this_pass = 0

            for img_name in img_files:
                if i >= needed:
                    break

                try:
                    with Image.open(os.path.join(input_path, img_name)) as src:
                        img_np = np.array(src.convert("RGB"))
                    augmented = aug_transform(image=img_np)['image']
                    aug_img = Image.fromarray(augmented)
                    save_name = f"aug_{i}_{img_name}"
                    aug_img.save(os.path.join(output_path, save_name))
                    i += 1
                    created_this_pass += 1
                except Exception as e:
                    tqdm.write(f"⚠️ Error processing {img_name}: {e}")

            # If a full pass produced nothing, every image is failing and
            # further passes would spin forever.
            if created_this_pass == 0:
                tqdm.write(f"⚠️ Could not augment any image in {input_path}; stopping at {i}/{needed}")
                break


# 5. EDGE DETECTION by Perera B.P.N (IT24100327)
#
# Reads every '.jpg' file in aug_dir/<split>/<class>, computes the Sobel
# gradient magnitude of the grayscale image and writes it as a single-channel
# image under the same file name to edge_dir/<split>/<class>.

print("\n=== EDGE DETECTION ===")

for split in splits:
    for cls in classes:
        in_path  = os.path.join(aug_dir, split, cls)
        out_path = os.path.join(edge_dir, split, cls)
        os.makedirs(out_path, exist_ok=True)

        img_files = [f for f in os.listdir(in_path) if f.endswith('.jpg')]

        for img_name in tqdm(img_files, desc=f"Edge detection {split}/{cls}", unit="img"):

            try:
                img = cv2.imread(os.path.join(in_path, img_name))
                # cv2.imread reports an unreadable file by returning None
                # instead of raising, so check explicitly for a clear message.
                if img is None:
                    tqdm.write(f" Error processing {img_name}: could not read image")
                    continue

                gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

                # Edge detection (Sobel): horizontal and vertical derivatives
                # in float64, combined into a gradient magnitude.
                sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
                sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
                sobel = np.sqrt(sobelx**2 + sobely**2)
                # Magnitudes above 255 are saturated to fit into uint8.
                sobel = np.uint8(np.clip(sobel, 0, 255))

                # Save edge map as JPG; cv2.imwrite returns False on failure
                # rather than raising.
                if not cv2.imwrite(os.path.join(out_path, img_name), sobel):
                    tqdm.write(f" Error processing {img_name}: could not write image")

            except Exception as e:
                tqdm.write(f" Error processing {img_name}: {e}")
            

# 6. NORMALIZATION by Thaveesha L.H.K (IT24100368)
#
# Reads every '.jpg' file in aug_dir/<split>/<class>, applies
# normalize_transform, reverses the normalization again so the result can be
# stored as an ordinary image, and writes it under the same file name to
# norm_dir/<split>/<class>.
# NOTE(review): because the normalization is reversed before saving, the files
# in norm_dir hold de-normalized pixel values; any consumer that needs
# normalized tensors has to apply the normalization again when loading.

print("\n=== NORMALIZATION ===")

# Per-channel statistics used to undo the normalization. They are constant, so
# allocate them on the device once rather than once per image.
norm_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)

for split in splits:
    for cls in classes:
        input_path = os.path.join(aug_dir, split, cls)
        output_path = os.path.join(norm_dir, split, cls)
        os.makedirs(output_path, exist_ok=True)

        img_files = [f for f in os.listdir(input_path) if f.endswith('.jpg')]

        for img_name in tqdm(img_files, desc=f"Normalizing {split}/{cls}", unit="img"):
            try:
                # The context manager closes the file handle once the pixels
                # have been converted.
                with Image.open(os.path.join(input_path, img_name)) as src:
                    # Normalize
                    img_norm = normalize_transform(src.convert("RGB")).to(device)

                # Save normalized image (reversed for visualization)
                img_save = (img_norm * norm_std + norm_mean).clamp(0, 1)

                # Convert to uint8 with explicit rounding. Handing a float
                # tensor to to_pil_image scales by 255 and casts, which
                # truncates; after the float round trip above that can lower
                # pixel values by one level.
                img_save = img_save.mul(255).round().to(torch.uint8)
                img_save = to_pil_image(img_save.cpu())
                img_save.save(os.path.join(output_path, img_name))

            except Exception as e:
                tqdm.write(f" Error processing {img_name}: {e}")

# NOTE(review): these messages are printed unconditionally; per-image errors
# logged by the stages above do not change them.
print("\n=== PIPELINE COMPLETE ===")
print("All processing steps finished successfully!")