Connecting Google Drive to Colab

In [None]:
# Mount Google Drive so the dataset is reachable under /content/drive.
import os

from google.colab import drive

drive.mount('/content/drive')

# Folder holding the oil-spill SAR images.
image_tiff = "/content/drive/My Drive/iitisoc_ps1/sar_dataset/Train_Val_Oil_Spill_images/Oil/"

# Print the folder listing as a quick check that the mount worked.
print(os.listdir(image_tiff))

In [None]:
#installing required
!pip install --quiet rasterio scikit-image opencv-python

In [None]:
# Collect the paths of all .tif images in the oil-spill image folder
import os
from glob import glob

image_folder = "/content/drive/My Drive/iitisoc_ps1/sar_dataset/Train_Val_Oil_Spill_images/Oil/"

# Sort in place so that positional indexing in later cells is stable across runs.
tif_files = glob(os.path.join(image_folder, "*.tif"))
tif_files.sort()

print(f"Total .tif files: {len(tif_files)}")


In [None]:
# Preprocessing steps


import numpy as np
import matplotlib.pyplot as plt
import rasterio
from scipy.ndimage import binary_opening, binary_closing, median_filter # Import median filter
from skimage.filters import threshold_otsu
import os

# Prototype the segmentation on a single scene.
# NOTE: tif_files is produced by the file-listing cell above, so that cell must run first.
PREVIEW_INDEX = 46  # zero-based position in the sorted file list; change as needed
sample_tif = tif_files[PREVIEW_INDEX]

# --- STEP 1: Read bands ---
# Band 1 is stored as `vv` and band 2 as `vh`; only `vh` is used below.
with rasterio.open(sample_tif) as src:
    vv = src.read(1).astype(np.float32)
    vh = src.read(2).astype(np.float32)

# Sanity check: the values are expected to already be in a dB range.
print("VH raw min/max:", np.nanmin(vh), np.nanmax(vh))

# --- STEP 2: Preprocess VH ---
# No dB conversion is applied here; the band is assumed to be in dB already
# (verify against the min/max printed above).
vh_dB = vh.copy()

# Replace NaNs with the smallest valid value so the median filter never sees NaN.
vh_cleaned = np.nan_to_num(vh_dB, nan=np.nanmin(vh_dB))

# --- STEP 3: Speckle Reduction ---
# 3x3 median filter; adjust `size` as needed.
vh_filtered = median_filter(vh_cleaned, size=3)

# --- STEP 4: Thresholding ---
# Pixels darker than this percentile of the filtered band are flagged as oil.
# The percentile is a single named constant so the printed label cannot drift
# away from the value actually used.
THRESHOLD_PERCENTILE = 31
manual_thresh = np.percentile(vh_filtered, THRESHOLD_PERCENTILE)
print(f"Manual threshold (percentile {THRESHOLD_PERCENTILE}):", manual_thresh)

binary_mask = vh_filtered < manual_thresh

# --- STEP 5: Morphological Cleanup ---
# Opening is applied first, then closing, both with scipy's default structuring element.
binary_cleaned = binary_closing(binary_opening(binary_mask))

# --- STEP 6: Visualization ---
# Show the raw band, the filtered band and the resulting mask side by side.
panels = [
    (vh_dB, "Original VH (dB)"),
    (vh_filtered, "Filtered VH (Median)"),
    (binary_cleaned, "Predicted Oil Spill Mask"),
]

plt.figure(figsize=(18, 5))
for panel_position, (panel_image, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 3, panel_position)
    plt.imshow(panel_image, cmap='gray')
    plt.title(panel_title)
    plt.axis('off')

plt.tight_layout()
plt.show()

# --- Histogram for inspection ---
# Useful for judging where a sensible threshold lies for this scene.
plt.hist(vh_filtered.flatten(), bins=100)
plt.title("VH Histogram After Filtering")
plt.xlabel("Intensity")
plt.ylabel("Pixel Count")
plt.show()

In [None]:
# Viewing one ground-truth mask

import os
from glob import glob
import matplotlib.pyplot as plt
import rasterio
import numpy as np

# Folder that stores the ground truth masks
mask_folder = "/content/drive/My Drive/iitisoc_ps1/sar_dataset/Train_Val_Oil_Spill_mask/Mask_oil/"

# sample_tif must already exist: it is assigned by the preprocessing cell above.
# The mask is looked up under the same file name as its source image.
sample_image_filename = os.path.basename(sample_tif)
sample_mask_filename = sample_image_filename
sample_mask_file = os.path.join(mask_folder, sample_mask_filename)

if not os.path.exists(sample_mask_file):
    # Nothing to show when the mask folder has no file with that name
    print(f"Error: Corresponding mask file not found for image {sample_image_filename} at {sample_mask_file}")
