### **Run Deforestation model (Unet-Diff) inference for a given AOI**

In [1]:
import io
import json
import logging
import os
import re
import tempfile
import urllib
import uuid
import warnings
from datetime import datetime, timedelta
from os.path import join
from pathlib import Path

import cv2
import geojson
import geopandas
import geopandas as gp
import geopandas as gpd
import imageio
import numpy as np
import pandas as pd
import pyproj
import rasterio
import rasterio.mask
import segmentation_models_pytorch as smp
import shapely
import torch
import torch.backends.cudnn as cudnn
from geojson import Feature
from geopandas import GeoSeries
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload
from rasterio import Affine, features
from rasterio.mask import mask as riomask
from rasterio.merge import merge
from rasterio.plot import reshape_as_image, reshape_as_raster
from rasterio.warp import Resampling, calculate_default_transform, reproject
from rasterio.windows import Window
from scipy import spatial
from sentinel2download.downloader import Sentinel2Downloader, logger
from sentinel2download.overlap import Sentinel2Overlap
from shapely import wkt
from shapely.geometry import MultiPolygon, Polygon, box
from shapely.ops import transform, unary_union
from skimage.exposure import match_histograms
from torch import nn
from torchvision import transforms
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Functions to create and load the model (use them instead of the Catalyst library code)
def prepare_device():
    """Pick the torch device for inference: CUDA when available, else CPU."""
    device_name = "cpu"
    if torch.cuda.is_available():
        device_name = "cuda"
    return torch.device(device_name)


def prepare_model(model):
    """Move ``model`` to the inference device.

    Turns on the cuDNN benchmark flag when CUDA is available.

    :param model: torch module to relocate
    :return: tuple ``(model, device)``
    :raises ValueError: when more than one CUDA device is visible
    """
    target_device = prepare_device()

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        cudnn.benchmark = True

    # Guard clause: multi-GPU setups are rejected outright.
    if torch.cuda.device_count() > 1:
        raise ValueError("Multi GPU mode is not supported")

    return model.to(target_device), target_device


def load_model(network, model_weights_path):
    """Build the requested network and load its weights.

    :param network: network identifier; only ``"unet-diff"`` is supported
    :param model_weights_path: path to the saved state dict
    :return: tuple ``(model, device)``
    :raises ValueError: for any other ``network`` value
    """
    if network != "unet-diff":
        raise ValueError("Unknown network")

    aux_params = {
        "pooling": "max",  # one of 'avg', 'max'
        "dropout": 0.1,  # dropout ratio, default is None
        "activation": "sigmoid",  # activation function, default is None
        "classes": 1,  # define number of output labels
    }
    unet = smp.Unet(
        "resnet18",
        aux_params=aux_params,
        encoder_weights=None,
        in_channels=27,
        encoder_depth=2,
        decoder_channels=(256, 128),
    )

    unet, device = prepare_model(unet)
    state_dict = torch.load(model_weights_path, map_location=torch.device(device))
    unet.load_state_dict(state_dict)
    return unet, device

### 0. Setting up parameters
#### Read Input from environment and setup folders and filenames

In [2]:
# Request parameters, credentials and folders are read from environment variables.
REQUEST_ID = os.getenv("REQUEST_ID")
START_DATE = os.getenv("START_DATE")
END_DATE = os.getenv("END_DATE")
AOI = os.getenv("AOI")
SENTINEL2_GOOGLE_API_KEY = os.getenv("SENTINEL2_GOOGLE_API_KEY")
SATELLITE_CACHE_FOLDER = os.getenv("SENTINEL2_CACHE")

# NOTE(review): os.getenv returns None when SENTINEL2_CACHE is unset, in which
# case os.path.dirname below raises TypeError.
INPUT_FOLDER = os.path.dirname(SATELLITE_CACHE_FOLDER)
OUTPUT_FOLDER = os.getenv("OUTPUT_FOLDER")
PREPARED_DATA_FOLDER = os.path.join(INPUT_FOLDER, "prepared")
CODE_FOLDER = "/code"

# Landcover working folder and the Sentinel-2 tiling grid KML expected inside it.
LANDCOVER_POLYGONS_PATH = os.path.join(CODE_FOLDER, "data", "landcovers")
LANDCOVER_FILENAME = (
    "S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.kml"
)
SENTINEL_TILES = os.path.join(LANDCOVER_POLYGONS_PATH, LANDCOVER_FILENAME)

# OAuth scopes used when downloading model weights from Google Drive.
SCOPES = ["https://www.googleapis.com/auth/drive.file"]

CLOUD_DATA_FOLDER = os.path.join(CODE_FOLDER, "data", "clouds")
MODEL_PATH = os.path.join(CODE_FOLDER, "models", "unet_diff.pth")

# Downloaded Sentinel-2 products are written straight into the cache folder.
LOAD_DIR = SATELLITE_CACHE_FOLDER

# Download settings passed to Sentinel2Downloader.download.
PRODUCT_TYPE = "L1C"
BANDS = {"TCI", "B01", "B02", "B04", "B05", "B08", "B8A", "B09", "B10", "B11", "B12"}
CONSTRAINTS = {
    "NODATA_PIXEL_PERCENTAGE": 1,
    "CLOUDY_PIXEL_PERCENTAGE": 15.0,
}
# NOTE(review): this value is reassigned to 15 in the inference cell further down.
CLOUDS_PROBABILITY_THRESHOLD = 1
REMOVE_OTHER_DATES = True


# Retry policy for load_images: number of download attempts per tile and the
# number of days the search window is moved back when nothing is found.
MAX_SHIFT_ITERS = 2
MAX_SHIFT = 30

In [3]:
# Create the working folders up front; existing folders are left untouched.
os.makedirs(LANDCOVER_POLYGONS_PATH, exist_ok=True)
os.makedirs(PREPARED_DATA_FOLDER, exist_ok=True)
os.makedirs(CLOUD_DATA_FOLDER, exist_ok=True)

In [4]:
# Display the requested date range as the cell output.
START_DATE, END_DATE

('2020-05-01', '2020-06-30')

### 1. Transform AOI to GeoJSON file and then download data from Sentinel L1C

In [5]:
# Parse the AOI from WKT (coordinates declared as EPSG:4326) and write it to a
# GeoJSON file, because Sentinel2Overlap below takes a file path.
aoi = gp.GeoDataFrame(geometry=[wkt.loads(AOI)], crs="epsg:4326")
aoi_filename = "provided_aoi.geojson"
aoi.to_file(aoi_filename, driver="GeoJSON")

In [6]:
# Find the Sentinel-2 tiles overlapping the AOI.
s2overlap = Sentinel2Overlap(aoi_path=aoi_filename)
overlap_tiles = s2overlap.overlap_with_geometry()
# Landcover rasters are addressed by the first three characters of the tile name
# (e.g. "16T" for "16TFK").
landcover_tiles = set(overlap_tiles.Name.apply(lambda x: x[:3]).to_list())

