In [1]:
import ipywidgets as widgets
config_yml_input = widgets.Textarea(
    placeholder="openmapflow.yaml",
    layout=widgets.Layout(width="50%", height="10em"),
)
config_yml_input

Textarea(value='', layout=Layout(height='10em', width='50%'), placeholder='openmapflow.yaml')

In [2]:
# with open('openmapflow.yaml', 'w') as f:
#   f.write(config_yml_input.value)

In [3]:
import os
import cmocean
import rasterio as rio
import matplotlib.pyplot as plt
import warnings
from pathlib import Path
import pickle  # For loading/saving models
import numpy as np
from azure.storage.blob import BlobServiceClient
from azure.identity import DefaultAzureCredential

# Silence FutureWarning noise from the geospatial / ML stack for the rest of the session.
warnings.simplefilter("ignore", FutureWarning)

PROJECT = "openmapflow_local"
print(PROJECT)

# Azure Storage Configuration
# Prefer DefaultAzureCredential (managed identity, CLI login, env credentials, ...);
# fall back to the AZURE_STORAGE_CONNECTION_STRING environment variable.
# On total failure both `blob_service_client` and `container_name` are None, which
# download_from_azure / upload_to_azure check before doing anything.
# NOTE(review): the azure-identity / azure-storage-blob clients appear to be lazy
# (no network call at construction), so auth problems most likely surface on the
# first blob operation rather than here -- confirm.
AZURE_ACCOUNT_URL = "https://openmapflow.blob.core.windows.net/"
AZURE_DEFAULT_CONTAINER = "openmap"

try:
    credential = DefaultAzureCredential()
    blob_service_client = BlobServiceClient(
        account_url=AZURE_ACCOUNT_URL, credential=credential
    )
    container_name = AZURE_DEFAULT_CONTAINER
    print("Azure Storage client initialized with managed identity.")
except Exception as identity_error:
    AZURE_STORAGE_CONNECTION_STRING = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
    if not AZURE_STORAGE_CONNECTION_STRING:
        # Report the real cause instead of the opaque error from from_connection_string(None).
        print(
            f"Azure Storage connection failed: {identity_error}, "
            "AZURE_STORAGE_CONNECTION_STRING is not set"
        )
        blob_service_client = None
        container_name = None
    else:
        try:
            blob_service_client = BlobServiceClient.from_connection_string(
                AZURE_STORAGE_CONNECTION_STRING
            )
            container_name = AZURE_DEFAULT_CONTAINER
            print("Azure Storage client initialized with connection string.")
        except Exception as connection_error:
            print(f"Azure Storage connection failed: {identity_error}, {connection_error}")
            blob_service_client = None
            container_name = None

# Azure Blob Download
def download_from_azure(blob_name, local_file_path):
    """Download a blob from the configured Azure container to a local file.

    Uses the module-level ``blob_service_client`` and ``container_name``.

    The blob is streamed to ``<local_file_path>.part`` and only renamed onto
    ``local_file_path`` after the download completes. A failed download therefore
    neither leaves a truncated file behind nor destroys an existing file at that path.

    Args:
        blob_name (str): Name of the blob inside ``container_name``.
        local_file_path (str): Destination path; missing parent directories are created.

    Returns:
        bool: True on success, False if storage is not configured or anything fails.
    """
    if blob_service_client is None:
        print("Azure Storage not configured.")
        return False

    partial_path = f"{local_file_path}.part"
    try:
        blob_client = blob_service_client.get_blob_client(
            container=container_name, blob=blob_name
        )
        parent_dir = os.path.dirname(local_file_path)
        if parent_dir:  # "" for bare file names -> current directory, nothing to create
            os.makedirs(parent_dir, exist_ok=True)
        # Stream straight to disk instead of buffering the whole blob in memory.
        with open(partial_path, "wb") as download_file:
            blob_client.download_blob().readinto(download_file)
        os.replace(partial_path, local_file_path)
        print(f"Downloaded {blob_name} to {local_file_path}")
        return True
    except Exception as e:
        print(f"Error downloading {blob_name}: {e}")
        if os.path.isfile(partial_path):
            os.remove(partial_path)
        return False