else:
    try:
        # Read the first band of the mask raster
        with rasterio.open(sample_mask_file) as src:
            mask = src.read(1)

        # Render it in grayscale with a colorbar so the stored pixel values are visible
        plt.figure(figsize=(8, 6))
        plt.imshow(mask, cmap='gray')
        plt.title(f"Ground Truth Mask for Image: {sample_image_filename}")
        plt.colorbar(label='Pixel Value')
        plt.axis('off')
        plt.show()

    except Exception as e:
        # Report the failure instead of stopping the notebook
        print(f"Error reading or plotting the mask file {sample_mask_file}: {e}")

In [None]:
#pipeline for generating masks for all 1200 oil spill images

import os
from glob import glob
import matplotlib.pyplot as plt
import rasterio
import numpy as np
from scipy.ndimage import binary_opening, binary_closing, median_filter

# --- Configuration ---
# Input SAR images and the folder the predicted masks are written to.
image_folder = "/content/drive/My Drive/iitisoc_ps1/sar_dataset/Train_Val_Oil_Spill_images/Oil/"
output_mask_folder = "/content/drive/My Drive/iitisoc_ps1/oil_spill_mask/"

# Lower/upper percentiles used for the contrast stretch.
STRETCH_PERCENTILES = (2, 98)
# Pixels below this percentile of the stretched band are labelled as oil.
OIL_PERCENTILE = 32

# Create output directory if it doesn't exist
os.makedirs(output_mask_folder, exist_ok=True)

# Sorted so the processing order (and the progress counter) is reproducible
tif_files = sorted(glob(os.path.join(image_folder, "*.tif")))

print(f"Found {len(tif_files)} .tif files to process.")

for i, sample_tif in enumerate(tif_files):
    image_filename = os.path.basename(sample_tif)
    # The mask is saved under the same file name as its source image
    output_mask_path = os.path.join(output_mask_folder, image_filename)

    print(f"Processing image {i+1}/{len(tif_files)}: {image_filename}")

    # A failure on one image is reported and the loop moves on to the next one.
    try:
        # Only band 2 (VH) feeds the segmentation, so band 1 is not read at all.
        with rasterio.open(sample_tif) as src:
            vh = src.read(2).astype(np.float32)
            profile = src.profile  # georeferencing/metadata reused for the output file

        # Replace NaNs with the smallest valid value before filtering
        vh_dB = vh.copy()
        vh_cleaned = np.nan_to_num(vh_dB, nan=np.nanmin(vh_dB))

        # Speckle reduction (3x3 median filter)
        vh_filtered = median_filter(vh_cleaned, size=3)

        # Contrast stretching to [0, 1] between the configured percentiles
        p2, p98 = np.percentile(vh_filtered, STRETCH_PERCENTILES)
        stretch_range = p98 - p2
        if stretch_range > 0:
            vh_contrast = np.clip((vh_filtered - p2) / stretch_range, 0, 1)
        else:
            # Flat (or all-NaN) scene: dividing by a zero range would fill the
            # array with NaN/inf. An all-zero array yields an empty mask below.
            vh_contrast = np.zeros_like(vh_filtered)

        # Thresholding: the darkest OIL_PERCENTILE percent of pixels become the mask
        manual_thresh = np.percentile(vh_contrast, OIL_PERCENTILE)
        binary_mask = vh_contrast < manual_thresh

        # Morphological cleanup: opening first, then closing
        binary_cleaned = binary_closing(binary_opening(binary_mask))

        # Describe the output as a single-band, LZW-compressed uint8 raster.
        # nodata is cleared because the inherited value describes the float
        # source bands, and rasterio rejects a nodata value that does not fit
        # the output dtype.
        profile.update(
            dtype=rasterio.uint8,
            count=1,
            compress='lzw',
            nodata=None)

        # The boolean mask is stored as 0 (background) / 1 (oil)
        with rasterio.open(output_mask_path, 'w', **profile) as dst:
            dst.write(binary_cleaned.astype(rasterio.uint8), 1)

    except Exception as e:
        print(f"Error processing {image_filename}: {e}")

print("\nMask generation complete.")



Found 1200 .tif files to process.
Processing image 1/1200: 00000.tif




Processing image 2/1200: 00001.tif




Processing image 3/1200: 00002.tif




Processing image 4/1200: 00003.tif




Processing image 5/1200: 00004.tif




Processing image 6/1200: 00005.tif




Processing image 7/1200: 00006.tif




Processing image 8/1200: 00007.tif




Processing image 9/1200: 00008.tif




Processing image 10/1200: 00009.tif