In [7]:
def shift_date(date, delta=5, format="%Y-%m-%d"):
    """Return ``date`` moved ``delta`` days into the past.

    :param date: date string written in ``format``
    :param delta: number of days to subtract
    :param format: strptime/strftime pattern used for both parsing and output
    :return: the shifted date as a string in ``format``
    """
    shifted = datetime.strptime(date, format) - timedelta(days=delta)
    return shifted.strftime(format)


def diff_date(date_a, date_b, format="%Y-%m-%d"):
    """Return ``date_a - date_b`` as a ``timedelta``.

    Both arguments are date strings written in ``format``.
    """
    parsed_a = datetime.strptime(date_a, format)
    parsed_b = datetime.strptime(date_b, format)
    return parsed_a - parsed_b

In [8]:
# Parse the requested range and take the calendar year of its midpoint; this
# year is later passed to the landcover preparation.
start_timestamp, end_timestamp = datetime.strptime(
    START_DATE, "%Y-%m-%d"
), datetime.strptime(END_DATE, "%Y-%m-%d")
year = ((end_timestamp - start_timestamp) / 2 + start_timestamp).year
year

2020

In [9]:
def load_images(tiles, start_date, end_date):
    """Download Sentinel-2 images for every tile, moving the window back on failure.

    For each tile, up to ``MAX_SHIFT_ITERS`` download attempts are made. When an
    attempt returns nothing, the search window is moved back: the new window ends
    where the failed one started and begins ``MAX_SHIFT`` days earlier.

    :param tiles: iterable of tile names
    :param start_date: window start, ``%Y-%m-%d`` string
    :param end_date: window end, ``%Y-%m-%d`` string
    :return: dict mapping tile name to whatever ``Sentinel2Downloader.download``
        returned for it; tiles without any result are omitted
    """
    loader = Sentinel2Downloader(SENTINEL2_GOOGLE_API_KEY)
    loadings = dict()

    for tile in tiles:
        start, end = start_date, end_date
        # Stays falsy if MAX_SHIFT_ITERS is 0, so the report below still works.
        loaded = None

        print(f"Loading images for tile: {tile}...")
        for _ in range(MAX_SHIFT_ITERS):
            loaded = loader.download(
                PRODUCT_TYPE,
                [tile],
                start_date=start,
                end_date=end,
                output_dir=LOAD_DIR,
                bands=BANDS,
                constraints=CONSTRAINTS,
            )
            if loaded:
                break

            # Shift relative to the window that just failed (not the original
            # one), so that every further attempt looks at an earlier period.
            failed_start, failed_end = start, end
            end = start
            start = shift_date(start, delta=MAX_SHIFT)
            print(
                f"For tile: {tile} and dates {failed_start} {failed_end} proper images not found! Shift dates to {start} {end}!"
            )

        if loaded:
            loadings[tile] = loaded
            print(f"Loading images for tile {tile} finished")
        else:
            print(f"Images for tile {tile} were not loaded!")

    return loadings

In [10]:
# Download one set of images for a 10-day window ending at START_DATE and one
# for a 10-day window ending at END_DATE.
loadings_start_date = load_images(
    overlap_tiles.Name.values, shift_date(START_DATE, delta=10), START_DATE
)
loadings_end_date = load_images(
    overlap_tiles.Name.values, shift_date(END_DATE, delta=10), END_DATE
)

Loading images for tile: 16TFK...
Loading images for tile 16TFK finished
Loading images for tile: 16TFK...
For tile: 16TFK and dates 2020-06-20 2020-06-30 proper images not found! Shift dates to 2020-05-21 2020-06-20!
Loading images for tile 16TFK finished


In [11]:
def filter_by_date(loadings, func=max, filtered=None, tag="start"):
    """Select, per tile, the band files of a single acquisition date.

    The date is chosen by applying ``func`` (e.g. ``max``) to the dates found in
    the downloaded items of the tile. The paths of the bands needed later and
    the folder containing them are stored under ``filtered[tile][tag]``.

    :param loadings: dict mapping tile name to a list of ``(path, _)`` items
    :param func: aggregation applied to the list of dates
    :param filtered: dict to update; a new one is created when omitted
    :param tag: key under which the result is stored for each tile
    :return: ``filtered`` with ``{tag: {"paths", "date", "folder"}}`` per tile
    """
    # File-name suffix identifying each required band, per product type.
    band_suffixes = {
        "L2A": {
            "B8A": "B8A_20m.jp2",
            "B11": "B11_20m.jp2",
            "B04": "B04_10m.jp2",
            "B08": "B08_10m.jp2",
            "B12": "B12_20m.jp2",
            "TCI": "TCI_10m.jp2",
        },
        "L1C": {
            "B8A": "B8A.jp2",
            "B11": "B11.jp2",
            "B04": "B04.jp2",
            "B08": "B08.jp2",
            "B12": "B12.jp2",
            "TCI": "TCI.jp2",
        },
    }

    # A fresh dict per call: a mutable default would leak results between calls.
    if filtered is None:
        filtered = {}

    def _find_agg_date(folders, func=func):
        dates = []
        for folder in folders:
            search = re.search(r"_(\d+)T\d+_", str(folder))
            dates.append(datetime.strptime(search.group(1), "%Y%m%d"))
        return datetime.strftime(func(dates), "%Y%m%d")

    def _get_folder(path):
        return os.path.join("/", *path.split("/")[:-1])

    for tile, items in loadings.items():
        try:
            if PRODUCT_TYPE not in band_suffixes:
                raise ValueError(f"Unsupported product type: {PRODUCT_TYPE}")
            suffixes = band_suffixes[PRODUCT_TYPE]

            last_date = _find_agg_date(items)
            bands_paths = dict()
            folder_path = None
            for path, _ in items:
                if last_date not in path:
                    continue
                for band, suffix in suffixes.items():
                    if suffix in path:
                        bands_paths[band] = path
                # Take the folder from a file of the selected date rather than
                # from the first downloaded item, which may be another date.
                if folder_path is None:
                    folder_path = _get_folder(path)

            if folder_path is None:
                raise ValueError(f"No files found for date {last_date}")

            info_dict = {
                tag: dict(paths=bands_paths, date=last_date, folder=folder_path)
            }

            if tile in filtered:
                filtered[tile].update(info_dict)
            else:
                filtered[tile] = info_dict

        except Exception as ex:
            # Best effort: a broken tile is reported and skipped.
            print(f"Error for {tile}: {str(ex)}")
    return filtered

In [12]:
# Keep the most recent acquisition of each window; the second call adds the
# "end" entry to the dict returned by the first.
filtered = filter_by_date(loadings_start_date, func=max, tag="start")
filtered = filter_by_date(loadings_end_date, func=max, filtered=filtered, tag="end")

