In [None]:
import mlflow
import numpy as np
from astrovision.data import SatelliteImage, SegmentationLabeledSatelliteImage
import yaml
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torch

In [None]:
%env MLFLOW_TRACKING_URI=https://projet-slums-detection-128833.user.lab.sspcloud.fr
%env MLFLOW_S3_ENDPOINT_URL=https://minio.lab.sspcloud.fr
model_name = "test"
model_version = "3"
model = mlflow.pytorch.load_model(model_uri=f"models:/{model_name}/{model_version}")
model_mlflow = mlflow.pyfunc.load_model(model_uri=f"models:/{model_name}/{model_version}")
n_bands = int(mlflow.get_run(model_mlflow.metadata.run_id).data.params["n_bands"])

In [None]:
# Import normalization metrics logged next to the model code in the training run.
params = yaml.safe_load(
    mlflow.artifacts.load_text(
        f"{mlflow.get_run(model_mlflow.metadata.run_id).info.artifact_uri}/model/code/metrics-normalization.yaml"
    )
)
# Keep only the statistics of the bands the model was trained on.
normalization_mean = params["mean"][:n_bands]
normalization_std = params["std"][:n_bands]

# Load an image
si = SatelliteImage.from_raster(
    file_path="/vsis3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0524_8587_U38S_8Bits_0005.jp2",
    dep=None,
    date=None,
    n_bands=n_bands,
)

# Reproduce the training-time transform.
# NOTE(review): Albumentations' Normalize computes (x - mean * max_pixel_value) / (std * max_pixel_value),
# so max_pixel_value=255.0 is only correct if the YAML statistics are on a 0-1 scale; with
# pixel-scale statistics it must be 1.0. Confirm which convention the training pipeline used.
transform = A.Compose(
    [
        A.Normalize(
            max_pixel_value=255.0,
            mean=normalization_mean,
            std=normalization_std,
        ),
        ToTensorV2(),
    ]
)

# Normalize the image: Albumentations works on channels-last arrays, hence the transpose of
# si.array (presumably (C, H, W)); ToTensorV2 returns channels-first and unsqueeze adds a batch axis.
normalized_si = transform(image=np.transpose(si.array, [1, 2, 0]))["image"].unsqueeze(dim=0)

# Predict the mask. Inference must run in eval mode (BatchNorm/Dropout otherwise behave as in
# training, using statistics of this single-image batch) and without autograd bookkeeping.
model.eval()
with torch.no_grad():
    logits = model(normalized_si)
# Binarise, then drop the singleton batch/channel axes to get a 2-D numpy mask, matching the
# `[0, :, :]` numpy mask built by the earlier softmax variant.
# NOTE(review): assumes the model output is (1, 1, H, W) or (1, H, W) — confirm.
mask = (logits.sigmoid() > 0.5).squeeze().cpu().numpy()

lsi = SegmentationLabeledSatelliteImage(si, mask)

plot = lsi.plot(bands_indices=[0, 1, 2])

In [None]:
from functions.instanciators import get_dataset

# A single golden-test patch and its label file, run through the project's dataset pipeline.
image_paths = [
    "projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0524_8587_U38S_8Bits_0005.jp2"
]
label_paths = [
    "projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0524_8587_U38S_8Bits_0005.np"
]
ds = get_dataset(
    "segmentation",
    image_paths,
    label_paths,
    n_bands,
    1,  # NOTE(review): meaning of this positional argument is defined by get_dataset — confirm
    transform,
)

In [None]:
# Inspect the normalized input batch (in top-to-bottom execution order this is the torch
# tensor produced by the transform cell above).
normalized_si

In [None]:
# Score the same input with the MLflow pyfunc wrapper. `.predict` lives on the pyfunc object
# (`model_mlflow`), not on the raw PyTorch module `model`. The pyfunc expects numpy input,
# while normalized_si is a torch tensor at this point, so convert it first.
model_mlflow.predict(np.asarray(normalized_si))

In [None]:
# Debug variant: hard-coded 3-band statistics. They look like raw pixel-scale values, so the
# transform uses max_pixel_value=1.0 to stop Albumentations from rescaling mean/std.
# Note: this overrides the n_bands / statistics obtained from MLflow above.
n_bands = 3
normalization_mean = [70.39515812366652, 88.14856950608, 71.13188369425293]
normalization_std = [24.82401944849501, 27.227723561124673, 30.270387883075287]

patch_path = "/vsis3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0524_8587_U38S_8Bits_0005.jp2"
si = SatelliteImage.from_raster(file_path=patch_path, dep=None, date=None, n_bands=n_bands)

# Reproduce the transform: normalization followed by conversion to a channels-first tensor.
normalize_step = A.Normalize(
    max_pixel_value=1.0,
    mean=normalization_mean,
    std=normalization_std,
)
transform = A.Compose([normalize_step, ToTensorV2()])

