# Extracting zipped files

In [None]:
import zipfile
import os

zip_path = "/content/drive/MyDrive/691_Team4_Dataset/Dataset.zip"
extract_to = "/content/Dataset"

# Unzip only if not already extracted.
# The archive is unpacked into a staging directory that is renamed to
# `extract_to` only after extractall() has finished. Creating `extract_to` up
# front would make a missing zip or an interrupted extraction look like a
# completed one on the next run, because the guard only checks that the path
# exists.
if not os.path.exists(extract_to):
    staging_dir = extract_to + ".partial"
    # Opening the archive first means a missing/corrupt zip fails before any
    # directory is created.
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        os.makedirs(staging_dir, exist_ok=True)
        zip_ref.extractall(staging_dir)
    os.rename(staging_dir, extract_to)
    print("✅ Extraction complete.")
else:
    print("✅ Already extracted.")

In [None]:
pip install geopandas rasterio shapely

In [None]:
import json
import matplotlib.pyplot as plt
from shapely import wkt
import geopandas as gpd
import rasterio
from rasterio.plot import show
from glob import glob
import random

# Assigning Labels and Colours to Classes

In [None]:
import numpy as np
import cv2
from shapely.geometry import Polygon

# xBD damage subtype -> integer class id (ids start at 1)
LABEL_NAME_TO_NUM = dict(
    zip(
        ('no-damage', 'minor-damage', 'major-damage', 'destroyed', 'un-classified'),
        range(1, 6),
    )
)

# Display colour, as an (R, G, B) triple, for every class id
CLASS_COLORS = dict([
    (1, (0, 255, 0)),      # no-damage: green
    (2, (255, 255, 0)),    # minor-damage: yellow
    (3, (255, 165, 0)),    # major-damage: orange
    (4, (255, 0, 0)),      # destroyed: red
    (5, (128, 128, 128)),  # un-classified: gray
])

# Creating Mask Based on Damage Labels

In [None]:
def create_mask(image_shape, label_json):
    """Rasterise the labelled polygons of one label file into a class-id mask.

    Args:
        image_shape: ``(height, width)`` of the mask to produce.
        label_json: Parsed label file; the polygons are read from
            ``label_json['features']['xy']``, each entry providing a ``'wkt'``
            string and a ``'properties'`` dict.

    Returns:
        ``uint8`` array of shape ``(height, width)``: 0 where no polygon was
        drawn, otherwise the ``LABEL_NAME_TO_NUM`` id of the polygon's
        ``'subtype'`` (1 when the subtype is missing or unknown).
    """
    rows, cols = image_shape
    canvas = np.zeros((rows, cols), dtype=np.uint8)

    for building in label_json['features']['xy']:
        # A missing subtype is treated as 'no-damage'; an unrecognised one maps to id 1.
        subtype = building['properties'].get('subtype', 'no-damage')
        class_id = LABEL_NAME_TO_NUM.get(subtype, 1)

        # Only the exterior ring is drawn; vertices are cast to int32 pixels.
        outline = wkt.loads(building['wkt']).exterior.coords
        cv2.fillPoly(canvas, [np.array(outline, np.int32)], class_id)

    return canvas

# Overlay Mask on Image for Visualization