In [13]:
def get_tile_and_images_folders(fitered, idx=0):
    """Return the tile name and its start/end image folders.

    Tiles are ordered alphabetically and the one at position ``idx`` is picked.

    :param fitered: dict produced by ``filter_by_date`` (the parameter name keeps
        its historical spelling so existing keyword callers keep working)
    :param idx: position of the tile in the sorted tile names
    :return: tuple ``(tile, start_folder, end_folder)``
    """
    # Read from the argument, not from a module-level variable.
    tile = sorted(fitered)[idx]
    tile_info = fitered[tile]
    return tile, tile_info["start"]["folder"], tile_info["end"]["folder"]


# Pick the tile to process together with its two product folders, then display
# the folders as the cell output.
tile, start_date_folder, end_date_folder = get_tile_and_images_folders(filtered)
start_date_folder, end_date_folder

('/input/SENTINEL2_CACHE/S2B_MSIL1C_20200424T161829_N0209_R040_T16TFK_20200424T195711',
 '/input/SENTINEL2_CACHE/S2A_MSIL1C_20200611T162901_N0209_R083_T16TFK_20200611T200748')

In [14]:
import shutil


def remove_not_used_dates(
    start_date_folder, end_date_folder, cache_dir=SATELLITE_CACHE_FOLDER
):
    """Delete every entry of ``cache_dir`` except the two product folders in use.

    :param start_date_folder: folder of the start-date product
    :param end_date_folder: folder of the end-date product
    :param cache_dir: cache folder that must be the parent of both folders
    :raises ValueError: when either folder does not live directly in ``cache_dir``
    """
    start_parts = start_date_folder.split("/")
    end_parts = end_date_folder.split("/")

    start_parent = os.path.join("/", *start_parts[:-1])
    end_parent = os.path.join("/", *end_parts[:-1])
    if not (cache_dir == start_parent and cache_dir == end_parent):
        raise ValueError("cache_dir is not valid")

    keep = (start_parts[-1], end_parts[-1])
    for entry in os.listdir(cache_dir):
        if entry in keep:
            continue
        shutil.rmtree(os.path.join(cache_dir, entry))


# if REMOVE_OTHER_DATES:
#     remove_not_used_dates(start_date_folder, end_date_folder)

### 2. Preparing images (calculating NDVI and NDMI, scaling, merging to TIFF)

In [15]:
from time_dependent.data_prepare.prepare_tif import (
    get_ndmi,
    get_ndvi,
    merge,
    scale_img,
    search_band,
    to_tiff,
)


def prepare_data(data_folder, save_path):
    """Build one merged GeoTIFF with the model input channels for a product.

    Converts the TCI, B08, B8A, B11, B12 and B04 files to tif, derives NDVI and
    NDMI rasters, scales TCI, B08, B8A, B11, B12, NDVI and NDMI, and merges the
    scaled files into ``all_merged_<product folder name>.tif`` in ``save_path``.
    All ``.tif`` files left in ``data_folder`` are removed afterwards.

    :param data_folder: folder of a downloaded product containing the jp2 files
    :param save_path: folder for the merged file (created if missing)
    :return: path of the merged file
    """
    img_folder = data_folder

    # The product folder name is used to make the output file name unique.
    tmp_file = data_folder.split("/")[-1]

    os.makedirs(save_path, exist_ok=True)
    save_file_merged = join(save_path, f"all_merged_{tmp_file}.tif")

    bands, band_names = ["TCI", "B08", "B8A", "B11", "B12"], []

    # presumably search_band returns the band file name without its extension,
    # since ".jp2" / ".tif" are appended below — verify against prepare_tif
    for band in bands:
        band_names.append(join(img_folder, search_band(band, img_folder, "jp2")))

    # B04 is only needed for NDVI and is not part of the merged output.
    b4_name = join(img_folder, search_band("B04", img_folder, "jp2"))
    ndvi_name = join(img_folder, "ndvi")
    ndmi_name = join(img_folder, "ndmi")
    print("\nall bands are converting to *tif...\n")

    for band_name in band_names:
        print(band_name[-3:])
        # Remember the bands that feed the index calculations below.
        if "B08" in band_name:
            b8_name = band_name
        if "B8A" in band_name:
            b8a_name = band_name
        if "B11" in band_name:
            b11_name = band_name
        to_tiff(f"{band_name}.jp2")

    to_tiff(f"{b4_name}.jp2")
    print("\nndvi band is processing...")

    # NDVI is computed from B04 and B08, written to ndvi.tif.
    get_ndvi(f"{b4_name}.tif", f"{b8_name}.tif", f"{ndvi_name}.tif")

    print("\nndmi band is processing...")

    # NDMI is computed from B11 and B8A, written to ndmi.tif.
    get_ndmi(f"{b11_name}.tif", f"{b8a_name}.tif", f"{ndmi_name}.tif")

    band_names.append(ndvi_name)
    band_names.append(ndmi_name)

    bands.append("ndvi")
    bands.append("ndmi")

    print("\nall bands are scaling to 8-bit images...\n")
    band_names_scaled = []
    for band_name in band_names:
        print(band_name)
        # scale_img returns the name of the scaled file it produced.
        scaled_name = scale_img(f"{band_name}.tif")
        band_names_scaled.append(scaled_name)

    print("\nall bands are being merged...\n")
    print(band_names_scaled)

    merge(save_file_merged, *band_names_scaled)

    # Clean up every intermediate tif in the product folder.
    for item in os.listdir(img_folder):
        if item.endswith(".tif"):
            os.remove(join(img_folder, item))
    return save_file_merged

In [16]:
os.makedirs(PREPARED_DATA_FOLDER, exist_ok=True)

# Build the merged model-input rasters for the start-date and end-date products.
start_date_merged_path = prepare_data(start_date_folder, PREPARED_DATA_FOLDER)
end_date_merged_path = prepare_data(end_date_folder, PREPARED_DATA_FOLDER)


all bands are converting to *tif...

TCI
Input file size is 10980, 10980
0...10...20...30...40...50...60...70...80...90...100 - done.
B08
Input file size is 10980, 10980
0...10...20...30...40...50...60...70...80...90...100 - done.
B8A
Input file size is 5490, 5490
0...10...20...30...40...50...60...70...80...90...100 - done.
B11
Input file size is 5490, 5490
0...10...20...30...40...50...60...70...80...90...100 - done.
B12
Input file size is 5490, 5490
0...10...20...30...40...50...60...70...80...90...100 - done.
Input file size is 10980, 10980
0...10...20...30...40...50...60...70...80...90...100 - done.

ndvi band is processing...
0 .. 10 .. 20 .. 30 .. 40 .. 50 .. 60 .. 70 .. 80 .. 90 .. 100 - Done

ndmi band is processing...
0 .. 10 .. 20 .. 30 .. 40 .. 50 .. 60 .. 70 .. 80 .. 90 .. 100 - Done

all bands are scaling to 8-bit images...

/input/SENTINEL2_CACHE/S2B_MSIL1C_20200424T161829_N0209_R040_T16TFK_20200424T195711/T16TFK_20200424T161829_TCI
Input file size is 10980, 10980
0...10..


