In [1]:
import ipyleaflet
import ipywidgets as widgets
from IPython.display import display
import numpy as np
from distributed import Client
from joblib import parallel_backend
from dask.distributed import Client
from datacube.testutils.io import rio_slurp_xarray

import datacube
from datacube import Datacube
import xarray as xr
from joblib import load
import matplotlib.pyplot as plt
from datacube.utils.cog import write_cog

from deafrica_tools.datahandling import load_ard
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.plotting import rgb, display_map
from deafrica_tools.classification import predict_xr

import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import seaborn as sns

import warnings
warnings.filterwarnings("ignore")

# Start a local Dask cluster for this notebook session (helper from
# deafrica_tools.dask).
create_local_dask_cluster()

# Paths to the training-data text file and the two pre-trained models
# (one per sensor family).
training_data = "L_training_data(1).txt"
S_model_path = 'S_model(1).joblib'
L_model_path = 'L_model(1).joblib'

# NOTE: joblib.load unpickles the file contents, so only load model files
# from a trusted source.
# n_jobs is forced to 1 on both estimators; presumably because prediction is
# parallelised through the dask joblib backend instead - TODO confirm.
landsat_model = load(L_model_path).set_params(n_jobs=1)
sentinel_model = load(S_model_path).set_params(n_jobs=1)

# `buffer` is added to / subtracted from the clicked coordinates to build the
# query extent (presumably decimal degrees, since the clicked map coordinates
# are used directly - TODO confirm). `dask_chunks` is the chunking passed to
# the datacube load and to the slope layer.
buffer = 0.1
dask_chunks = {'x': 1000, 'y': 1000}

# NOTE(review): Client() is called without a scheduler address. Per the
# dask.distributed documentation this can start a *new* local cluster rather
# than attach to the one created by create_local_dask_cluster() above -
# confirm this is the intended behaviour.
client = Client()
client.cluster.scale(n=4)  # Request 4 workers on this client's cluster

# Module-level datacube handle; feature_layers() opens its own connection.
dc = datacube.Datacube(app='prediction')

# Load the training data as a numeric array
model_input = np.loadtxt(training_data)

# Read the first line of the training file to recover the column names
with open(training_data, 'r') as file:
    header = file.readline()

# Drop the first two whitespace-separated tokens of the header line
# (presumably a comment marker and the class-label column - TODO confirm).
column_names = header.split()[2:]

# Create a map centered in the Western Cape Province ([lat, lon] order)
m = ipyleaflet.Map(center=[-34.00, 20.00], zoom=7)

# Output widget that receives all printed messages and plots from the callbacks
output = widgets.Output()

# Button widget to clear the inputs (wired to clear_output)
clear_button = widgets.Button(description="Clear", button_style="danger")

# Button widget to process the inputs (wired to process_inputs)
process_button = widgets.Button(description="Process", button_style="success")

# Create one checkbox per year from 1995 to 2022 inclusive; the checkbox
# description (a string such as "2005") is what process_inputs reads back.
years = [str(year) for year in range(1995, 2023)]
checkboxes = [widgets.Checkbox(value=False, description=year) for year in years]

# Arrange them in a vertical box
years_box = widgets.VBox(checkboxes)

# Create a horizontal box to position the map and the years box next to each other
hbox = widgets.HBox([m, years_box])

# Module-level state: the last clicked map coordinates, or None when nothing
# is selected. Written by handle_interaction/clear_output, read by
# process_inputs.
selected_coordinates = None

# Function to handle the click event and get the coordinates
def handle_interaction(**kwargs):
    """Record the coordinates of a map click and echo them to the output widget.

    Called by the map for every interaction event with the event details as
    keyword arguments. Only events whose ``type`` is ``'click'`` are handled;
    the event's ``coordinates`` value is stored in the module-level
    ``selected_coordinates`` and printed. All other events are ignored.
    """
    global selected_coordinates

    # Guard clause: anything other than a click (mouse move, etc.) is a no-op.
    if kwargs.get('type') != 'click':
        return

    selected_coordinates = kwargs.get('coordinates')
    with output:
        # Replace whatever was shown before with the new selection.
        output.clear_output()
        print(f"Coordinates selected: {selected_coordinates}")

