# MBSP Cosine Similarity Explorer

This notebook demonstrates an interactive cosine-similarity search across Sentinel-2 TOA bands. Select a pixel on the MBSP fractional image to compare it with all other pixels in the scene.

In [None]:
# Interactive (weighted) cosine-similarity + distance search across Sentinel-2 TOA bands
# -------------------------------------------------------------------------------------
# 2025-06-10   WEIGHTED VERSION  +  BAND-SPACE DISTANCE MODE

import math
import datetime as dt
import ee
import geemap

ee.Authenticate()
ee.Initialize()


# --------------------------------------------------------------------------
# Cloud mask helper
# --------------------------------------------------------------------------
def mask_s2_clouds(image: ee.Image) -> ee.Image:
    """Mask flagged pixels using the ``QA60`` band and divide the image by 10000.

    Bits 10 and 11 of ``QA60`` are tested; a pixel is kept only when both
    bits are zero.  Every band of the masked image (``QA60`` included) is
    then divided by 10000, and all properties of the input are copied onto
    the result.

    Args:
        image: Image that contains a ``QA60`` band.

    Returns:
        The masked, rescaled image carrying the input image's properties.
    """
    qa = image.select("QA60")
    cloud_bit_mask = 1 << 10
    cirrus_bit_mask = 1 << 11
    mask = qa.bitwiseAnd(cloud_bit_mask).eq(0).And(qa.bitwiseAnd(cirrus_bit_mask).eq(0))
    scaled = image.updateMask(mask).divide(10000)
    # copyProperties() returns a generic ee.Element; cast it so the value
    # really is the ee.Image the signature promises and image methods can
    # be chained on it by direct callers.
    return ee.Image(scaled.copyProperties(image, image.propertyNames()))


# --------------------------------------------------------------------------
# MBSP fractional band helper
# --------------------------------------------------------------------------
def mbsp_fractional_image(image: ee.Image, region: ee.Geometry) -> ee.Image:
    """Append the MBSP fractional band ``R`` to *image*.

    A scalar slope is fitted over *region* as
    ``sum(B11 * B12) / sum(B12 * B12)``; both sums come from
    ``reduceRegion`` at scale 20 with ``bestEffort=True``.  The new band is
    ``R = (slope * B12 - B11) / B11``.

    Args:
        image: Image providing bands ``B11`` and ``B12``.
        region: Geometry over which the two sums are accumulated.

    Returns:
        *image* with the extra band ``R`` and the property ``slope`` set to
        the fitted value.
    """
    b11 = image.select("B11")
    b12 = image.select("B12")

    def _regional_sum(band_img):
        # Sum of a single-band image over the fitting region.
        return band_img.reduceRegion(ee.Reducer.sum(), region, 20, bestEffort=True)

    # The sums are read back under the keys "B11" (cross product) and
    # "B12" (squared term).
    cross_sum = ee.Number(_regional_sum(b11.multiply(b12)).get("B11"))
    square_sum = ee.Number(_regional_sum(b12.multiply(b12)).get("B12"))
    slope = cross_sum.divide(square_sum)

    fractional = b12.multiply(slope).subtract(b11).divide(b11).rename("R")
    return image.addBands(fractional).set({"slope": slope})


# --------------------------------------------------------------------------
# Bands + weights  (all zeros except B9 and the engineered R band; you can tune freely)
# --------------------------------------------------------------------------
# Weight per band, keyed by band name.  The dict's insertion order fixes the
# band order; unzipping its items yields two parallel tuples.  Only B9 and R
# carry a non-zero weight here.
BANDS, WEIGHTS = zip(
    *{
        "B1": 0,
        "B2": 0,
        "B3": 0,
        "B4": 0,
        "B5": 0,
        "B6": 0,
        "B7": 0,
        "B8": 0,
        "B8A": 0,
        "B9": 1,
        "B11": 0,
        "B12": 0,
        "R": 1,  # engineered band (MBSP)
    }.items()
)
assert len(BANDS) == len(WEIGHTS), "BANDS and WEIGHTS must be same length"