Processing file     1 of     7,  0.000% completed in 0 minutes.
Filename: /input/SENTINEL2_CACHE/S2A_MSIL1C_20200611T162901_N0209_R083_T16TFK_20200611T200748/T16TFK_20200611T162901_TCI_scaled.tif
File Size: 10980x10980x3
Pixel Size: 10.000000 x -10.000000
UL:(600000.000000,4500000.000000)   LR:(709800.000000,4390200.000000)
Copy 0,0,10980,10980 to 0,0,10980,10980.
Copy 0,0,10980,10980 to 0,0,10980,10980.
Copy 0,0,10980,10980 to 0,0,10980,10980.

Processing file     2 of     7, 14.286% completed in 0 minutes.
Filename: /input/SENTINEL2_CACHE/S2A_MSIL1C_20200611T162901_N0209_R083_T16TFK_20200611T200748/T16TFK_20200611T162901_B08_scaled.tif
File Size: 10980x10980x1
Pixel Size: 10.000000 x -10.000000
UL:(600000.000000,4500000.000000)   LR:(709800.000000,4390200.000000)
Copy 0,0,10980,10980 to 0,0,10980,10980.

Processing file     3 of     7, 28.571% completed in 0 minutes.
Filename: /input/SENTINEL2_CACHE/S2A_MSIL1C_20200611T162901_N0209_R083_T16TFK_20200611T200748/T16TFK_20200611T162901_

### 3.1 Prepare clouds tif files for postprocessing step

In [17]:
from time_dependent.data_prepare.prepare_clouds import (
    detect_clouds,
    merge,
    search_band,
    to_tiff,
)


def prepare_clouds(data_folder, save_path):
    """Produce a clouds raster for a product, for use in postprocessing.

    Converts ten bands to tif, merges them into a temporary file, runs
    ``detect_clouds`` on it to write ``<product folder name>_clouds.tiff`` in
    ``save_path``, then removes the temporary merged file and every ``.tif``
    left in ``data_folder``.

    :param data_folder: folder of a downloaded product containing the jp2 files
    :param save_path: folder for the clouds file (must already exist; it is not
        created here)
    """
    img_folder = data_folder
    tile_folder = data_folder.split("/")[-1]
    print(tile_folder)
    # Bands merged as input for detect_clouds.
    bands, band_names = [
        "B01",
        "B02",
        "B04",
        "B05",
        "B08",
        "B8A",
        "B09",
        "B10",
        "B11",
        "B12",
    ], []

    # presumably search_band returns the band file name without its extension,
    # since ".jp2" is appended below — verify against prepare_clouds helpers
    for band in bands:
        band_names.append(join(img_folder, search_band(band, img_folder, "jp2")))

    print("\nall bands are converting to *tif...\n")

    for band_name in band_names:
        print(band_name[-3:])
        to_tiff(f"{band_name}.jp2")

    print("\n all bands are being merged...\n")

    save_file_merged = join(save_path, f"{tile_folder}_full_merged.tif")
    merge(save_file_merged, *band_names)

    save_file_clouds = join(save_path, f"{tile_folder}_clouds.tiff")
    detect_clouds(save_file_merged, save_file_clouds)
    # The merged multi-band file is only an intermediate for detect_clouds.
    os.remove(save_file_merged)

    # Only ".tif" files are removed; the ".tiff" clouds output is kept.
    for item in os.listdir(img_folder):
        if item.endswith(".tif"):
            os.remove(join(img_folder, item))

    # os.system(f'rm {join(granule_folder, tile_folder, 'IMG_DATA')}*.jp2')
    print("\ntemp files have been deleted\n")

In [18]:
os.makedirs(CLOUD_DATA_FOLDER, exist_ok=True)

# Produce the clouds rasters for the start-date and end-date products.
prepare_clouds(start_date_folder, CLOUD_DATA_FOLDER)
prepare_clouds(end_date_folder, CLOUD_DATA_FOLDER)

S2B_MSIL1C_20200424T161829_N0209_R040_T16TFK_20200424T195711

all bands are converting to *tif...

B01
Input file size is 1830, 1830
0...10...20...30...40...50...60...70...80...90...100 - done.
B02
Input file size is 10980, 10980
0...10...20...30...40...50...60...70...80...90...100 - done.
B04
Input file size is 10980, 10980
0...10...20...30...40...50...60...70...80...90...100 - done.
B05
Input file size is 5490, 5490
0...10...20...30...40...50...60...70...80...90...100 - done.
B08
Input file size is 10980, 10980
0...10...20...30...40...50...60...70...80...90...100 - done.
B8A
Input file size is 5490, 5490
0...10...20...30...40...50...60...70...80...90...100 - done.
B09
Input file size is 1830, 1830
0...10...20...30...40...50...60...70...80...90...100 - done.
B10
Input file size is 1830, 1830
0...10...20...30...40...50...60...70...80...90...100 - done.
B11
Input file size is 5490, 5490
0...10...20...30...40...50...60...70...80...90...100 - done.
B12
Input file size is 5490, 5490
0...10

predict.
resize.
save cloud.

temp files have been deleted



### 3.2 Preparing forest landcover data, also needed for postprocessing stage