Processing image 11/1200: 00010.tif




Processing image 12/1200: 00011.tif




Processing image 13/1200: 00012.tif




Processing image 14/1200: 00013.tif




Processing image 15/1200: 00014.tif




Processing image 16/1200: 00015.tif




Processing image 17/1200: 00016.tif




Processing image 18/1200: 00017.tif




Processing image 19/1200: 00018.tif




Processing image 20/1200: 00019.tif




Processing image 21/1200: 00020.tif




Processing image 22/1200: 00021.tif




Processing image 23/1200: 00022.tif




Processing image 24/1200: 00023.tif




Processing image 25/1200: 00024.tif




Processing image 26/1200: 00025.tif




Processing image 27/1200: 00026.tif




Processing image 28/1200: 00027.tif




Processing image 29/1200: 00028.tif




Processing image 30/1200: 00029.tif




Processing image 31/1200: 00030.tif




Processing image 32/1200: 00031.tif




Processing image 33/1200: 00032.tif




Processing image 34/1200: 00036.tif




Processing image 35/1200: 00037.tif
Processing image 36/1200: 00038.tif


KeyboardInterrupt: 

Exception ignored in: 'rasterio._env.log_error'
Traceback (most recent call last):
  File "/usr/lib/python3.11/logging/__init__.py", line 1544, in log
    def log(self, level, msg, *args, **kwargs):

KeyboardInterrupt: 


Processing image 37/1200: 00039.tif




Processing image 38/1200: 00040.tif




Processing image 39/1200: 00041.tif




Processing image 40/1200: 00042.tif




Processing image 41/1200: 00043.tif




Processing image 42/1200: 00045.tif




Processing image 43/1200: 00046.tif




Processing image 44/1200: 00047.tif




Processing image 45/1200: 00048.tif




Processing image 46/1200: 00050.tif




Processing image 47/1200: 00051.tif




Processing image 48/1200: 00052.tif




Processing image 49/1200: 00053.tif




Processing image 50/1200: 00054.tif




Processing image 51/1200: 00055.tif




Processing image 52/1200: 00056.tif




Processing image 53/1200: 00057.tif




Processing image 54/1200: 00058.tif




Processing image 55/1200: 00059.tif




Processing image 56/1200: 00061.tif


In [None]:
# Plotting predicted mask and ground-truth mask

import os
from glob import glob
import matplotlib.pyplot as plt
import rasterio

# Directory for generated masks
output_mask_directory = "/content/drive/My Drive/iitisoc_ps1/oil_spill_mask/"

# Directory for ground truth masks
ground_truth_directory = "/content/drive/My Drive/iitisoc_ps1/sar_dataset/Train_Val_Oil_Spill_mask/Mask_oil/"

# Zero-based position (in each sorted file list) of the mask to display.
# NOTE: each folder is indexed independently, so the two plots show the same
# scene only if both folders contain the same set of file names.
MASK_SAMPLE_INDEX = 51

# --- Sample from the generated masks ---
mask_files_to_plot = sorted(glob(os.path.join(output_mask_directory, "*.tif")))

# Check the list is long enough for the index, not merely non-empty;
# otherwise a partially filled folder raises IndexError.
if len(mask_files_to_plot) > MASK_SAMPLE_INDEX:
    sample_mask_file = mask_files_to_plot[MASK_SAMPLE_INDEX]
    print(f"Plotting sample generated mask file: {sample_mask_file}")

    try:
        with rasterio.open(sample_mask_file) as src:
            mask_image = src.read(1)  # Read the first band

        plt.figure(figsize=(8, 6))
        plt.imshow(mask_image, cmap='gray')
        plt.title(f"Sample Generated Mask: {os.path.basename(sample_mask_file)}")
        plt.colorbar(label='Pixel Value')
        plt.axis('off')
        plt.show()

    except Exception as e:
        print(f"Error reading or plotting {sample_mask_file}: {e}")
elif mask_files_to_plot:
    print(f"Only {len(mask_files_to_plot)} .tif files in the generated masks directory; "
          f"cannot plot index {MASK_SAMPLE_INDEX}.")
else:
    print(f"No .tif files found in the generated masks directory: {output_mask_directory}")

print("-" * 30) # Separator

# --- Sample from the ground truth masks ---
ground_truth_files_to_plot = sorted(glob(os.path.join(ground_truth_directory, "*.tif")))

