In [None]:
# os.chdir("satellite-images-inference/")
import os

import geopandas as gpd
import mlflow
import numpy as np
import torch
from astrovision.data import SegmentationLabeledSatelliteImage
from astrovision.plot import make_mosaic
from matplotlib import pyplot as plt

from app.utils.data import get_file_system, get_filename_to_polygons, get_satellite_image
from app.utils.predict import make_batched_prediction
from app.utils.preprocess_image import get_transform
from app.utils.split_and_normalize import get_normalized_sis

os.environ["MLFLOW_MODEL_NAME"] = "Segmentation-multiclass"
os.environ["MLFLOW_MODEL_VERSION"] = "1"
os.environ["MLFLOW_TRACKING_URI"] = "https://projet-slums-detection-mlflow.user.lab.sspcloud.fr/"

try:
    del os.environ["AWS_SESSION_TOKEN"]
except KeyError:
    pass

%load_ext autoreload
%autoreload 2

In [None]:
# Data selection: which department / year / bands to load, and which image to inspect.
fs = get_file_system()
dep = "MAYOTTE"
year = 2022
n_bands = 3
image_idx = 409  # position of the inspected image among those intersecting the ROI

filename_table = get_filename_to_polygons(dep, year, fs)

# Read the region of interest; the `with` block closes the remote file handle.
with fs.open(f"projet-slums-detection/data-roi/{dep}.geojson", "rb") as roi_file:
    roi = gpd.read_file(roi_file)
roi_polygon = roi.geometry.iloc[0]

# Keep only the image files whose footprint intersects the ROI polygon.
images = filename_table.loc[
    filename_table.geometry.intersects(roi_polygon),
    "filename",
].tolist()

assert image_idx < len(images), f"image_idx={image_idx} but only {len(images)} images intersect the ROI"

# `image` is reused by the normalization cell further down, so bind it explicitly
# instead of relying on a loop variable leaking out of a for-loop.
image = images[image_idx]
si = get_satellite_image(image, n_bands)
print(si.bounds)
print(si.crs)

fig, ax = plt.subplots()
ax.imshow(np.transpose(si.array, (1, 2, 0)))  # (bands, H, W) -> (H, W, bands) for imshow
ax.set_title(f"Image #{image_idx}")
plt.show()

In [None]:
# Fetch the registered model version from the MLflow model registry configured above.
model_name, model_version = (os.environ[key] for key in ("MLFLOW_MODEL_NAME", "MLFLOW_MODEL_VERSION"))
model_uri = f"models:/{model_name}/{model_version}"

model = mlflow.pytorch.load_model(model_uri=model_uri)

In [None]:
# Inference configuration. `n_bands` comes from the data-selection cell, so it has
# a single source of truth.
tiles_size = 250
augment_size = 512
normalization_mean = [67.43142604916895, 83.8419884471921, 67.89098874405661]
normalization_std = [27.83435228788356, 30.8127535004091, 32.37272004596061]
sliding_window_split = True
overlap = 125
batch_size = 25

# Use the GPU when one is present, otherwise fall back to CPU. The prediction cell
# moves its input tensor with the same rule, so model and data stay on one device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

In [None]:
# Build the preprocessing transform, then split the selected image into tiles
# and normalize them for the model.
transform = get_transform(
    tiles_size,
    augment_size,
    n_bands,
    normalization_mean,
    normalization_std,
)

split_kwargs = dict(
    n_bands=n_bands,
    tiles_size=tiles_size,
    normalization_mean=normalization_mean,
    sliding_window_split=sliding_window_split,
    overlap=overlap,
)
normalized_sis_tensor, si_splitted = get_normalized_sis(image=image, transform=transform, **split_kwargs)

In [None]:
# Run the model over the normalized tiles in batches; the input goes to the GPU
# when one is available, else it stays on CPU.
target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
prediction = make_batched_prediction(
    normalized_si=normalized_sis_tensor.to(target_device),
    model=model,
    tiles_size=tiles_size,
    batch_size=batch_size,
)  # already softmaxed

In [None]:
# Pair every tile with its prediction, then stitch the labeled tiles back into
# the full image.
# NOTE(review): the prediction cell says its output is already softmaxed, yet
# `logits=True` is passed here — confirm which flag astrovision expects.
assert len(si_splitted) == len(prediction), (
    f"{len(si_splitted)} tiles but {len(prediction)} predictions"
)
lsi_splitted = [
    SegmentationLabeledSatelliteImage(tile, tile_prediction, logits=True)
    for tile, tile_prediction in zip(si_splitted, prediction)
]
lsi = make_mosaic(lsi_splitted, list(range(n_bands)))  # get back to full image

In [None]:
class_to_show = 1  # channel of the predicted label map overlaid on the image

fig, ax = plt.subplots(figsize=(6, 6))
# Keep the first n_bands channels and move them last: (bands, H, W) -> (H, W, bands).
ax.imshow(np.transpose(lsi.satellite_image.array[:n_bands], (1, 2, 0)))
# NOTE(review): vmax is the max over *all* label channels, not just the shown one —
# confirm this is the intended colour scaling.
ax.imshow(
    lsi.label[class_to_show],
    alpha=0.5,
    cmap="jet",
    vmin=0,
    vmax=lsi.label.max(),
    interpolation="none",
)
ax.set_title(f"Predicted label channel {class_to_show} over the mosaicked image")
ax.set_axis_off()
plt.show()