In [19]:
class LandcoverPolygons:
    """
    Access to the forest polygons of a Sentinel-2 tile.

    Before usage, be sure that the SENTINEL_TILES file is downloaded.
    SENTINEL_TILES_POLYGONS = 'https://sentinel.esa.int/documents/247904/1955685/S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.kml'

    :param tile: tile name (str), e.g. '36UYA'
    :param crs: coordinate system (str), e.g. 'EPSG:4326'
    :param year: year forwarded to ``prepare_landcover``
    :param aoi: area of interest forwarded to ``prepare_landcover``
    """

    def __init__(self, tile, crs, year, aoi):
        self.tile = tile
        self.crs = crs
        # Landcover data is prepared for the first three characters of the tile name.
        self.LANDCOVER_GEOJSON = prepare_landcover(
            year, [tile[:3]], LANDCOVER_POLYGONS_PATH, aoi
        )
        # Enable the KML driver, needed to read the SENTINEL_TILES file.
        gpd.io.file.fiona.drvsupport.supported_drivers["KML"] = "rw"

    def _polygon_path(self):
        """Return the path of the per-tile forest polygons GeoJSON file."""
        return os.path.join(LANDCOVER_POLYGONS_PATH, f"{self.tile}.geojson")

    def get_polygon(self):
        """Return the forest polygons of the tile, reprojected to ``self.crs``.

        The per-tile GeoJSON file is read when it exists and created otherwise.

        :return: list of geometries; when there are no polygons the empty
            GeoDataFrame is returned as-is
        """
        polygon_path = self._polygon_path()
        logging.info(f"LANDCOVER_POLYGONS_PATH: {polygon_path}")
        if os.path.isfile(polygon_path):
            logging.info(f"{self.tile} forests polygons file exists.")
            polygons = gpd.read_file(polygon_path)
        else:
            logging.info(
                f"{self.tile} forests polygons file does not exist. Creating polygons..."
            )
            polygons = self.create_polygon()

        if len(polygons) > 0:
            polygons = polygons.to_crs(self.crs)
            polygons = list(polygons["geometry"])
        else:
            logging.info("No forests polygons.")
        return polygons

    def create_polygon(self):
        """Select the landcover polygons intersecting the tile and save them.

        :return: GeoDataFrame of the intersecting polygons (possibly empty; an
            empty result is not written to disk)
        :raises FileNotFoundError: when the SENTINEL_TILES file is missing
        :raises IndexError: when the tile is not listed in SENTINEL_TILES
        """
        if not os.path.isfile(SENTINEL_TILES):
            logging.error(f"{SENTINEL_TILES} does not exist")
            raise FileNotFoundError(f"{SENTINEL_TILES} does not exist")

        logging.info(
            f"read forests_polygons_file: {SENTINEL_TILES}, for tile {self.tile}"
        )

        sentinel_tiles = gpd.read_file(SENTINEL_TILES, driver="KML")
        sentinel_tiles = sentinel_tiles[sentinel_tiles["Name"] == self.tile]

        logging.info(f"sentinel_tiles for {self.tile}: {sentinel_tiles}")

        # Same exception type as indexing the empty selection would raise,
        # but with a message that names the actual problem.
        if sentinel_tiles.empty:
            raise IndexError(f"tile {self.tile} not found in {SENTINEL_TILES}")

        bounding_polygon = sentinel_tiles["geometry"].values[0]
        polygons = gpd.read_file(self.LANDCOVER_GEOJSON)
        polygons = polygons[polygons["geometry"].intersects(bounding_polygon)]
        polygon_path = self._polygon_path()

        logging.info(f"forests_polygons_file_path: {polygon_path}")

        if polygons.empty:
            return polygons
        polygons.to_file(polygon_path, driver="GeoJSON")
        return polygons


def landcover_annual(year, landcover_tiles, output_path, aoi):
    """Download annual landcover rasters, crop them to the AOI and merge them.

    For every tile the annual landcover GeoTIFF is downloaded (unless already
    present), cropped to ``aoi`` and written as a binary mask that is 1 where
    the landcover class equals 2 (Trees). The crops are then merged with
    ``gdalwarp`` into ``landcover<year>.tif`` in ``output_path``.

    :param year: landcover year (int)
    :param landcover_tiles: iterable of landcover tile ids, e.g. ``"16T"``
    :param output_path: working/output folder (created if missing)
    :param aoi: GeoDataFrame with the area of interest
    """
    # `import urllib` alone does not guarantee that the submodule is loaded.
    import urllib.request

    # landcover_classes = {
    #    1: "Water",
    #    2: "Trees",
    #    4: "Flooded vegetation",
    #    5: "Crops",
    #    7: "Built Area",
    #    8: "Bare ground",
    #   9: "Snow/Ice",
    #   10: "Clouds",
    #    11: "Rangeland"
    # }
    EPSG = "EPSG:4326"
    os.makedirs(output_path, exist_ok=True)

    landcover_downloaded = []
    for tile_i in landcover_tiles:
        tile_url = f"https://lulctimeseries.blob.core.windows.net/lulctimeseriespublic/lc{year}/{tile_i}_{year}0101-{year+1}0101.tif"
        path = f"{output_path}/{os.path.basename(tile_url)}"

        if not os.path.exists(path):
            urllib.request.urlretrieve(tile_url, path)
        else:
            print("File already exists")
        landcover_downloaded.append(path)

    crops = []
    for path in landcover_downloaded:
        # The context manager closes the source dataset once the crop is read.
        with rasterio.open(path, "r") as src:
            profile = src.profile
            aoi_crs = aoi.to_crs(src.crs)
            crop, crop_transform = riomask(
                src, aoi_crs.geometry, all_touched=False, crop=True
            )
        profile["width"] = crop.shape[2]
        profile["height"] = crop.shape[1]
        profile["transform"] = crop_transform
        crop_name = os.path.join(
            output_path, os.path.split(path)[1].split("_")[0] + "_crop.tif"
        )
        with rasterio.open(crop_name, "w", **profile, nbits=1) as dst:
            # Keep only the "Trees" class (value 2) as a binary mask.
            dst.write(np.where(crop == 2, 1, 0).astype(np.uint8))
        crops.append(crop_name)

    landcover_name = os.path.join(output_path, f"landcover{year}.tif")
    exit_code = os.system(
        " ".join(
            [
                f"gdalwarp --config GDAL_CACHEMAX 3000 -wm 3000 -t_srs {EPSG}",
                " ".join(crops),
                landcover_name,
            ]
        )
    )
    # Do not claim success when gdalwarp reported a failure.
    if exit_code == 0:
        print(f"{landcover_name} was merged")
    else:
        print(f"gdalwarp failed with exit code {exit_code} for {landcover_name}")


def rescale(img, ratio):
    """Resize ``img`` by ``ratio`` along both axes using area interpolation.

    The target width and height are the original sizes multiplied by ``ratio``
    and truncated to integers.
    """
    src_height, src_width = img.shape[0], img.shape[1]
    target_size = (int(src_width * ratio), int(src_height * ratio))
    return cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)


def mask_to_polygons(mask: np.ndarray, transform) -> MultiPolygon:
    """
    Converts raster mask to shapely MultiPolygon

    Only pixels with ``mask > 0`` are polygonized. An invalid result is repaired
    with ``buffer(0)``.

    :param mask: raster mask
    :param transform: affine transform passed to ``rasterio.features.shapes``
    :return: the polygons of the mask
    """
    shapes = features.shapes(
        mask.astype(np.uint8), mask=(mask > 0), transform=transform
    )
    polygons = MultiPolygon([shapely.geometry.shape(geom) for geom, _ in shapes])

    if not polygons.is_valid:
        polygons = polygons.buffer(0)
        # `geom_type` replaces the deprecated `type` attribute.
        if polygons.geom_type == "Polygon":
            polygons = MultiPolygon([polygons])
    return polygons