if len(ground_truth_files_to_plot) > MASK_SAMPLE_INDEX:
    sample_truth_file = ground_truth_files_to_plot[MASK_SAMPLE_INDEX]
    print(f"Plotting sample ground truth mask file: {sample_truth_file}")

    try:
        with rasterio.open(sample_truth_file) as src:
            truth_image = src.read(1)  # Read the first band

        plt.figure(figsize=(8, 6))
        plt.imshow(truth_image, cmap='gray')
        plt.title(f"Sample Ground Truth Mask: {os.path.basename(sample_truth_file)}")
        plt.colorbar(label='Pixel Value')
        plt.axis('off')
        plt.show()

    except Exception as e:
        print(f"Error reading or plotting {sample_truth_file}: {e}")
elif ground_truth_files_to_plot:
    print(f"Only {len(ground_truth_files_to_plot)} .tif files in the ground truth masks directory; "
          f"cannot plot index {MASK_SAMPLE_INDEX}.")
else:
    print(f"No .tif files found in the ground truth masks directory: {ground_truth_directory}")

In [None]:
import os
from glob import glob
import matplotlib.pyplot as plt
import rasterio
import numpy as np
import cv2
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def compute_iou(mask_true, mask_pred):
    """Return the Intersection over Union of two masks.

    Any non-zero element counts as foreground. When neither mask has any
    foreground (empty union) the result is 0.0.
    """
    union_count = np.sum(np.logical_or(mask_true, mask_pred))
    if not union_count > 0:
        return 0.0
    overlap_count = np.sum(np.logical_and(mask_true, mask_pred))
    return overlap_count / union_count


# Directory for generated masks
output_mask_directory = "/content/drive/My Drive/iitisoc_ps1/oil_spill_mask/"

# Directory for ground truth masks
ground_truth_directory = "/content/drive/My Drive/iitisoc_ps1/sar_dataset/Train_Val_Oil_Spill_mask/Mask_oil/"

# --- Plotting Sample Masks ---
mask_files_to_plot = sorted(glob(os.path.join(output_mask_directory, "*.tif")))
ground_truth_files_to_plot = sorted(glob(os.path.join(ground_truth_directory, "*.tif")))

# Zero-based position of the generated mask to inspect (for example, index 51)
image_index = 51

# The ground truth is paired with the generated mask by FILE NAME, not by list
# position: the generation cell saves each mask under its source image's name,
# and positional pairing silently compares different scenes as soon as either
# folder is missing a file.
sample_generated_mask_file = None
sample_ground_truth_file = None
if len(mask_files_to_plot) > image_index:
    sample_generated_mask_file = mask_files_to_plot[image_index]
    sample_ground_truth_file = os.path.join(
        ground_truth_directory, os.path.basename(sample_generated_mask_file))

if sample_generated_mask_file is None:
    print(f"Error: Not enough files found in {output_mask_directory} to select image index {image_index}.")
elif not os.path.exists(sample_ground_truth_file):
    print(f"Error: No ground truth mask found at {sample_ground_truth_file} "
          f"for generated mask index {image_index}.")
else:
    print(f"Plotting sample generated mask file: {sample_generated_mask_file}")
    try:
        with rasterio.open(sample_generated_mask_file) as src:
            generated_mask_image = src.read(1)

        plt.figure(figsize=(8, 6))
        plt.imshow(generated_mask_image, cmap='gray')
        plt.title(f"Sample Generated Mask: {os.path.basename(sample_generated_mask_file)}")
        plt.colorbar(label='Pixel Value')
        plt.axis('off')
        plt.show()

    except Exception as e:
        print(f"Error reading or plotting {sample_generated_mask_file}: {e}")

    print("-" * 30) # Separator

    print(f"Plotting sample ground truth mask file: {sample_ground_truth_file}")
    try:
        with rasterio.open(sample_ground_truth_file) as src:
            ground_truth_image = src.read(1)

        plt.figure(figsize=(8, 6))
        plt.imshow(ground_truth_image, cmap='gray')
        plt.title(f"Sample Ground Truth Mask: {os.path.basename(sample_ground_truth_file)}")
        plt.colorbar(label='Pixel Value')
        plt.axis('off')
        plt.show()

    except Exception as e:
        print(f"Error reading or plotting {sample_ground_truth_file}: {e}")

    print("-" * 30) # Separator

    # --- Comparison for the selected image ---
    print(f"Comparing masks for image index {image_index} ({os.path.basename(sample_generated_mask_file)})...")

    try:
        # cv2.imread returns None (instead of raising) when a file cannot be decoded
        pred_mask = cv2.imread(sample_generated_mask_file, cv2.IMREAD_GRAYSCALE)
        truth_mask = cv2.imread(sample_ground_truth_file, cv2.IMREAD_GRAYSCALE)

        if pred_mask is None or truth_mask is None:
             print(f"Error loading masks for comparison.")
        elif pred_mask.shape != truth_mask.shape:
            print(f"Dimension mismatch for image index {image_index}. Cannot compare.")
        else:
            # Binarize with "> 0" so both 0/1 and 0/255 encodings work.
            # The generated masks are written as 0/1, so a fixed cut-off of
            # 127 would turn every predicted pixel into background.
            pred_bin = (pred_mask > 0).astype(np.uint8)
            truth_bin = (truth_mask > 0).astype(np.uint8)

            # Flatten arrays for sklearn metrics
            pred_flat = pred_bin.ravel()
            truth_flat = truth_bin.ravel()

            # zero_division=0 makes precision/recall/F1 report 0 instead of
            # warning when either mask has no positive pixels
            accuracy = accuracy_score(truth_flat, pred_flat)
            precision = precision_score(truth_flat, pred_flat, zero_division=0)
            recall = recall_score(truth_flat, pred_flat, zero_division=0)
            f1 = f1_score(truth_flat, pred_flat, zero_division=0)
            iou = compute_iou(truth_bin, pred_bin)

            print(f"Metrics for Image Index {image_index} ({os.path.basename(sample_generated_mask_file)}):")
            print(f"Accuracy:  {accuracy:.4f}")
            print(f"Precision: {precision:.4f}")
            print(f"Recall:    {recall:.4f}")
            print(f"F1 Score:  {f1:.4f}")
            print(f"IoU:       {iou:.4f}")

    except Exception as e:
        print(f"Error during comparison for image index {image_index}: {e}")

