In [1]:
import torch

# Sanity check: can PyTorch see a CUDA-capable GPU in this kernel?
_cuda_ok = torch.cuda.is_available()
print("CUDA available:", _cuda_ok)

CUDA available: True


In [2]:
# @title
# Imports
import opensr_model

# other imports
import torch
import rasterio
from omegaconf import OmegaConf
from importlib.resources import files
from io import StringIO
import requests
from IPython.display import Image, display
import cubo
import numpy as np
import ipywidgets as widgets
import datetime
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from opensr_model.utils import plot_example
from IPython.display import display, clear_output, Image as IPImage

# Select the compute device: prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"‚úÖ Using compute device: {device}")


‚úÖ Using compute device: cuda


In [3]:
import torch

# Environment report: PyTorch build, CUDA visibility and the GPU model (if any).
print(torch.__version__)
_has_gpu = torch.cuda.is_available()
print("CUDA available:", _has_gpu)
if _has_gpu:
    print("GPU name:", torch.cuda.get_device_name(0))
else:
    print("GPU name:", "No GPU")


2.5.1+cu121
CUDA available: True
GPU name: NVIDIA GeForce GTX 1650 with Max-Q Design


In [4]:
# @title
# Helper Functions
def create_model(config_url=None, compute_device=None, timeout=30):
    """Create the OpenSR latent-diffusion model and load its pretrained weights.

    Args:
        config_url: URL of the OmegaConf YAML model config. ``None`` uses the
            10 m config from the ESAOpenSR/opensr-model ``main`` branch.
        compute_device: device passed to ``SRLatentDiffusion``. ``None`` uses
            the notebook-level ``device`` chosen in the imports cell.
        timeout: seconds to wait for the config download before giving up.

    Returns:
        The model with pretrained weights loaded, verified to be in eval mode.

    Raises:
        requests.HTTPError: if the config download returns an HTTP error status
            (instead of the error body being parsed as YAML).
        requests.exceptions.Timeout: if the config server does not answer
            within ``timeout`` seconds.
    """
    if config_url is None:
        config_url = "https://raw.githubusercontent.com/ESAOpenSR/opensr-model/refs/heads/main/opensr_model/configs/config_10m.yaml"
    if compute_device is None:
        compute_device = device

    # Download the config: fail loudly on HTTP errors and never hang forever.
    response = requests.get(config_url, timeout=timeout)
    response.raise_for_status()
    config = OmegaConf.load(StringIO(response.text))

    # Make sure you're running this notebook in a GPU environment.
    model = opensr_model.SRLatentDiffusion(config, device=compute_device)
    model.load_pretrained(config.ckpt_version)  # downloads the checkpoint
    assert not model.training, "Model has to be in eval mode."

    return model


model = create_model()

LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 113.63 M params.
Keeping EMAs of 308.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 128, 128) = 65536 dimensions.
making attention of type 'vanilla' with 512 in_channels
Normalization disabled.


  weights = torch.load(weights_file, map_location=self.device)["state_dict"]


Loaded pretrained weights from:  opensr-ldsrs2_v1_0_0.ckpt