# Square roots of the weights, computed once so that multiplying both
# vectors by them weights each product term by the full weight.
SQRT_W = list(map(math.sqrt, WEIGHTS))


# --------------------------------------------------------------------------
# Weighted cosine similarity
# --------------------------------------------------------------------------
def weighted_cosine_image(image: ee.Image, ref_feat: ee.Feature) -> ee.Image:
    """Return the per-pixel weighted cosine similarity to a reference sample.

    Both the scene and the reference vector are multiplied band-wise by
    ``SQRT_W`` before the usual ``dot / (|a| * |b|)`` is evaluated, so each
    band contributes to the dot product and to both magnitudes in
    proportion to its entry in ``WEIGHTS``.

    Args:
        image: Image containing every band listed in ``BANDS``.
        ref_feat: Feature whose properties hold one value per band in
            ``BANDS``.

    Returns:
        Single-band image named ``wcos``, computed in the projection of
        band ``B11`` and masked wherever none of the selected bands is
        unmasked.
    """
    target_proj = image.select("B11").projection()
    sqrt_weights = ee.Image.constant(SQRT_W)

    # Constant image holding the reference value of every band.
    reference = (
        ee.Image.constant([ee.Number(ref_feat.get(name)) for name in BANDS])
        .rename(BANDS)
        .reproject(target_proj)
    )
    scene = image.select(BANDS).resample("bilinear").reproject(target_proj)

    scaled_scene = scene.multiply(sqrt_weights)
    scaled_ref = reference.multiply(sqrt_weights)

    band_sum = ee.Reducer.sum()
    numerator = scaled_scene.multiply(scaled_ref).reduce(band_sum)
    scene_norm = scaled_scene.pow(2).reduce(band_sum).sqrt()
    ref_norm = scaled_ref.pow(2).reduce(band_sum).sqrt()

    # Keep a pixel when at least one of the selected bands is unmasked.
    valid = scene.mask().reduce(ee.Reducer.anyNonZero())
    return numerator.divide(scene_norm.multiply(ref_norm)).rename("wcos").updateMask(valid)


# --------------------------------------------------------------------------
# Weighted Euclidean distance  (band-space distance)
# --------------------------------------------------------------------------
def weighted_distance_image(image: ee.Image, ref_feat: ee.Feature) -> ee.Image:
    """Return the per-pixel weighted Euclidean distance to a reference sample.

    The distance is ``sqrt(sum_b WEIGHTS[b] * (pixel[b] - ref[b]) ** 2)``
    taken over the bands in ``BANDS``.

    Args:
        image: Image containing every band listed in ``BANDS``.
        ref_feat: Feature whose properties hold one value per band in
            ``BANDS``.

    Returns:
        Single-band image named ``wdist``, computed in the projection of
        band ``B11`` and masked wherever none of the selected bands is
        unmasked.
    """
    target_proj = image.select("B11").projection()

    # Constant image holding the reference value of every band.
    reference = (
        ee.Image.constant([ee.Number(ref_feat.get(name)) for name in BANDS])
        .rename(BANDS)
        .reproject(target_proj)
    )
    scene = image.select(BANDS).resample("bilinear").reproject(target_proj)

    weighted_sq_diff = (
        scene.subtract(reference).pow(2).multiply(ee.Image.constant(WEIGHTS))
    )

    # Keep a pixel when at least one of the selected bands is unmasked.
    valid = scene.mask().reduce(ee.Reducer.anyNonZero())
    return (
        weighted_sq_diff.reduce(ee.Reducer.sum())
        .sqrt()
        .rename("wdist")
        .updateMask(valid)
    )


# --------------------------------------------------------------------------
# Data selection
# --------------------------------------------------------------------------
# Point of interest and the date window handed to filterDate.
lat = 31.6585
lon = 5.9053
start = dt.date(2019, 10, 14)
end = dt.date(2019, 10, 15)
point = ee.Geometry.Point(lon, lat)