In [None]:
import os
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

def extract_id(path):
    """Return the integer formed by every digit character in the file name.

    Only the final path component is inspected. All digits found in it are
    concatenated in order (so "img_2_v3.tif" gives 23). Returns -1 when the
    file name contains no digits.
    """
    name = os.path.basename(path)
    digit_chars = [ch for ch in name if ch.isdigit()]
    if not digit_chars:
        return -1
    return int(''.join(digit_chars))

def _class_scores(in_class_true, in_class_pred):
    """Score one class from boolean membership arrays (truth, prediction).

    Returns (true_positive_count, scores) where scores holds precision,
    recall, f1, iou and accuracy. Every ratio with a zero denominator is 0.0.
    """
    tp = int(np.count_nonzero(in_class_true & in_class_pred))
    fp = int(np.count_nonzero(~in_class_true & in_class_pred))
    fn = int(np.count_nonzero(in_class_true & ~in_class_pred))
    tn = int(np.count_nonzero(~in_class_true & ~in_class_pred))
    total = tp + fp + fn + tn

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
    iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0
    accuracy = (tp + tn) / total if total > 0 else 0.0

    scores = {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'iou': iou,
        'accuracy': accuracy,
    }
    return tp, scores


def calculate_metrics(gt_mask, pred_mask):
    """Calculate class-wise and overall metrics for a pair of binary masks.

    A pixel equal to 0 is background (class 0); any non-zero pixel is oil
    spill (class 1). Masks encoded as 0/1 and as 0/255 therefore give the
    same result.

    Args:
        gt_mask: Ground truth mask (array-like).
        pred_mask: Predicted mask (array-like) with the same number of pixels.

    Returns:
        dict with keys 'class_0', 'class_1' and 'overall'. Each class entry
        holds 'name', 'precision', 'recall', 'f1', 'iou' and 'accuracy'.
        'overall' holds 'accuracy' plus the unweighted two-class means
        'mean_iou', 'mean_precision', 'mean_recall', 'mean_f1' and
        'mean_accuracy'. Ratios with a zero denominator (including empty
        masks) are reported as 0.0 rather than NaN.

    Raises:
        ValueError: if the two masks do not contain the same number of pixels.
    """
    # Flatten to 1D boolean "is oil" arrays
    y_true = np.asarray(gt_mask).ravel() != 0
    y_pred = np.asarray(pred_mask).ravel() != 0

    if y_true.size != y_pred.size:
        raise ValueError(
            f"Mask size mismatch: ground truth has {y_true.size} pixels, "
            f"prediction has {y_pred.size}")

    metrics = {
        'class_0': {'name': 'Background'},
        'class_1': {'name': 'Oil Spill'},
        'overall': {}
    }

    # Background membership is the complement of oil membership
    tp0, background_scores = _class_scores(~y_true, ~y_pred)
    tp1, oil_scores = _class_scores(y_true, y_pred)
    metrics['class_0'].update(background_scores)
    metrics['class_1'].update(oil_scores)

    # Overall accuracy = correctly labelled pixels of either class / all pixels
    total_pixels = y_true.size
    overall = metrics['overall']
    overall['accuracy'] = (tp0 + tp1) / total_pixels if total_pixels > 0 else 0.0
    overall['mean_iou'] = (metrics['class_0']['iou'] + metrics['class_1']['iou']) / 2
    overall['mean_precision'] = (metrics['class_0']['precision'] + metrics['class_1']['precision']) / 2
    overall['mean_recall'] = (metrics['class_0']['recall'] + metrics['class_1']['recall']) / 2
    overall['mean_f1'] = (metrics['class_0']['f1'] + metrics['class_1']['f1']) / 2
    overall['mean_accuracy'] = (metrics['class_0']['accuracy'] + metrics['class_1']['accuracy']) / 2

    return metrics

