In [None]:
# 3_Final_Inference.ipynb

import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import os
import random
from tensorflow.keras.models import load_model
from tensorflow.keras import backend as K

# --- 1. CONFIGURATION ---

# Selectable models: menu key -> (display name, path to the saved .h5 model).
available_models = {
    "1": ("UNet (Standard)", "../saved_models/unet_oil_spill.h5"),
    "2": ("DeepLabV3+ (Experimental)", "../saved_models/deeplabv3_oil_spill.h5")
}
# Used when the user's entry does not match any menu key.
DEFAULT_MODEL_KEY = "1"

# Prompt user for input (requires an interactive session because of input()).
print("\n🤖 SELECT AI MODEL:")
# Build the menu from the dict so it can never drift out of sync with it.
for menu_key, (menu_name, _menu_path) in available_models.items():
    print(f"{menu_key}. {menu_name}")

# strip() so stray whitespace such as "2 " still selects the intended model.
choice = input("\n👉 Enter model number (1 or 2): ").strip()

if choice not in available_models:
    # Tell the user about the fallback instead of silently switching models.
    print(f"Unrecognized choice {choice!r}; defaulting to "
          f"{available_models[DEFAULT_MODEL_KEY][0]}.")
    choice = DEFAULT_MODEL_KEY

MODEL_NAME, MODEL_PATH = available_models[choice]

print(f"\n✅ Configuration Locked: Using {MODEL_NAME}")
print(f"📂 Path: {MODEL_PATH}")

# CSV of AIS vessel reports consumed by detect_anomaly().
AIS_DATA_PATH = '../data/ais_data/vessel_data_clean.csv'

# AUTOMATICALLY FIND A RANDOM TEST IMAGE
TEST_DIR = '../data/test/images'
# Stays None when no usable image is found; the pipeline checks for None before running.
TEST_IMG_PATH = None
valid_extensions = {".jpg", ".jpeg", ".png", ".bmp"}
images = []

# isdir (not exists) so a stray *file* at TEST_DIR cannot make listdir() raise.
if os.path.isdir(TEST_DIR):
    files = os.listdir(TEST_DIR)
    # Sorted so that random.choice is reproducible when random is seeded
    # (listdir order is arbitrary); isfile skips directories named like images.
    images = sorted(
        f for f in files
        if os.path.splitext(f)[1].lower() in valid_extensions
        and os.path.isfile(os.path.join(TEST_DIR, f))
    )

    if images:
        selected_file = random.choice(images)
        TEST_IMG_PATH = os.path.join(TEST_DIR, selected_file)
        print(f"🎲 Randomly selected test image: {selected_file}")
    else:
        print(f"WARNING: No valid images found in '{TEST_DIR}'.")
else:
    print(f"Error: Directory {TEST_DIR} does not exist.")

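# --- 2. CUSTOM OBJECTS (passed to load_model so Keras can resolve the custom loss/metric names stored in the .h5 file; without them loading fails with a 'str' error) ---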
# These must match the definitions used in 1_UNet_Training.ipynb
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss: 1 minus the batch-averaged soft Dice coefficient.

    Sums over axes 1-3 and averages over axis 0 (presumably batch, H, W, C —
    confirm against the training notebook). `smooth` avoids division by zero
    when both masks are empty. Must stay identical to the training definition.
    """
    spatial_axes = [1, 2, 3]
    overlap = K.sum(y_true * y_pred, axis=spatial_axes)
    total = K.sum(y_true, axis=spatial_axes) + K.sum(y_pred, axis=spatial_axes)
    per_sample_dice = (2. * overlap + smooth) / (total + smooth)
    return 1 - K.mean(per_sample_dice, axis=0)

def iou_metric(y_true, y_pred, smooth=1e-6):
    """Hard IoU (Jaccard) score with predictions binarised at 0.5.

    Sums over axes 1-3 and averages over axis 0 (presumably batch, H, W, C —
    confirm). Must stay identical to the training definition.
    """
    spatial_axes = [1, 2, 3]
    hard_pred = K.cast(K.greater(y_pred, 0.5), K.floatx())
    overlap = K.sum(K.abs(y_true * hard_pred), axis=spatial_axes)
    combined = K.sum(y_true, axis=spatial_axes) + K.sum(hard_pred, axis=spatial_axes)
    union = combined - overlap
    per_sample_iou = (overlap + smooth) / (union + smooth)
    return K.mean(per_sample_iou, axis=0)

def dice_coeff_metric(y_true, y_pred, smooth=1e-6):
    """Hard Dice coefficient with predictions binarised at 0.5.

    Sums over axes 1-3 and averages over axis 0 (presumably batch, H, W, C —
    confirm). Must stay identical to the training definition.
    """
    spatial_axes = [1, 2, 3]
    binary_pred = K.cast(K.greater(y_pred, 0.5), K.floatx())
    overlap = K.sum(y_true * binary_pred, axis=spatial_axes)
    denom = K.sum(y_true, axis=spatial_axes) + K.sum(binary_pred, axis=spatial_axes)
    return K.mean((2. * overlap + smooth) / (denom + smooth), axis=0)

# --- 3. LOAD MODEL ---
# Leaves `model` as None on any failure; the pipeline checks for None and halts cleanly.
print("Loading Model...")
model = None
if not os.path.exists(MODEL_PATH):
    # Report the missing file directly instead of surfacing an opaque Keras error.
    print(f"Error loading model: file not found at {MODEL_PATH}")
    print(f"Check that the training notebook for {MODEL_NAME} completed and saved the model.")
else:
    try:
        # Pass the custom functions so Keras can resolve the names stored in the .h5 file.
        model = load_model(MODEL_PATH, custom_objects={
            'dice_loss': dice_loss,
            'iou_metric': iou_metric,
            'dice_coeff_metric': dice_coeff_metric
        })
        print("Model loaded successfully!")
    except Exception as e:  # notebook top-level boundary: report and let the pipeline halt
        model = None
        print(f"Error loading model: {e}")
        # Name the model actually selected (the old hint always pointed at the UNet notebook).
        print(f"Check that the training notebook for {MODEL_NAME} completed successfully "
              f"and saved {MODEL_PATH}.")

# --- 4. PREDICT FUNCTION ---
def predict_spill(image_path, threshold=0.05, input_size=(256, 256)):
    """Run the globally loaded `model` on one image and threshold its output.

    Args:
        image_path: Path to an image file readable by cv2.imread.
        threshold: Cut-off applied to the raw model output to build the binary
            mask. Defaults to 0.05, the previously hard-coded value. NOTE(review):
            this is far below the 0.5 used by the metrics above — confirm intent.
        input_size: (height, width) the image is resized to before inference.
            It must match the model's expected input size.

    Returns:
        (img, mask_pred, raw_pred):
        - img: the resized RGB image (uint8 for ordinary 8-bit files).
        - mask_pred: uint8 0/1 mask, computed as raw_pred > threshold.
        - raw_pred: the model output for this single image.
        If no model is loaded, or the image is missing or unreadable, zero arrays
        of shape (H, W, 3), (H, W, 1), (H, W, 1) are returned instead.
    """
    height, width = input_size

    def _empty_prediction():
        # Fallback result that keeps the caller's plotting/assessment code working.
        return (np.zeros((height, width, 3)),
                np.zeros((height, width, 1)),
                np.zeros((height, width, 1)))

    if model is None:
        print("Error: No model loaded; returning an empty prediction.")
        return _empty_prediction()

    if not os.path.exists(image_path):
        print(f"Error: Image not found at {image_path}")
        return _empty_prediction()

    original_img = cv2.imread(image_path)
    if original_img is None:
        # cv2.imread signals unreadable/corrupt files by returning None, not by raising.
        print(f"Error: Could not decode image at {image_path}")
        return _empty_prediction()

    original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
    # cv2.resize takes its size as (width, height).
    img = cv2.resize(original_img, (width, height))
    img_input = np.expand_dims(img, axis=0) / 255.0

    raw_pred = model.predict(img_input)[0]
    print(f"DEBUG: Model Output Stats -> Max: {np.max(raw_pred):.4f}, Mean: {np.mean(raw_pred):.4f}")

    mask_pred = (raw_pred > threshold).astype(np.uint8)
    return img, mask_pred, raw_pred

# --- 5. ANOMALY DETECTION ---
def detect_anomaly(ais_csv, spill_lat, spill_lon, search_radius=0.5, stop_speed=1.0):
    """List AIS vessels inside a lat/lon box around a spill and flag stopped ones.

    Args:
        ais_csv: Path to an AIS CSV containing at least the columns LAT, LON,
            SOG, VesselName and MMSI.
        spill_lat, spill_lon: Centre of the search area, in the same units as
            the LAT/LON columns (presumably decimal degrees — confirm).
        search_radius: Half-width of the search area in LAT/LON units. Despite
            the name the area is a square box, not a circle; bounds are exclusive.
        stop_speed: Vessels whose SOG is strictly below this value are reported
            as "STOPPED (SUSPECT)"; all others (including missing SOG) as "MOVING".

    Returns:
        A list of dicts with keys "Status", "Vessel Name", "MMSI" and
        "Speed (knots)", in CSV row order. The list is empty if the file does not
        exist or no vessel is inside the box.
    """
    if not os.path.exists(ais_csv):
        # Warn so the caller's "No vessels detected." is not mistaken for a real result.
        print(f"WARNING: AIS data file not found at {ais_csv}; vessel check skipped.")
        return []

    df = pd.read_csv(ais_csv)
    in_box = (
        (df['LAT'] > spill_lat - search_radius) & (df['LAT'] < spill_lat + search_radius)
        & (df['LON'] > spill_lon - search_radius) & (df['LON'] < spill_lon + search_radius)
    )
    nearby_ships = df[in_box]

    # Zipping the needed columns avoids building a Series per row as iterrows() does.
    return [
        {
            "Status": "STOPPED (SUSPECT)" if sog < stop_speed else "MOVING",
            "Vessel Name": str(name),
            "MMSI": mmsi,
            "Speed (knots)": sog,
        }
        for name, mmsi, sog in zip(
            nearby_ships['VesselName'], nearby_ships['MMSI'], nearby_ships['SOG']
        )
    ]

# --- 6. DAMAGE ASSESSMENT ---
def assess_damage(mask, raw_pred, pixel_area_m2=100):
    """Print a damage-assessment report for a predicted spill mask and return its figures.

    Args:
        mask: Binary prediction mask. Non-zero entries count as oil, and the
            first two axes are treated as image height and width.
        raw_pred: Raw model output with the same shape as `mask`. The average
            AI confidence is taken over the oil entries.
        pixel_area_m2: Ground area represented by one mask pixel, in square
            metres. Defaults to 100, the previously hard-coded value.
            NOTE(review): the mask is produced on the resized model input, so
            this presumably needs scaling by the resize factor — confirm.

    Returns:
        A dict with these keys:
        - "oil_pixels": count of non-zero mask entries.
        - "area_km2": estimated spill area in square kilometres.
        - "spill_percentage": oil pixels as a % of height * width.
        - "avg_confidence": mean raw_pred over oil pixels, as a %.
        - "severity": one of "CRITICAL", "HIGH", "MODERATE", "NONE".
    """
    oil_pixel_count = int(np.count_nonzero(mask))
    if oil_pixel_count > 0:
        avg_confidence = float(np.mean(raw_pred[mask > 0])) * 100
        total_pixels = mask.shape[0] * mask.shape[1]
        spill_percentage = (oil_pixel_count / total_pixels) * 100
    else:
        avg_confidence, spill_percentage = 0.0, 0.0

    total_area_km2 = (oil_pixel_count * pixel_area_m2) / 1_000_000

    if total_area_km2 > 1.0:
        severity = "CRITICAL"
    elif total_area_km2 > 0.1:
        severity = "HIGH"
    elif total_area_km2 > 0.0:
        severity = "MODERATE"
    else:
        severity = "NONE"

    print("\n--- DAMAGE ASSESSMENT REPORT ---")
    print(f"Estimated Spill Area:  {total_area_km2:.4f} sq. km")
    # Coverage was computed but never reported previously.
    print(f"Image Coverage:        {spill_percentage:.2f}%")
    print(f"Average AI Confidence: {avg_confidence:.2f}%")
    print(f"SEVERITY: {severity}")

    return {
        "oil_pixels": oil_pixel_count,
        "area_km2": total_area_km2,
        "spill_percentage": spill_percentage,
        "avg_confidence": avg_confidence,
        "severity": severity,
    }

# --- 7. RUN PIPELINE ---
# NOTE(review): the spill location is a fixed placeholder and is not derived from the
# image; replace it with the image's georeference when that is available.
SPILL_LAT, SPILL_LON = 28.5, -90.5

if model is None:
    print("FATAL ERROR: Model not defined. Pipeline halted.")
elif TEST_IMG_PATH is None or not os.path.exists(TEST_IMG_PATH):
    print("FATAL ERROR: Invalid test image path.")
else:
    print("\n--- STARTING ANALYSIS ---")
    final_img, final_mask, raw_pred = predict_spill(TEST_IMG_PATH)
    # Drop the trailing singleton channel so imshow and the channel assignment get 2-D data.
    mask_2d = final_mask.squeeze()

    # Visualization: input, binary mask, and the mask blended in red over the input.
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.title("Satellite Input")
    plt.imshow(final_img)
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title("AI Mask")
    plt.imshow(mask_2d * 255, cmap='gray')
    plt.axis('off')

    mask_red = np.zeros_like(final_img)
    mask_red[:, :, 0] = mask_2d * 255
    overlay = cv2.addWeighted(final_img, 0.7, mask_red, 0.3, 0)
    plt.subplot(1, 3, 3)
    plt.title("Forensic Overlay")
    plt.imshow(overlay)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    assess_damage(final_mask, raw_pred)

    anomalies = detect_anomaly(AIS_DATA_PATH, SPILL_LAT, SPILL_LON)
    print("\n" + "=" * 50 + "\n🚢 NEARBY VESSEL ACTIVITY REPORT\n" + "=" * 50)
    if anomalies:
        print(pd.DataFrame(anomalies).sort_values(by="Speed (knots)").to_string(index=False))
    else:
        print("No vessels detected.")