In [None]:
def overlay_mask(image_path, label_json_path, title=""):
    """Show a raster image blended with the colour-coded mask of its label file.

    Args:
        image_path: Raster opened with rasterio; bands 1-3 are used.
        label_json_path: Label JSON passed to ``create_mask``.
        title: Figure title.
    """
    with open(label_json_path) as label_file:
        labels = json.load(label_file)

    with rasterio.open(image_path) as raster:
        bands = raster.read([1, 2, 3])
    # rasterio returns (bands, rows, cols); move the bands last for display.
    rgb = bands.transpose(1, 2, 0)

    # Min-max stretch to the full 0-255 range before casting to uint8.
    rgb = cv2.normalize(rgb, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    class_mask = create_mask(rgb.shape[:2], labels)

    tint = np.zeros_like(rgb)
    for class_id in CLASS_COLORS:
        tint[class_mask == class_id] = CLASS_COLORS[class_id]

    # 60 % image, 40 % class colours.
    blended = cv2.addWeighted(rgb, 0.6, tint, 0.4, 0)

    plt.figure(figsize=(6, 6))
    plt.imshow(blended)
    plt.title(title)
    plt.axis("off")
    plt.show()

In [None]:
import os

# File Paths and Sample Selection for Pre-Disaster Labels and Images

In [None]:
label_dir = "/content/Dataset/Dataset/train/hurricane/labels"
pre_img_dir = "/content/Dataset/Dataset/train/hurricane/images"
post_img_dir = "/content/Dataset/Dataset/train/hurricane/images"
label_files = sorted(glob(os.path.join(label_dir, "*pre_disaster.json")))
# Randomly select up to 10 label files for processing. random.sample raises
# ValueError when asked for more items than the population holds, so the sample
# size is capped at the number of label files actually found.
sample_files = random.sample(label_files, min(10, len(label_files)))

# Looping through each randomly sampled pre-disaster label file for visualization

In [None]:
# For every sampled scene, draw the pre-disaster overlay followed by the
# post-disaster overlay.
for pre_label_path in sample_files:
    base = os.path.basename(pre_label_path).replace("_pre_disaster.json", "")
    print(f"▶️ Visualizing: {base}")

    phases = (
        ("pre", pre_img_dir, "Pre-disaster (Building Mask)"),
        ("post", post_img_dir, "Post-disaster (Damage Mask)"),
    )
    for phase, img_dir, caption in phases:
        # Label and image names differ between phases only in the pre/post token.
        phase_label_path = pre_label_path.replace("pre_disaster", f"{phase}_disaster")
        phase_image_path = os.path.join(img_dir, f"{base}_{phase}_disaster.tif")
        overlay_mask(phase_image_path, phase_label_path, title=f"{base} — {caption}")

# Analysis and Visualization of Pre- and Post-Disaster Data: Area Calculations and Damage Class Distribution

In [None]:
import json
import numpy as np
from shapely import wkt
from shapely.geometry import Polygon
import matplotlib.pyplot as plt
from glob import glob

# Label locations (both lists are sorted by file name)
label_dir = "/content/Dataset/Dataset/train/hurricane/labels"
pre_label_files = sorted(glob(os.path.join(label_dir, "*pre_disaster.json")))
post_label_files = sorted(glob(os.path.join(label_dir, "*post_disaster.json")))

# Chart 1 data: per-file building area and the remaining (non-building) area
building_areas = []
non_building_areas = []

# Chart 2 data: number of post-disaster polygons per damage class
damage_class_counts = dict.fromkeys(
    ['no-damage', 'minor-damage', 'major-damage', 'destroyed', 'un-classified'],
    0,
)

# --- Areas (pre-disaster labels) ---
for pre_json_path in pre_label_files:
    with open(pre_json_path) as f:
        data = json.load(f)
    footprints = [wkt.loads(feat['wkt']) for feat in data['features']['xy']]
    # Invalid geometries are left out of the total.
    area = sum(shape.area for shape in footprints if shape.is_valid)
    building_areas.append(area)
    # Every tile is assumed to be 1024 x 1024 = 1,048,576 pixels.
    non_building_areas.append(1024 * 1024 - area)

# --- Damage class counts (post-disaster labels) ---
for post_json_path in post_label_files:
    with open(post_json_path) as f:
        data = json.load(f)
    for feat in data['features']['xy']:
        damage = feat['properties'].get('subtype', 'no-damage')
        # Anything outside the predefined classes is counted as 'no-damage'.
        bucket = damage if damage in damage_class_counts else 'no-damage'
        damage_class_counts[bucket] += 1

# --- Bar chart 1: average building vs non-building area ---
avg_building_area = np.mean(building_areas)
avg_non_building_area = np.mean(non_building_areas)

plt.figure(figsize=(6, 4))
plt.bar(
    ["Building", "Non-Building"],
    [avg_building_area, avg_non_building_area],
    color=["green", "gray"],
)
plt.title("Average Area: Building vs Non-Building (Pre-disaster)")
plt.ylabel("Average Area (pixels)")
plt.grid(True)
plt.show()

# --- Bar chart 2: damage class distribution ---
plt.figure(figsize=(8, 4))
plt.bar(
    list(damage_class_counts.keys()),
    list(damage_class_counts.values()),
    color=["green", "yellow", "orange", "red", "gray"],
)
plt.title("Count of Buildings per Damage Class (Post-disaster)")
plt.ylabel("Number of Buildings")
plt.xticks(rotation=15)
plt.grid(True)
plt.show()

In [None]:
import pandas as pd
import seaborn as sns

# DataFrame Creation and Seaborn Visualization: Plotting Average Building vs Non-Building Area


In [None]:
# Two-row frame feeding the seaborn bar chart
area_df = pd.DataFrame(
    [
        ("Building", avg_building_area),
        ("Non-Building", avg_non_building_area),
    ],
    columns=["Category", "Average Area (pixels)"],
)

# Annotated seaborn bar plot
plt.figure(figsize=(6, 4))
ax1 = sns.barplot(data=area_df, x="Category", y="Average Area (pixels)", palette=["green", "gray"])
plt.title("Average Area: Building vs Non-Building (Pre-disaster)")
plt.grid(True)

# Write each bar's value, centred just above the bar
for bar in ax1.patches:
    bar_top = bar.get_height()
    bar_centre = bar.get_x() + bar.get_width() / 2.
    ax1.annotate(f'{bar_top:,.0f}', (bar_centre, bar_top),
                 ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

# Plotting Damage Class Distribution

In [None]:
# One row per damage class, in the dictionary's insertion order
damage_df = pd.DataFrame({
    "Damage Class": list(damage_class_counts.keys()),
    "Count": list(damage_class_counts.values()),
})

# Bar colour for each damage class
colors = dict(zip(
    ["no-damage", "minor-damage", "major-damage", "destroyed", "un-classified"],
    ["green", "yellow", "orange", "red", "gray"],
))
# Colours listed in the same order as the rows of the DataFrame
palette = list(map(colors.__getitem__, damage_df["Damage Class"]))

# Annotated seaborn bar plot
plt.figure(figsize=(8, 4))
ax2 = sns.barplot(data=damage_df, x="Damage Class", y="Count", palette=palette)
plt.title("Count of Buildings per Damage Class (Post-disaster)")
plt.ylabel("Number of Buildings")
plt.xticks(rotation=15)
plt.grid(True)

# Write each bar's value, centred just above the bar
for bar in ax2.patches:
    bar_top = bar.get_height()
    bar_centre = bar.get_x() + bar.get_width() / 2.
    ax2.annotate(f'{bar_top:,}', (bar_centre, bar_top),
                 ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

# Image Mask Generation from Damage Labels: Polygon Drawing and Mask Saving

In [None]:
import os
import json
import numpy as np
from shapely import wkt
from shapely.geometry import Polygon, mapping
from skimage.io import imread, imsave
import cv2
from tqdm import tqdm

# Damage subtype -> class index written into the masks.
# Index 5 ('un-classified') is used for buildings without damage labels.
LABEL_NAME_TO_NUM = {
    subtype: index
    for index, subtype in enumerate(
        ('no-damage', 'minor-damage', 'major-damage', 'destroyed', 'un-classified'),
        start=1,
    )
}

def get_image_shape(image_path):
    """Read the image at ``image_path`` and return the first two entries of its shape, (H, W)."""
    return imread(image_path).shape[:2]

def read_label_json(json_path):
    """Load a label JSON file and return its ``['features']['xy']`` entry."""
    with open(json_path, 'r') as handle:
        return json.load(handle)['features']['xy']

def get_polygons_and_labels(features):
    """Convert label features into ``(exterior_coords, class_id)`` pairs.

    Args:
        features: Iterable of dicts, each with a ``'wkt'`` polygon string and a
            ``'properties'`` dict.

    Returns:
        List of tuples: an ``int32`` array of the polygon's exterior
        coordinates, and the ``LABEL_NAME_TO_NUM`` id of its ``'subtype'``
        (1, i.e. no-damage, when the subtype is missing or unknown).
    """
    def _to_entry(feat):
        outline_wkt = feat['wkt']
        subtype = feat['properties'].get('subtype', 'no-damage')
        class_id = LABEL_NAME_TO_NUM.get(subtype, 1)  # default to 1 (no-damage)
        ring = wkt.loads(outline_wkt).exterior.coords
        return (np.array(ring, dtype=np.int32), class_id)

    return [_to_entry(feat) for feat in features]

def draw_mask(shape, polygons, border_shrink=1):
    """Draw filled, slightly inset polygons into a new class-id mask.

    Args:
        shape: Shape of the mask to allocate.
        polygons: Iterable of ``(coords, label)`` pairs.
        border_shrink: Number of pixels each vertex is moved towards the
            polygon centroid along x and along y before filling.

    Returns:
        ``uint8`` mask that is 0 wherever no polygon was drawn.
    """
    canvas = np.zeros(shape, dtype=np.uint8)

    def _toward(value, centre):
        # Coordinates below the centroid move up, all others move down.
        return value + border_shrink if value < centre else value - border_shrink

    for coords, label in polygons:
        footprint = Polygon(coords)
        cx, cy = footprint.centroid.coords[0]
        inset = np.array(
            [[_toward(px, cx), _toward(py, cy)] for px, py in footprint.exterior.coords],
            dtype=np.int32,
        )
        cv2.fillPoly(canvas, [inset], label)

    return canvas

def generate_mask(image_path, json_path, output_path, border_shrink=1):
    """Build the class-id mask for one image and save it to ``output_path``.

    The mask takes its size from the image at ``image_path`` and its polygons
    from the label file at ``json_path``; ``border_shrink`` is forwarded to
    ``draw_mask``.
    """
    print(f"Processing: {os.path.basename(image_path)}")
    tile_shape = get_image_shape(image_path)
    labelled_polygons = get_polygons_and_labels(read_label_json(json_path))
    imsave(output_path, draw_mask(tile_shape, labelled_polygons, border_shrink))

# Mask Generation for Hurricane Dataset: Processing Images and Generating Masks

In [None]:
# Input images, their label files, and the folder that receives the masks
image_dir = "/content/Dataset/Dataset/train/hurricane/images"
label_dir = "/content/Dataset/Dataset/train/hurricane/labels"
output_dir = "/content/masks"
os.makedirs(output_dir, exist_ok=True)

# Walk the image folder; only .tif files with a matching label JSON get a mask
for filename in tqdm(os.listdir(image_dir)):
    if filename.endswith(".tif"):
        base_name = filename.replace(".tif", "")
        json_path = os.path.join(label_dir, f"{base_name}.json")

        if not os.path.exists(json_path):
            print(f"Missing label for {filename}")
            continue

        generate_mask(
            os.path.join(image_dir, filename),
            json_path,
            os.path.join(output_dir, f"{base_name}_mask.png"),
            border_shrink=1,
        )

# Visualization of Post-Disaster Satellite Image with Damage Class Overlay

In [None]:
import os
import cv2
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from glob import glob

# Overlay colour for every damage class id (1-5)
CLASS_COLORS = dict(
    zip(
        range(1, 6),
        (
            (0, 255, 0),      # no-damage - green
            (255, 255, 0),    # minor-damage - yellow
            (255, 165, 0),    # major-damage - orange
            (255, 0, 0),      # destroyed - red
            (128, 128, 128),  # unclassified - gray
        ),
    )
)

def visualize_post_image_with_mask(image_path, mask_path, title=""):
    """Show an image next to the same image overlaid with its damage classes.

    Args:
        image_path: Raster opened with rasterio; bands 1-3 are displayed.
        mask_path: Mask image, read as single-channel grayscale, whose pixel
            values are matched against the keys of ``CLASS_COLORS``.
        title: Title placed above both panels.
    """
    # Read three bands, move them last, and stretch the values to 0-255.
    with rasterio.open(image_path) as src:
        rgb = src.read([1, 2, 3]).transpose(1, 2, 0)
    rgb = cv2.normalize(rgb, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    class_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    # Paint every class id with its colour on an otherwise black canvas.
    tint = np.zeros_like(rgb)
    for class_id in CLASS_COLORS:
        tint[class_mask == class_id] = CLASS_COLORS[class_id]

    # Blend: 0.7 of the image plus 0.3 of the class colours.
    blended = cv2.addWeighted(rgb, 0.7, tint, 0.3, 0)

    # Side-by-side panels: plain image on the left, overlay on the right.
    fig, (left, right) = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (left, rgb, "Post-Disaster Image"),
        (right, blended, "Overlay with Damage Classes"),
    )
    for ax, picture, caption in panels:
        ax.imshow(picture)
        ax.set_title(caption)
        ax.axis("off")
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# Visualization of Post-Disaster Images with Corresponding Damage Masks

In [None]:
# Where the post-disaster images and the generated masks live
image_dir = "/content/Dataset/Dataset/train/hurricane/images"
mask_dir = "/content/masks"

# First 15 post-disaster images, in sorted file-name order
post_images = sorted(glob(os.path.join(image_dir, "*post_disaster.tif")))
post_image_samples = post_images[:15]

# Show each image with its mask; report images whose mask file is absent
for image_path in post_image_samples:
    # The mask is named after the image file name with ".tif" removed
    base_name = os.path.basename(image_path).replace(".tif", "")
    mask_path = os.path.join(mask_dir, f"{base_name}_mask.png")
    if not os.path.exists(mask_path):
        print(f"❌ Missing mask for: {base_name}")
        continue
    visualize_post_image_with_mask(image_path, mask_path, title=base_name)

# Tile Generation for Satellite Images: Pre-disaster and Post-disaster Image Splitting with Damage and Building Masks

In [None]:
import os
import cv2
import numpy as np
import tifffile as tiff
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt

def make_output_dirs(base_dir):
    """Create the four tile sub-folders under ``base_dir`` (existing ones are kept)."""
    subfolders = ('images_pre', 'images_post', 'damage_masks', 'building_masks')
    for target in (os.path.join(base_dir, name) for name in subfolders):
        os.makedirs(target, exist_ok=True)

def get_base_ids_from_paths(pre_image_paths):
    """Return each path's file name with the ``_pre_disaster.tif`` suffix text removed."""
    suffix = "_pre_disaster.tif"
    base_ids = []
    for path in pre_image_paths:
        base_ids.append(os.path.basename(path).replace(suffix, ""))
    return base_ids

def convert_to_rgb(img):
    """Return ``img`` as an 8-bit, three-channel array.

    Non-``uint8`` input is min-max scaled to 0-255 and cast to ``uint8``.
    A 2-D image is repeated onto three channels; an image with three or more
    channels on its last axis keeps only the first three.

    Raises:
        ValueError: If the image has fewer than three channels on axis 2.
    """
    as_uint8 = img
    if as_uint8.dtype != np.uint8:
        as_uint8 = cv2.normalize(as_uint8, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    if as_uint8.ndim == 2:
        return np.stack((as_uint8,) * 3, axis=-1)

    if as_uint8.shape[2] < 3:
        raise ValueError(f"Unsupported image shape: {as_uint8.shape}, dtype: {as_uint8.dtype}")

    return as_uint8[:, :, :3]

def create_tiles(image_dir, mask_dir, output_dir, tile_size=256, visualize_first=True):
    """Cut every pre/post scene and its damage mask into square tiles.

    For each ``*_pre_disaster.tif`` in ``image_dir`` the matching post image and
    ``<base>_post_disaster_mask.png`` in ``mask_dir`` are loaded, and every
    complete ``tile_size`` x ``tile_size`` window (partial edge windows are
    dropped) is written as a PNG into ``output_dir/images_pre``,
    ``images_post``, ``damage_masks`` and ``building_masks`` under the same
    file name ``<base>_<y>_<x>_0.png``.

    Args:
        image_dir: Folder holding the pre- and post-disaster ``.tif`` images.
        mask_dir: Folder holding the damage-mask PNGs.
        output_dir: Root folder for the four tile sub-folders.
        tile_size: Edge length of a tile in pixels.
        visualize_first: Accepted for backward compatibility; currently unused.

    Scenes that cannot be tiled (missing file, undecodable mask, or an image
    whose height/width differs from the mask) are skipped and listed in a
    summary printed at the end instead of aborting the whole run.
    """
    make_output_dirs(output_dir)

    pre_image_paths = sorted(glob(os.path.join(image_dir, "*_pre_disaster.tif")))
    base_ids = get_base_ids_from_paths(pre_image_paths)
    skipped = []  # (base id, reason) for every scene that produced no tiles

    # One iteration per scene id derived from the pre-disaster file names
    for base in tqdm(base_ids, desc="Creating tiles"):
        pre_path = os.path.join(image_dir, f"{base}_pre_disaster.tif")
        post_path = os.path.join(image_dir, f"{base}_post_disaster.tif")
        dmg_path = os.path.join(mask_dir, f"{base}_post_disaster_mask.png")

        if not (os.path.exists(pre_path) and os.path.exists(post_path) and os.path.exists(dmg_path)):
            skipped.append((base, "missing pre image, post image or damage mask"))
            continue

        pre_rgb = convert_to_rgb(tiff.imread(pre_path))
        post_rgb = convert_to_rgb(tiff.imread(post_path))

        # cv2.imread returns None (it does not raise) when a file cannot be decoded.
        damage_mask = cv2.imread(dmg_path, cv2.IMREAD_GRAYSCALE)
        if damage_mask is None:
            skipped.append((base, "damage mask could not be decoded"))
            continue

        h, w = damage_mask.shape
        # Slicing past an array's edge silently truncates, so a size mismatch
        # would otherwise produce tiles of different sizes across the folders.
        if pre_rgb.shape[:2] != (h, w) or post_rgb.shape[:2] != (h, w):
            skipped.append((base, f"image size {pre_rgb.shape[:2]}/{post_rgb.shape[:2]} differs from mask size {(h, w)}"))
            continue

        # Any labelled pixel (class id > 0) counts as building; stored as 255.
        building_mask = (damage_mask > 0).astype(np.uint8) * 255

        # Stopping at `h - tile_size` / `w - tile_size` visits exactly the
        # windows that fit completely inside the image.
        for y in range(0, h - tile_size + 1, tile_size):
            for x in range(0, w - tile_size + 1, tile_size):
                pre_patch = pre_rgb[y:y+tile_size, x:x+tile_size]
                post_patch = post_rgb[y:y+tile_size, x:x+tile_size]
                dmg_patch = damage_mask[y:y+tile_size, x:x+tile_size]
                bld_patch = building_mask[y:y+tile_size, x:x+tile_size]

                # (y, x) already makes the id unique. The trailing field was a
                # counter that never advanced, so it is kept as a constant 0 to
                # leave tile names identical to those produced before.
                patch_id = f"{base}_{y}_{x}_0"
                # NOTE(review): cv2.imwrite interprets 3-channel data as BGR while
                # the arrays presumably come from tifffile in RGB order; readers in
                # this notebook use cv2.imread, which reverses it again - confirm
                # before reading these PNGs with another library.
                cv2.imwrite(os.path.join(output_dir, "images_pre", f"{patch_id}.png"), pre_patch)
                cv2.imwrite(os.path.join(output_dir, "images_post", f"{patch_id}.png"), post_patch)
                cv2.imwrite(os.path.join(output_dir, "damage_masks", f"{patch_id}.png"), dmg_patch)
                cv2.imwrite(os.path.join(output_dir, "building_masks", f"{patch_id}.png"), bld_patch)

    if skipped:
        print(f"Skipped {len(skipped)} of {len(base_ids)} scene(s):")
        for base, reason in skipped:
            print(f"  {base}: {reason}")

# Executing Tile Generation for Satellite Images: Pre-disaster and Post-disaster Image Splitting with Damage and Building Masks

In [None]:
# Arguments in order: image_dir, mask_dir, output_dir
create_tiles(
    "/content/Dataset/Dataset/train/hurricane/images",
    "/content/masks",
    "/content/tiles",
)

# Visualization of Satellite Image Tiles and Corresponding Masks: Pre-image, Post-image, Building Mask, and Damage Mask


In [None]:
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
from glob import glob

# Path to tile directory
tile_dir = "/content/tiles"  # <-- update if needed

# Get sample file names from images_pre
sample_paths = sorted(glob(os.path.join(tile_dir, "images_pre", "*.png")))[:5]  # Visualize 5 tiles

n = len(sample_paths)  # Number of samples to plot

if n == 0:
    # plt.subplots cannot build a grid with zero rows.
    print(f"No tiles found in {os.path.join(tile_dir, 'images_pre')}; nothing to plot.")
else:
    # squeeze=False keeps `axs` two-dimensional even when only one tile exists;
    # otherwise a single row comes back as a 1-D array and axs[i, 0] fails.
    fig, axs = plt.subplots(n, 4, figsize=(12, 3 * n), squeeze=False)

    # One row per tile: pre image, post image, building mask, damage mask
    for i, pre_path in enumerate(sample_paths):
        base_name = os.path.basename(pre_path)

        # The four tile folders share file names
        post_path = os.path.join(tile_dir, "images_post", base_name)
        bld_path = os.path.join(tile_dir, "building_masks", base_name)
        dmg_path = os.path.join(tile_dir, "damage_masks", base_name)

        pre = cv2.imread(pre_path)
        post = cv2.imread(post_path)
        bld = cv2.imread(bld_path, cv2.IMREAD_GRAYSCALE)
        dmg = cv2.imread(dmg_path, cv2.IMREAD_GRAYSCALE)

        axs[i, 0].imshow(cv2.cvtColor(pre, cv2.COLOR_BGR2RGB))
        axs[i, 0].set_title("Pre Image")

        axs[i, 1].imshow(cv2.cvtColor(post, cv2.COLOR_BGR2RGB))
        axs[i, 1].set_title("Post Image")

        axs[i, 2].imshow(bld, cmap="gray")
        axs[i, 2].set_title("Building Mask")

        # Class ids run from 0 (background) to 5 ('un-classified'); with vmax=4
        # class 5 was clipped to the same colour as class 4.
        axs[i, 3].imshow(dmg, cmap="nipy_spectral", vmin=0, vmax=5)
        axs[i, 3].set_title("Damage Mask")

        for ax in axs[i]:
            ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# import shutil

# source_folder = "/content/shards"
# zip_name = "/content/shards"  # No .zip extension here

# # This creates /content/shards.zip
# shutil.make_archive(zip_name, 'zip', source_folder)
# print("✅ Folder zipped successfully.")


In [None]:
# import shutil

# # Local destination (outside Google Drive)
# destination_path = "/content/final_splits_sliced.json"

# # Copy the split JSON from Google Drive to the local destination
# shutil.copy("/content/drive/MyDrive/691_Team4_Dataset/final_splits_sliced.json", destination_path)
# print("✅ Zip file copied to Drive.")

# Check the Disk Usage of the "tiles" Directory to Monitor Storage Consumption

In [None]:
# Print the total on-disk size of the tiles directory via the shell `du` tool.
# The commented-out line is the same check for the shards archive on Drive.
# !du -sh "/content/drive/MyDrive/691_Team4_Dataset/shards.zip"
!du -sh "/content/tiles"

# Counting PNG Files Recursively in Directories: Images and Masks

In [None]:
def count_png_files(directory):
    """Return the number of files ending in ``.png`` anywhere under ``directory``.

    A directory that does not exist yields 0, because ``os.walk`` produces
    nothing for it.
    """
    return sum(
        1
        for _root, _dirs, files in os.walk(directory)
        for name in files
        if name.endswith('.png')
    )


# The four folders are expected to hold the same number of tiles; the counts
# are printed in the order images_post, images_pre, building_masks, damage_masks.
images_post_count = count_png_files("/content/tiles/images_post")
print("Total files (recursive):", images_post_count)

images_pre_count = count_png_files("/content/tiles/images_pre")
print("Total files (recursive):", images_pre_count)

building_masks_count = count_png_files("/content/tiles/building_masks")
print("Total files (recursive):", building_masks_count)

damage_masks_count = count_png_files("/content/tiles/damage_masks")
print("Total files (recursive):", damage_masks_count)

# Creating .npz Shards for Image Tiles: Pre- and Post-Disaster Images, Masks, and Validations

In [None]:
import os
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm

def create_shards(tile_root, output_dir):
    """Pack each tile's four PNGs into one compressed ``.npz`` shard.

    For every PNG in ``tile_root/images_pre`` the files with the same name in
    ``images_post``, ``building_masks`` and ``damage_masks`` are read and saved
    to ``output_dir/<name>.npz`` under the keys ``pre``, ``post``, ``mask``
    (damage), ``bld_mask`` and ``original_image_name``.

    A tile that cannot be read or fails validation is skipped; its file name
    and the reason are written to ``output_dir/skipped_tiles_log.txt``.
    """
    os.makedirs(output_dir, exist_ok=True)

    pre_tiles = sorted(glob(os.path.join(tile_root, "images_pre", "*.png")))
    failures = []

    for pre_path in tqdm(pre_tiles, desc="Creating individual .npz shards"):
        try:
            fname = os.path.basename(pre_path)
            base_id = fname.replace(".png", "")

            # Load the colour tiles as-is and both masks as single-channel images.
            pre = cv2.imread(pre_path)
            post = cv2.imread(os.path.join(tile_root, "images_post", fname))
            bld = cv2.imread(os.path.join(tile_root, "building_masks", fname), cv2.IMREAD_GRAYSCALE)
            dmg = cv2.imread(os.path.join(tile_root, "damage_masks", fname), cv2.IMREAD_GRAYSCALE)

            # cv2.imread signals failure by returning None.
            if any(part is None for part in (pre, post, bld, dmg)):
                raise ValueError("One or more files could not be read.")

            # Both colour tiles must be three-dimensional with three channels.
            if not all(img.ndim == 3 and img.shape[2] == 3 for img in (pre, post)):
                raise ValueError("Pre/Post images are not RGB.")

            # All four components must agree on height and width.
            tile_hw = pre.shape[:2]
            if not (pre.shape == post.shape and bld.shape == tile_hw and dmg.shape == tile_hw):
                raise ValueError("Shape mismatch among components.")

            np.savez_compressed(
                os.path.join(output_dir, f"{base_id}.npz"),
                pre=pre.astype(np.uint8),
                post=post.astype(np.uint8),
                mask=dmg.astype(np.uint8),
                bld_mask=bld.astype(np.uint8),
                original_image_name=base_id
            )

        except Exception as e:
            # Best effort: remember the failure and move on to the next tile.
            failures.append(f"{fname}: {e}")

    # Write the log only when at least one tile was skipped.
    if failures:
        with open(os.path.join(output_dir, "skipped_tiles_log.txt"), "w") as log_file:
            log_file.writelines(entry + "\n" for entry in failures)

# Executing .npz Shard Creation for Image Tiles: Pre- and Post-Disaster Images, Masks, and Validations


In [None]:
# Arguments in order: tile_root, output_dir
create_shards("/content/tiles", "/content/shards")

# Creating Data Split for .npz Files: Train, Validation, and Test Set Assignment


In [None]:
import os
import json
from glob import glob
from collections import defaultdict
import random

def create_data_split(npz_dir, output_json_path, seed=42, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Assign whole scenes to train/val/test and save the split as JSON.

    Shards are grouped by scene id (the first two ``_``-separated fields of the
    file name, e.g. ``hurricane-harvey_00000092``) so that all tiles of one
    scene land in the same split. Scene ids are shuffled with ``seed``; the
    first ``int(train_ratio * n)`` go to train, the next ``int(val_ratio * n)``
    to val, and the remainder to test.

    Args:
        npz_dir: Folder containing the ``.npz`` shards.
        output_json_path: Destination of the JSON file.
        seed: Seed passed to ``random.seed`` (this reseeds the global generator).
        train_ratio, val_ratio, test_ratio: Split fractions; must sum to 1.

    The JSON maps each scene id to ``{"train": [...], "val": [...],
    "test": [...]}`` where only the list of the scene's own split is non-empty.

    Raises:
        AssertionError: If the three ratios do not sum to 1.
    """
    # Ensure the ratios sum to 1
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-5

    npz_files = sorted(glob(os.path.join(npz_dir, "*.npz")))
    print(f"Found {len(npz_files)} .npz files")

    # Group shard names (without extension) by scene id
    scene_groups = defaultdict(list)
    for f in npz_files:
        fname = os.path.basename(f).replace(".npz", "")
        scene_id = "_".join(fname.split("_")[:2])  # e.g., hurricane-harvey_00000092
        scene_groups[scene_id].append(fname)

    # Sort before shuffling so the result depends only on the seed
    scene_ids = sorted(scene_groups.keys())
    random.seed(seed)
    random.shuffle(scene_ids)

    num_scenes = len(scene_ids)
    num_train = int(train_ratio * num_scenes)
    num_val = int(val_ratio * num_scenes)

    # The test split takes whatever is left after train and val
    train_scenes = scene_ids[:num_train]
    val_scenes = scene_ids[num_train:num_train + num_val]
    test_scenes = scene_ids[num_train + num_val:]

    # scene id -> split name, so each lookup below is O(1) instead of a list scan
    split_of = {}
    for split_name, members in (("train", train_scenes), ("val", val_scenes), ("test", test_scenes)):
        for scene_id in members:
            split_of[scene_id] = split_name

    split_dict = {
        scene_id: {
            split_name: scene_groups[scene_id] if split_of[scene_id] == split_name else []
            for split_name in ("train", "val", "test")
        }
        for scene_id in scene_ids
    }

    with open(output_json_path, "w") as f:
        json.dump(split_dict, f, indent=4)

    print(f"✅ Split saved to {output_json_path}")
    print(f"📊 Scenes → Train: {len(train_scenes)}, Val: {len(val_scenes)}, Test: {len(test_scenes)}")

In [None]:
# Example usage — arguments in order: npz_dir, output_json_path
create_data_split("/content/shards", "/content/final_splits_sliced.json")

# Compute and Save Mean and Standard Deviation for Image Tiles: Pre-image Normalization and Statistics Calculation

In [None]:
import os
import json
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm

def compute_mean_std(tile_root, output_json):
    """Compute per-tile channel mean and standard deviation and save them as JSON.

    Every PNG in ``tile_root/images_pre`` is read with ``cv2.imread``, scaled
    to [0, 1], and reduced over its rows and columns. The JSON written to
    ``output_json`` maps the tile name (file name without ``.png``) to
    ``[mean, std]``, each a list with one value per channel in the channel
    order of the array returned by ``cv2.imread``. Unreadable files are
    reported and skipped.
    """
    stats_by_tile = {}

    pre_tiles = sorted(glob(os.path.join(tile_root, "images_pre", "*.png")))
    for path in tqdm(pre_tiles, desc="Computing mean/std"):
        pixels = cv2.imread(path)
        if pixels is None:
            print(f"Image could not be read: {path}")
            continue

        # Bring the 8-bit values into [0, 1] before taking statistics.
        scaled = pixels.astype(np.float32) / 255.0
        tile_key = os.path.basename(path).replace(".png", "")
        stats_by_tile[tile_key] = [
            scaled.mean(axis=(0, 1)).tolist(),
            scaled.std(axis=(0, 1)).tolist(),
        ]

    with open(output_json, 'w') as out_file:
        json.dump(stats_by_tile, out_file, indent=4)

# Computing mean and standard deviation

In [None]:
# Run this — arguments in order: tile_root, output_json
compute_mean_std("/content/tiles", "/content/mean_stddev_tiles.json")

In [None]:
import torch
from torch.utils.data import Dataset
import numpy as np
import json
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Custom Dataset for Damage Detection: Handling Image Tiles, Augmentations and Normalization

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import cv2
from albumentations import Compose, HorizontalFlip, RandomRotate90
from albumentations.pytorch import ToTensorV2
class DamageDataset(Dataset):
    """Dataset of paired pre-/post-disaster tiles stored as ``.npz`` shards.

    Each sample is loaded from ``<base_path>/<tile_id>.npz``, which must hold
    the arrays ``pre``, ``post``, ``bld_mask`` and ``mask``.

    Args:
        tile_ids (list): List of tile IDs (file names without extension).
        mean_std_dict (dict): Maps a tile ID to ``[mean, std]`` (one value per
            channel each), used to normalize that tile.
        base_path (str): Base directory where the tile data is located.
        transform (bool): Whether to apply augmentations (random horizontal
            flip and random 90-degree rotation) in addition to tensor conversion.
        normalize (bool): Whether to normalize the images.

    NOTE(review): ``bld_mask`` is returned with the values stored in the shard;
    the tiling step in this notebook writes buildings as 255, so a loss that
    expects class ids {0, 1} presumably needs a remap - confirm with the
    training code.
    """

    def __init__(self, tile_ids, mean_std_dict, base_path, transform=True, normalize=True):
        self.tile_ids = tile_ids
        self.mean_std_dict = mean_std_dict
        self.base_path = base_path
        self.transform = transform
        self.normalize = normalize

        # Default fallback: ImageNet mean and std for normalization
        self.default_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.default_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

        # Augmentation; 'image1'/'mask1' receive the same spatial transform as
        # 'image'/'mask', keeping the four arrays aligned.
        self.transform_aug = Compose([
            HorizontalFlip(p=0.5),
            RandomRotate90(p=0.5),
            ToTensorV2()
        ], additional_targets={'image1': 'image', 'mask1': 'mask'})

        # Base transformation (no augmentation), only converts to tensor
        self.transform_base = Compose([
            ToTensorV2()
        ], additional_targets={'image1': 'image', 'mask1': 'mask'})

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.tile_ids)

    def _normalization_stats(self, tile_id):
        """Return ``(mean, std)`` for ``tile_id``, falling back to the defaults.

        The defaults are used when the tile has no entry in ``mean_std_dict``
        or its entry cannot be converted; both cases print a notice.
        """
        if tile_id not in self.mean_std_dict:
            print(f"🔁 Using default normalization for {tile_id}")
            return self.default_mean, self.default_std

        try:
            mean_list, std_list = self.mean_std_dict[tile_id]
            mean = np.array(mean_list, dtype=np.float32)
            std = np.array(std_list, dtype=np.float32)
            std = np.where(std == 0, 1.0, std)  # prevent division by zero
            return mean, std
        except Exception as e:
            print(f"⚠️ Error in mean/std for {tile_id}: {e}")
            return self.default_mean, self.default_std

    def __getitem__(self, idx):
        """Fetches a sample from the dataset given an index.

        Returns:
            dict with ``pre_image`` and ``post_image`` (float tensors, scaled to
            [0, 1] and normalized when ``normalize`` is set) and
            ``building_mask`` and ``damage_mask`` (long tensors).

        Raises:
            FileNotFoundError: If the shard for the tile does not exist.
        """
        tile_id = self.tile_ids[idx]
        npz_path = os.path.join(self.base_path, f"{tile_id}.npz")
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"Missing: {npz_path}")

        # np.load on an .npz returns a lazily-read archive that keeps the file
        # open; the context manager closes it once the arrays are in memory so
        # long-running loaders do not accumulate open file handles.
        with np.load(npz_path) as data:
            pre = data["pre"].astype(np.float32) / 255.0
            post = data["post"].astype(np.float32) / 255.0
            bld = data["bld_mask"]
            dmg = data["mask"]

        # --- Mean/Std Handling ---
        # Only look the statistics up when they are actually used; the same
        # mean/std is applied to both the pre and the post image.
        if self.normalize:
            mean, std = self._normalization_stats(tile_id)
            pre = (pre - mean) / std
            post = (post - mean) / std

        # --- Transform ---
        # Apply transformations: augmentation or base transformations
        pipeline = self.transform_aug if self.transform else self.transform_base
        transformed = pipeline(image=pre, image1=post, mask=bld, mask1=dmg)

        # Return the transformed images and masks as a dictionary
        return {
            "pre_image": transformed['image'],
            "post_image": transformed['image1'],
            "building_mask": transformed['mask'].long(),
            "damage_mask": transformed['mask1'].long()
        }

# Data Augmentation and Transformation for Image Tiles: Handling Train vs Test Transformations


In [None]:
def get_transform(train=True):
    """Build the Albumentations pipeline for a pre/post tile pair and its two masks.

    Args:
        train: When True, a random horizontal flip and a random 90-degree
            rotation (each applied with probability 0.5) run before the tensor
            conversion. When False, the pipeline only converts to tensors.

    Returns:
        An ``A.Compose`` pipeline. In addition to the standard ``image`` and
        ``mask`` keys it accepts ``image1`` (handled like an image) and
        ``mask1`` (handled like a mask).
    """
    # The dataset in this notebook calls its pipeline with
    # image=, image1=, mask=, mask1= and then calls .long() on both returned
    # masks. "mask1" therefore has to be registered as a mask target so that it
    # receives the same sampled flip/rotation as the images and is converted to
    # a tensor by ToTensorV2.
    # NOTE(review): assumes the dataset's transform_aug / transform_base are
    # built by this factory - confirm against DamageDataset.__init__.
    paired_targets = {"image1": "image", "mask1": "mask"}

    steps = []
    if train:
        # Geometric augmentations only; pixel values are left untouched.
        steps.append(A.HorizontalFlip(p=0.5))
        steps.append(A.RandomRotate90(p=0.5))
    steps.append(ToTensorV2())  # always last: converts the arrays to PyTorch tensors

    return A.Compose(steps, additional_targets=paired_targets)

In [None]:
import torch
import torch.nn as nn
from collections import OrderedDict

# SiamUnet Architecture for Image Segmentation and Classification: Combining UNet with Siamese Networks


In [None]:
class SiamUnet(nn.Module):
    """Siamese U-Net with a segmentation head and a change-classification head.

    A single U-Net (shared weights) is run on each of the two inputs and yields
    per-input segmentation logits through ``conv_s``. A second decoder consumes
    the differences (second input minus first input) of the bottleneck and of
    every encoder stage and yields classification logits through ``conv_c``.

    The encoder applies four 2x2 max-poolings, so input height and width are
    expected to be multiples of 16 for the skip concatenations to line up.

    Args:
        in_channels: Channels of each input image.
        out_channels_s: Output channels of the segmentation head.
        out_channels_c: Output channels of the classification head.
        init_features: Base channel width; deeper stages use 2x, 4x, 8x, 16x.
    """

    def __init__(self, in_channels=3, out_channels_s=2, out_channels_c=5, init_features=16):
        super().__init__()

        # Channel widths per depth, as multiples of the base width.
        w1 = init_features
        w2, w4, w8, w16 = w1 * 2, w1 * 4, w1 * 8, w1 * 16
        make = SiamUnet._block

        def halve():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        def double(src, dst):
            return nn.ConvTranspose2d(src, dst, kernel_size=2, stride=2)

        # Construction order fixes parameter order and seeded initialisation,
        # and attribute/layer names fix the state_dict keys: keep both stable.

        # --- Shared encoder (downsampling path) ---
        self.encoder1 = make(in_channels, w1, name="enc1")
        self.pool1 = halve()
        self.encoder2 = make(w1, w2, name="enc2")
        self.pool2 = halve()
        self.encoder3 = make(w2, w4, name="enc3")
        self.pool3 = halve()
        self.encoder4 = make(w4, w8, name="enc4")
        self.pool4 = halve()

        # --- Bottleneck ---
        self.bottleneck = make(w8, w16, name="bottleneck")

        # --- Shared segmentation decoder; each block sees upsampled + skip channels ---
        self.upconv4 = double(w16, w8)
        self.decoder4 = make(w16, w8, name="dec4")
        self.upconv3 = double(w8, w4)
        self.decoder3 = make(w8, w4, name="dec3")
        self.upconv2 = double(w4, w2)
        self.decoder2 = make(w4, w2, name="dec2")
        self.upconv1 = double(w2, w1)
        self.decoder1 = make(w2, w1, name="dec1")

        # 1x1 segmentation head, applied to both branches.
        self.conv_s = nn.Conv2d(in_channels=w1, out_channels=out_channels_s, kernel_size=1)

        # --- Difference decoder: each block sees (encoder difference, upsampled features) ---
        self.upconv4_c = double(w16, w8)
        self.conv4_c = make(w16, w16, name="conv4")

        self.upconv3_c = double(w16, w4)
        self.conv3_c = make(w8, w8, name="conv3")

        self.upconv2_c = double(w8, w2)
        self.conv2_c = make(w4, w4, name="conv2")

        self.upconv1_c = double(w4, w1)
        self.conv1_c = make(w2, w2, name="conv1")

        # 1x1 classification head on the fused difference features.
        self.conv_c = nn.Conv2d(in_channels=w2, out_channels=out_channels_c, kernel_size=1)

    def _unet_branch(self, x):
        """Run the shared U-Net on one input.

        Returns:
            ``(skips, bottleneck, decoded)`` where ``skips`` holds the four
            encoder outputs from shallowest to deepest and ``decoded`` is the
            last decoder block's output (the input to ``conv_s``).
        """
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        deepest = self.bottleneck(self.pool4(skip4))

        # Upsample, concatenate the matching encoder output on the channel axis, refine.
        up = self.decoder4(torch.cat((self.upconv4(deepest), skip4), dim=1))
        up = self.decoder3(torch.cat((self.upconv3(up), skip3), dim=1))
        up = self.decoder2(torch.cat((self.upconv2(up), skip2), dim=1))
        up = self.decoder1(torch.cat((self.upconv1(up), skip1), dim=1))

        return (skip1, skip2, skip3, skip4), deepest, up

    def forward(self, x1, x2):
        """Compute segmentation logits for both inputs and classification logits.

        Args:
            x1: First input batch.
            x2: Second input batch, same shape as ``x1``.

        Returns:
            ``(conv_s(x1 branch), conv_s(x2 branch), conv_c(difference branch))``.
        """
        # Same weights, first input then second input.
        skips_a, deepest_a, decoded_a = self._unet_branch(x1)
        skips_b, deepest_b, decoded_b = self._unet_branch(x2)

        # Difference decoder: start from the bottleneck difference and walk up,
        # fusing the encoder difference of the matching resolution at each step.
        fused = deepest_b - deepest_a
        ladder = (
            (self.upconv4_c, self.conv4_c, 3),
            (self.upconv3_c, self.conv3_c, 2),
            (self.upconv2_c, self.conv2_c, 1),
            (self.upconv1_c, self.conv1_c, 0),
        )
        for upsample, refine, depth in ladder:
            raised = upsample(fused)
            delta = skips_b[depth] - skips_a[depth]
            fused = refine(torch.cat((delta, raised), dim=1))

        return self.conv_s(decoded_a), self.conv_s(decoded_b), self.conv_c(fused)

    @staticmethod
    def _block(in_channels, features, name):
        """Two 3x3 conv -> batch-norm -> ReLU stages used by every encoder/decoder block.

        The convolutions use padding 1 (spatial size preserved) and no bias.
        Layers are named ``<name>conv1``, ``<name>norm1``, ``<name>relu1``,
        ``<name>conv2``, ``<name>norm2``, ``<name>relu2``.
        """
        layers = OrderedDict()
        # The first stage maps in_channels -> features, the second features -> features.
        for idx, stage_in in enumerate((in_channels, features), start=1):
            tag = str(idx)
            layers[name + "conv" + tag] = nn.Conv2d(
                in_channels=stage_in,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )
            layers[name + "norm" + tag] = nn.BatchNorm2d(num_features=features)
            layers[name + "relu" + tag] = nn.ReLU(inplace=True)
        return nn.Sequential(layers)

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Loading Dataset Splits, Mean/Std Statistics, and Creating Datasets for Train, Validation, and Test Sets

In [None]:
import os
import json

def load_data(split_json_path, mean_std_json_path, base_path="shards"):
    """Build the train/val/test datasets from a split file and per-tile statistics.

    A tile id is kept for a split only if a ``<tile_id>.npz`` file exists in
    ``base_path`` and the id has an entry in the mean/std file.

    Args:
        split_json_path: JSON file mapping each scene id to a dict whose
            optional ``"train"``, ``"val"`` and ``"test"`` keys list tile ids.
        mean_std_json_path: JSON file keyed by tile id; passed unchanged to
            ``DamageDataset``.
        base_path: Directory containing the ``.npz`` tiles.

    Returns:
        ``(train_ds, val_ds, test_ds)`` ``DamageDataset`` objects; only the
        training dataset is created with ``transform=True``.
    """
    with open(split_json_path, "r") as f:
        raw_splits = json.load(f)
    print("✅ Loaded scenes:", len(raw_splits))

    with open(mean_std_json_path, "r") as f:
        mean_std = json.load(f)
    print("✅ Loaded mean/std stats:", len(mean_std))

    # Tile ids present on disk. Only the trailing ".npz" is stripped, so an id
    # that itself contains ".npz" still round-trips to its file name.
    suffix = ".npz"
    available_tiles = {
        fname[: -len(suffix)]
        for fname in os.listdir(base_path)
        if fname.endswith(suffix)
    }

    def collect_ids(split_name):
        """Return the usable tile ids of one split, in split-file order."""
        return [
            tile_id
            for splits in raw_splits.values()
            for tile_id in splits.get(split_name, [])
            if tile_id in available_tiles and tile_id in mean_std
        ]

    train_ids = collect_ids("train")
    val_ids = collect_ids("val")
    test_ids = collect_ids("test")

    print(f"✅ Final splits → Train: {len(train_ids)}, Val: {len(val_ids)}, Test: {len(test_ids)}")

    # transform=True is passed for the training split only.
    train_ds = DamageDataset(train_ids, mean_std, transform=True, base_path=base_path)
    val_ds = DamageDataset(val_ids, mean_std, transform=False, base_path=base_path)
    test_ds = DamageDataset(test_ids, mean_std, transform=False, base_path=base_path)

    return train_ds, val_ds, test_ds

# Model Training and Validation: Training Loop with Loss Calculation, Optimizer, and Saving Best Model

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F

def train_model(
    model,
    train_ds,
    val_ds,
    device,
    epochs=30,
    batch_size=8,
    save_dir="checkpoints",
    num_classes=5,
    ignore_index=255,
    early_stopping_patience=3
):
    """Train ``model`` with weighted cross-entropy and checkpoint the best epoch.

    Each epoch runs one pass over ``train_ds`` (shuffled) and one pass over
    ``val_ds``. Whenever the mean validation loss improves, the model's
    ``state_dict`` is written to ``<save_dir>/best_model.pth``. After the last
    epoch the loss curves are saved to ``<save_dir>/loss_curve_ce_only.png``
    and shown.

    The total loss is ``0.0 * pre-seg + 0.0 * post-seg + 1.0 * damage``, so only
    the damage head contributes gradients. Damage logits are zeroed wherever the
    model's own pre-image segmentation does not predict class 1.

    Args:
        model: Called as ``model(pre, post)``; must return
            ``(pre_seg_logits, post_seg_logits, damage_logits)``.
        train_ds, val_ds: Datasets whose items are dicts with the keys
            ``"pre_image"``, ``"post_image"``, ``"building_mask"`` and
            ``"damage_mask"``.
        device: Device that batches and class weights are moved to.
        epochs: Number of epochs; also the cosine schedule's ``T_max``.
        batch_size: Batch size of both loaders.
        save_dir: Output directory (created if missing).
        num_classes: Damage targets are clamped into ``[0, num_classes - 1]``.
            The damage class-weight vector is fixed at five entries.
        ignore_index: Target value excluded from every loss. Pixels carrying
            this value keep it through the clamping step.
        early_stopping_patience: Accepted for backward compatibility; early
            stopping is currently disabled, so the value is not used.

    Returns:
        None.
    """
    import math  # local import: only needed for the finite-loss guard

    os.makedirs(save_dir, exist_ok=True)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    # Per-class weights of the cross-entropy terms.
    weights_seg = torch.tensor([1.0, 15.0], dtype=torch.float32).to(device)
    weights_dmg = torch.tensor([1.0, 2.11, 2.51, 9.51, 17.02], dtype=torch.float32).to(device)

    loss_seg_pre = torch.nn.CrossEntropyLoss(weight=weights_seg, ignore_index=ignore_index)
    loss_seg_post = torch.nn.CrossEntropyLoss(weight=weights_seg, ignore_index=ignore_index)
    loss_dmg = torch.nn.CrossEntropyLoss(weight=weights_dmg, ignore_index=ignore_index)

    # Order: pre segmentation, post segmentation, damage.
    # NOTE(review): with both segmentation weights at 0 the segmentation head
    # gets no gradient here, yet its prediction gates the damage logits below.
    # Unless `model` arrives with a trained segmentation head, that gate comes
    # from untrained weights - confirm this is intended.
    loss_weights = [0.0, 0.0, 1.0]

    def prepare_batch(batch):
        """Move one batch to `device` and sanitise the damage targets."""
        pre = batch["pre_image"].to(device)
        post = batch["post_image"].to(device)
        bld_mask = batch["building_mask"].to(device)
        dmg_mask = batch["damage_mask"].to(device)

        # Clamp out-of-range labels, but keep ignore_index pixels ignored:
        # a plain clamp would turn them into class num_classes - 1.
        ignored = dmg_mask == ignore_index
        dmg_mask = torch.clamp(dmg_mask, 0, num_classes - 1)
        dmg_mask[ignored] = ignore_index
        return pre, post, bld_mask, dmg_mask

    def gate_damage_logits(out_pre, out_dmg):
        """Zero the damage logits wherever the pre-image head does not predict class 1."""
        # argmax over logits equals argmax over softmax; detach keeps the gate
        # out of the autograd graph.
        building = out_pre.detach().argmax(dim=1, keepdim=True) == 1
        return out_dmg * building  # broadcast over the class axis, not in-place

    def combined_loss(out_pre, out_post, out_dmg, bld_mask, dmg_mask):
        """Weighted sum of the two segmentation losses and the damage loss."""
        loss_pre = loss_seg_pre(out_pre, bld_mask)
        loss_post = loss_seg_post(out_post, bld_mask)
        loss_damage = loss_dmg(out_dmg, dmg_mask)
        return loss_weights[0]*loss_pre + loss_weights[1]*loss_post + loss_weights[2]*loss_damage

    def mean_or_nan(total, count):
        """Average over the batches that actually contributed."""
        return total / count if count else float("nan")

    train_losses = []
    val_losses = []
    best_val_loss = float("inf")

    for epoch in range(epochs):
        # ---------------- Training ----------------
        model.train()
        total_train_loss = 0.0
        train_batches_done = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            pre, post, bld_mask, dmg_mask = prepare_batch(batch)

            optimizer.zero_grad()
            out_pre, out_post, out_dmg = model(pre, post)
            out_dmg = gate_damage_logits(out_pre, out_dmg)

            try:
                loss = combined_loss(out_pre, out_post, out_dmg, bld_mask, dmg_mask)
                loss_value = loss.item()
                if not math.isfinite(loss_value):
                    # E.g. every target pixel is ignore_index (mean over zero
                    # pixels). Stepping on such a loss would corrupt the weights.
                    print("❌ Non-finite loss, skipping batch")
                    continue
                loss.backward()
                optimizer.step()
                total_train_loss += loss_value
                train_batches_done += 1
            except RuntimeError as e:
                # Best effort: report and move on to the next batch.
                print("❌ Runtime error during loss computation:", e)
                continue

        # Skipped batches must not dilute the epoch average.
        avg_train_loss = mean_or_nan(total_train_loss, train_batches_done)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}")
        skipped = len(train_loader) - train_batches_done
        if skipped:
            print(f"⚠️ Skipped {skipped} training batch(es) in epoch {epoch+1}")

        # ---------------- Validation ----------------
        model.eval()
        total_val_loss = 0.0
        val_batches_done = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                pre, post, bld_mask, dmg_mask = prepare_batch(batch)
                out_pre, out_post, out_dmg = model(pre, post)
                out_dmg = gate_damage_logits(out_pre, out_dmg)

                loss_value = combined_loss(out_pre, out_post, out_dmg, bld_mask, dmg_mask).item()
                if math.isfinite(loss_value):
                    total_val_loss += loss_value
                    val_batches_done += 1

        # NaN when no batch produced a finite loss; NaN never counts as "best".
        avg_val_loss = mean_or_nan(total_val_loss, val_batches_done)
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}, Val Loss: {avg_val_loss:.4f}")

        scheduler.step()

        # Keep only the checkpoint with the lowest validation loss so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
            print(f"✅ Best model saved at epoch {epoch+1} with val loss {avg_val_loss:.4f}")

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training vs Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, "loss_curve_ce_only.png"))
    plt.show()


In [None]:
# del model
# torch.cuda.empty_cache()

In [None]:
# Use the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fresh Siamese U-Net with its default arguments, in float32, on the chosen device.
model = SiamUnet().to(device).float()

# Build the three dataset splits from the split file, the per-tile statistics
# and the directory of .npz tiles.
train_ds, val_ds, test_ds = load_data(
    "final_splits_sliced.json",
    "mean_stddev_tiles.json",
    base_path="/content/shards",
)

# Train for 30 epochs with batches of 32; the best checkpoint and the loss
# curve are written to the checkpoints directory.
train_model(
    model,
    train_ds,
    val_ds,
    device,
    epochs=30,
    batch_size=32,
    save_dir="/content/drive/MyDrive/691_Team4_Outputs/checkpoints",
)


# Model Evaluation and Visualization: Generating Classification Report, Confusion Matrix, and IoU Scores

In [None]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    ConfusionMatrixDisplay,
    jaccard_score,
    accuracy_score
)

def denormalize(img_tensor, mean, std):
    """Reverse a per-channel ``(x - mean) / std`` normalization for visualization.

    Args:
        img_tensor: Tensor whose first dimension is the channel axis
            (the statistics are broadcast as ``(C, 1, 1)``).
        mean: Per-channel means (sequence, array or tensor of length C).
        std: Per-channel standard deviations, same length as ``mean``.

    Returns:
        ``img_tensor * std + mean``. For floating-point images the result keeps
        the image's dtype; for integer images it is promoted to the default
        floating-point dtype.
    """
    # Build the statistics in the image's own floating dtype so that e.g. a
    # float64 image is not denormalised with float32-rounded values. Integer
    # images cannot hold fractional statistics, so those use the default dtype.
    if img_tensor.is_floating_point():
        stats_dtype = img_tensor.dtype
    else:
        stats_dtype = torch.get_default_dtype()

    # as_tensor also accepts existing tensors/arrays without a copy warning.
    mean = torch.as_tensor(mean, dtype=stats_dtype, device=img_tensor.device).reshape(-1, 1, 1)
    std = torch.as_tensor(std, dtype=stats_dtype, device=img_tensor.device).reshape(-1, 1, 1)
    return img_tensor * std + mean

def visualize_prediction(pre, post, pred, mask, save_path):
    """Save a 1x4 figure: pre image, post image, predicted mask, ground-truth mask.

    Args:
        pre, post: 3-D image tensors; dimension 0 is moved last before plotting
            (presumably C,H,W -> H,W,C).
        pred, mask: 2-D label tensors, drawn with the "viridis" colormap on a
            fixed 0..4 colour scale.
        save_path: File the figure is written to (150 dpi).
    """
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))

    def channels_last(img):
        return img.permute(1, 2, 0).cpu().numpy()

    # Both label maps share one colour scale so their colours are comparable.
    label_style = {"cmap": "viridis", "vmin": 0, "vmax": 4}

    panels = (
        ("Pre-disaster", channels_last(pre), {}),
        ("Post-disaster", channels_last(post), {}),
        ("Predicted Mask", pred.cpu().numpy(), label_style),
        ("Ground Truth", mask.cpu().numpy(), label_style),
    )
    for ax, (title, pixels, style) in zip(axs, panels):
        ax.imshow(pixels, **style)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()  # release the figure; this runs once per visualised sample

def evaluate_model(model, test_loader, device, save_dir="checkpoints", num_classes=5, ignore_index=255):
    """Evaluate the best checkpoint on ``test_loader`` and write reports to ``save_dir``.

    Loads ``<save_dir>/best_model.pth`` into ``model``, predicts a damage class
    per pixel (argmax over the third model output) and, over all pixels whose
    ground truth differs from ``ignore_index``, writes:

    * ``classification_report.txt``
    * ``confusion_matrix.png`` / ``confusion_matrix.npy``
    * ``ious_per_class.txt`` (per-class IoU and their mean)
    * ``accuracy.txt``
    * ``test_visuals/sample_<i>.png`` for the first sample of each of the
      first five batches

    Args:
        model: Called as ``model(pre, post)``; its third output holds the damage logits.
        test_loader: Yields dicts with ``"pre_image"``, ``"post_image"`` and ``"damage_mask"``.
        device: Device used for inference.
        save_dir: Directory holding the checkpoint and receiving all outputs.
        num_classes: Labels ``0 .. num_classes - 1`` are used for the confusion
            matrix and the IoU scores.
        ignore_index: Ground-truth value excluded from every metric.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If ``test_loader`` yields no batches.
    """
    model_path = os.path.join(save_dir, "best_model.pth")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Best model not found at: {model_path}")

    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # One flattened array per batch; concatenated once after the loop.
    pred_chunks = []
    mask_chunks = []

    vis_dir = os.path.join(save_dir, "test_visuals")
    os.makedirs(vis_dir, exist_ok=True)

    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            pre = batch["pre_image"].to(device).float()
            post = batch["post_image"].to(device).float()
            mask = batch["damage_mask"].to(device)

            # NOTE(review): training/validation zero the damage logits outside
            # the predicted building mask and clamp the targets; neither
            # happens here - confirm the metrics are meant to be computed on
            # raw logits and raw labels.
            _, _, logits = model(pre, post)
            preds = torch.argmax(logits, dim=1)

            # Flattening a whole batch yields the same pixel order as
            # flattening sample by sample, with far fewer small arrays.
            pred_chunks.append(preds.flatten().cpu().numpy())
            mask_chunks.append(mask.flatten().cpu().numpy())

            if i < 5:
                # NOTE(review): fixed statistics are used to undo the
                # normalisation; if the dataset normalised with per-tile
                # values the colours shown will be off - confirm.
                pre_vis = denormalize(pre[0], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                post_vis = denormalize(post[0], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                visualize_prediction(pre_vis, post_vis, preds[0], mask[0], os.path.join(vis_dir, f"sample_{i}.png"))

    if not pred_chunks:
        raise ValueError("test_loader produced no batches; nothing to evaluate")

    all_preds = np.concatenate(pred_chunks)
    all_masks = np.concatenate(mask_chunks)

    # Drop pixels the ground truth marks as ignored.
    valid = all_masks != ignore_index
    y_true = all_masks[valid]
    y_pred = all_preds[valid]

    # 📄 Classification Report
    cls_report = classification_report(y_true, y_pred, digits=4, zero_division=0)
    print("✅ Classification Report:")
    print(cls_report)
    with open(os.path.join(save_dir, "classification_report.txt"), "w") as f:
        f.write(cls_report)

    # 📊 Confusion Matrix (rows: true label, columns: predicted label)
    class_ids = list(range(num_classes))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in class_ids])
    disp.plot(cmap="Blues", xticks_rotation=45)
    plt.savefig(os.path.join(save_dir, "confusion_matrix.png"))
    np.save(os.path.join(save_dir, "confusion_matrix.npy"), cm)
    plt.close()

    # 📈 IoU per class
    ious = jaccard_score(y_true, y_pred, average=None, labels=class_ids, zero_division=0)
    iou_lines = [f"Class {i}: {iou:.4f}" for i, iou in enumerate(ious)]
    mean_iou = np.mean(ious)
    print("📈 Mean IoU per class:")
    print("\n".join(iou_lines))
    print(f"🔁 Mean IoU: {mean_iou:.4f}")

    with open(os.path.join(save_dir, "ious_per_class.txt"), "w") as f:
        f.write("\n".join(iou_lines))
        f.write(f"\nMean IoU: {mean_iou:.4f}")

    # ✅ Overall Accuracy
    acc = accuracy_score(y_true, y_pred)
    with open(os.path.join(save_dir, "accuracy.txt"), "w") as f:
        f.write(f"Overall Accuracy: {acc:.4f}")
    print(f"✅ Overall Accuracy: {acc:.4f}")

# Loading Test Data and Evaluating the Model: Prediction, Metrics, and Saving Results

In [None]:
from torch.utils.data import DataLoader
# Un-shuffled loader over the test split: batches of 8, 4 worker processes,
# pinned host memory.
test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=4, pin_memory=True)

# Reload the best checkpoint from the checkpoints directory, score it on the
# test split, and write the reports and sample visuals next to it.
evaluate_model(
    model,
    test_loader,
    device,
    save_dir="/content/drive/MyDrive/691_Team4_Outputs/checkpoints",
    num_classes=5,
    ignore_index=255,
)