# Albumentations consumes channels-last images; the result gets a batch axis and is
# returned as a numpy array (the form the pyfunc .predict consumes).
image_hwc = np.transpose(si.array, [1, 2, 0])
normalized_si = transform(image=image_hwc)["image"].unsqueeze(dim=0).numpy()

In [None]:
from models.segmentation_module import SegmentationModule
from models.components.segmentation_models import SingleClassDeepLabv3Module
from config.loss import WeightedBCEWithLogitsLoss
from torch import optim

In [None]:
# Restore the Lightning module from a local checkpoint, to compare against the MLflow model.
# NOTE(review): the loss/optimizer/scheduler arguments are presumably only consumed when
# configuring training, but must still be consistent with what SegmentationModule expects.
model_torch = SegmentationModule.load_from_checkpoint(
    "epoch=19-step=10620.ckpt",
    # The inference cell feeds a CPU tensor; map the weights to CPU so a checkpoint saved on
    # GPU is not restored onto a GPU device.
    map_location="cpu",
    model=SingleClassDeepLabv3Module(),
    loss=WeightedBCEWithLogitsLoss(label_smoothing=0, building_class_weight=2),
    optimizer=optim.Adam,
    optimizer_params={"lr": 0.00005},
    # "mode" and "patience" are ReduceLROnPlateau arguments ("monitor" is the metric it
    # follows); OneCycleLR accepts neither and requires max_lr, so the class must match them.
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params={
        "monitor": "validation_loss",
        "mode": "min",
        "patience": 3,
    },
    scheduler_interval="epoch",
)

In [None]:
# Predict with the restored checkpoint: eval mode for BatchNorm/Dropout, and no autograd
# graph since only the thresholded output is needed.
model_torch.eval()
with torch.no_grad():
    logits_torch = model_torch(
        transform(image=np.transpose(si.array, [1, 2, 0]))["image"].unsqueeze(dim=0)
    )
# Binarise and squeeze singleton batch/channel axes into a 2-D numpy mask.
# NOTE(review): assumes the output is (1, 1, H, W) or (1, H, W) — confirm.
mask = (logits_torch.sigmoid() > 0.5).squeeze().cpu().numpy()

In [None]:
# Pair the patch with the predicted mask and render it using its first three bands.
display_bands = [0, 1, 2]
lsi = SegmentationLabeledSatelliteImage(si, mask)
plot = lsi.plot(bands_indices=display_bands)

In [None]:
# Shape of the raw (unnormalized) patch once cast to float32 and given a leading batch axis.
si.array[None].astype("float32").shape

In [None]:
# Score the raw, unnormalized patch (float32, batch axis added) with the pyfunc wrapper;
# `.predict` belongs to `model_mlflow`, not to the raw PyTorch module `model`.
model_mlflow.predict(si.array.astype("float32")[np.newaxis, :])

In [None]:
# Score the normalized numpy batch with the pyfunc wrapper (`model_mlflow` exposes `.predict`;
# the raw PyTorch module `model` is called directly instead).
model_mlflow.predict(normalized_si)

In [None]:
# Same call as the previous cell (duplicate): score the normalized batch with the pyfunc
# wrapper, which is the object exposing `.predict`.
model_mlflow.predict(normalized_si)

In [None]:
# Turn the pyfunc's raw scores into probabilities; `.predict` is called on the pyfunc
# wrapper `model_mlflow`, whose numpy output is wrapped back into a tensor for sigmoid.
torch.sigmoid(torch.tensor(model_mlflow.predict(normalized_si)))

In [None]:
# Exploratory sanity check: undo the forward normalization and compare with the raw patch.
# Albumentations' Normalize computes y = (x - mean * M) / (std * M) with M = max_pixel_value.
# The exact inverse, x = y * std * M + mean * M, is itself a Normalize with
# max_pixel_value=1.0, mean' = -mean / std and std' = 1 / (std * M).
# M is read from the current forward transform so the inverse always matches it
# (this notebook builds forward transforms with both M=255.0 and M=1.0).
forward_pixel_scale = transform.transforms[0].max_pixel_value
inv_transform = A.Compose(
    [
        A.Normalize(
            max_pixel_value=1.0,
            mean=[-mean / std for mean, std in zip(normalization_mean, normalization_std)],
            std=[1 / (std * forward_pixel_scale) for std in normalization_std],
        ),
    ]
)

# normalized_si holds a batched (1, C, H, W) array (numpy or CPU tensor depending on which
# cell produced it); drop the batch axis and move channels last to match `original`.
normalized_chw = np.asarray(normalized_si)
if normalized_chw.ndim == 4:
    normalized_chw = normalized_chw[0]

original = np.transpose(si.array, [1, 2, 0])
unnormalized = np.round(
    inv_transform(image=np.transpose(normalized_chw, [1, 2, 0]))["image"]
).astype("uint16")