In [None]:
import os
import mlflow
import json

In [None]:

def get_model_from_id(run_id: str) -> mlflow.pyfunc.PyFuncModel:
    """
    Load the model logged under an MLflow run as a generic pyfunc model.

    The model is read from the ``model`` artifact path of the run, i.e. from
    the URI ``runs:/<run_id>/model``.

    Args:
        run_id (str): Identifier of the MLflow run that logged the model.
    Returns:
        model (mlflow.pyfunc.PyFuncModel): The loaded machine learning model.
    Raises:
        RuntimeError: If loading the model fails for any reason. The
        underlying error is chained as ``__cause__``.
    """

    try:
        model_uri = f"runs:/{run_id}/model"
        return mlflow.pyfunc.load_model(model_uri=model_uri)
    except Exception as error:
        # RuntimeError is a subclass of Exception, so callers that catch
        # Exception keep working while getting a more specific type.
        raise RuntimeError(f"Failed to fetch model from run_id : {run_id}") from error


def fetch_model(run_id):
    """
    Load a model from an MLflow run together with the parameters logged on it.

    Args:
        run_id (str): Identifier of the MLflow run that logged the model.
    Returns:
        dict: A dictionary with the keys ``"model"``, ``"n_bands"``,
        ``"tiles_size"``, ``"augment_size"``, ``"normalization_mean"``,
        ``"normalization_std"`` and ``"module_name"``.
    Raises:
        Exception: Propagated from ``get_model_from_id`` when the model
        cannot be loaded.
        KeyError: If one of the expected parameters is missing on the run.
        ValueError: If ``n_bands``, ``tiles_size`` or ``augment_size`` cannot
        be parsed as integers.
    """
    # Load the ML model
    model = get_model_from_id(run_id)

    # Query the tracking server once and reuse the params mapping instead of
    # issuing one get_run call per parameter.
    run_params = mlflow.get_run(model.metadata.run_id).data.params

    # Parameters are read back as strings, hence the int() conversions.
    n_bands = int(run_params["n_bands"])
    tiles_size = int(run_params["tiles_size"])
    augment_size = int(run_params["augment_size"])
    module_name = run_params["module_name"]
    normalization_mean, normalization_std = get_normalization_metrics(model, n_bands)

    return {
        "model": model,
        "n_bands": n_bands,
        "tiles_size": tiles_size,
        "augment_size": augment_size,
        "normalization_mean": normalization_mean,
        "normalization_std": normalization_std,
        "module_name": module_name,
    }


def get_normalization_metrics(model: mlflow.pyfunc.PyFuncModel, n_bands: int):
    """
    Retrieves normalization metrics (mean and standard deviation) for the model.

    The values are read from the ``normalization_mean`` and
    ``normalization_std`` parameters of the run referenced by
    ``model.metadata.run_id``. Both parameters are JSON-decoded and then
    truncated to their first ``n_bands`` entries.

    Args:
        model (mlflow.pyfunc.PyFuncModel): MLflow PyFuncModel object representing the model.
        n_bands (int): Number of leading entries to keep from each metric.

    Returns:
        Tuple: A tuple ``(normalization_mean, normalization_std)``.

    Raises:
        KeyError: If either parameter is missing on the run.
        json.JSONDecodeError: If either parameter is not valid JSON.
    """
    # Fetch the run once; both metrics live in the same params mapping.
    run_params = mlflow.get_run(model.metadata.run_id).data.params

    normalization_mean = json.loads(run_params["normalization_mean"])
    normalization_std = json.loads(run_params["normalization_std"])

    # Keep only the entries for the requested number of bands
    return (normalization_mean[:n_bands], normalization_std[:n_bands])



In [None]:
# Identifier of the MLflow run whose logged model is loaded below.
run_id = "5cf1eb7bdd4141529688acdd738e739e"
# Load the model together with the parameters logged on that run
# (see fetch_model for the keys of the returned dict).
model_info = fetch_model(run_id)

In [None]:
# Keep a direct handle on the loaded pyfunc model for the cells below.
model = model_info["model"]

In [None]:

from typing import Tuple
import argparse
import gc
import numpy as np
import random
import ast

import albumentations as A
import mlflow
import torch
from albumentations.pytorch.transforms import ToTensorV2
from osgeo import gdal
from torch import Generator
from torch.utils.data import DataLoader, random_split

from functions.download_data import (
    get_file_system,
    get_patchs_labels,
    normalization_params,
    get_golden_paths,
    pooled_std_dev,
)
from functions.instanciators import get_dataset, get_lightning_module, get_trainer
from functions.filter import filter_indices_from_labels

# Have the GDAL Python bindings raise exceptions on errors.
gdal.UseExceptions()


