In [1]:
import numpy as np
import os
from collections import defaultdict
import numpy as np
from sklearn.metrics import mean_squared_error

from tqdm.autonotebook import tqdm

  from tqdm.autonotebook import tqdm


In [2]:
from dataset import SeqGreenEarthNetDataset, ADDITIONAL_INFO_DICT

In [3]:
os.makedirs('example_predictions', exist_ok=True)

In [None]:
info_list = list(ADDITIONAL_INFO_DICT.keys())

ds_train = SeqGreenEarthNetDataset(
    folder="example_preprocessed_dataset/",
    input_channels=["red", "green", "blue"],
    target_channels=["ndvi", "class"],
    additional_info_list=info_list,
    time=True,
    use_mask=True,
    return_filename=True
)

In [5]:
# Make predictions
for batch in tqdm(ds_train):
    print(batch["filename"])
    pred =  np.zeros_like(batch["inputs"][1:2, 0])
    np.savez_compressed("example_predictions/" + str(batch["filename"].name), pred=pred)


  0%|          | 0/4 [00:00<?, ?it/s]

example_preprocessed_dataset/JAS20_minicube_164_33UVR_50.25_14.71_15.npz
example_preprocessed_dataset/JAS20_minicube_164_33UVR_50.25_14.71_10.npz
example_preprocessed_dataset/JAS20_minicube_164_33UVR_50.25_14.71_5.npz
example_preprocessed_dataset/JAS20_minicube_164_33UVR_50.25_14.71_1.npz


# Evaluation 

In [6]:
CLASSES2EVAL = [10, 30, 40] # Only evaluate on these classes

In [7]:
info_list = list(ADDITIONAL_INFO_DICT.keys())

ds_test = SeqGreenEarthNetDataset(
    folder="example_preprocessed_dataset/",
    input_channels=["red", "green", "blue"],
    target_channels=["evi", "class"],
    additional_info_list=info_list,
    time=True,
    use_mask=True,
    return_filename=True
)

In [8]:
class RMSEimagewise():
    def __init__(self, name):
        self.name = name
        self.rmse = defaultdict(lambda: np.array([]))  
        
    def update(self, class_idx, y_gt, y_pred):
        """Update RMSE values for a specific class index."""
        self.rmse[class_idx] = np.append(self.rmse[class_idx], mean_squared_error(y_gt, y_pred))

    def compute(self):
        """Compute the mean RMSE for all class indices."""
        rmse = {"name": self.name}
        for class_idx in self.rmse.keys():
            rmse[class_idx] = float(np.mean(self.rmse[class_idx]))
        return rmse
    
    def __repr__(self):
        """Print the computed RMSE values in a tabular format."""
        rmse = self.compute()
        output = f"{'Class':<10}{'RMSE':<10}\n"
        output += "-" * 20 + "\n"
        for class_idx in rmse.keys():
            if class_idx == "name":
                continue
            output += f"{class_idx:<10}{rmse[class_idx]:<10.4f}\n"
        return output

    def reset(self):
        """Reset the stored RMSE values."""
        self.rmse = defaultdict(lambda: np.array([]))

    def set_name(self, name):
        """Set a new name for the RMSE tracker."""
        self.name = name

    def get_class_rmse(self, class_idx):
        """Retrieve the RMSE values for a specific class index."""
        return self.rmse[class_idx] if class_idx in self.rmse else np.array([])

In [None]:
stats = RMSEimagewise("RMSE")

for idx, batch in enumerate(tqdm(ds_test)):

    pred = np.load("example_predictions/" + str(batch["filename"].name))["pred"]
    print(pred.shape)
    assert pred.shape == (1, 128, 128)
    evi = batch["targets"][0, 0:1]
    class_mask = batch["targets"][0, 1:2]

    for class_idx in CLASSES2EVAL:

        # NOTE: The evi channel should be in range [-1, 1], but due to the preprocessing and
        #  noise on cameras, it might not be the case -> we filter out invalid values (consider 
        # to adjust the evi channels in train phase)
        valid_mask = (class_mask == class_idx) & (~np.isnan(evi)) & (evi >= -1) & (evi <= 1)
 
        gt = evi[valid_mask]
        pred_act = pred[valid_mask]
    
        if len(gt) == 0:
            continue

        stats.update(class_idx, gt, pred_act)


stats



  0%|          | 0/4 [00:00<?, ?it/s]

Class     RMSE      
--------------------
10        0.3248    
30        0.2385    
40        0.1637    