def evaluate_mask_pair(pred_dir, gt_dir, index):
    """
    Display and score one predicted / ground truth mask pair.

    Files are matched by the numeric ID that extract_id derives from each
    file name, not by their position in the directory listing.

    Args:
        pred_dir: Directory with predicted masks (*.tif)
        gt_dir: Directory with ground truth masks (*.tif)
        index: Numeric file-name ID of the mask pair to evaluate

    Returns:
        The metrics dictionary produced by calculate_metrics.

    Raises:
        FileNotFoundError: if either directory has no mask with that ID.
        ValueError: if the two masks have different shapes.
    """
    # Sorted so that, if two files ever map to the same ID, the one that wins
    # does not depend on the filesystem's listing order
    pred_files = sorted(glob.glob(os.path.join(pred_dir, "*.tif")))
    gt_files = sorted(glob.glob(os.path.join(gt_dir, "*.tif")))

    # Create ID-path mappings
    pred_dict = {extract_id(p): p for p in pred_files}
    gt_dict = {extract_id(p): p for p in gt_files}

    # Find requested files
    pred_path = pred_dict.get(index)
    gt_path = gt_dict.get(index)

    if not pred_path:
        raise FileNotFoundError(f"No predicted mask found with index {index} in {pred_dir}")
    if not gt_path:
        raise FileNotFoundError(f"No ground truth mask found with index {index} in {gt_dir}")

    # Load masks; the context managers release the file handles right away
    with Image.open(pred_path) as pred_img:
        pred_mask = np.array(pred_img)
    with Image.open(gt_path) as gt_img:
        gt_mask = np.array(gt_img)

    # Fail early with a clear message instead of a broadcast error from the metrics
    if pred_mask.shape != gt_mask.shape:
        raise ValueError(
            f"Shape mismatch for index {index}: predicted mask is {pred_mask.shape}, "
            f"ground truth mask is {gt_mask.shape}")

    # Ensure binary masks (0 and 1 only); any non-zero pixel counts as oil
    pred_mask = (pred_mask > 0).astype(np.uint8)
    gt_mask = (gt_mask > 0).astype(np.uint8)

    # Display masks side by side
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.imshow(pred_mask, cmap='gray', interpolation='none')
    plt.title(f'Predicted Mask\n{os.path.basename(pred_path)}')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(gt_mask, cmap='gray', interpolation='none')
    plt.title(f'Ground Truth Mask\n{os.path.basename(gt_path)}')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Calculate metrics
    metrics = calculate_metrics(gt_mask, pred_mask)

    # Print results
    print("\n" + "="*60)
    print(f"EVALUATION RESULTS FOR MASK PAIR #{index}")
    print("="*60)

    for class_key in ['class_0', 'class_1']:
        class_name = metrics[class_key]['name']
        print(f"\n{class_name.upper()} CLASS METRICS:")
        print(f"  Accuracy:   {metrics[class_key]['accuracy']:.4f}")
        print(f"  Precision:  {metrics[class_key]['precision']:.4f}")
        print(f"  Recall:     {metrics[class_key]['recall']:.4f}")
        print(f"  F1-Score:   {metrics[class_key]['f1']:.4f}")
        print(f"  IoU:        {metrics[class_key]['iou']:.4f}")

    print("\nOVERALL METRICS:")
    print(f"  Accuracy:          {metrics['overall']['accuracy']:.4f}")
    print(f"  Mean Accuracy:     {metrics['overall']['mean_accuracy']:.4f}")
    print(f"  Mean IoU:          {metrics['overall']['mean_iou']:.4f}")
    print(f"  Mean Precision:    {metrics['overall']['mean_precision']:.4f}")
    print(f"  Mean Recall:       {metrics['overall']['mean_recall']:.4f}")
    print(f"  Mean F1-Score:     {metrics['overall']['mean_f1']:.4f}")
    print("="*60 + "\n")

    return metrics


# Evaluate the mask pair whose file-name ID is 42
metrics = evaluate_mask_pair(
    pred_dir="/content/drive/My Drive/iitisoc_ps1/oil_spill_mask/",
    gt_dir="/content/drive/My Drive/iitisoc_ps1/sar_dataset/Train_Val_Oil_Spill_mask/Mask_oil/",
    index=42)