In [5]:
# @title
# Helpers: fetch Sentinel-2 cube -> tensor -> quick RGB preview
def get_s2_scene(lat, lon, start, end, max_cc):
    """
    Fetch the least-cloudy Sentinel-2 L2A acquisition in the date window.

    Every acquisition returned by the STAC query is scored by the fraction of
    its 128 x 128 px tile whose SCL class is 8, 9 or 10; the lowest score wins
    (the first one in cube order on ties). Prints human-friendly diagnostics
    with emojis.

    Args:
        lat, lon: tile centre passed to ``cubo.create``.
        start, end: date window as 'YYYY-MM-DD' strings.
        max_cc: scene-level cloud-cover limit (%) for the STAC query
            (strictly-less-than filter).

    Returns:
        (LR_tensor, xarray_datacube): the tensor has shape (1, 4, 128, 128)
        with bands B04, B03, B02, B08 divided by 10 000 (NaN -> 0), on
        ``device``; the datacube is the full cube from ``cubo.create``
        (all acquisitions, including the SCL band).

    Raises:
        ValueError: if the query returned no acquisitions.
    """
    bands = ["B04", "B03", "B02", "B08", "SCL"]  # RGB + NIR, plus scene classification
    cloud_classes = (8, 9, 10)                   # SCL codes counted as cloud

    print("üì°  Requesting Sentinel-2 cube ‚Ä¶")
    da = cubo.create(
        lat=lat, lon=lon,
        collection="sentinel-2-l2a",
        bands=bands,
        start_date=start, end_date=end,
        edge_size=128, resolution=10,
        query={"eo:cloud_cover": {"lt": max_cc}},
    )

    # An empty time axis would otherwise fail later (in argmin) with an opaque error.
    n_acquisitions = da.shape[0]
    if n_acquisitions == 0:
        raise ValueError(
            f"No Sentinel-2 acquisitions found between {start} and {end} "
            f"with cloud cover < {max_cc}%."
        )

    # --- pick the least-cloudy acquisition ---
    scl = da.sel(band="SCL")                                   # (time, y, x)
    cloud_fraction = scl.isin(cloud_classes).mean(("y", "x"))  # (time,)
    best_idx = int(cloud_fraction.argmin())                    # first index on ties
    best_date = str(da.time.values[best_idx])[:10]             # 'YYYY-MM-DD'
    cloud_pct = float(cloud_fraction[best_idx] * 100)          # 0-100

    # --- pretty summary ---
    print("\nüìù  Scene Summary")
    print(f"üõ∞Ô∏è  Data provider      :  Microsoft Planetary Computer")
    print(f"üìç  Location           :  {lat:.4f}, {lon:.4f}")
    print(f"üïì  Time span          :  {start}  ‚Üí  {end}")
    print(f"‚òÅÔ∏è  Cloud cover max    :  ‚â§ {max_cc}%")
    print(f"üìê  Tile size / res    :  {da.attrs['edge_size']} px  @ {da.attrs['resolution']} m")
    print(f"üñºÔ∏è  Acquisitions found :  {n_acquisitions}")
    print(f"‚úÖ  Returning Image from '{best_date}' with {cloud_pct:.2f}% cloud cover.")

    # Keep the four reflectance bands (drop SCL) and scale DN to reflectance.
    # NOTE(review): L2A products from processing baseline 04.00 onward carry a
    # +1000 DN offset (BOA_ADD_OFFSET); confirm whether this data needs it removed.
    np_img = da[best_idx].compute().to_numpy().astype("float32")[:4]
    tensor = torch.from_numpy(np_img).to(device) / 10_000
    tensor = torch.nan_to_num(tensor, nan=0.0)
    return tensor.unsqueeze(0), da

def plot_lr_image(tensor):
    """Display a contrast-stretched RGB preview of a (1, 4, H, W) tensor.

    The first three channels are shown as R, G, B: clamped to [0, 1],
    multiplied by 3 for visibility, then clipped back to [0, 1].
    """
    channels_first = tensor[0, :3]
    preview = channels_first.permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    preview = (preview * 3).clip(0, 1)  # brighten for display

    fig = plt.figure(figsize=(5, 5))
    plt.title("Sen-2 RGB")
    fig.tight_layout()
    plt.imshow(preview)
    plt.show()