def create_landcover_gdf(landcover_dir, output_dir, filename):
    """Polygonize the cropped landcover masks and save them as one GeoJSON.

    Every ``*_crop.tif`` file in ``landcover_dir`` is smoothed with a
    morphological close/open at half resolution, converted to polygons and
    reprojected to EPSG:4326. All polygons are written to
    ``<output_dir>/<filename>.geojson``.

    :param landcover_dir: folder containing the ``*_crop.tif`` files
    :param output_dir: folder for the GeoJSON file (created if missing)
    :param filename: output file name without extension
    :return: path of the written GeoJSON file
    """
    landcover_names = [
        name for name in os.listdir(landcover_dir) if name.endswith("_crop.tif")
    ]
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))

    gdfs = []
    for name in tqdm(landcover_names):
        lc_fullpath = os.path.join(landcover_dir, name)
        print(lc_fullpath)
        with rasterio.open(lc_fullpath, "r") as src:
            data = src.read().squeeze()
            # NOTE(review): for odd raster sizes the 0.5x / 2x round trip does
            # not restore the original shape exactly, while the polygons are
            # still georeferenced with the original transform.
            data = rescale(data, 0.5)
            data = cv2.morphologyEx(data, cv2.MORPH_CLOSE, kernel)
            data = cv2.morphologyEx(data, cv2.MORPH_OPEN, kernel)
            data = rescale(data, 2)

            # `.geoms` works on all Shapely versions; iterating the
            # MultiPolygon itself is not supported by Shapely 2.
            current_polygons = [
                poly
                for poly in mask_to_polygons(data, src.transform).geoms
                if poly.area
            ]
            current_areas = [poly.area for poly in current_polygons]
            current_tilename = [name] * len(current_areas)

            current_gdf = gpd.GeoDataFrame(
                {
                    "area": current_areas,
                    "names": current_tilename,
                    "geometry": current_polygons,
                },
                crs=src.crs,
            )
            current_gdf.to_crs("EPSG:4326", inplace=True)
            gdfs.append(current_gdf)

    gdf = gpd.GeoDataFrame(pd.concat(gdfs, ignore_index=True), crs="EPSG:4326")

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{filename}.geojson")
    gdf.to_file(output_path, driver="GeoJSON")
    return output_path


def prepare_landcover(year, landcover_tiles, output_path, aoi):
    """Return the landcover GeoJSON for the first tile and year, creating it if needed.

    Every file in ``output_path`` other than the tiles KML and the
    ``*_landcover.geojson`` files is deleted first. When a file whose name
    contains ``<first tile>_<year>_landcover`` is present it is reused;
    otherwise the landcover is downloaded and polygonized.

    :return: path of the landcover GeoJSON file
    """
    filename = f"{landcover_tiles[0]}_{year}_landcover"

    # Cleanup pass: drop everything that is neither the KML nor a cached result.
    for entry in os.listdir(output_path):
        keep = entry == LANDCOVER_FILENAME or entry.endswith("_landcover.geojson")
        if not keep:
            os.remove(os.path.join(output_path, entry))

    # Lookup pass: reuse the first cached file matching this tile and year.
    cached = next(
        (entry for entry in os.listdir(output_path) if filename in entry), None
    )
    if cached is not None:
        return os.path.join(output_path, cached)

    landcover_annual(year, landcover_tiles, output_path, aoi)
    return create_landcover_gdf(output_path, output_path, filename)


def weights_exists_or_download(path, file_id):
    """Make sure the model weights exist at ``path``, downloading them if not.

    When ``path`` does not exist, the Google Drive file ``file_id`` is
    downloaded using the service-account file named by the ``CREDENTIAL_FILE``
    environment variable.

    :param path: destination of the weights file
    :param file_id: Google Drive id of the weights file
    :return: ``path``
    """
    weights_path = Path(path)
    if weights_path.exists():
        return path

    creds_file = os.environ.get("CREDENTIAL_FILE")
    creds = service_account.Credentials.from_service_account_file(
        creds_file, scopes=SCOPES
    )

    service = build("drive", "v3", credentials=creds)
    request = service.files().get_media(fileId=file_id)

    weights_path.parent.mkdir(parents=True, exist_ok=True)
    # Download to a temporary name so that an interrupted download is never
    # mistaken for complete weights on the next call.
    partial_path = f"{path}.part"
    with io.FileIO(partial_path, mode="wb") as fh:
        downloader = MediaIoBaseDownload(fh, request)
        done = False
        while not done:
            status, done = downloader.next_chunk()
            print(f"Download {int(status.progress() * 100)}")
    # Write to the requested path rather than a hard-coded file name.
    os.replace(partial_path, path)

    return path

### 4. Unet-diff inference on the prepared data, followed by postprocessing with cloud and forest landcover data

In [20]:
# Overrides the value 1 assigned in the parameters cell above.
CLOUDS_PROBABILITY_THRESHOLD = 15
# Number of nearest mask polygons checked per clearcut during postprocessing.
NEAREST_POLYGONS_NUMBER = 10
DATES_FOR_TILE = 2


# NOTE(review): the result of this lookup is discarded, so the statement has no
# effect; os.environ.setdefault may have been intended — confirm.
os.environ.get("CUDA_VISIBLE_DEVICES", "0")

logging.basicConfig(format="%(asctime)s %(message)s")


def predict(image_tensor, model, channels, neighbours, size, device):
    """Run the model on one window and return the predicted mask as a numpy array.

    The tensor is viewed as a batch of one image with
    ``count_channels(channels) * neighbours`` channels and ``size`` x ``size``
    pixels. The first element of the model output is reshaped to
    ``(size, size)``; the second element is ignored.
    """
    batch_shape = (1, count_channels(channels) * neighbours, size, size)
    model_input = image_tensor.view(batch_shape).to(device, dtype=torch.float)
    outputs = model.predict(model_input)
    prediction, _ = outputs
    return prediction.view(size, size).detach().cpu().numpy()


def diff(img1, img2):
    """Build the model input from two image windows of the same area.

    ``img2`` is histogram-matched to ``img1``, a normalized difference of the
    two is computed, and the difference, ``img1`` and the matched ``img2`` are
    concatenated along the last axis as uint8.

    :param img1: current image window — assumes channels-last layout, TODO confirm
    :param img2: previous image window with the same shape
    :return: uint8 array with three times the channels of ``img1``
    """
    # Match the histogram of img2 to that of img1, channel by channel.
    img2 = match_histograms(img2, img1, multichannel=True)
    # NOTE(review): pixels where img1 + img2 == 0 divide by zero here; also, if
    # both operands are unsigned integers the subtraction can wrap around —
    # confirm the dtype returned by match_histograms.
    difference = (img1 - img2) / (img1 + img2)
    # For non-negative inputs the ratio lies in [-1, 1]; map it to [0, 254].
    difference = (difference + 1) * 127
    return np.concatenate(
        (difference.astype(np.uint8), img1.astype(np.uint8), img2.astype(np.uint8)),
        axis=-1,
    )


def mask_postprocess(mask):
    """Clean a predicted mask: one 3x3 erosion followed by a 5x5 closing."""
    erode_kernel = np.ones((3, 3), np.uint8)
    close_kernel = np.ones((5, 5), np.uint8)
    eroded = cv2.erode(mask, erode_kernel, iterations=1)
    return cv2.morphologyEx(eroded, cv2.MORPH_CLOSE, close_kernel)