# Azure Blob Upload
def upload_to_azure(local_file_path, blob_name):
    """Upload ``local_file_path`` to ``blob_name`` in the configured Azure container.

    Uses the module-level ``blob_service_client`` and ``container_name``; an existing
    blob with the same name is overwritten.

    Returns:
        bool: True on success, False if storage is not configured or the upload fails.
    """
    if blob_service_client is None:
        print("Azure Storage not configured.")
        return False

    try:
        target_blob = blob_service_client.get_blob_client(
            container=container_name, blob=blob_name
        )
        with open(local_file_path, "rb") as source_stream:
            target_blob.upload_blob(source_stream, overwrite=True)
        print(f"Uploaded {local_file_path} to {blob_name}")
    except Exception as err:
        print(f"Error uploading {local_file_path} to {blob_name}: {err}")
        return False
    return True

# Prediction Logic (Example)
def make_new_predictions(input_blob_name, model_blob_name, output_blob_name):
    """
    Makes predictions using a model and input data from Azure Blob Storage.

    The input raster and the pickled model are downloaded into a private temporary
    directory. Every pixel (one row of band values) is passed to ``model.predict``,
    and the resulting single-band float32 GeoTIFF (LZW-compressed, same georeferencing
    profile as the input) is uploaded to ``output_blob_name``. The temporary directory
    is removed whether or not the run succeeds. Using a temp dir means nothing is
    written into ``data/models/`` and concurrent runs cannot clobber each other.

    Args:
        input_blob_name (str): Name of the input data blob.
        model_blob_name (str): Name of the model blob (a pickled object exposing
            ``predict(array_of_shape(n_pixels, n_bands))``).
        output_blob_name (str): Name of the output prediction blob.

    Returns:
        bool: True if the prediction raster was uploaded, False otherwise.
    """
    import tempfile  # local import: only this helper needs it

    with tempfile.TemporaryDirectory(prefix="openmapflow_pred_") as work_dir:
        local_input_path = os.path.join(work_dir, "input.tif")
        local_model_path = os.path.join(work_dir, "model.pkl")
        local_output_path = os.path.join(work_dir, "output.tif")

        if not download_from_azure(input_blob_name, local_input_path):
            return False
        if not download_from_azure(model_blob_name, local_model_path):
            return False

        try:
            # SECURITY: pickle.load executes arbitrary code embedded in the file.
            # Only load model blobs from a container you trust.
            with open(local_model_path, "rb") as f:
                model = pickle.load(f)

            with rio.open(local_input_path) as src:
                data = src.read()  # (bands, rows, cols)
                profile = src.profile

            bands, height, width = data.shape
            # Example prediction: replace with your model's prediction logic.
            # One row per pixel, one column per band -> predict -> back to (rows, cols).
            pixels = data.reshape(bands, -1).T
            predictions = model.predict(pixels).reshape(height, width)
            predictions = np.expand_dims(predictions, axis=0)  # add a channel dimension

            profile.update(dtype=rio.float32, count=1, compress="lzw")
            with rio.open(local_output_path, "w", **profile) as dst:
                dst.write(predictions.astype(rio.float32))

            if not upload_to_azure(local_output_path, output_blob_name):
                return False
        except Exception as e:
            print(f"Error during prediction: {e}")
            return False

    print(f"Predictions saved to {output_blob_name}")
    return True

# Example Usage:
# make_new_predictions("input_data.tif", "my_model.pkl", "predictions.tif")

openmapflow_local
Azure Storage client initialized with managed identity.


In [4]:
import os
import torch  # Import PyTorch

def get_available_models(model_directory, validate=True):
    """Return the sorted names of ``.pt`` files in ``model_directory``.

    Args:
        model_directory (str): Directory to scan (not recursive).
        validate (bool): When True (the default, matching the original behaviour),
            each candidate is loaded with ``torch.load`` on the CPU. Files that fail
            to load are skipped with a message. Pass False to list files without
            paying the cost of fully loading every model.

    Returns:
        list[str]: File names (not paths), sorted so that ``result[0]`` is stable
        across runs; an empty list if the directory does not exist.
    """
    try:
        # os.listdir order is arbitrary; sort so callers picking [0] are deterministic.
        candidates = sorted(
            name for name in os.listdir(model_directory) if name.endswith(".pt")
        )
    except FileNotFoundError:
        print(f"Directory not found: {model_directory}")
        return []

    if not validate:
        return candidates

    models = []
    for filename in candidates:
        model_path = os.path.join(model_directory, filename)
        try:
            # SECURITY: torch.load may unpickle arbitrary objects; scan trusted dirs only.
            torch.load(model_path, map_location=torch.device("cpu"))
        except Exception as e:
            print(f"Error loading model {filename}: {e}")
            continue
        models.append(filename)
    return models