In [6]:
# @title
# Export to Disk
def save_tensor_as_geotiff(tensor, attrs, out_path, super_resolved=False, sr_factor=4):
    """
    Save a reflectance tensor as a georeferenced uint16 GeoTIFF.

    The georeferencing is rebuilt from the metadata of the low-resolution cube.
    The tile is taken to be centred on (central_x, central_y) in EPSG:<epsg> and
    to span edge_size * resolution map units per side. For a super-resolved
    image the footprint is unchanged and only the pixel size shrinks by
    ``sr_factor``.

    Parameters:
        tensor (torch.Tensor or np.ndarray): shape (1, bands, H, W) or
            (bands, H, W); reflectance values in 0-1. They are written as
            0-10000 DN, rounded to nearest, with NaN stored as 0.
        attrs (dict): metadata from the LR image (.attrs); needs the keys
            resolution, edge_size, central_x, central_y and epsg.
        out_path (str): output file path (.tif).
        super_resolved (bool): if True, the image is the SR output upscaled by sr_factor.
        sr_factor (int): SR upscale factor.
    """
    # Move torch tensors to host memory; detach so grad-tracking tensors convert too.
    if hasattr(tensor, "cpu"):
        tensor = tensor.detach().cpu().numpy()
    arr = np.asarray(tensor)
    if arr.ndim == 4:
        arr = arr[0]  # drop the batch dimension

    # Reflectance -> DN. Round to nearest: a plain astype truncates, so e.g.
    # 999.9 would be written as 999. NaN is mapped to 0 before the integer cast.
    arr = np.rint(np.nan_to_num(arr * 10000, nan=0.0)).clip(0, 10000).astype(np.uint16)

    # Original georef info
    pixel_size = attrs["resolution"]
    edge_size = attrs["edge_size"]
    central_x = attrs["central_x"]
    central_y = attrs["central_y"]
    epsg = attrs["epsg"]

    # The bounding box is the same for LR and SR output
    half_extent = edge_size * pixel_size / 2
    ul_x = central_x - half_extent
    ul_y = central_y + half_extent

    # For SR only the pixel size changes (the array is already upsampled)
    if super_resolved:
        pixel_size = pixel_size / sr_factor

    transform = rasterio.transform.from_origin(ul_x, ul_y, pixel_size, pixel_size)

    with rasterio.open(
        out_path,
        "w",
        driver="GTiff",
        height=arr.shape[1],
        width=arr.shape[2],
        count=arr.shape[0],
        dtype=arr.dtype,
        crs=f"EPSG:{epsg}",
        transform=transform,
    ) as dst:
        dst.write(arr)