def predict_raster(
    img_current,
    img_previous,
    channels,
    network="unet-diff",
    model_weights_path="/code/models/unet_diff.pth",
    input_size=56,
    neighbours=3,
):
    """Predict a clearcut mask for a pair of merged rasters, window by window.

    Both rasters are read in non-overlapping ``input_size`` x ``input_size``
    windows. Each pair of windows is combined with ``diff``, passed through the
    model and post-processed, and the result is stored at the window position
    in the output mask. Pixels beyond the last full window stay zero.

    :param img_current: path of the raster for the later date
    :param img_previous: path of the raster for the earlier date
    :param channels: channel names, see ``count_channels``
    :param network: network identifier passed to ``load_model``
    :param model_weights_path: path of the model weights
    :param input_size: window size in pixels
    :param neighbours: multiplier applied to the channel count in ``predict``
    :return: tuple ``(mask, meta)`` with a float32 mask of the raster size and
        the metadata of ``img_current`` adjusted to one float32 band
    """
    model, device = load_model(network, model_weights_path)
    to_tensor = transforms.ToTensor()

    with rasterio.open(img_current) as source_current, rasterio.open(
        img_previous
    ) as source_previous:
        meta = source_current.meta
        meta["count"] = 1

        clearcut_mask = np.zeros((source_current.height, source_current.width))
        n_rows = source_current.height // input_size
        n_cols = source_current.width // input_size

        for row_idx in tqdm(range(n_rows)):
            row_off = row_idx * input_size
            for col_idx in range(n_cols):
                col_off = col_idx * input_size

                # Window takes (col_off, row_off, width, height); the mask is
                # indexed [row, col]. Keeping both explicit makes the tiling
                # correct for non-square rasters as well.
                window = Window(col_off, row_off, input_size, input_size)
                image_current = reshape_as_image(source_current.read(window=window))
                image_previous = reshape_as_image(source_previous.read(window=window))

                difference_image = diff(image_current, image_previous)
                image_tensor = to_tensor(difference_image.astype(np.uint8)).to(
                    device, dtype=torch.float
                )

                predicted = predict(
                    image_tensor, model, channels, neighbours, input_size, device
                )
                predicted = mask_postprocess(predicted)
                clearcut_mask[
                    row_off : row_off + input_size, col_off : col_off + input_size
                ] += predicted
    meta["dtype"] = "float32"
    return clearcut_mask.astype(np.float32), meta


def count_channels(channels):
    """Return the total number of raster bands behind the given channel names.

    ``"rgb"`` counts as three bands, each of ``ndvi``, ``ndmi``, ``b08``,
    ``b8a``, ``b11`` and ``b12`` as one. Names are case-insensitive.

    :raises Exception: for an unknown channel name
    """
    band_counts = {
        "rgb": 3,
        "ndvi": 1,
        "ndmi": 1,
        "b08": 1,
        "b8a": 1,
        "b11": 1,
        "b12": 1,
    }

    total = 0
    for name in channels:
        key = name.lower()
        bands = band_counts.get(key)
        if bands is None:
            raise Exception("{} channel is unknown!".format(key))
        total += bands
    return total


def scale(tensor, max_value):
    """Rescale ``tensor`` so that its maximum becomes ``max_value``.

    When the maximum is not positive the tensor is returned unchanged.
    """
    peak = tensor.max()
    if not peak > 0:
        return tensor
    return tensor / peak * max_value


def save_raster(raster_array, meta, save_path, filename):
    """Write the predicted raster as a PNG and as a GeoTIFF.

    The output base name is ``predicted_`` followed by the part of ``filename``
    after the last ``_all_merged`` (the whole ``filename`` when that marker is
    absent). The same array is written to every band declared by ``meta``.

    :param raster_array: array to write
    :param meta: rasterio metadata for the GeoTIFF
    :param save_path: output folder (created if missing)
    :param filename: name the output name is derived from
    """
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)
        logging.info("Data directory created.")

    name_tail = filename.split("_all_merged")[-1]
    target_base = os.path.join(save_path, f"predicted_{name_tail}")

    cv2.imwrite(f"{target_base}.png", raster_array)

    band_indexes = range(1, meta["count"] + 1)
    with rasterio.open(f"{target_base}.tif", "w", **meta) as dst:
        for band_index in band_indexes:
            dst.write(raster_array, band_index)


def polygonize(raster_array, meta, transform=True, mode=cv2.RETR_TREE):
    """Convert a mask raster to a list of shapely polygons via its contours.

    The raster is multiplied by 255 and cast to uint8 before contour detection.
    Contours with fewer than three points are skipped.

    :param raster_array: mask raster
    :param meta: raster metadata; ``meta["transform"]`` is used when
        ``transform`` is true
    :param transform: when true, pixel coordinates are mapped through
        ``meta["transform"]``; otherwise pixel coordinates are kept
    :param mode: contour retrieval mode passed to ``cv2.findContours``
    :return: list of ``Polygon`` objects
    """
    raster_array = (raster_array * 255).astype(np.uint8)

    contours, _ = cv2.findContours(raster_array, mode, cv2.CHAIN_APPROX_SIMPLE)

    polygons = []
    for contour in tqdm(contours):
        points = contour.reshape(contour.shape[0], contour.shape[2])
        if len(points) <= 2:
            continue
        if transform:
            # `affine * (x, y)` is the supported order; `(x, y) * affine`
            # relies on a deprecated right-multiplication.
            affine = meta["transform"]
            coords = [affine * tuple(point) for point in points]
        else:
            coords = [tuple(point) for point in points]
        polygons.append(Polygon(coords))

    return polygons


def save_polygons(polygons, save_path, filename):
    """Write ``polygons`` to ``<save_path>/<filename>.geojson``.

    Nothing is written when ``polygons`` is empty.

    :param polygons: object with ``to_file`` (a GeoDataFrame in this notebook)
    :param save_path: output folder (created if missing)
    :param filename: output file name without extension
    """
    if not len(polygons):
        logging.info("no_polygons detected")
        return

    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)
        logging.info("Data directory created.")

    message = f"{filename} saved."
    logging.info(message)
    print(message)

    destination = os.path.join(save_path, f"{filename}.geojson")
    polygons.to_file(destination, driver="GeoJSON")


def intersection_poly(test_poly, mask_poly):
    """Tell whether two valid geometries have a non-empty intersection.

    Returns False when either geometry is invalid. An invalid intersection
    result is repaired with ``buffer(0)`` before the emptiness check.
    """
    if not (test_poly.is_valid and mask_poly.is_valid):
        return False

    overlap = test_poly.intersection(mask_poly)
    if not overlap.is_valid:
        overlap = overlap.buffer(0)
    return not overlap.is_empty


def morphological_transform(img):
    """Apply a morphological closing followed by a single dilation.

    :param img: image array passed to OpenCV

    :return: ``img`` after a closing with a 5x5 kernel of ones and one
        dilation pass with a 3x3 kernel of ones
    """
    closing_kernel = np.ones((5, 5), np.uint8)
    dilation_kernel = np.ones((3, 3), np.uint8)

    closed = cv2.morphologyEx(img, cv2.MORPH_CLOSE, closing_kernel)
    return cv2.dilate(closed, dilation_kernel, iterations=1)


