# SPECTRAL UNMIXING

## Import Libraries

In [None]:
import pandas as pd
import geopandas as gpd
import rasterio
from rasterio.transform import from_origin
from rasterio.transform import rowcol
import numpy as np
import os
import glob
from collections import defaultdict
from sklearn.ensemble import RandomForestRegressor

## Optionally mask S2 images with the SCL band if it has not already been done

### Define Parameters

In [None]:
# Acquisition date (YYYY-MM-DD) embedded in the Sentinel-2 file names
target_date = "2023-11-17"
# Folder with the unmasked input bands
folder_path = r"E:\TESI\OpenEO\Output\AllBands"
# Folder where the masked bands are written
output_folder = r"E:\TESI\OpenEO\Output\MaskedMaps"
# Nodata value to use for masked pixels
nodata_value = -99_999
# Sentinel-2 bands to mask
bands = "B02 B03 B04 B05 B06 B07 B08 B8A B11 B12".split()

### Clouds, Cloud Shadows and Cirrus Masking

In [None]:
scl_path = rf"E:\TESI\OpenEO\Output\SCL_20m_{target_date}.tiff"

# --- Load the SCL (Scene Classification Layer) band ---
with rasterio.open(scl_path) as scl_ds:
    scl_data = scl_ds.read(1)

# Cloud mask: True where the pixel is usable, False where the SCL class is
# 3 (cloud shadows), 8 (cloud medium probability), 9 (cloud high probability)
# or 10 (thin cirrus).
valid_mask = ~np.isin(scl_data, [3, 8, 9, 10])

# Ensure the output folder exists (once, before writing any band)
os.makedirs(output_folder, exist_ok=True)

# --- Process each band ---
for band in bands:
    input_path = os.path.join(folder_path, f"Sentinel2_{band}_{target_date}.tiff")
    output_path = os.path.join(output_folder, f"Sentinel2_{band}_{target_date}_masked.tiff")

    print(f"Processing {band}...")

    with rasterio.open(input_path) as src:
        # masked=True marks pixels that are already nodata in the input file
        raw = src.read(1, masked=True)
        profile = src.profile

    # The SCL mask is applied pixel by pixel, so the band must share its grid.
    if raw.shape != valid_mask.shape:
        raise ValueError(
            f"Band {band} has shape {raw.shape} but the SCL mask has shape "
            f"{valid_mask.shape}; resample both to the same grid before masking."
        )

    input_nodata = np.ma.getmaskarray(raw)
    band_data = raw.data.astype(np.int32)  # Promote to avoid overflow

    # Write nodata_value where cloudy OR where the input was already nodata.
    # Without the second condition, original nodata pixels would turn into
    # "valid" data once the profile's nodata value is changed below.
    masked_band = np.where(valid_mask & ~input_nodata, band_data, nodata_value)

    # Update profile
    profile.update(dtype=rasterio.int32, nodata=nodata_value)

    with rasterio.open(output_path, 'w', **profile) as dst:
        dst.write(masked_band, 1)

    print(f"Saved masked band: {output_path}")


## Generate a dataframe containing the spectral signature values for each point

### Define Parameters

In [None]:
target_date = "2023-11-17"  # Date you are interested in
folder_path = r"E:\TESI\OpenEO\Output\MaskedMaps"  # Folder containing all (masked) bands
bands = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"]  # Bands you are interested in
gpkg_path = rf"E:\TESI\Spectral_Unmixing_Geoinf_Proj\Unmixing_QGIS\Output\Final_points_{target_date}.gpkg"  # Path to the geopackage containing training points
nan_values = -99999  # Nodata value used for masked pixels (must match the masking step)
output_path = rf"E:\TESI\Spectral_Unmixing_Geoinf_Proj\Python_Output\fraction_maps_multiband_masked_{target_date}.tif"  # Output GeoTIFF file for the multi-band fraction map
# Reference raster to copy spatial metadata from; could be any band.
# Raw f-string: the Windows backslashes must not be parsed as escape sequences.
reference_raster = rf"E:\TESI\OpenEO\Output\MaskedMaps\Sentinel2_B02_{target_date}_masked.tiff"

### Load and Stack Raster Bands

In [None]:
band_arrays = []
band_names = []

for band in bands:
    # Naming convention of your S2 images -- change it based on your image name convention
    pattern = os.path.join(folder_path, f"Sentinel2_{band}_{target_date}_masked.tiff")
    matches = glob.glob(pattern)  # Images matching the pattern
    if not matches:
        raise FileNotFoundError(f"No file found for band {band!r} on {target_date}")
    tif_path = matches[0]

    with rasterio.open(tif_path) as src:
        if not band_arrays:
            # Capture spatial metadata from the first band; all others must match it
            transform = src.transform
            print("Transform:", transform)
            crs = src.crs
            print("Raster CRS:", crs)
            bounds = src.bounds  # BoundingBox(left, bottom, right, top)
        elif src.transform != transform or src.crs != crs:
            # Pixel-wise stacking is only meaningful if every band shares one grid
            raise ValueError(
                f"Band {band!r} is not on the same grid as band {band_names[0]!r}"
            )
        band_arrays.append(src.read(1))  # (rows, cols)
        band_names.append(band)

# stack into (n_bands, rows, cols)
stacked_bands = np.stack(band_arrays, axis=0)
height, width = stacked_bands.shape[1:]  # rows, cols

# LOAD POINTS AND REPROJECT
gdf = gpd.read_file(gpkg_path)
gdf = gdf.to_crs(crs)  # ensure points are in the same CRS as the images

# EXTRACT SPECTRA AT POINT LOCATIONS
samples = []
for idx, pt in gdf.iterrows():  # each row holds the geometry, endmember and endmember_class
    x, y = pt.geometry.x, pt.geometry.y

    # convert map coordinates to (row, col) pixel indices
    row, col = map(int, rowcol(transform, x, y))

    if not (0 <= row < height and 0 <= col < width):
        print(f"⚠️  Point {idx} at ({x:.0f},{y:.0f}) outside raster bounds; skipping")
        continue

    spectrum = stacked_bands[:, row, col]  # length = len(bands)

    # A point on a masked (cloud / shadow / cirrus) pixel has nan_values in its
    # spectrum; keeping it would inject -99999 "spectra" into the training set.
    if np.any(spectrum == nan_values):
        print(f"⚠️  Point {idx} at ({x:.0f},{y:.0f}) falls on a masked pixel; skipping")
        continue

    # build dict: {"B02": val, ..., "B12": val, "fid": ..., "endmember": ..., "endmember_class": ...}
    sample = {name: float(value) for name, value in zip(band_names, spectrum)}
    sample["fid"] = idx
    sample["endmember"] = pt["endmember"]
    sample["endmember_class"] = pt["endmember_class"]
    samples.append(sample)

# TRANSFORM TO DATAFRAME
endmember_df = pd.DataFrame(samples)

## Build a synthetic training set

### Get Pure Spectra for Each Endmember Class

In [None]:
# assume endmember_df has columns B02…B12 and endmember_class (integers 1–7)
# Group the pure spectra by endmember class. endmember_df is expected to hold one
# column per band (in `bands` order) plus an "endmember_class" column.
spectra_by_class = defaultdict(list)

for _, record in endmember_df.iterrows():
    spectra_by_class[record["endmember_class"]].append(
        record[bands].to_numpy(dtype=float)  # one value per band
    )

# Turn each class's list of spectra into a (n_spectra, n_bands) array
for class_id, spectra in spectra_by_class.items():
    spectra_by_class[class_id] = np.stack(spectra, axis=0)
    print(f"Class {class_id} has {spectra_by_class[class_id].shape[0]} pure spectra")

### Function to create a Synthetic Dataset
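Given a class, make N mixed spectra by combining:
- a random pure target spectrum $t$ (from the given class),
- a random pure background spectrum $b$ (from any other class),
- a random fraction $f \in [0, 1)$,

into the linear mixture $x = f\,t + (1 - f)\,b$, whose regression target is $f$.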

In [None]:
def make_synthetic_dataset(target_cls, spectra_by_class, n_samples=1000, random_state=None):
    """Build a synthetic linear-mixing training set for one target class.

    Each sample is ``f * t + (1 - f) * b``: ``t`` is a randomly drawn pure
    spectrum of ``target_cls``, ``b`` is a randomly drawn pure spectrum of any
    other class, and ``f`` is drawn uniformly from [0, 1). With ``f`` close to 1
    the mix is close to the pure target spectrum.

    Parameters
    ----------
    target_cls : hashable
        Key of ``spectra_by_class`` whose fraction is the regression target.
    spectra_by_class : mapping
        Class id -> array of pure spectra, shape (n_spectra, n_bands).
    n_samples : int, default 1000
        Number of mixed spectra to generate.
    random_state : None, int or numpy.random.Generator, optional
        Seed or generator for the random draws. Pass a fixed value to make the
        synthetic set (and therefore a seeded model trained on it) reproducible.

    Returns
    -------
    X : ndarray, shape (n_samples, n_bands)
        Mixed spectra.
    y : ndarray, shape (n_samples,)
        Target-class fraction of each mixed spectrum.

    Raises
    ------
    KeyError
        If ``target_cls`` is not a key of ``spectra_by_class``.
    ValueError
        If the target class has no spectra, or there is no other class to use
        as background.
    """
    # Membership test first: indexing a defaultdict would silently add the key.
    if target_cls not in spectra_by_class:
        raise KeyError(target_cls)
    target_specs = np.asarray(spectra_by_class[target_cls], dtype=float)
    if target_specs.shape[0] == 0:
        raise ValueError(f"Class {target_cls!r} has no pure spectra")

    # Background pool: all spectra from all other classes
    other_specs = [specs for cls, specs in spectra_by_class.items() if cls != target_cls]
    if not other_specs:
        raise ValueError("At least one other class is needed as background")
    bg_specs = np.vstack(other_specs).astype(float)

    rng = np.random.default_rng(random_state)
    t_idx = rng.integers(len(target_specs), size=n_samples)  # random pure target per sample
    b_idx = rng.integers(len(bg_specs), size=n_samples)      # random pure background per sample
    f = rng.random(n_samples)                                 # target fraction in [0, 1)

    # Linear mixture, vectorized over all samples at once
    X = f[:, np.newaxis] * target_specs[t_idx] + (1.0 - f[:, np.newaxis]) * bg_specs[b_idx]
    return X, f

## Train One Regressor Per Class

In [None]:
# Train one Random Forest regressor per endmember class. Each regressor
# predicts the fraction of its own class inside a (mixed) spectrum.
models = {}
for target_class in spectra_by_class:
    X_train, y_train = make_synthetic_dataset(target_class, spectra_by_class, n_samples=2000)
    regressor = RandomForestRegressor(
        n_estimators=100,
        # Out-of-bag score: each tree is scored on the samples it did not see
        # during training (roughly 1/3 of them).
        oob_score=True,
        random_state=42,  # fixed seed for the forest itself
    )
    regressor.fit(X_train, y_train)
    # For a regressor the OOB score is R^2: 1.0 is perfect, 0.0 is no better
    # than always predicting the mean, and negative values are worse than that.
    print(f"Trained RF for class {target_class}, OOB score: {regressor.oob_score_:.3f}")
    models[target_class] = regressor

## Apply each model to every pixel in the stacked image

!!!WARNING!!!   --    THIS PROCEDURE IS COMPUTATIONALLY INTENSIVE

In [None]:
# One spectrum per pixel: (n_bands, rows, cols) -> (n_pixels, n_bands).
# The band count comes from the stack instead of being hard-coded.
n_bands = stacked_bands.shape[0]
flat_pixels = stacked_bands.reshape(n_bands, -1).T

# Mask of valid pixels: True where no band holds the nodata value
valid_mask = np.all(stacked_bands != nan_values, axis=0)  # shape: (height, width)
flat_valid_mask = valid_mask.flatten()
valid_pixels = flat_pixels[flat_valid_mask]
print(f"{valid_pixels.shape[0]} of {flat_valid_mask.size} pixels are valid")

fraction_maps = {}  # class -> (height, width) fraction image, NaN on masked pixels
for cls, rf in models.items():
    frac_map = np.full(valid_mask.shape, np.nan, dtype=np.float32)
    if valid_pixels.shape[0] > 0:  # predict() rejects an empty input
        preds = np.clip(rf.predict(valid_pixels), 0, 1)  # clamp predictions to [0, 1]
        frac_map[valid_mask] = preds

    fraction_maps[cls] = frac_map
    print(f"Predicted fraction map for class {cls}")

## Enforce constraints
The fractions of all classes must sum to 1 in each pixel

In [81]:
# Ascending list of class labels; layer k of every stack below is cls_list[k].
cls_list = sorted(fraction_maps)
# (n_classes, rows, cols): one fraction layer per class
stacked_f = np.stack([fraction_maps[label] for label in cls_list], axis=0)

# Sum-to-one constraint: divide each pixel's fractions by their per-pixel total.
# A total of exactly 0 is replaced by 1 so all-zero pixels stay 0 instead of
# becoming NaN; masked (NaN) pixels stay NaN.
s = stacked_f.sum(axis=0, keepdims=True)
s[s == 0] = 1
normalized = stacked_f / s

## PLOTTING

### Plot Each Fraction Map

In [None]:
import matplotlib.pyplot as plt

# One panel per class, in the same order as cls_list / the layers of `normalized`
n_classes = normalized.shape[0]
# squeeze=False keeps `axes` 2-D even for a single class; otherwise plt.subplots
# returns a bare Axes and iterating over it fails.
fig, axes = plt.subplots(1, n_classes, figsize=(4 * n_classes, 5), squeeze=False)

for i, ax in enumerate(axes[0]):
    im = ax.imshow(normalized[i], cmap='viridis', vmin=0, vmax=1)
    ax.set_title(f"Class {cls_list[i]}")
    ax.axis('off')
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

## Saving Maps

### Saving all bands in a single .tiff file

In [None]:
with rasterio.open(reference_raster) as ref:  # any band works: only its grid/georeferencing is reused
    meta = ref.meta.copy()
    transform = ref.transform
    crs = ref.crs

# The fraction maps must lie on the reference grid, otherwise the copied
# georeferencing would be wrong.
if normalized.shape[1:] != (meta["height"], meta["width"]):
    raise ValueError(
        f"Fraction maps have shape {normalized.shape[1:]} but the reference raster "
        f"is {(meta['height'], meta['width'])}"
    )

# Update metadata for multi-band output
meta.update({
    "count": len(cls_list),  # number of bands = number of classes
    "dtype": "float32",      # fraction values
    "driver": "GTiff",
    # Masked pixels are NaN in `normalized`. The masked reference band declares
    # a different nodata value (the integer used by the masking step), so
    # declare NaN here; GIS tools then treat masked pixels as missing.
    "nodata": np.nan,
})

out_dir = os.path.dirname(output_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

# Save the multi-band raster: band k (1-based) holds class cls_list[k - 1]
with rasterio.open(output_path, "w", **meta) as dst:
    for band_idx, cls in enumerate(cls_list, start=1):
        dst.write(normalized[band_idx - 1].astype("float32"), band_idx)
        dst.set_band_description(band_idx, f"Class_{cls}")

print(f"✅ Saved multi-band fraction map to: {output_path}")

# Verify the results

## Check .tif file obtained with the code against the .tif file obtained with the plugin

### Import Libraries

In [3]:
import rasterio
import numpy as np
import matplotlib.pyplot as plt

### Define Parameters

In [5]:
# Acquisition date used in both file names
target_date = "2023-11-17"

# The two fraction maps being compared
path_plugin = rf"E:\TESI\Spectral_Unmixing_Geoinf_Proj\Unmixing_QGIS\Output\Final_Class_Fraction_Layer_Masked_{target_date}.tif"  # generated via the QGIS plugin
path_code = rf"E:\TESI\Spectral_Unmixing_Geoinf_Proj\Python_Output\fraction_maps_multiband_masked_{target_date}.tif"  # generated via this notebook

# Class labels; later cells use class_names[k] as the label of band k + 1
class_names = [
    "Shingle",
    "Metal",
    "Asphalt/Concrete",
    "Sand/Bare Soil",
    "Tall Vegetation",
    "Water",
    "Grass Low Vegetation",
]

# A pixel counts as confidently classified when its highest fraction is above
# dominant_thresh AND its second-highest fraction is below second_thresh.
dominant_thresh, second_thresh = 0.7, 0.2

### Verify the results against the ones obtained through the EnMAP-Box 3 QGIS plugin.
Both maps are created using a Random Forest (RF) model.
We will refer to the 2 images as "Plugin" and "Coded".

#### Check Map Alignment

In [None]:
with rasterio.open(path_plugin) as src1:
    plugin_data = src1.read()  # shape (bands, rows, cols)
    plugin_transform, plugin_crs = src1.transform, src1.crs

with rasterio.open(path_code) as src2:
    code_data = src2.read()    # shape (bands, rows, cols)
    code_transform, code_crs = src2.transform, src2.crs

print("Plugin shape:", plugin_data.shape)
print("Code shape:", code_data.shape)

# Every comparison below subtracts / indexes the two arrays pixel by pixel,
# which is only meaningful if both maps share the same grid.
if plugin_data.shape != code_data.shape:
    print("⚠️  The two maps have different shapes: pixel-wise comparisons are invalid")
if plugin_transform != code_transform or plugin_crs != code_crs:
    print("⚠️  The two maps have different georeferencing (transform/CRS)")

#### Visual Comparison

In [None]:
# Shared colour limits so the two fraction maps are directly comparable
vmin, vmax = 0, 1
# Difference panels saturate at ±0.1: larger differences hit the extreme
# colours, smaller ones stay in the soft middle of the diverging colormap.
diff_vmin, diff_vmax = -0.1, 0.1

n_classes = plugin_data.shape[0]


def _class_label(band_idx):
    """Human-readable name of band `band_idx` (0-based), with a numeric fallback."""
    if band_idx < len(class_names):
        return class_names[band_idx]
    return f"Class {band_idx+1}"


def _masked_layers(band_idx):
    """Return (plugin, code, code - plugin) for one band, NaN/inf masked out."""
    return (
        np.ma.masked_invalid(plugin_data[band_idx]),
        np.ma.masked_invalid(code_data[band_idx]),
        np.ma.masked_invalid(code_data[band_idx] - plugin_data[band_idx]),
    )


# One row per class; columns = plugin, code, difference
fig, axs = plt.subplots(n_classes, 3, figsize=(15, 5 * n_classes))
if n_classes == 1:
    axs = np.expand_dims(axs, axis=0)  # keep indexing 2-D for a single class

for cls_idx in range(n_classes):
    label = _class_label(cls_idx)
    plugin_m, code_m, diff_m = _masked_layers(cls_idx)
    panels = (
        (plugin_m, 'viridis', vmin, vmax, f"Plugin - {label}"),
        (code_m, 'viridis', vmin, vmax, f"Coded - {label}"),
        (diff_m, 'bwr', diff_vmin, diff_vmax, "Difference (zoomed)"),  # negative = blue, positive = red
    )
    for ax, (layer, cmap, lo, hi, title) in zip(axs[cls_idx], panels):
        ax.imshow(layer, cmap=cmap, vmin=lo, vmax=hi)
        ax.set_title(title)
        ax.axis('off')

plt.tight_layout()
plt.show()

# Per-class value ranges
for cls_idx in range(n_classes):
    label = _class_label(cls_idx)
    plugin_m, code_m, diff_m = _masked_layers(cls_idx)
    print(f"{label} (Class {cls_idx + 1})")
    print(f"  Plugin    min/max: {plugin_m.min():.3f}, {plugin_m.max():.3f}")
    print(f"  Coded min/max: {code_m.min():.3f}, {code_m.max():.3f}")
    print(f"  Difference range : {diff_m.min():.3f}, {diff_m.max():.3f}")
    print("-" * 50)

### Average Value Difference

In [None]:
# Per-class mean |code - plugin|, skipping pixels that are NaN in either map
abs_diff = np.abs(np.ma.masked_invalid(code_data - plugin_data))
mean_abs_diff_per_class = abs_diff.mean(axis=(1, 2))

for class_number, diff_val in enumerate(mean_abs_diff_per_class, start=1):
    print(f"Class {class_number} - Mean Absolute Difference: {diff_val:.4f}")

### Correlation Coefficient

In [None]:
# Pearson correlation between the two maps, per class, over pixels valid in both
for band_idx in range(code_data.shape[0]):
    code_flat = code_data[band_idx].ravel()
    plugin_flat = plugin_data[band_idx].ravel()

    both_valid = ~(np.isnan(code_flat) | np.isnan(plugin_flat))
    code_ok = code_flat[both_valid]
    plugin_ok = plugin_flat[both_valid]

    # corrcoef needs at least two samples
    if code_ok.size <= 1:
        print(f"Class {band_idx+1} - Not enough valid data for correlation.")
        continue

    r = np.corrcoef(code_ok, plugin_ok)[0, 1]
    print(f"Class {band_idx+1} - Correlation: {r:.4f}")

### Scatter Plot

In [None]:
for i in range(code_data.shape[0]):  # loop over each class
    # Same fallback as the visual comparison: bands beyond class_names get a numeric label
    class_name = class_names[i] if i < len(class_names) else f"Class {i+1}"

    flat_code = code_data[i].flatten()
    flat_plugin = plugin_data[i].flatten()

    # Keep only pixels that are valid (non-NaN) in both maps
    valid_mask = ~np.isnan(flat_code) & ~np.isnan(flat_plugin)
    flat_code_valid = flat_code[valid_mask]
    flat_plugin_valid = flat_plugin[valid_mask]

    # Pearson correlation needs at least two samples
    if len(flat_code_valid) > 1:
        corr = np.corrcoef(flat_code_valid, flat_plugin_valid)[0, 1]
    else:
        corr = np.nan

    # Scatter plot against the 1:1 line
    plt.figure(figsize=(6, 6))
    plt.scatter(flat_plugin_valid, flat_code_valid, s=1, alpha=0.5, color='royalblue')
    plt.plot([0, 1], [0, 1], 'r--', label='Ideal match')
    plt.title(f"{class_name} - Scatter Plot\nCorrelation: {corr:.4f}")
    plt.xlabel("Plugin pixel values")
    plt.ylabel("Your code pixel values")
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

### Average, Maximum and Minimum Difference Between the Plugin and the Code Only on "Confidently Classified" Pixels
NOTE: a pixel is defined as "Confidently Classified" if the class with the highest value inside the pixel has a value higher than "dominant_thresh" and the second highest value inside the pixel has a value lower than "second_thresh".

In [None]:
# Confidence is judged on the plugin map: its largest fraction must exceed
# dominant_thresh and its second largest must be below second_thresh.
# np.sort puts NaN last, so any pixel containing NaN is never confident.
with np.errstate(invalid="ignore"):
    sorted_plugin = np.sort(plugin_data, axis=0)
    confident = (sorted_plugin[-1] > dominant_thresh) & (sorted_plugin[-2] < second_thresh)

# Dominant plugin class per pixel, and both maps' values for that class
dominant_class = np.argmax(plugin_data, axis=0)[np.newaxis]  # (1, rows, cols)
plugin_val = np.take_along_axis(plugin_data, dominant_class, axis=0)[0]
code_val = np.take_along_axis(code_data, dominant_class, axis=0)[0]

# Skip pixels where either value is NaN
keep = confident & ~np.isnan(plugin_val) & ~np.isnan(code_val)
abs_diffs = np.abs(plugin_val[keep] - code_val[keep])

# Print summary (statistics of an empty array would raise / be undefined)
print(f"Number of clearly defined pixels: {len(abs_diffs)}")
if abs_diffs.size:
    print(f"Mean absolute difference: {abs_diffs.mean():.4f}")
    print(f"Max absolute difference: {abs_diffs.max():.4f}")
    print(f"Min absolute difference: {abs_diffs.min():.4f}")
else:
    print("No confidently classified pixels: differences cannot be computed.")

### Absolute and Percentage Mismatches among "Confidently Classified" Pixels

NOTE: a pixel is defined as "Confidently Classified" if the class with the highest value inside the pixel has a value higher than "dominant_thresh" and the second highest value inside the pixel has a value lower than "second_thresh".

In [None]:
rows, cols = plugin_data.shape[1], plugin_data.shape[2]
total_pixels = rows * cols

# Pixels with no NaN in any band, per map
plugin_ok = ~np.isnan(plugin_data).any(axis=0)
code_ok = ~np.isnan(code_data).any(axis=0)

# Confident pixels, judged on the plugin map
with np.errstate(invalid="ignore"):
    sorted_plugin = np.sort(plugin_data, axis=0)
    confident = (
        plugin_ok
        & (sorted_plugin[-1] > dominant_thresh)
        & (sorted_plugin[-2] < second_thresh)
    )

# Dominant class per pixel in both maps
plugin_class = np.argmax(plugin_data, axis=0)
code_class = np.argmax(code_data, axis=0)

# Only confident pixels that are also fully valid in the code map can be
# compared. Counting the others as matches would deflate the mismatch rate,
# and argmax over NaN would pick an arbitrary class.
compared = confident & code_ok

total_confident = int(confident.sum())
compared_count = int(compared.sum())
mismatch_count = int((compared & (plugin_class != code_class)).sum())

# Calculate percentages
percent_confident = (total_confident / total_pixels) * 100 if total_pixels > 0 else 0.0
percent_mismatch = (mismatch_count / compared_count) * 100 if compared_count > 0 else 0.0

# Print results
print(f"Percentage of confidently classified pixels: {percent_confident:.2f}%")
print(f"Number of confident pixels also valid in the code map: {compared_count}")
print(f"Number of mismatches: {mismatch_count}")
print(f"Percentage of mismatches among confident pixels: {percent_mismatch:.2f}%")

### Histograms Comparison
NOTE: the plugin sets classes present with a low fraction (close to 0) directly to 0. This does not happen in my code.

In [None]:
# Only plot classes that exist both as a name and as a band in each map
n_classes = min(len(class_names), plugin_data.shape[0], code_data.shape[0])

# One row per class; squeeze=False keeps `axs` indexable even for one class
fig, axs = plt.subplots(n_classes, 1, figsize=(8, 4 * n_classes), sharex=True, squeeze=False)
axs = axs[:, 0]

for i in range(n_classes):
    ax = axs[i]

    # Drop invalid pixels from BOTH maps (NaN fails the >= 0 test) so each
    # histogram is a percentage of that map's valid pixels only.
    code_vals = code_data[i].ravel()
    code_vals = code_vals[code_vals >= 0]
    plugin_vals = plugin_data[i].ravel()
    plugin_vals = plugin_vals[plugin_vals >= 0]

    for color_idx, (vals, label) in enumerate(((code_vals, "Code"), (plugin_vals, "Plugin"))):
        if vals.size == 0:  # nothing valid to plot for this map
            continue
        # Weights make the bar heights percentages of the valid pixels
        weights = np.full(vals.shape, 100.0 / vals.size)
        ax.hist(vals, bins=50, alpha=0.5, label=label, range=(0, 1),
                weights=weights, color=f"C{color_idx}")

    # Labels and formatting
    ax.set_title(f"{class_names[i]} - Value Distribution")
    ax.set_ylabel("Percentage of Pixels (%)")
    ax.grid(True, linestyle="--", alpha=0.4)
    ax.legend()

# Common x-label for all
axs[-1].set_xlabel("Fraction Value")

plt.tight_layout()
plt.show()

## Testing

### Import libraries

In [None]:
import pandas as pd
import geopandas as gpd
import rasterio
import numpy as np
from sklearn.metrics import confusion_matrix

### Define Parameters

In [None]:
# Independent testing points (to be renamed "Testing Points" in the future)
# and the two fraction maps under evaluation.
testing_points_path = rf"E:\TESI\Spectral_Unmixing_Geoinf_Proj\Unmixing_QGIS\Output\NEW_VALIDATION_points_{target_date}.gpkg"
plugin_tif = rf"E:\TESI\Spectral_Unmixing_Geoinf_Proj\Unmixing_QGIS\Output\Final_Class_Fraction_Layer_Masked_{target_date}.tif"
code_tif = rf"E:\TESI\Spectral_Unmixing_Geoinf_Proj\Python_Output\fraction_maps_multiband_masked_{target_date}.tif"

# Permissive thresholds: any top fraction above 0 and any second fraction
# below 1 pass, so practically every pixel counts as "confident".
dominant_thresh, second_thresh = 0.0, 1.0

### Testing
Considering Dominant OR Non-Dominant Pixels

In [None]:

# LOAD TESTING POINTS
gdf = gpd.read_file(testing_points_path)
truth_col = "endmember_class"


def classify(vals):
    """Return ``(predicted_class, is_confident)`` for one pixel's fraction vector.

    The pixel is confident when its largest fraction is above ``dominant_thresh``
    and its second largest is below ``second_thresh``. The predicted class is
    then ``cls_list[argmax(vals)]``; otherwise ``(None, False)`` is returned.
    NaN sorts last in ``np.sort``, so a pixel containing NaN is never confident.
    """
    top2 = np.sort(vals)[-2:]
    if top2[-1] > dominant_thresh and top2[-2] < second_thresh:
        return cls_list[np.argmax(vals)], True
    return None, False


def _point_coords(points, target_crs):
    """Return the (x, y) coordinates of ``points`` expressed in ``target_crs``.

    ``DatasetReader.sample`` expects coordinates in the raster's CRS. Points
    without a declared CRS are used as-is (assumed to be in ``target_crs``).
    """
    if points.crs is not None and target_crs is not None:
        points = points.to_crs(target_crs)
    return [(geom.x, geom.y) for geom in points.geometry]


# OPEN RASTERS (closed automatically at the end of the with-block)
with rasterio.open(plugin_tif) as src_plug, rasterio.open(code_tif) as src_code:
    # Band k (1-based) is assumed to hold the fraction of class k: [1, 2, ..., n_classes]
    cls_list = list(range(1, src_plug.count + 1))

    # sample() yields one array of length = band count per coordinate
    plugin_samples = list(src_plug.sample(_point_coords(gdf, src_plug.crs)))
    code_samples = list(src_code.sample(_point_coords(gdf, src_code.crs)))

# SAMPLE FRACTIONS & BUILD RECORDS
records = []
for (idx, pt), plugin_vals, code_vals in zip(gdf.iterrows(), plugin_samples, code_samples):
    truth = pt[truth_col]

    plug_pred, plug_conf = classify(np.asarray(plugin_vals))
    code_pred, code_conf = classify(np.asarray(code_vals))

    records.append({
        "fid": idx,
        "truth": truth,
        # plugin result
        "plug_pred": plug_pred,
        "plug_conf": plug_conf,
        "plug_correct": plug_conf and (plug_pred == truth),  # confident AND predicted class == true class
        # code result
        "code_pred": code_pred,
        "code_conf": code_conf,
        "code_correct": code_conf and (code_pred == truth),
    })

df = pd.DataFrame(records)

# 4) OVERALL METRICS
total = len(df)


def summarize(method):
    """Overall confident/correct counts for one method ("plug" or "code")."""
    conf = df[f"{method}_conf"].sum()
    correct = df[f"{method}_correct"].sum()
    incorrect = conf - correct
    non_conf = total - conf
    return pd.Series({
        "TotalPts": total,
        "Confident": conf,
        "  Correct": correct,
        "  Incorrect": incorrect,
        "NonConfident": non_conf,
        "PctConfident": conf/total*100 if total > 0 else np.nan,
        "PctCorrect|Conf": correct/conf*100 if conf > 0 else np.nan,
    })


summary = pd.DataFrame({
    "Plugin": summarize("plug"),
    "Code":   summarize("code"),
}).T

print("\n=== OVERALL SUMMARY ===")
print(summary)

# PER-CLASS METRICS
per_class = []
for cls in cls_list:
    sub = df[df["truth"] == cls]
    total_cls = len(sub)
    for method in ("plug", "code"):
        conf = sub[f"{method}_conf"].sum()
        correct = sub[f"{method}_correct"].sum()
        incorrect = conf - correct
        per_class.append({
            "Class": cls,
            "Method": method,
            "TotalPts": total_cls,
            "Confident": conf,
            "Correct":   correct,
            "Incorrect": incorrect,
            "NonConf":   total_cls - conf,
            "PctConf":   conf/total_cls*100 if total_cls > 0 else np.nan,
            "PctCorr|Conf": correct/conf*100 if conf > 0 else np.nan,
        })

pc_df = pd.DataFrame(per_class)
print("\n=== PER-CLASS SUMMARY ===")
print(pc_df.pivot(index="Class", columns="Method",
                  values=["TotalPts", "Confident", "Correct", "PctCorr|Conf"]))

### Confusion Matrix

In [None]:
# List of all class labels present in the testing points
cls_list = sorted(df['truth'].unique())


def print_confusion(method_name, pred_col, conf_col):
    """Print the confusion matrix of one method, using only its confident rows.

    Rows are true classes, columns are predicted classes. Labels are the union
    of the truth classes and the predicted classes. With only the truth labels,
    predictions of a class absent from the testing points would be silently
    dropped from the matrix.
    """
    mask = df[conf_col].astype(bool)
    print(f"\n=== Confusion Matrix for {method_name} (confident only) ===")
    if not mask.any():
        print("No confident predictions: nothing to compare.\n")
        return

    y_true = df.loc[mask, 'truth']
    y_pred = df.loc[mask, pred_col]
    labels = sorted(set(cls_list) | set(y_pred.tolist()))

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    cm_df = pd.DataFrame(cm, index=labels, columns=labels)

    print("Rows = true class, Columns = predicted class\n")
    print(cm_df)
    print()


# Plugin confusion
print_confusion("Plugin", pred_col="plug_pred", conf_col="plug_conf")

# Code confusion
print_confusion("Code",   pred_col="code_pred", conf_col="code_conf")