def load_model(model_filename, model_directory):
    """Load ``model_directory/model_filename`` with ``torch.load`` onto the CPU.

    Returns:
        Whatever ``torch.load`` returns, or None if the file does not exist.
        Any other loading error propagates to the caller.
    """
    full_path = os.path.join(model_directory, model_filename)
    cpu_device = torch.device('cpu')  # map to CPU so GPU-saved checkpoints load anywhere
    try:
        return torch.load(full_path, map_location=cpu_device)
    except FileNotFoundError:
        print(f"Model file not found: {full_path}")
        return None

# Set the directory where your models are stored
model_directory = "data/models/"  # Replace with your actual directory

available_models = get_available_models(model_directory)

if available_models:
    print(available_models)
    model_name = available_models[0]  # load the first model in the returned list
    loaded_model = load_model(model_name, model_directory)

    # Compare against None rather than truthiness: a successfully loaded object can
    # still be falsy (e.g. an empty state dict or a zero-length nn.Sequential).
    if loaded_model is not None:
        print(f"Model {model_name} loaded successfully.")
        # Use the loaded PyTorch model for predictions
        # For example, if you are doing inference.
        # loaded_model.eval() #set the model to evaluation mode.
        # with torch.no_grad(): #disable gradient calculations.
        #    predictions = loaded_model(input_tensor) #input_tensor would be your input data.
    else:
        print(f"Failed to load model {model_name}.")
else:
    print("No models found in the specified directory.")



['ksaopenmapflow.pt']
Model ksaopenmapflow.pt loaded successfully.


In [5]:
import os
import re
from dataclasses import dataclass

from openmapflow.bbox import BBox  # Assuming you have a BBox class

@dataclass
class CustomBBox(BBox):
    """BBox variant whose ``name`` defaults to "from_file" (used for boxes parsed from file names)."""

    # Default label; extract_bbox_from_filename builds CustomBBox without passing a name.
    # NOTE(review): the original comment said a name is "required" -- confirm against
    # openmapflow.bbox.BBox whether it already declares `name` (this would then only
    # override its default).
    name: str = "from_file"

def extract_bbox_from_filename(filename):
    """Parse a ``CustomBBox`` out of a file name.

    Searches anywhere in ``filename`` for
    ``min_lat=<num>_min_lon=<num>_max_lat=<num>_max_lon=<num>`` (numbers may be
    negative and may have a fractional part).

    Returns:
        CustomBBox built from the four numbers, or None if the pattern is absent.
    """
    pattern = r"min_lat=(-?\d+\.?\d*)_min_lon=(-?\d+\.?\d*)_max_lat=(-?\d+\.?\d*)_max_lon=(-?\d+\.?\d*)"
    found = re.search(pattern, filename)
    if found is None:
        return None

    min_lat, min_lon, max_lat, max_lon = (float(value) for value in found.groups())
    return CustomBBox(
        min_lat=min_lat,
        min_lon=min_lon,
        max_lat=max_lat,
        max_lon=max_lon,
    )

def generate_bboxes_from_directory(directory):
    """Collect bounding boxes encoded in ``min_lat=...*.tif`` file names in ``directory``.

    Files are visited in sorted name order so the result is deterministic
    (``os.listdir`` order is arbitrary). Names that start with ``min_lat=`` but do
    not match the full coordinate pattern are skipped.

    Returns:
        list: The CustomBBox objects returned by extract_bbox_from_filename.

    Raises:
        FileNotFoundError: if ``directory`` does not exist (from ``os.listdir``).
    """
    bboxes = []
    for filename in sorted(os.listdir(directory)):
        if not (filename.startswith("min_lat=") and filename.endswith(".tif")):
            continue
        bbox = extract_bbox_from_filename(filename)
        if bbox is not None:
            bboxes.append(bbox)
    return bboxes

# Replace "./your_directory" with the actual path to your directory
# Folder scanned for "min_lat=..._max_lon=....tif" files; adjust to your data location.
directory = "./openmapflow/satdata/"
available_bboxes = generate_bboxes_from_directory(directory)

if not available_bboxes:
    print("No bounding boxes found in the directory.")
else:
    print("Available Bounding Boxes:")
    for bbox in available_bboxes:
        print(bbox)