In [7]:
# @title
# ------------------------------- build_ui -------------------------------
def build_ui(model):
    """Display the interactive scene-selection / super-resolution widget.

    Flow:
        1. The user picks a date window and a max. cloud cover.
        2. The user picks coordinates, typed in or clicked on a map.
        3. *Load Scene* fetches and previews the scene.
        4. *Use this scene* runs ``model`` on it.

    The result is written to ``lr.tif`` / ``sr.tif`` and plotted via
    ``example.png``. *Start Over* (shown after every SR attempt, successful or
    not) restores all inputs.

    Side effects: sets the notebook globals ``lr_tensor`` and ``lr_attrs`` when
    a scene is loaded, and ``sr_tensor`` after super-resolution.

    Args:
        model: super-resolution model, called as ``model.forward(lr_tensor,
            sampling_eta=..., sampling_steps=..., sampling_temperature=...)``.
    """
    from ipyleaflet import Map, Marker, basemaps

    def _show(*ws):
        """Make every given widget visible."""
        for w in ws:
            w.layout.display = ""

    def _hide(*ws):
        """Collapse every given widget (CSS display: none)."""
        for w in ws:
            w.layout.display = "none"

    # -- basic inputs (visible until super-resolution starts) --
    intro_label = widgets.HTML(
        "<b>üóìÔ∏è Select the time window and maximum cloud cover:</b><br>"
        "<i>Large windows may result in slow fetching.</i>"
    )

    sdt = widgets.DatePicker(description="Start Date:", value=datetime.date(2023, 6, 1))
    edt = widgets.DatePicker(description="End Date:",   value=datetime.date(2023, 6, 30))
    cc  = widgets.FloatText(description="Max. CC (%):", value=10.0)

    coord_label = widgets.HTML("<b>üìç Choose to either input coordinates or select on the map:</b>")

    # coordinate widgets (hidden until chosen)
    lat = widgets.FloatText(description="Latitude:", value=44.80)
    lon = widgets.FloatText(description="Longitude:", value=2.40)
    _hide(lat, lon)

    # choose-coordinates buttons
    btn_coords = widgets.Button(description="Enter coordinates", button_style="info")
    btn_map    = widgets.Button(description="Select on map",    button_style="info")

    # Map container, hidden until requested. The old height "250" had no CSS unit
    # and was ignored by the browser, so the container is simply auto-sized.
    map_box = widgets.Output()
    _hide(map_box)

    # main action buttons
    btn_load    = widgets.Button(description="Load Scene", button_style="success", disabled=True)
    btn_yes     = widgets.Button(description="Use this scene",        button_style="success")
    btn_no      = widgets.Button(description="Get different scene",   button_style="warning")
    btn_restart = widgets.Button(description="Start Over", button_style="danger")
    _hide(btn_restart)

    out = widgets.Output()
    confirm_box = widgets.HBox([btn_yes, btn_no])
    _hide(confirm_box)

    # -- coordinate selection --
    def enable_load(_=None):
        """Enable *Load Scene* once both latitude and longitude hold a value."""
        btn_load.disabled = lat.value is None or lon.value is None

    def show_coord_boxes(_):
        """Switch to typed-in coordinates."""
        _show(lat, lon)
        _hide(map_box)
        enable_load()

    def show_map(_):
        """Switch to map selection; *Load Scene* stays disabled until a click."""
        _hide(lat, lon)
        _show(map_box)
        btn_load.disabled = True

        with map_box:
            clear_output(wait=True)
            m = Map(
                basemap=basemaps.OpenStreetMap.Mapnik,
                center=(20, 0),
                zoom=2,
                scroll_wheel_zoom=True,
                layout={"height": "250px"},  # smaller map; CSS needs the unit
            )
            marker = Marker()

            def handle_click(**kwargs):
                """Move the marker to the clicked point and copy it into lat/lon."""
                if kwargs.get("type") != "click":
                    return
                coords = kwargs["coordinates"]
                marker.location = coords
                if marker not in m.layers:
                    m.add_layer(marker)
                lat.value, lon.value = coords
                btn_load.disabled = False

            m.on_interaction(handle_click)
            display(m)

    lat.observe(enable_load, "value")
    lon.observe(enable_load, "value")
    btn_coords.on_click(show_coord_boxes)
    btn_map.on_click(show_map)

    # -- scene loading & super-resolution --
    def on_load(_):
        """Fetch the scene for the chosen inputs and ask the user to confirm it."""
        global lr_tensor, lr_attrs
        # hide map and coordinate-choice widgets immediately
        _hide(map_box, btn_coords, btn_map, coord_label, btn_load)

        with out:
            clear_output(wait=True)
            try:
                lr_tensor, lr_attrs = get_s2_scene(
                    lat.value, lon.value,
                    sdt.value.strftime("%Y-%m-%d"),
                    edt.value.strftime("%Y-%m-%d"),
                    float(cc.value),
                )
                plot_lr_image(lr_tensor)

                # restore the normal confirm buttons (an earlier error may have relabelled them)
                _show(btn_yes)
                btn_no.description = "Get different scene"
                btn_no.button_style = "warning"

                print("\n üîç  Does this look OK?")
            except Exception as e:
                print(f"‚ùå  Error: {e}")

                # reuse the "Get different scene" button as a restart button
                _hide(btn_yes)
                btn_no.description = "Start Again"
                btn_no.button_style = "danger"
            _show(confirm_box)
            display(confirm_box)

    def on_yes(_):
        """Run super-resolution on the confirmed scene, then export and plot it."""
        global sr_tensor
        _hide(confirm_box,
              btn_load, btn_coords, btn_map,  # buttons
              coord_label, intro_label,       # labels
              sdt, edt, cc,                   # date / CC
              lat, lon,                       # coordinate boxes
              map_box)                        # map

        with out:
            clear_output(wait=True)
            print("\n\n\nüöÄ  Running super-resolution , please wait . . . \n\n")
            try:
                with torch.no_grad():
                    sr_tensor = model.forward(lr_tensor,
                                              sampling_eta=1.,
                                              sampling_steps=200,
                                              sampling_temperature=1.0)

                save_tensor_as_geotiff(lr_tensor, lr_attrs.attrs, "lr.tif", False)
                save_tensor_as_geotiff(sr_tensor, lr_attrs.attrs, "sr.tif", True)
                clear_output(wait=True)  # remove the "Running..." message

                print("‚úÖ  Super-resolution complete!")
                print("üíæ  GeoTIFFs saved ‚Üí lr.tif  |  sr.tif")

                plot_example(lr_tensor, sr_tensor, out_file="example.png")
                display(IPImage("example.png"))
            except Exception as e:
                print(f"‚ùå  SR error: {e}")
            finally:
                # Every input is hidden at this point, so "Start Over" must be offered
                # on failure too (e.g. GPU out of memory), not only on success.
                _show(btn_restart)

    def on_no(_):
        """Return to the input stage so the user can pick another scene."""
        _hide(confirm_box)

        # Restore every input. on_yes also hid the date / cloud-cover widgets,
        # so they must come back here for "Start Over" to be usable.
        _show(intro_label, sdt, edt, cc,
              coord_label, btn_coords, btn_map, btn_load)
        btn_load.disabled = True  # disabled until the user picks coordinates again

        # map and manual boxes stay hidden until the user chooses a method
        _hide(map_box, lat, lon)

        with out:
            clear_output(wait=True)
            print("üîÑ  Choose coordinates again and click *Load Scene*.")

    def on_restart(_):
        """Hide *Start Over* and reset the UI like "Get different scene"."""
        _hide(btn_restart)
        on_no(None)

    btn_load.on_click(on_load)
    btn_yes.on_click(on_yes)
    btn_no.on_click(on_no)
    btn_restart.on_click(on_restart)

    # assemble UI
    inputs_box = widgets.VBox([
        intro_label,
        widgets.HBox([sdt, edt, cc]),
        coord_label,
        widgets.HBox([btn_coords, btn_map]),
        widgets.HBox([lat, lon]),
        map_box,
    ])

    ui = widgets.VBox([inputs_box, btn_load, out, btn_restart])
    display(ui)