# Register handle_interaction as the map's interaction callback; it receives
# the event details as keyword arguments and only reacts to 'click' events.
m.on_interaction(handle_interaction)

# Function to clear the output and reset the map interactions
def clear_output(button):
    """Reset the UI: forget the clicked coordinates and untick every year.

    ``button`` is the widget that triggered the callback; it is not used.
    The output widget is emptied and then shows a short confirmation line.
    """
    global selected_coordinates
    selected_coordinates = None

    # Drop the previously printed coordinates / results.
    output.clear_output()

    # Untick all year checkboxes.
    for year_box in checkboxes:
        year_box.value = False

    with output:
        print("All inputs cleared.")

# Attach the clear function to the clear button
clear_button.on_click(clear_output)

# Band names requested from the datacube for both the Sentinel and Landsat
# geomedian products; also the first part of the model feature list.
measurements =  ['blue','green','red','nir','swir_1','swir_2']

# Pixel resolutions (y, x) for Sentinel and Landsat loads.
# NOTE(review): these two constants are not referenced by the functions
# visible in this file; predict_for_location spells out the same tuples.
S_resolution = (-10, 10)
L_resolution = (-30, 30)

# CRS passed as `output_crs` in every datacube query
output_crs = 'epsg:6933'

# Build the model feature layers (bands, spectral indices and slope) for a query
def feature_layers(query, model):
    """Load the geomedian bands for ``query`` and derive the model features.

    The annual geomedian product(s) matching the sensor of ``model`` are
    loaded, the NDVI, LAI and MNDWI indices are appended, and a slope layer
    read on the same grid is merged in.

    Parameters
    ----------
    query : dict
        Keyword arguments forwarded to ``Datacube.load`` (extent, time,
        measurements, resolution, output CRS, dask chunks).
    model : object
        The classifier the features are built for. If it is the module-level
        ``sentinel_model`` the Sentinel-2 geomedian is used, otherwise the
        Landsat geomedians.

    Returns
    -------
    xarray.Dataset
        Bands, indices and ``slope``, with length-1 dimensions squeezed out.

    Raises
    ------
    ValueError
        If the load returns no data for the query. Any other error raised by
        the underlying calls is printed and re-raised unchanged.
    """
    try:
        dc = datacube.Datacube(app='feature_layers')

        # Only load the geomedian that matches the model's sensor. Loading
        # every product at once lets Sentinel-era queries also pick up the
        # Landsat 8/9 geomedian for the same year and mix the two sources.
        # Identity is the right comparison for the two loaded estimators.
        if model is sentinel_model:
            satellite_mission = 's2'
            products = ['gm_s2_annual']
        else:
            satellite_mission = 'ls'
            products = ['gm_ls5_ls7_annual', 'gm_ls8_ls9_annual']

        ds = dc.load(product=products, **query)

        input_features = list(ds.data_vars)
        print(f"Available bands in dataset: {input_features}")

        # An empty result has no bands (and no usable grid for the slope
        # layer), so fail with a clear message instead of an obscure error
        # further down.
        if not input_features:
            raise ValueError(
                f"No data found for products {products} with time={query.get('time')}"
            )

        # Append the spectral indices, keeping the original bands (drop=False)
        da = calculate_indices(
            ds,
            index=['NDVI', 'LAI', 'MNDWI'],
            drop=False,
            satellite_mission=satellite_mission,
        )

        # Add slope data, read onto the same grid as the loaded bands
        url_slope = "https://deafrica-input-datasets.s3.af-south-1.amazonaws.com/srtm_dem/srtm_africa_slope.tif"
        slope = rio_slurp_xarray(url_slope, gbox=ds.geobox)
        slope = slope.to_dataset(name='slope').chunk(dask_chunks)

        # compat='override' skips conflict checks between the two datasets
        result = xr.merge([da, slope], compat='override')

        return result.squeeze()

    except ValueError as ve:
        print(f"ValueError occurred: {ve}")
        raise
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        raise
     
    
# Function to predict for a location
def predict_for_location(location_id, value, year):
    """Classify the area around one coordinate pair for one year.

    Parameters
    ----------
    location_id
        Label used only in the progress and error messages.
    value
        Indexable pair; ``value[0]`` feeds the ``y`` range of the query and
        ``value[1]`` the ``x`` range, each widened by the module-level
        ``buffer``.
    year
        Year to classify (anything ``int()`` accepts). Years before 2017 use
        the Landsat model at 30 m, later years the Sentinel model at 10 m.

    Returns
    -------
    The computed result of ``predict_xr`` on success. If building the
    features or predicting raises, the error is printed and ``None`` is
    returned so the caller can skip that year.
    """
    print(f'Working on location {location_id} for year {year}')

    # Pick the model/resolution pair for the sensor era of this year.
    landsat_era = int(year) < 2017
    if landsat_era:
        model, resolution = landsat_model, (-30, 30)
        print("Using Landsat model...")
    else:
        model, resolution = sentinel_model, (-10, 10)
        print("Using Sentinel model...")

    y_centre, x_centre = value[0], value[1]
    query = dict(
        x=(x_centre - buffer, x_centre + buffer),
        y=(y_centre + buffer, y_centre - buffer),
        time=year,
        measurements=measurements,
        resolution=resolution,
        output_crs=output_crs,
        dask_chunks=dask_chunks,
    )

    try:
        features = feature_layers(query, model)

        # Keep only the variables the models expect, in a fixed order.
        feature_names = measurements + ['NDVI', 'LAI', 'MNDWI', 'slope']
        features = features[feature_names]

        # Run the prediction through the dask joblib backend.
        with parallel_backend('dask', wait_for_workers_timeout=60):
            lazy_prediction = predict_xr(
                model,
                features,
                proba=True,
                persist=True,
                clean=True,
                return_input=True,
            )
            predicted = lazy_prediction.compute()

        return predicted

    except ValueError as e:
        print(f"Error predicting for location {location_id} for year {year}: {e}")
    except Exception as ex:
        print(f"An unexpected error occurred for location {location_id} for year {year}: {ex}")

    # Best-effort: a failed year is reported above and skipped by the caller.
    return None