def postprocessing(tile, cloud_files, clearcuts, src_crs, year, aoi):
    """Flag clearcut polygons that intersect cloud or forest polygons.

    :param tile: tile identifier forwarded to ``LandcoverPolygons``
    :param cloud_files: paths of cloud rasters; band 1 of each file is read
    :param clearcuts: clearcut geometries used as the ``geometry`` column
    :param src_crs: CRS assigned to the resulting GeoDataFrame and forwarded
        to ``LandcoverPolygons``
    :param year: forwarded to ``LandcoverPolygons``
    :param aoi: forwarded to ``LandcoverPolygons``

    :return: GeoDataFrame with ``geometry``, ``forest`` and ``clouds``
        columns; a flag is 1 where the clearcut intersects one of its nearest
        mask polygons and 0 otherwise
    """

    def get_intersected_polygons(polygons, masks, mask_column_name):
        """Finding in GeoDataFrame with clearcuts the masked polygons.

        :param polygons: GeoDataFrame with clearcuts and mask columns
        :param masks: list of masks (e.g., polygons of clouds)
        :param mask_column_name: name of mask column in polygons GeoDataFrame

        :return: GeoDataFrame with filled mask flags in corresponding column
        """
        if len(masks) == 0:
            # Nothing can be masked; assigning an empty list to a non-empty
            # frame would raise a length-mismatch error.
            polygons[mask_column_name] = [0] * len(polygons)
            return polygons

        centroids = [[mask.centroid.x, mask.centroid.y] for mask in masks]
        kdtree = spatial.KDTree(centroids)

        masked_values = []
        for polygon in polygons["geometry"]:
            centre = polygon.centroid
            # Query with plain coordinates, matching how the tree was built.
            _, idxs = kdtree.query(
                [centre.x, centre.y], k=NEAREST_POLYGONS_NUMBER
            )
            masked_value = 0
            # ``query`` returns a scalar index when k == 1.
            for idx in np.atleast_1d(idxs):
                # KDTree pads missing neighbours with an index equal to the
                # number of points; those come last, so stop there.
                if idx >= len(masks):
                    break
                if intersection_poly(polygon, masks[idx].buffer(0)):
                    masked_value = 1
                    break
            masked_values.append(masked_value)

        polygons[mask_column_name] = masked_values
        return polygons

    landcover = LandcoverPolygons(tile, src_crs, year, aoi)
    forest_polygons = landcover.get_polygon()

    cloud_polygons = []
    for cloud_file in cloud_files:
        with rasterio.open(cloud_file) as src:
            clouds = src.read(1)
            meta = src.meta
        clouds = morphological_transform(clouds)
        clouds = (clouds > CLOUDS_PROBABILITY_THRESHOLD).astype(np.uint8)
        if clouds.sum() > 0:
            cloud_polygons.extend(polygonize(clouds, meta, mode=cv2.RETR_LIST))

    n_clearcuts = len(clearcuts)
    polygons = {
        "geometry": clearcuts,
        "forest": np.zeros(n_clearcuts),
        "clouds": np.zeros(n_clearcuts),
    }

    polygons = geopandas.GeoDataFrame(polygons, crs=src_crs)

    if cloud_polygons:
        polygons = get_intersected_polygons(polygons, cloud_polygons, "clouds")
    else:
        print('Clouds with specified CLOUDS_PROBABILITY_THRESHOLD not found')
    polygons = get_intersected_polygons(polygons, forest_polygons, "forest")
    return polygons

In [21]:
def deforestation_inference(
    img_start_path,
    img_end_path,
    tile,
    network="unet-diff",
    model_weights_path=MODEL_PATH,
    save_path=OUTPUT_FOLDER,
    channels=None,
    threshold=0.4,
    polygonize_only=False,
):
    """Predict clearcuts for an image pair, polygonize, filter and save them.

    NOTE: ``year`` and ``aoi`` are not parameters; they are read from the
    enclosing (notebook) scope and must be defined before calling.

    :param img_start_path: path of the start-date image; its base name (up to
        the first ``"."``) names the outputs
    :param img_end_path: path of the end-date image
    :param tile: tile identifier forwarded to ``postprocessing``
    :param network: network name forwarded to ``predict_raster``
    :param model_weights_path: weights path forwarded to ``predict_raster``
    :param save_path: directory for the predicted raster and the GeoJSON
    :param channels: channel names forwarded to ``predict_raster``; defaults
        to ``["RGB", "B08", "B8A", "B11", "B12", "NDVI", "NDMI"]``
    :param threshold: raster values above this are treated as clearcut
    :param polygonize_only: when true, skip prediction and read the existing
        ``predicted_<filename>.tif`` from ``save_path`` instead
    """
    if channels is None:
        # Fresh list per call instead of a shared mutable default.
        channels = ["RGB", "B08", "B8A", "B11", "B12", "NDVI", "NDMI"]

    filename = img_start_path.split("/")[-1].split(".")[0]
    predicted_filename = f"predicted_{filename}"

    if not polygonize_only:
        raster_array, meta = predict_raster(
            img_start_path, img_end_path, channels, network, model_weights_path
        )
        save_raster(raster_array, meta, save_path, filename)
    else:
        # The context manager closes the dataset; no explicit close needed.
        with rasterio.open(os.path.join(save_path, f"{predicted_filename}.tif")) as src:
            raster_array = src.read()
            # Move the band axis from first to last position.
            raster_array = np.moveaxis(raster_array, 0, -1)
            meta = src.meta

    logging.info("Polygonize raster array of clearcuts...")
    clearcuts = polygonize(raster_array > threshold, meta)
    logging.info("Filter polygons of clearcuts")
    # Every entry of CLOUD_DATA_FOLDER is handed over as a cloud raster.
    cloud_files = [
        os.path.join(CLOUD_DATA_FOLDER, file)
        for file in os.listdir(CLOUD_DATA_FOLDER)
    ]
    polygons = postprocessing(
        tile,
        cloud_files,
        clearcuts,
        meta["crs"],
        year,
        aoi,
    )

    save_polygons(polygons, save_path, predicted_filename)

In [22]:
# Run inference for the tile with the default network, weights and threshold.
deforestation_inference(
    img_start_path=start_date_merged_path,
    img_end_path=end_date_merged_path,
    tile=tile,
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [18:04<00:00,  5.53s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4243/4243 [00:00<00:00, 12437.42it/s]


Copying color table from /code/data/landcovers/16T_crop.tif to new file.
Creating output file that is 1853P x 1060L.
Processing input file /code/data/landcovers/16T_crop.tif.
Using internal nodata values (e.g. 0) for image /code/data/landcovers/16T_crop.tif.
Copying nodata values from source /code/data/landcovers/16T_crop.tif to destination /code/data/landcovers/landcover2020.tif.
0...10...20...30...40...50...60...70...80...90...100 - done.
/code/data/landcovers/landcover2020.tif was merged


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.68it/s]

/code/data/landcovers/16T_crop.tif





predicted_all_merged_S2B_MSIL1C_20200424T161829_N0209_R040_T16TFK_20200424T195711 saved.