In [8]:
# @title
"""
For statistical purposes, this cell logs that this notebook has been run, including the name of the notebook and the current timestamp.
If you do not wish to log this information, skip this cell. No information about you is collected, other than the fact that this notebook has been run.
"""
import datetime, requests

# Google Form that receives the run log
form_url = "https://docs.google.com/forms/d/e/1FAIpQLSefLEzmtNEEdzkCkjiO5OLAs_GvGLnpDcumhFW2oBCw6jX6Xw/formResponse"

# Entry IDs of the form's fields
NOTEBOOK_ID_ENTRY = "entry.1687006499"  # Short Answer field for the notebook name
EVENT_ENTRY       = "entry.779554792"   # Static marker - logs a run
TIMESTAMP_ENTRY   = "entry.1419170723"  # Short Answer - logs the timestamp

form_data = {
    NOTEBOOK_ID_ENTRY: "NoCode",  # which notebook has been run
    EVENT_ENTRY: "run",
    TIMESTAMP_ENTRY: datetime.datetime.now().isoformat(),
}

try:
    # Bounded wait: without a timeout an unreachable endpoint can block this cell indefinitely.
    requests.post(form_url, data=form_data, timeout=5)
except Exception as e:
    # Best effort only: report the failure (offline, blocked, timed out) but never stop the notebook.
    print(e)

In [9]:
# Launch the interactive scene-selection / super-resolution UI.
build_ui(model=model)

VBox(children=(VBox(children=(HTML(value='<b>üóìÔ∏è Select the time window and maximum cloud cover:</b><br><i>Larg‚Ä¶