# Function to process the selected inputs
def _class_areas(classes, class_labels, area_per_pixel):
    """Return ``{label: area}`` for one classified raster.

    ``classes`` holds one class id per pixel; ids are used as indices into
    ``class_labels`` (assumes ids run 0..len(class_labels)-1 - TODO confirm
    against the trained models). Non-finite pixels are ignored.
    """
    flat = np.asarray(classes).ravel()
    if np.issubdtype(flat.dtype, np.floating):
        # NaN pixels cannot be mapped to a class and would crash int().
        flat = flat[np.isfinite(flat)]

    areas = {label: 0 for label in class_labels}
    class_ids, counts = np.unique(flat, return_counts=True)
    for class_id, count in zip(class_ids, counts):
        areas[class_labels[int(class_id)]] += count * area_per_pixel
    return areas


def _compute_transition_matrix(classes_from, classes_to, num_classes):
    """Count pixel transitions between two equally shaped class rasters.

    Entry ``[i, j]`` is the number of pixels that were class ``i`` in
    ``classes_from`` and class ``j`` in ``classes_to``. Pixels that are
    non-finite in either raster are skipped.
    """
    flat_from = np.asarray(classes_from).ravel()
    flat_to = np.asarray(classes_to).ravel()

    valid = np.isfinite(flat_from) & np.isfinite(flat_to)
    idx_from = flat_from[valid].astype(int)
    idx_to = flat_to[valid].astype(int)

    matrix = np.zeros((num_classes, num_classes), dtype=int)
    # Unbuffered add so repeated (from, to) pairs are all counted.
    np.add.at(matrix, (idx_from, idx_to), 1)
    return matrix


def _normalize_transition_matrix(matrix):
    """Convert transition counts to row percentages (empty rows stay 0)."""
    row_sums = matrix.sum(axis=1, keepdims=True)
    fractions = np.divide(
        matrix,
        row_sums,
        out=np.zeros_like(matrix, dtype=float),
        where=row_sums != 0,
    )
    return fractions * 100