Available Bounding Boxes:
CustomBBox(min_lat=-0.95, max_lat=-0.55, min_lon=36.75, max_lon=37.35, name='from_file')


## Inference configuration


In [7]:
from openmapflow.inference_widgets import InferenceWidget
# Widget for choosing a model and a region/date range for inference.
inference_widget = InferenceWidget(
    available_models=available_models,
    available_bboxes=available_bboxes,
)
inference_widget.ui()

VBox(children=(Box(children=(VBox(children=(HTML(value='<h3>Select model and specify region of interest</h3>')…

In [8]:
from azure.storage.blob import BlobServiceClient  # For Azure Blob Storage
import os
# Assuming you have inference_widget and available_models defined elsewhere

# Azure Configuration
# Secrets come from the environment only; never echo them in a cell's output.
AZURE_STORAGE_CONNECTION_STRING = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
AZURE_CONTAINER_PREDS_MERGED = os.environ.get("AZURE_CONTAINER_PREDS_MERGED")
# Local Directory Configuration
LOCAL_PREDS_MERGED_PATH = "./local_preds_merged"  # Replace with your local path.

# Run parameters as selected in the inference widget.
config = inference_widget.get_config_as_dict()
map_key = config["map_key"]
bbox = config["bbox"]
start_date = config["start_date"]
end_date = config["end_date"]
tifs_in_gcloud = config["tifs_in_gcloud"]

def get_map_files(map_key):
    """List merged maps in Azure whose blob name starts with ``f"{map_key}.tif"``.

    Looks in the ``AZURE_CONTAINER_PREDS_MERGED`` container using
    ``AZURE_STORAGE_CONNECTION_STRING`` (both module-level settings).

    Returns:
        list[str]: ``az://<container>/<blob name>`` URIs. If either setting is
        missing, a warning is printed and an empty list is returned. In that case
        the caller's "map already exists" check cannot detect collisions.
    """
    if not (AZURE_STORAGE_CONNECTION_STRING and AZURE_CONTAINER_PREDS_MERGED):
        print(
            "Warning: AZURE_STORAGE_CONNECTION_STRING or AZURE_CONTAINER_PREDS_MERGED "
            "is not set; cannot check Azure for existing maps."
        )
        return []

    blob_service_client = BlobServiceClient.from_connection_string(AZURE_STORAGE_CONNECTION_STRING)
    container_client = blob_service_client.get_container_client(AZURE_CONTAINER_PREDS_MERGED)
    blobs = container_client.list_blobs(name_starts_with=f"{map_key}.tif")
    return [f"az://{AZURE_CONTAINER_PREDS_MERGED}/{blob.name}" for blob in blobs]

# Keep asking for a suffix until the map key does not collide with an existing merged map.
existing_map_files = get_map_files(map_key)
while existing_map_files:
    print(f"Map for {map_key} already exists: \n{existing_map_files}")
    map_key = map_key + "_" + input(f"Append to map key: {map_key}_")
    existing_map_files = get_map_files(map_key)

print(f"Using map key: {map_key}")
# Now you can use the unique map_key for your further processing.

# Environment variables this notebook expects. Set them in your shell before starting
# Jupyter; never paste real keys into the notebook.
# SECURITY: a real storage account key was previously committed in this comment.
# Rotate that key in the Azure portal.
# export AZURE_STORAGE_ACCOUNT_NAME=openmapflow
# export AZURE_STORAGE_ACCOUNT_KEY=<storage-account-key>
# export AZURE_CONTAINER_NAME=openmap
# export MODELS_API_URL="./data/models"
# export AZURE_INFERCONTAINER_NAME=inference-eo-container
# export AZURE_CONTAINER_PREDS=preds-container
# (quote the connection string: unquoted ';' would end the shell command)
# export AZURE_STORAGE_CONNECTION_STRING="DefaultEndpointsProtocol=https;EndpointSuffix=core.windows.net;AccountName=openmapflow;AccountKey=<storage-account-key>;BlobEndpoint=https://openmapflow.blob.core.windows.net/;FileEndpoint=https://openmapflow.file.core.windows.net/;QueueEndpoint=https://openmapflow.queue.core.windows.net/;TableEndpoint=https://openmapflow.table.core.windows.net/"
#
# export AZURE_CONTAINER_PREDS_MERGED=preds-merged-container
                                                          

Using map key: ksaopenmapflow.pt/min_lat=-0.95_min_lon=36.75_max_lat=-0.55_max_lon=37.35_dates=2020-02-01_2021-02-01_all


## Run first inference

In [9]:
from azure.storage.blob import BlobServiceClient
import os

def inference_status(map_key, existing_map_files, local_tifs, preds_in_azure):
    """
    Checks inference status based on TIFF files stored locally and in Azure Blob Storage.
    Handles file uploads, movement, and inference status checks.
    Provides detailed logging of map_key, file locations, and transformations.

    Depends on helpers not defined in this cell: ``get_status``, ``confirmation``,
    ``find_missing_predictions``, ``make_new_predictions``, ``move_tifs_in_azure`` and
    ``upload_tifs_to_azure``. It also reads ``AZURE_CONTAINER_PREDS_MERGED``.

    :param map_key: Unique identifier for the map
    :param existing_map_files: List of merged map files in `preds-merged-container`
    :param local_tifs: List of local TIFF files available for processing
    :param preds_in_azure: List of TIFF files in `preds-container` on Azure
    :return: Status message (str) indicating the next step
    """
    print(f"Inference Status Check for map_key: {map_key}")  # log the map_key

    # A merged map short-circuits everything else, so check it before asking
    # get_status for the per-tile counts.
    if existing_map_files:
        print(f"Merged map found in Azure: {existing_map_files}")
        return f"Merged map available in Azure: {existing_map_files}"

    tifs_amount, predictions_amount = get_status(map_key)

    # Check if inference is complete
    if tifs_amount > 0 and tifs_amount == predictions_amount:
        print(f"Inference complete. Total TIFs: {tifs_amount}, Predictions: {predictions_amount}")
        return "Inference complete! Time to merge predictions into a map."

    # Retry missing predictions
    if tifs_amount > predictions_amount:
        print(f"Inference incomplete. Total TIFs: {tifs_amount}, Predictions: {predictions_amount}")
        if confirmation("Predictions in progress but incomplete. Retry missing predictions? (y/n)"):
            missing = find_missing_predictions(map_key)
            print(f"Missing predictions found: {missing}")  # log missing files
            print("Retrying model on missing predictions...")
            # NOTE(review): the make_new_predictions defined earlier in this notebook
            # takes (input_blob_name, model_blob_name, output_blob_name), so this
            # one-argument call raises TypeError unless a different make_new_predictions
            # that accepts `missing` is in scope. Confirm which helper is intended.
            make_new_predictions(missing)
            return "Retrying model on missing predictions..."
        print("Waiting for predictions to complete.")
        return "Waiting for predictions to complete."

    # Move TIFFs within Azure if misplaced
    if preds_in_azure:
        dest_container = AZURE_CONTAINER_PREDS_MERGED
        print(f"TIFFs found in incorrect Azure container: {preds_in_azure}")  # log found files
        if confirmation(f"Move TIFFs to {dest_container}?"):
            move_tifs_in_azure(preds_in_azure, dest_container, map_key)
            print(f"Moved TIFFs to {dest_container}")
            # Return a status string like every other branch (previously a raw tuple).
            tifs_amount, predictions_amount = get_status(map_key)
            return (
                f"Moved TIFFs to {dest_container}. "
                f"Total TIFs: {tifs_amount}, Predictions: {predictions_amount}"
            )

    # Upload local TIFFs to Azure if no existing data is found
    if not preds_in_azure and local_tifs:
        print(f"Local TIFFs found: {local_tifs}")  # log local files
        if confirmation("No existing predictions found in Azure. Upload local TIFFs?"):
            upload_tifs_to_azure(local_tifs, AZURE_CONTAINER_PREDS_MERGED, map_key)
            print(f"Uploaded {len(local_tifs)} TIFFs to Azure container {AZURE_CONTAINER_PREDS_MERGED}")
            return f"Uploading {len(local_tifs)} TIFFs to Azure container {AZURE_CONTAINER_PREDS_MERGED}..."

    print("No data available for inference.")
    return "No data available for inference. Provide TIFF files locally or in Azure."

## Merge predictions into a map

In [10]:
from pathlib import Path
import os
from azure.storage.blob import BlobServiceClient
from azure.core.exceptions import AzureError
from openmapflow.inference_utils import build_vrt

# Get Azure Storage account details from environment variables
storage_account_name = os.environ.get("AZURE_STORAGE_ACCOUNT_NAME")
storage_account_key = os.environ.get("AZURE_STORAGE_ACCOUNT_KEY")
# Deliberately NOT named `container_name` / `blob_service_client`: those globals are
# read by download_from_azure / upload_to_azure, and overwriting them here would
# silently repoint those helpers at the predictions container.
preds_container_name = os.environ.get("AZURE_CONTAINER_PREDS")

# Raise instead of calling exit(): inside a Jupyter kernel exit() shuts the kernel
# down and discards all notebook state, while an exception only stops this cell.
missing_env = [
    var_name
    for var_name, value in (
        ("AZURE_STORAGE_ACCOUNT_NAME", storage_account_name),
        ("AZURE_STORAGE_ACCOUNT_KEY", storage_account_key),
        ("AZURE_CONTAINER_PREDS", preds_container_name),
    )
    if not value
]
if missing_env:
    raise RuntimeError(f"Missing Azure Storage environment variables: {', '.join(missing_env)}")

# Check if map_key is defined (it is produced by the map-key cell above).
try:
    map_key
except NameError:
    raise RuntimeError(
        "map_key is not defined. Please ensure it is defined in a previous cell."
    ) from None

# Create Azure Blob Service client
connection_string = f"DefaultEndpointsProtocol=https;AccountName={storage_account_name};AccountKey={storage_account_key};EndpointSuffix=core.windows.net"
try:
    preds_blob_service_client = BlobServiceClient.from_connection_string(connection_string)
except ValueError as e:
    raise RuntimeError(f"Error creating BlobServiceClient: {e}") from e

# Create local directories
prefix = map_key.replace("/", "_")
Path(f"{prefix}_preds").mkdir(exist_ok=True)
Path(f"{prefix}_vrts").mkdir(exist_ok=True)
Path(f"{prefix}_tifs").mkdir(exist_ok=True)

print("Download predictions as nc files (may take several minutes)")
source_prefix = f"{map_key}"
destination_folder = f"{prefix}_preds"

container_client = preds_blob_service_client.get_container_client(preds_container_name)
try:
    # list_blobs() is lazy: request errors only surface while iterating, so the
    # listing is materialised inside the try block.
    blobs = list(container_client.list_blobs(name_starts_with=source_prefix))
except AzureError as e:
    raise RuntimeError(f"Error listing blobs: {e}") from e
print(f"Found {len(blobs)} blobs with prefix {source_prefix}")

# Download each blob
for blob in blobs:
    relative_path = blob.name[len(source_prefix):].lstrip('/')
    if not relative_path:
        # A blob named exactly like the prefix has no file name to save under.
        print(f"Skipping blob without a relative path: {blob.name}")
        continue
    destination_path = os.path.join(destination_folder, relative_path)
    os.makedirs(os.path.dirname(destination_path), exist_ok=True)

    if os.path.exists(destination_path):
        print(f"Skipping existing file: {destination_path}")
        continue

    print(f"Downloading {blob.name}...")
    blob_client = container_client.get_blob_client(blob.name)
    # Write to a side file and rename on success. Otherwise a failed download would
    # leave a truncated file that the "skip existing" check above never retries.
    partial_path = destination_path + ".part"
    try:
        with open(partial_path, "wb") as download_file:
            blob_client.download_blob().readinto(download_file)
        os.replace(partial_path, destination_path)
    except (AzureError, OSError) as e:
        print(f"Error downloading {blob.name}: {e}")
        if os.path.isfile(partial_path):
            os.remove(partial_path)

# Call the build_vrt function after downloads complete
try:
    build_vrt(prefix)
except Exception as e:
    print(f"Error building VRT: {e}")

Loading config from: /home/joshua/openmapflow/openmapflow.yaml
Config loaded: {'project': 'openmapflow_project', 'data_paths': {'raw_labels': '/data/raw_labels', 'datasets': '/data/datasets', 'models': '/data/models', 'metrics': '/data/metrics', 'report': '/data/report'}, 'azure': {'labeled_eo_container': 'openmap', 'inference_eo_container': 'inference-eo-container', 'preds_container': 'preds-container', 'preds_merged_container': 'preds-merged-container'}}
Download predictions as nc files (may take several minutes)
Building vrt for each batch


0it [00:00, ?it/s]

Prefix for final VRT: ksaopenmapflow.pt_min_lat=-0.95_min_lon=36.75_max_lat=-0.55_max_lon=37.35_dates=2020-02-01_2021-02-01_all
Building full vrt





0...10...20...30...40...50...60...70...80...90...100 - done.