In [None]:
# Notebook configuration. These names are plain module-level variables read by
# the cells below; they are set here by hand and are not taken from model_info.
remote_server_uri = "https://projet-slums-detection-128833.user.lab.sspcloud.fr"
experiment_name = "test-dev"
run_name =  "stagiosessaye"
task = "segmentation"
source = "PLEIADES"
tiles_size = 250
augment_size = 250
type_labeler = "BDTOPO"
n_bands = 3
logits = 1
freeze_encoder = 0
epochs = 10
batch_size = 8
test_batch_size = 8
num_sanity_val_steps = 1
accumulate_batch = 8
module_name = "deeplab-v3plus-ocr"
loss_name =  "cross_entropy_weighted"
building_class_weight = 1
label_smoothing = 0.0
lr = 0.00005
# NOTE(review): this binds the built-in type `float`, not a number. It looks
# like a leftover from a typed argument declaration; confirm the intended value.
momentum = float
scheduler_name = "one_cycle"
scheduler_patience = 3
patience = 200
from_s3 = 0
seed = 12345 
cuda = 0
# With cuda = 0 the `and` short-circuits, so cuda stays 0 and kwargs is {}.
cuda = cuda and torch.cuda.is_available()
# Extra DataLoader keyword arguments, only populated when cuda is truthy.
kwargs = {"num_workers": os.cpu_count(), "pin_memory": True} if cuda else {}

# (dep, year) pairs, matched by position, for which normalization statistics
# are collected.
deps = ["MAYOTTE","MAYOTTE","MAYOTTE"]
years = ["2017","2019","2020"]

# One entry per (dep, year) pair, in the same order as deps/years.
normalization_means = []
normalization_stds = []

for dep,year in zip(deps,years):
    # normalization_params is a project helper; presumably it returns the
    # per-band mean and std for that dataset (verify against its definition).
    normalization_mean, normalization_std = normalization_params(
        task, source, dep, year, tiles_size, type_labeler
    )
    normalization_means.append(normalization_mean)
    normalization_stds.append(normalization_std)


In [None]:
# Golden test: collect the patch and label paths of the "MAYOTTE_CLEAN" 2022 set.
golden_patches, golden_labels = get_golden_paths(
    from_s3, task, source, "MAYOTTE_CLEAN", "2022", tiles_size
)

# Sort both lists in place. Presumably patch and label file names sort into
# the same order so that index i of each list refers to the same tile
# (verify against the naming scheme).
golden_patches.sort()
golden_labels.sort()


In [None]:
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Pool the per-dataset normalization statistics collected in
# normalization_means / normalization_stds into one mean and one std per band.
# Every dataset gets the same weight. The weight vector is sized from
# normalization_means so it stays valid if deps/years change length.
dataset_weights = [1.0] * len(normalization_means)

normalization_mean = np.average(
    [mean[:n_bands] for mean in normalization_means], weights=dataset_weights, axis=0
)

normalization_std = [
    pooled_std_dev(
        # Fresh copy per call, as before, in case the helper mutates its input.
        list(dataset_weights),
        [mean[i] for mean in normalization_means],
        [std[i] for std in normalization_stds],
    )
    for i in range(n_bands)
]

# np.average returns an ndarray; hand A.Normalize a plain list instead.
normalization_mean_list = normalization_mean.tolist()


# Test-time transform: normalize with the pooled statistics, then convert to a
# torch tensor.
test_transform = [
    A.Normalize(
        max_pixel_value=1.0,
        mean=normalization_mean_list,
        std=normalization_std,
    ),
    ToTensorV2(),
]


In [None]:
# Build the golden evaluation dataset from the patch/label paths and the
# test-time transform list.
golden_dataset = get_dataset(
    task, golden_patches, golden_labels, n_bands, from_s3, test_transform
)

# shuffle=False keeps batches in dataset order; drop_last=True discards a
# final batch smaller than test_batch_size. kwargs is empty unless cuda is
# truthy.
golden_loader = DataLoader(
    golden_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)

In [None]:
# Display the loaded model's repr as the cell output.
model

In [None]:
# Pull the first batch from the golden loader and split it by key.
batch = next(iter(golden_loader))
labels = batch["labels"]
images = batch["pixel_values"]


In [None]:
# Satellite image inference
# NOTE(review): `preprocess_image` and `image` are not defined in the visible
# part of this notebook; they must be defined or imported before this cell
# can run. The statements were dedented because the original cell started
# with indented code and raised IndentationError.

# Preprocess the image
normalized_si = preprocess_image(
    model=model,
    image=image,
    tiles_size=tiles_size,
    augment_size=augment_size,
    n_bands=n_bands,
    normalization_mean=normalization_mean,
    normalization_std=normalization_std,
)

# Make prediction using the model
# The model is fed a NumPy array and its output is wrapped back into a tensor.
with torch.no_grad():
    prediction = torch.tensor(model.predict(normalized_si.numpy()))