def _plot_transition_matrix(transition_matrix, class_labels, year_from, year_to):
    """Draw one transition matrix as an annotated heatmap."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        transition_matrix,
        annot=True,
        fmt='.2f',
        cmap='Blues',
        xticklabels=class_labels,
        yticklabels=class_labels,
    )
    plt.title(f'Transition Matrix from {year_from} to {year_to}')
    plt.xlabel('Class in ' + str(year_to))
    plt.ylabel('Class in ' + str(year_from))
    plt.show()


def _plot_prediction(location_id, year, prediction_data, class_labels, cmap, norm):
    """Plot the classified map, true-colour image and probability map for one year."""
    fig, axes = plt.subplots(1, 3, figsize=(24, 8))

    # Classified image with the per-class colours
    classified = axes[0].imshow(prediction_data.Predictions, cmap=cmap, norm=norm)
    class_bar = fig.colorbar(classified, ax=axes[0], ticks=range(len(class_labels)))
    class_bar.ax.set_yticklabels(class_labels)
    axes[0].set_title(f'Classified Image - Location {location_id} Year {year}')

    # True colour image
    rgb(prediction_data, bands=['red', 'green', 'blue'], ax=axes[1], percentile_stretch=(0.01, 0.99))
    axes[1].set_title('True Colour Geomedian')

    # Probability map. The colourbar label is passed to xarray directly and
    # the title is set unconditionally: the previous isinstance(..., plt.Axes)
    # check on the plot's return value did not appear to match what the 2-D
    # plot returns (TODO confirm), so the title and label were never applied.
    prediction_data.Probabilities.plot(
        ax=axes[2],
        cmap='magma',
        vmin=0,
        vmax=100,
        add_labels=False,
        add_colorbar=True,
        cbar_kwargs={'label': 'Probability (%)'},
    )
    axes[2].set_title('Probability Map')

    plt.tight_layout()
    plt.show()


def process_inputs(button):
    """Run predictions for the clicked location and ticked years, then plot.

    Callback for the "Process" button (``button`` itself is not used). For
    every ticked year a prediction is requested from ``predict_for_location``;
    failed years are skipped. For the successful years the function shows,
    inside the output widget:

    * a three-panel figure per year (classes, true colour, probability),
    * a stacked bar chart and table of area per class per year,
    * a bar chart of area change between consecutive years (needs two or
      more years),
    * a percentage transition matrix for each pair of consecutive years
      whose prediction grids have the same shape.

    If no coordinates are selected, no year is ticked, or no prediction
    succeeds, a message is printed instead.
    """
    import matplotlib.colors as mcolors

    if not selected_coordinates:
        with output:
            output.clear_output()  # Clears previous output
            print("Please select coordinates on the map before processing.")
        return

    selected_years = [box.description for box in checkboxes if box.value]

    with output:
        output.clear_output()  # Clears previous output
        print(f"Processing data for coordinates: {selected_coordinates}")
        print(f"Selected years: {selected_years}")

        if not selected_years:
            # Nothing to predict; plotting an empty table would raise.
            print("Please select at least one year before processing.")
            return

        # Run predictions for the selected coordinates and years
        predictions = []
        for year in selected_years:
            prediction = predict_for_location("user_selected_location", selected_coordinates, year)
            if prediction is not None:
                predictions.append((year, prediction))

        if not predictions:
            print("No predictions could be generated for the selected years.")
            return

        # Class labels (indexed by class id) and their display colours
        class_labels = ['Water', 'Forestry', 'Urban', 'Barren', 'Crop', 'Other']
        class_colors = {
            'Water': 'blue',
            'Forestry': 'green',
            'Urban': 'red',
            'Barren': 'yellow',
            'Crop': 'purple',
            'Other': 'grey'
        }
        palette = [class_colors[label] for label in class_labels]

        # Colormap and norm mapping class ids to the colours above
        cmap = mcolors.ListedColormap(palette)
        norm = mcolors.BoundaryNorm(range(len(class_labels) + 1), cmap.N)

        # Per-year figures and per-class areas
        areas_per_class = {}
        for year, prediction_data in predictions:
            _plot_prediction(selected_coordinates, year, prediction_data, class_labels, cmap, norm)

            # Pixel footprint in square metres follows the sensor era:
            # 30m x 30m (Landsat) before 2017, 10m x 10m (Sentinel) after.
            area_per_pixel = 900 if int(year) < 2017 else 100
            areas_per_class[year] = _class_areas(
                prediction_data.Predictions.values, class_labels, area_per_pixel
            )

        # Convert to DataFrame (rows = years, columns = classes) for plotting
        df_areas = pd.DataFrame(areas_per_class).T

        # Plot the area covered by each class over time
        fig, ax = plt.subplots(figsize=(12, 8))
        df_areas.plot(kind='bar', stacked=True, color=palette, ax=ax)

        ax.set_title('Area Covered by Each Class Over Time')
        ax.set_xlabel('Year')
        ax.set_ylabel('Area (square meters)')
        ax.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left', labels=class_labels)

        for bars in ax.containers:
            ax.bar_label(bars, label_type='center', fmt='%.0f', fontsize=9, color='black')

        plt.tight_layout()
        plt.show()

        # Display DataFrame
        print("Areas covered per class:\n", df_areas)

        # Differences between consecutive years need the rows in year order
        df_areas = df_areas.sort_index()
        print("df_areas DataFrame:\n", df_areas)

        df_area_diff = df_areas.diff().dropna()
        print("df_area_diff DataFrame:\n", df_area_diff)

        if df_area_diff.empty:
            # With a single year there is nothing to difference, and pandas
            # refuses to plot an empty frame.
            print("Select at least two years to see changes between years.")
        else:
            fig, ax = plt.subplots(figsize=(12, 8))

            # Same colour scheme as the stacked chart, for consistency
            colors = [class_colors[label] for label in df_area_diff.columns]
            df_area_diff.plot(kind='bar', color=colors, ax=ax)

            ax.set_title('Change in Area Covered by Each Class Between Years')
            ax.set_xlabel('Year')
            ax.set_ylabel('Change in Area (sq meters)')
            ax.legend(title='Class', bbox_to_anchor=(1.05, 1), loc='upper left')

            plt.tight_layout()
            plt.show()

        # Transition matrices between consecutive years
        classes_by_year = {
            year: prediction_data.Predictions.values
            for year, prediction_data in predictions
        }
        ordered_years = sorted(classes_by_year, key=int)
        for year_from, year_to in zip(ordered_years, ordered_years[1:]):
            classes_from = classes_by_year[year_from]
            classes_to = classes_by_year[year_to]

            # A pixel-by-pixel comparison only makes sense on identical
            # grids; a 30 m Landsat year next to a 10 m Sentinel year would
            # otherwise be compared out of alignment.
            if classes_from.shape != classes_to.shape:
                print(
                    f"Skipping transition matrix from {year_from} to {year_to}: "
                    f"prediction grids differ ({classes_from.shape} vs {classes_to.shape})."
                )
                continue

            transition_matrix = _compute_transition_matrix(
                classes_from, classes_to, len(class_labels)
            )

            # Show the transitions as row percentages
            percentage_matrix = _normalize_transition_matrix(transition_matrix)
            _plot_transition_matrix(percentage_matrix, class_labels, year_from, year_to)


# Attach the process function to the process button
process_button.on_click(process_inputs)

# Stack the map + year checkboxes row, the output area, and the two buttons
# vertically, then render the whole UI in the notebook.
ui = widgets.VBox([hbox, output, clear_button, process_button])
display(ui)


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /user/sakhilemkhize18@gmail.com/proxy/8787/status,

0,1
Dashboard: /user/sakhilemkhize18@gmail.com/proxy/8787/status,Workers: 1
Total threads: 4,Total memory: 26.21 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:36423,Workers: 1
Dashboard: /user/sakhilemkhize18@gmail.com/proxy/8787/status,Total threads: 4
Started: Just now,Total memory: 26.21 GiB

0,1
Comm: tcp://127.0.0.1:35683,Total threads: 4
Dashboard: /user/sakhilemkhize18@gmail.com/proxy/40789/status,Memory: 26.21 GiB
Nanny: tcp://127.0.0.1:36651,
Local directory: /tmp/dask-scratch-space/worker-pwy9v1g9,Local directory: /tmp/dask-scratch-space/worker-pwy9v1g9


VBox(children=(HBox(children=(Map(center=[-34.0, 20.0], controls=(ZoomControl(options=['position', 'zoom_in_te…