In [None]:
import os
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import rasterio  # For reading specialized TIFF formats
from rasterio.plot import show  # For displaying raster images

def extract_id(path):
    """Build an integer ID from the digit characters of a path's file name.

    Digits are taken from the base name only and joined in the order they
    appear; a file name without any digit yields -1.
    """
    digits = ''.join(ch for ch in os.path.basename(path) if ch.isdigit())
    return -1 if not digits else int(digits)

def _class_scores(in_class_true, in_class_pred):
    """Score one class from boolean membership arrays (truth, prediction).

    Returns (true_positive_count, scores) where scores holds precision,
    recall, f1, iou and accuracy. Every ratio with a zero denominator is 0.0.
    """
    tp = int(np.count_nonzero(in_class_true & in_class_pred))
    fp = int(np.count_nonzero(~in_class_true & in_class_pred))
    fn = int(np.count_nonzero(in_class_true & ~in_class_pred))
    tn = int(np.count_nonzero(~in_class_true & ~in_class_pred))
    total = tp + fp + fn + tn

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
    iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0
    accuracy = (tp + tn) / total if total > 0 else 0.0

    scores = {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'iou': iou,
        'accuracy': accuracy,
    }
    return tp, scores


def calculate_metrics(gt_mask, pred_mask):
    """Calculate class-wise and overall metrics for a pair of binary masks.

    A pixel equal to 0 is background (class 0); any non-zero pixel is oil
    spill (class 1). Masks encoded as 0/1 and as 0/255 therefore give the
    same result.

    Args:
        gt_mask: Ground truth mask (array-like).
        pred_mask: Predicted mask (array-like) with the same number of pixels.

    Returns:
        dict with keys 'class_0', 'class_1' and 'overall'. Each class entry
        holds 'name', 'precision', 'recall', 'f1', 'iou' and 'accuracy'.
        'overall' holds 'accuracy' plus the unweighted two-class means
        'mean_iou', 'mean_precision', 'mean_recall', 'mean_f1' and
        'mean_accuracy'. Ratios with a zero denominator (including empty
        masks) are reported as 0.0 rather than NaN.

    Raises:
        ValueError: if the two masks do not contain the same number of pixels.
    """
    # Flatten to 1D boolean "is oil" arrays
    y_true = np.asarray(gt_mask).ravel() != 0
    y_pred = np.asarray(pred_mask).ravel() != 0

    if y_true.size != y_pred.size:
        raise ValueError(
            f"Mask size mismatch: ground truth has {y_true.size} pixels, "
            f"prediction has {y_pred.size}")

    metrics = {
        'class_0': {'name': 'Background'},
        'class_1': {'name': 'Oil Spill'},
        'overall': {}
    }

    # Background membership is the complement of oil membership
    tp0, background_scores = _class_scores(~y_true, ~y_pred)
    tp1, oil_scores = _class_scores(y_true, y_pred)
    metrics['class_0'].update(background_scores)
    metrics['class_1'].update(oil_scores)

    # Overall accuracy = correctly labelled pixels of either class / all pixels
    total_pixels = y_true.size
    overall = metrics['overall']
    overall['accuracy'] = (tp0 + tp1) / total_pixels if total_pixels > 0 else 0.0
    overall['mean_iou'] = (metrics['class_0']['iou'] + metrics['class_1']['iou']) / 2
    overall['mean_precision'] = (metrics['class_0']['precision'] + metrics['class_1']['precision']) / 2
    overall['mean_recall'] = (metrics['class_0']['recall'] + metrics['class_1']['recall']) / 2
    overall['mean_f1'] = (metrics['class_0']['f1'] + metrics['class_1']['f1']) / 2
    overall['mean_accuracy'] = (metrics['class_0']['accuracy'] + metrics['class_1']['accuracy']) / 2

    return metrics