# Scenes inside the date window that cover the point and report a
# CLOUDY_PIXEL_PERCENTAGE below 20, ordered by system:time_start and passed
# through the cloud-mask helper.
collection = (
    ee.ImageCollection("COPERNICUS/S2_HARMONIZED")
    .filterDate(start.isoformat(), end.isoformat())
    .filterBounds(point)
    .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 20))
    .sort("system:time_start")
    .map(mask_s2_clouds)
)

# Keep the scenes as a list so one can be picked by index; only the
# element count is fetched to the client.
images = collection.toList(collection.size())
count = images.size().getInfo()
print(f"Found {count} images")

# --------------------------------------------------------------------------
# Interactive map
# --------------------------------------------------------------------------
if count:
    # The MBSP slope is fitted over the bounding box of point.buffer(1000),
    # using the first image of the list.
    region = point.buffer(1000).bounds()
    img0 = ee.Image(images.get(0))
    img = mbsp_fractional_image(img0, region)

    m = geemap.Map(center=(lat, lon), zoom=16)
    m.addLayer(img.select(["B4", "B3", "B2"]), {"min": 0, "max": 0.3}, "RGB", True)
    m.addLayer(
        img.select("R"),
        {"min": -0.1, "max": 0.1, "palette": ["red", "white", "blue"]},
        "MBSP (R)",
        True,
    )
    # Capture the handle immediately after adding the MBSP layer.  Taking
    # m.layers[-1] after the point-of-interest layer has been added would
    # return that layer instead.
    mbsp_layer = m.layers[-1]

    poi = ee.FeatureCollection([point]).style(**{"color": "green", "pointSize": 8})
    m.addLayer(poi, {}, "Point of interest", True)

    # Handles of the result layers created by the click handler; None until
    # the first click has produced them.
    m.sim_layer = None  # weighted cosine layer
    m.dist_layer = None  # weighted distance layer

    # ----------------------------------------------------------------------
    # Click handler
    # ----------------------------------------------------------------------
    def handle_click(**kwargs):
        """Recompute the similarity and distance layers for a clicked pixel.

        Interactions whose ``type`` is not ``"click"`` are ignored.  The
        clicked coordinates are sampled from ``img`` at scale 20; when the
        sample is empty a message is printed and nothing changes.
        Otherwise the previous result layers (if any) are removed and two
        new, initially hidden layers are added and remembered on the map
        object as ``m.sim_layer`` and ``m.dist_layer``.
        """
        if kwargs.get("type") != "click":
            return

        click_lat, click_lon = kwargs["coordinates"]
        clicked = ee.Geometry.Point(click_lon, click_lat)

        samples = img.select(BANDS).sample(clicked, 20, numPixels=1, tileScale=1)
        if samples.size().getInfo() == 0:
            print("No valid data at clicked location.")
            return
        reference = ee.Feature(samples.first())

        similarity = weighted_cosine_image(img, reference)
        distance = weighted_distance_image(img, reference)

        # Discard the layers left over from the previous click.
        for stale in (m.sim_layer, m.dist_layer):
            if stale is not None:
                m.remove_layer(stale)

        similarity_vis = {"min": 0.9, "max": 1.0, "palette": ["blue", "white", "red"]}
        m.addLayer(similarity, similarity_vis, "Weighted similarity", False)
        m.sim_layer = m.layers[-1]

        distance_vis = {"min": 0.0, "max": 0.1, "palette": ["red", "white", "blue"]}
        m.addLayer(distance, distance_vis, "Weighted distance", False)
        m.dist_layer = m.layers[-1]

    # Register the handler for map interactions (it returns early for
    # anything whose type is not "click"), then show the map.
    # NOTE(review): display is not imported in this file; presumably the
    # IPython/Jupyter builtin -- confirm if this is ever run outside a notebook.
    m.on_interaction(handle_click)
    display(m)