def evaluate_mask_pair(pred_dir, gt_dir, sar_dir, index):
    """
    Display one SAR image next to its predicted and ground truth masks, then score the masks.

    Files are matched across the three directories by the numeric ID that
    extract_id derives from each file name, not by listing position.

    Args:
        pred_dir: Directory with predicted masks (*.tif)
        gt_dir: Directory with ground truth masks (*.tif)
        sar_dir: Directory with original SAR images (*.tif)
        index: Numeric file-name ID of the image/mask set to evaluate

    Returns:
        The metrics dictionary produced by calculate_metrics.

    Raises:
        FileNotFoundError: if any directory has no file with that ID.
        ValueError: if the two masks have different shapes.
    """
    # Sorted so that, if two files ever map to the same ID, the one that wins
    # does not depend on the filesystem's listing order
    pred_files = sorted(glob.glob(os.path.join(pred_dir, "*.tif")))
    gt_files = sorted(glob.glob(os.path.join(gt_dir, "*.tif")))
    sar_files = sorted(glob.glob(os.path.join(sar_dir, "*.tif")))

    # Create ID-path mappings
    pred_dict = {extract_id(p): p for p in pred_files}
    gt_dict = {extract_id(p): p for p in gt_files}
    sar_dict = {extract_id(p): p for p in sar_files}

    # Find requested files
    pred_path = pred_dict.get(index)
    gt_path = gt_dict.get(index)
    sar_path = sar_dict.get(index)

    if not pred_path:
        raise FileNotFoundError(f"No predicted mask found with index {index} in {pred_dir}")
    if not gt_path:
        raise FileNotFoundError(f"No ground truth mask found with index {index} in {gt_dir}")
    if not sar_path:
        raise FileNotFoundError(f"No SAR image found with index {index} in {sar_dir}")

    # The SAR image is read with rasterio; band 2 is displayed
    # (the band the other cells of this notebook treat as VH)
    with rasterio.open(sar_path) as src:
        sar_image = src.read(2)

    # Masks are read with Pillow; the context managers release the file handles right away
    with Image.open(pred_path) as pred_img:
        pred_mask = np.array(pred_img)
    with Image.open(gt_path) as gt_img:
        gt_mask = np.array(gt_img)

    # Fail early with a clear message instead of a broadcast error from the metrics
    if pred_mask.shape != gt_mask.shape:
        raise ValueError(
            f"Shape mismatch for index {index}: predicted mask is {pred_mask.shape}, "
            f"ground truth mask is {gt_mask.shape}")

    # Ensure masks are binary (0 and 1 only); any non-zero pixel counts as oil
    pred_mask = (pred_mask > 0).astype(np.uint8)
    gt_mask = (gt_mask > 0).astype(np.uint8)

    # Display images in a row: SAR, Predicted Mask, Ground Truth Mask
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # SAR Image
    show(sar_image, cmap='gray', ax=axes[0])
    axes[0].set_title(f'Original SAR Image\n{os.path.basename(sar_path)}')
    axes[0].axis('off')

    # Predicted Mask
    axes[1].imshow(pred_mask, cmap='gray', interpolation='none')
    axes[1].set_title(f'Predicted Mask\n{os.path.basename(pred_path)}')
    axes[1].axis('off')

    # Ground Truth Mask
    axes[2].imshow(gt_mask, cmap='gray', interpolation='none')
    axes[2].set_title(f'Ground Truth Mask\n{os.path.basename(gt_path)}')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    # Calculate metrics
    metrics = calculate_metrics(gt_mask, pred_mask)

    # Print results
    print("\n" + "="*60)
    print(f"EVALUATION RESULTS FOR MASK PAIR #{index}")
    print("="*60)

    for class_key in ['class_0', 'class_1']:
        class_name = metrics[class_key]['name']
        print(f"\n{class_name.upper()} CLASS METRICS:")
        print(f"  Accuracy:   {metrics[class_key]['accuracy']:.4f}")
        print(f"  Precision:  {metrics[class_key]['precision']:.4f}")
        print(f"  Recall:     {metrics[class_key]['recall']:.4f}")
        print(f"  F1-Score:   {metrics[class_key]['f1']:.4f}")
        print(f"  IoU:        {metrics[class_key]['iou']:.4f}")

    print("\nOVERALL METRICS:")
    print(f"  Accuracy:          {metrics['overall']['accuracy']:.4f}")
    print(f"  Mean Accuracy:     {metrics['overall']['mean_accuracy']:.4f}")
    print(f"  Mean IoU:          {metrics['overall']['mean_iou']:.4f}")
    print(f"  Mean Precision:    {metrics['overall']['mean_precision']:.4f}")
    print(f"  Mean Recall:       {metrics['overall']['mean_recall']:.4f}")
    print(f"  Mean F1-Score:     {metrics['overall']['mean_f1']:.4f}")
    print("="*60 + "\n")

    return metrics


# Evaluate the SAR image / mask pair whose file-name ID is 51
metrics = evaluate_mask_pair(
    pred_dir="/content/drive/My Drive/iitisoc_ps1/oil_spill_mask/",
    gt_dir="/content/drive/My Drive/iitisoc_ps1/sar_dataset/Train_Val_Oil_Spill_mask/Mask_oil/",
    sar_dir="/content/drive/My Drive/iitisoc_ps1/sar_dataset/Train_Val_Oil_Spill_images/Oil/",
    index=51
)