In this notebook, the model will be tested on all four test sets using the EarthNetScore. 

First, we use both RGB and NIR models to generate predictions for all test tracks:

In [1]:
import os
import csv
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline
import time
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class SatelliteImageTestDataset(Dataset):
    def __init__(self, csv_path, nir_image_dir, rgb_image_dir, max_minicubes=None):
        self.data = []
        self.nir_image_dir = nir_image_dir
        self.rgb_image_dir = rgb_image_dir
        
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        
        self.metadata = self.load_metadata(csv_path)
        logging.info(f"Loaded metadata for {len(self.metadata)} minicubes")
        
        minicube_data = {}
        minicube_count = 0
        total_images = 0
        usable_images_count = 0

        for minicube, images in self.metadata.items():
            total_images += len(images)
            usable_images = [img for img in images.values() if img['usable']]
            usable_images_count += len(usable_images)
            logging.info(f"Minicube {minicube}: {len(usable_images)} usable images out of {len(images)} total images")
            
            if len(usable_images) >= 30:  # We need at least 30 images (10 context + 20 target)
                sorted_images = sorted(usable_images, key=lambda x: int(x['number']))
                minicube_data[minicube] = sorted_images[9:30]  # 10th image (index 9) to 30th image
                minicube_count += 1
                if max_minicubes and minicube_count >= max_minicubes:
                    break
            else:
                logging.warning(f"Minicube {minicube} skipped: only {len(usable_images)} usable images")

        self.data = list(minicube_data.items())
        
        logging.info(f"Total images: {total_images}")
        logging.info(f"Total usable images: {usable_images_count}")
        logging.info(f"Total test data points: {len(self.data)}")
        logging.info(f"Number of minicubes loaded: {len(minicube_data)}")

        if len(self.data) == 0:
            logging.error("No valid data points found. Check your CSV file and image directories.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        minicube, images = self.data[idx]
        
        input_image = images[0]
        input_file = f"{input_image['full_identifier']}_{input_image['date']}.png"
        nir_input_image_tensor = self.load_image(input_file, is_nir=True)
        rgb_input_image_tensor = self.load_image(input_file, is_nir=False)
        
        return nir_input_image_tensor, rgb_input_image_tensor, minicube, input_image

    def load_image(self, filename, is_nir=True):
        if is_nir:
            filename = filename.replace('_NIR_', '_')
            path = os.path.join(self.nir_image_dir, filename)
        else:
            path = os.path.join(self.rgb_image_dir, filename)
        
        if os.path.exists(path):
            image = Image.open(path)
            if is_nir:
                image = image.convert("L")
                image = self.transform(image)
                image = image.repeat(3, 1, 1)
            else:
                image = image.convert("RGB")
                image = self.transform(image)
            return image
        else:
            raise FileNotFoundError(f"Image not found: {path}")

    def load_metadata(self, metadata_csv_path):
        metadata = {}
        row_count = 0
        with open(metadata_csv_path, 'r') as csvfile:
            reader = csv.DictReader(csvfile)
            for i, row in enumerate(reader):
                if i < 5:  # Print first 5 rows
                    logging.info(f"Sample row {i+1}: {row}")
                row_count += 1
                minicube = row['minicube']
                number = row['image_count']
                
                # Extract the full identifier from the filename
                full_identifier = '_'.join(row['file_name'].split('_')[:-1])  # Remove the date part
                
                if minicube not in metadata:
                    metadata[minicube] = {}
                
                try:
                    metadata[minicube][number] = {
                        'Temperature Avg': float(row['Temperature Avg']) if row['Temperature Avg'] else None,
                        'Sea-level Pressure': float(row['Sea-level Pressure']) if row['Sea-level Pressure'] else None,
                        'Shortwave Downwelling Radiation': float(row['Shortwave Downwelling Radiation']) if row['Shortwave Downwelling Radiation'] else None,
                        'Relative Humidity': float(row['Relative Humidity']) if row['Relative Humidity'] else None,
                        'region': row['region'],
                        'season': row['season'],
                        'number': row['image_count'],
                        'date': row['date'],
                        'usable': row['usable'].lower() != 'no',  # Consider all non-'no' values as usable
                        'full_identifier': full_identifier
                    }
                except ValueError as e:
                    logging.warning(f"Skipping row due to invalid data - {e}")
                    continue
        
        logging.info(f"Processed {row_count} rows from CSV file")
        return metadata

    def generate_prompt(self, minicube, input_number, target_number):
        input_data = self.get_metadata(minicube, input_number)
        target_data = self.get_metadata(minicube, target_number)
        
        if input_data is None or target_data is None:
            return "Orthophoto taken by a satellite of a nature region"
        
        differences = []
        for key in ['Temperature Avg', 'Sea-level Pressure', 'Shortwave Downwelling Radiation', 'Relative Humidity']:
            if input_data[key] is not None and target_data[key] is not None:
                if input_data[key] != 0:  # Add this check to avoid division by zero
                    diff = (target_data[key] - input_data[key]) / input_data[key] * 100  # Convert to percentage change
                    if abs(diff) > 5:  # Only include significant changes (5%)
                        if key == 'Temperature Avg':
                            differences.append(f"{'warmer' if diff > 0 else 'colder'} by {abs(diff):.2f}%")
                        elif key == 'Sea-level Pressure':
                            differences.append(f"{'higher pressure' if diff > 0 else 'lower pressure'} by {abs(diff):.2f}%")
                        elif key == 'Shortwave Downwelling Radiation':
                            differences.append(f"{'more solar radiation' if diff > 0 else 'less solar radiation'} by {abs(diff):.2f}%")
                        elif key == 'Relative Humidity':
                            differences.append(f"{'more humid' if diff > 0 else 'drier'} by {abs(diff):.2f}%")
                else:
                    # Handle the case where input_data[key] is zero
                    if target_data[key] > 0:
                        differences.append(f"increased {key}")
                    elif target_data[key] < 0:
                        differences.append(f"decreased {key}")
                    # If target_data[key] is also zero, we don't add any difference
        
        region = input_data.get('region', 'unknown region')
        season = input_data.get('season', 'unknown season')
        prompt = f"Orthophoto taken by a satellite of {region} in {season} that is " + ", ".join(differences)
        return prompt

    def get_metadata(self, minicube, number):
        if minicube in self.metadata and number in self.metadata[minicube]:
            return self.metadata[minicube][number]
        else:
            logging.warning(f"Metadata not found for minicube {minicube}, number {number}")
            return None

def check_and_resume_predictions(csv_path, predictions_dir, max_images=20):
    incomplete_minicubes = []
    
    # Load minicube data from CSV
    minicube_data = {}
    with open(csv_path, 'r') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            minicube = row['minicube']
            if minicube not in minicube_data:
                minicube_data[minicube] = []
            minicube_data[minicube].append(row)
    
    # Check each minicube
    for minicube, data in minicube_data.items():
        usable_images = [img for img in data if img['usable'].lower() != 'no']
        if len(usable_images) >= 30:  # We need at least 30 images (10 context + 20 target)
            nir_minicube_folder = os.path.join(predictions_dir, "NIR", minicube)
            rgb_minicube_folder = os.path.join(predictions_dir, "RGB", minicube)
            
            # Check if all predicted images exist
            all_predictions_exist = True
            for i in range(11, 31):  # Check images 11 to 30
                file_name = next((img['file_name'] for img in usable_images if int(img['image_count']) == i), None)
                if file_name:
                    nir_file_path = os.path.join(nir_minicube_folder, file_name)
                    rgb_file_path = os.path.join(rgb_minicube_folder, file_name)
                    if not (os.path.exists(nir_file_path) and os.path.exists(rgb_file_path)):
                        all_predictions_exist = False
                        break
            
            if not all_predictions_exist:
                incomplete_minicubes.append(minicube)
    
    return incomplete_minicubes

# Function to get file names for a minicube
def get_file_names(csv_path, minicube):
    file_names = {}
    with open(csv_path, 'r') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            if row['minicube'] == minicube and int(row['image_count']) >= 11:
                file_names[int(row['image_count'])] = row['file_name']
    return file_names

# Main execution
if __name__ == "__main__":
    # Base path and subfolders
    base_path = r"C:\TjallingData\greenearthnet_additional"
    subfolders = ['iid_chopped', 'ood-s_chopped', 'ood-st_chopped', 'ood-t_chopped']

    # Load the trained models
    nir_model_path = "model_checkpoint_NIR"
    rgb_model_path = "model_checkpoint-RGB"  
    nir_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(nir_model_path, torch_dtype=torch.float16)
    rgb_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(rgb_model_path, torch_dtype=torch.float16)
    nir_pipeline = nir_pipeline.to("cuda")
    rgb_pipeline = rgb_pipeline.to("cuda")

    # Hyperparameters
    strength = 0.3
    guidance_scale = 3.0
    num_inference_steps = 50
    max_minicubes = None  # Set to None or a number if you want to limit the number of minicubes

    for subfolder in subfolders:
        logging.info(f"\nProcessing subfolder: {subfolder}")

        # Paths for the current subfolder
        csv_path = os.path.join(base_path, subfolder, f"{subfolder}_combined_imputed_with_region_normalized_season_minicube_and_usable.csv")
        NIR_image_dir = os.path.join(base_path, subfolder, "NIR_total")
        RGB_image_dir = os.path.join(base_path, subfolder, "RGB_total")
        predictions_dir = os.path.join(base_path, subfolder, "predictions_trained_SDXL")

        # Check if files and directories exist
        if not os.path.exists(csv_path):
            logging.error(f"CSV file not found: {csv_path}")
            continue
        if not os.path.exists(NIR_image_dir):
            logging.error(f"NIR image directory not found: {NIR_image_dir}")
            continue
        if not os.path.exists(RGB_image_dir):
            logging.error(f"RGB image directory not found: {RGB_image_dir}")
            continue

        # Check for incomplete predictions
        incomplete_minicubes = check_and_resume_predictions(csv_path, predictions_dir)
        
        if incomplete_minicubes:
            logging.info(f"Found {len(incomplete_minicubes)} incomplete minicubes. Resuming predictions for these.")
        else:
            logging.info("All minicubes have complete predictions. Moving to next subfolder.")
            continue

        # Create dataset for the current subfolder
        test_dataset = SatelliteImageTestDataset(csv_path, NIR_image_dir, RGB_image_dir, max_minicubes=max_minicubes)
        
        if len(test_dataset) == 0:
            logging.warning(f"No valid data points found for subfolder: {subfolder}")
            continue

        test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        # Prediction loop
        for batch_idx, (nir_input_image, rgb_input_image, minicube, input_data) in enumerate(test_dataloader):
            if minicube[0] not in incomplete_minicubes:
                continue  # Skip this minicube if it's already complete
            nir_input_image = nir_input_image.to("cuda")
            rgb_input_image = rgb_input_image.to("cuda")
            minicube = minicube[0]
            
            logging.info(f"\nProcessing batch {batch_idx+1}")
            logging.info(f"Minicube: {minicube}")
            logging.info(f"Starting with image number: {input_data['number'][0]}")
            logging.info(f"Input image date: {input_data['date'][0]}")
            
            # Create folders for this minicube
            nir_minicube_folder = os.path.join(predictions_dir, "NIR", minicube)
            rgb_minicube_folder = os.path.join(predictions_dir, "RGB", minicube)
            os.makedirs(nir_minicube_folder, exist_ok=True)
            os.makedirs(rgb_minicube_folder, exist_ok=True)
            
            # Get file names for this minicube
            file_names = get_file_names(csv_path, minicube)
            
            nir_current_input = nir_input_image
            rgb_current_input = rgb_input_image
            current_data = {k: v[0] for k, v in input_data.items()}

            for i in range(20):  # Generate 20 future images
                # Generate the prompt for the next image
                next_number = str(int(current_data['number']) + 1)
                prompt = test_dataset.generate_prompt(minicube, current_data['number'], next_number)
                
                logging.info(f"\nGenerating image {i+1}")
                logging.info(f"Current image number: {current_data['number']}")
                logging.info(f"Next image number: {next_number}")
                logging.info(f"Prompt: {prompt}")
                
                # Generate NIR image
                start_time = time.time()
                nir_generated_image = nir_pipeline(
                    prompt=prompt,
                    image=nir_current_input,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps
                ).images[0]
                end_time = time.time()
                logging.info(f"NIR image generation took {end_time - start_time:.2f} seconds")
                
                # Generate RGB image
                start_time = time.time()
                rgb_generated_image = rgb_pipeline(
                    prompt=prompt,
                    image=rgb_current_input,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps
                ).images[0]
                end_time = time.time()
                logging.info(f"RGB image generation took {end_time - start_time:.2f} seconds")
                
                # Save the generated images immediately
                image_number = i + 11  # Start from image 11
                if image_number in file_names:
                    file_name = file_names[image_number]
                    nir_save_path = os.path.join(nir_minicube_folder, file_name)
                    rgb_save_path = os.path.join(rgb_minicube_folder, file_name)
                    try:
                        nir_generated_image.save(nir_save_path)
                        rgb_generated_image.save(rgb_save_path)
                        logging.info(f"Saved NIR image: {nir_save_path}")
                        logging.info(f"Saved RGB image: {rgb_save_path}")
                    except Exception as e:
                        logging.error(f"Error saving images: {str(e)}")
                else:
                    logging.warning(f"No file name found for image number {image_number}")
                
                # Prepare the generated images for the
                
                # Prepare the generated images for the next iteration
                nir_current_input = transforms.ToTensor()(nir_generated_image).unsqueeze(0).to("cuda")
                rgb_current_input = transforms.ToTensor()(rgb_generated_image).unsqueeze(0).to("cuda")
                
                # Update current_data for the next iteration
                current_data = test_dataset.get_metadata(minicube, next_number)
                if current_data is None:
                    logging.warning(f"No metadata found for image number {next_number}. Stopping generation.")
                    break
                
                logging.info(f"Generated image pair {i+1} completed")
            
            logging.info(f"\nProcessed batch {batch_idx+1}, Input number: {input_data['number'][0]}, Generated 20 future image pairs")

    logging.info("Prediction completed for all test samples")

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
A matching Triton is not available, some optimizations will not be enabled
Traceback (most recent call last):
  File "C:\Users\r0902260\AppData\Roaming\Python\Python311\site-packages\xformers\__init__.py", line 55, in _is_triton_available
    from xformers.triton.softmax import softmax as triton_softmax  # noqa
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\r0902260\AppData\Roaming\Python\Python311\site-packages\xformers\triton\softmax.py", line 11, in <module>
    import triton
ModuleNotFoundError: No module named 'triton'


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

KeyboardInterrupt: 

Afterwards, we can calculate the EarthNetScore for each testing track. I manually implemented the calculation code, since I utilized the GreenEarthNet2021 dataset, and the EarthNetScore API provided by the EarthNetChallenge is exclusive for the original EarthNet2021 dataset. Notable changes between these datasets are a different cloud mask, and a different format (.npz -> .nc).

Before calculations, all required images are loaded, normalized, and reshaped to matching resolutions.

First, the NDVI scores are calculated for all 20 prediction dates, using both the RGB and NIR image. This results in two 4D arrays: pred_ndvi and target_ndvi, both consisting of [Date , Height of images, Width of images, 1]. On these, the  four subscores MAD, OLS, EMD, and SSIM are then calculated. These are calculated once for the entire array, not per image.

1. MAD calculates the absolute difference between all corresponding pixels in the pred_ndvi and target_ndivi,
2. OLS fits a linear trend to each pixel's time series in both arrays and compares the slopes of both,
3. EMD calcuates the Wasserstein distance between the two arrays,
4. and SSIM computes the structural similarity as a whole between the two arrays.


The ENS is then calculated as an average of these scores.

In [17]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from scipy.stats import wasserstein_distance
from skimage.metrics import structural_similarity as ssim
from skimage.transform import resize

'''
THe predictions are only created on 'usable' minicubes. However, should i add an extra layer of checking here?
'''

def load_images(directory, file_names, target_size=None):
    images = []
    for file_name in file_names:
        img_path = os.path.join(directory, file_name)
        print(f"Loading image: {img_path}")
        img = np.array(Image.open(img_path)) / 255.0  # Normalize to [0, 1]
        img[np.isnan(img)] = 0
        img = np.clip(img, 0, 1)
        if target_size is not None and img.shape[:2] != target_size:
            print(f"Resizing image from {img.shape[:2]} to {target_size}")
            img = resize(img, target_size + (img.shape[2],), anti_aliasing=True, preserve_range=True)
        images.append(img)
    return np.array(images)

def calculate_ndvi(nir, rgb):
    epsilon = 1e-8
    numerator = nir[:,:,:,0] - rgb[:,:,:,2]
    denominator = nir[:,:,:,0] + rgb[:,:,:,2] + epsilon
    ndvi = np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator!=0)
    return ndvi[:,:,:,np.newaxis]

def calculate_mad(preds, targs):
    dists = np.abs(preds - targs)
    print(f"MAD dists shape: {dists.shape}, min: {np.min(dists)}, max: {np.max(dists)}")
    print(f"MAD dists NaN count: {np.isnan(dists).sum()}")
    
    scaling_factor = 0.06649346971087526
    dists = dists.astype(np.float64)
    scaled_dists = dists ** scaling_factor
    print(f"MAD scaled_dists min: {np.min(scaled_dists)}, max: {np.max(scaled_dists)}")
    print(f"MAD scaled_dists NaN count: {np.isnan(scaled_dists).sum()}")
    
    distmedian = np.nanmedian(scaled_dists)
    mad = max(0, min(1, 1-distmedian)) if distmedian is not None else None
    print(f"MAD distmedian: {distmedian}, mad: {mad}")
    
    return mad

def calculate_ols(preds, targs):
    try:
        h, w, c, t = preds.shape
        preds = preds.reshape(h*w*c, t)
        targs = targs.reshape(h*w*c, t)
        
        A = np.vstack([np.linspace(1, t, t), np.ones(t)]).T
        
        btarg = np.linalg.lstsq(A, targs.T, rcond=None)[0]
        bpred = np.linalg.lstsq(A, preds.T, rcond=None)[0]
        
        dists = np.abs(btarg[0] - bpred[0]) / 2
        print(f"OLS dists shape: {dists.shape}, min: {np.min(dists)}, max: {np.max(dists)}")
        print(f"OLS dists NaN count: {np.isnan(dists).sum()}")
        
        scaling_factor = 0.10082047548620601
        scaled_dists = dists ** scaling_factor
        print(f"OLS scaled_dists min: {np.min(scaled_dists)}, max: {np.max(scaled_dists)}")
        print(f"OLS scaled_dists NaN count: {np.isnan(scaled_dists).sum()}")
        
        distmean = np.nanmean(scaled_dists)
        ols = max(0, min(1, 1-distmean)) if distmean is not None else None
        print(f"OLS distmean: {distmean}, ols: {ols}")
        
        return ols
    except Exception as e:
        print(f"Error in OLS calculation: {str(e)}")
        return None

def calculate_emd(preds, targs):
    print(f"EMD preds shape: {preds.shape}, targs shape: {targs.shape}")
    preds_flat = preds.flatten()
    targs_flat = targs.flatten()
    print(f"EMD flattened shapes: preds {preds_flat.shape}, targs {targs_flat.shape}")
    print(f"EMD preds NaN count: {np.isnan(preds_flat).sum()}, targs NaN count: {np.isnan(targs_flat).sum()}")
    
    # Remove any NaN values
    preds_flat = preds_flat[~np.isnan(preds_flat)]
    targs_flat = targs_flat[~np.isnan(targs_flat)]
    
    if len(preds_flat) == 0 or len(targs_flat) == 0:
        print("EMD: All values are NaN")
        return None
    
    dists = wasserstein_distance(preds_flat, targs_flat)
    print(f"EMD distance: {dists}")
    
    scaling_factor = 0.10082047548620601
    scaled_dists = dists ** scaling_factor
    print(f"EMD scaled distance: {scaled_dists}")
    
    emd = max(0, min(1, 1-scaled_dists))
    print(f"EMD score: {emd}")
    return emd

# def calculate_ssim(preds, targs):
#     try:
#         print(f"SSIM preds shape: {preds.shape}, targs shape: {targs.shape}")
#         print(f"SSIM preds NaN count: {np.isnan(preds).sum()}, targs NaN count: {np.isnan(targs).sum()}")
        
#         # Remove the time dimension if it exists
#         if preds.ndim == 4:
#             preds = preds[:,:,:,0]
#         if targs.ndim == 4:
#             targs = targs[:,:,:,0]
        
#         ssim_score = ssim(preds, targs, data_range=1.0, multichannel=True)
#         print(f"SSIM raw score: {ssim_score}")
        
#         scaling_factor = 10.31885115
#         scaled_ssim = float(ssim_score ** scaling_factor)
#         print(f"SSIM scaled score: {scaled_ssim}")
        
#         return max(0, min(1, scaled_ssim))
#     except Exception as e:
#         print(f"Error in SSIM calculation: {str(e)}")
#         return None

import numpy as np

def calculate_ssim(preds, targs):
    try:
        print(f"SSIM preds shape: {preds.shape}, targs shape: {targs.shape}")
        print(f"SSIM preds NaN count: {np.isnan(preds).sum()}, targs NaN count: {np.isnan(targs).sum()}")
        
        # Remove the time dimension if it exists
        if preds.ndim == 4:
            preds = preds[:,:,:,0]
        if targs.ndim == 4:
            targs = targs[:,:,:,0]
        
        # Constants
        L = 1.0  # Dynamic range of pixel values (normalized to [0, 1])
        k1, k2 = 0.01, 0.03
        c1 = (k1 * L) ** 2
        c2 = (k2 * L) ** 2
        
        # Calculate means
        mu_x = np.mean(preds)
        mu_y = np.mean(targs)
        
        # Calculate variances and covariance
        var_x = np.var(preds)
        var_y = np.var(targs)
        cov_xy = np.mean((preds - mu_x) * (targs - mu_y))
        
        # Calculate SSIM
        numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
        denominator = (mu_x**2 + mu_y**2 + c1) * (var_x + var_y + c2)
        ssim_score = numerator / denominator
        
        print(f"SSIM raw score: {ssim_score}")
        
        scaling_factor = 10.31885115
        scaled_ssim = float(ssim_score ** scaling_factor)
        print(f"SSIM scaled score: {scaled_ssim}")
        
        return max(0, min(1, scaled_ssim))
    except Exception as e:
        print(f"Error in SSIM calculation: {str(e)}")
        return None

def earthnetscore(preds, targs):
    mad = calculate_mad(preds, targs)
    ols = calculate_ols(preds, targs)
    emd = calculate_emd(preds, targs)
    ssim_score = calculate_ssim(preds, targs)
    
    scores = [mad, ols, emd, ssim_score]
    valid_scores = list(filter(None, scores))
    
    ens = min(1, len(valid_scores) / sum([1/(v+1e-8) for v in valid_scores])) if valid_scores else None # v+1e-8: also used in source code, so keep it
    
    return ens, {'mad': mad, 'ols': ols, 'emd': emd, 'ssim': ssim_score}


def get_non_imputed_pairs(csv_path, pred_files):
    print(f"Reading CSV file: {csv_path}")
    if not os.path.exists(csv_path):
        print(f"Error: CSV file not found at {csv_path}")
        return []

    df = pd.read_csv(csv_path)
    print(f"CSV file shape: {df.shape}")
    print(f"CSV columns: {df.columns}")
    print(f"Sample data:\n{df.head()}")

    if 'imputed' not in df.columns or 'file_name' not in df.columns:
        print("Error: 'imputed' or 'file_name' column not found in CSV file")
        return []

    non_imputed = df[df['imputed'] == 'no']
    print(f"Number of non-imputed rows: {len(non_imputed)}")

    valid_pairs = []
    for _, row in non_imputed.iterrows():
        file_name = row['file_name']
        if file_name in pred_files:
            valid_pairs.append((file_name, _))
        else:
            print(f"Prediction file not found: {file_name}")

    print(f"Number of valid pairs found: {len(valid_pairs)}")
    return valid_pairs

# Base path and subfolders
base_path = r"C:\TjallingData\greenearthnet_additional"
subfolders = ['iid_chopped', 'ood-s_chopped', 'ood-st_chopped', 'ood-t_chopped']

all_scores = []

for subfolder in subfolders:
    print(f"\nProcessing subfolder: {subfolder}")
    
    # Paths for the current subfolder
    predictions_dir = os.path.join(base_path, subfolder, "predictions_base_SDXL")
    NIR_image_dir = os.path.join(base_path, subfolder, "NIR_total")
    RGB_image_dir = os.path.join(base_path, subfolder, "RGB_total")

    # Correct CSV file name
    csv_filename = f"{subfolder}_combined_imputed_with_region_normalized_season_minicube_and_usable.csv"
    csv_path = os.path.join(base_path, subfolder, csv_filename)

    # Get the minicubes with prediction images
    minicubes = [d for d in os.listdir(os.path.join(predictions_dir, "NIR")) if os.path.isdir(os.path.join(predictions_dir, "NIR", d))]

    subfolder_scores = []

    for minicube in minicubes[:2]:
        try:
            print(f"\nProcessing minicube: {minicube}")
            
            # Get prediction files
            pred_nir_dir = os.path.join(predictions_dir, "NIR", minicube)
            pred_rgb_dir = os.path.join(predictions_dir, "RGB", minicube)
            pred_files = set(os.listdir(pred_nir_dir))
            print(f"Number of prediction files: {len(pred_files)}")

            # Get non-imputed pairs
            valid_pairs = get_non_imputed_pairs(csv_path, pred_files)

            print(f"Valid non-imputed pairs for {minicube}:")
            for pair in valid_pairs:
                print(f"Prediction: {pair[0]}, Index: {pair[1]}")

            if not valid_pairs:
                print(f"No valid pairs found for {minicube}. Skipping.")
                continue

            # Load predictions and targets
            pred_nir = load_images(pred_nir_dir, [pair[0] for pair in valid_pairs])
            pred_rgb = load_images(pred_rgb_dir, [pair[0] for pair in valid_pairs])
            
            target_size = pred_nir.shape[1:3]
            target_files = [pair[0] for pair in valid_pairs]
            target_nir = load_images(NIR_image_dir, target_files, target_size=target_size)
            target_rgb = load_images(RGB_image_dir, target_files, target_size=target_size)

            # Calculate NDVI
            pred_ndvi = calculate_ndvi(pred_nir, pred_rgb)
            target_ndvi = calculate_ndvi(target_nir, target_rgb)

            # Calculate EarthNetScore
            ens, subscores = earthnetscore(pred_ndvi, target_ndvi)
            
            minicube_score = {
                'minicube': minicube,
                'earthnetscore': ens,
                'subscores': subscores,
                'valid_pairs': len(valid_pairs)
            }
            subfolder_scores.append(minicube_score)
            all_scores.append({
                'subfolder': subfolder,
                **minicube_score
            })
            
            print(f"MiniCube: {minicube}")
            print(f"EarthNetScore: {ens}")
            print(f"Subscores: {subscores}")
            print(f"Valid pairs: {len(valid_pairs)}")
            print()
        except Exception as e:
            print(f"Error processing minicube {minicube} in subfolder {subfolder}: {str(e)}")
            print()

    # Calculate and print average scores for the subfolder
    subfolder_ens = np.mean([score['earthnetscore'] for score in subfolder_scores if score['earthnetscore'] is not None])
    subfolder_subscores = {
        'mad': np.mean([score['subscores']['mad'] for score in subfolder_scores if score['subscores']['mad'] is not None]),
        'ols': np.mean([score['subscores']['ols'] for score in subfolder_scores if score['subscores']['ols'] is not None]),
        'emd': np.mean([score['subscores']['emd'] for score in subfolder_scores if score['subscores']['emd'] is not None]),
        'ssim': np.mean([score['subscores']['ssim'] for score in subfolder_scores if score['subscores']['ssim'] is not None])
    }
    print(f"\nSubfolder {subfolder} average EarthNetScore: {subfolder_ens}")
    print(f"Subfolder {subfolder} average subscores: {subfolder_subscores}")

# Calculate overall EarthNetScore across all test tracks
valid_scores = [score['earthnetscore'] for score in all_scores if score['earthnetscore'] is not None]
overall_ens = np.mean(valid_scores) if valid_scores else None
print(f"\nOverall EarthNetScore across all test tracks: {overall_ens}")

# Calculate overall average subscores
overall_subscores = {
    'mad': np.mean([score['subscores']['mad'] for score in all_scores if score['subscores']['mad'] is not None]),
    'ols': np.mean([score['subscores']['ols'] for score in all_scores if score['subscores']['ols'] is not None]),
    'emd': np.mean([score['subscores']['emd'] for score in all_scores if score['subscores']['emd'] is not None]),
    'ssim': np.mean([score['subscores']['ssim'] for score in all_scores if score['subscores']['ssim'] is not None])
}
print(f"Overall average subscores: {overall_subscores}")

# Write results to JSON file
results = {
    'overall_earthnetscore': overall_ens,
    'overall_subscores': overall_subscores,
    'test_tracks': {}
}

for subfolder in subfolders:
    subfolder_scores = [score for score in all_scores if score['subfolder'] == subfolder]
    results['test_tracks'][subfolder] = {
        'average_earthnetscore': np.mean([score['earthnetscore'] for score in subfolder_scores if score['earthnetscore'] is not None]),
        'average_subscores': {
            'mad': np.mean([score['subscores']['mad'] for score in subfolder_scores if score['subscores']['mad'] is not None]),
            'ols': np.mean([score['subscores']['ols'] for score in subfolder_scores if score['subscores']['ols'] is not None]),
            'emd': np.mean([score['subscores']['emd'] for score in subfolder_scores if score['subscores']['emd'] is not None]),
            'ssim': np.mean([score['subscores']['ssim'] for score in subfolder_scores if score['subscores']['ssim'] is not None])
        },
        'minicubes': subfolder_scores
    }

with open('earthnetscore_results.json', 'w') as f:
    json.dump(results, f, indent=2)

print("\nResults have been written to earthnetscore_results.json")


Processing subfolder: iid_chopped

Processing minicube: JAS17_minicube_0_29SND_39.29_-8
Number of prediction files: 20
Reading CSV file: C:\TjallingData\greenearthnet_additional\iid_chopped\iid_chopped_combined_imputed_with_region_normalized_season_minicube_and_usable.csv
CSV file shape: (85680, 13)
CSV columns: Index(['file_name', 'property', 'date', 'imputed', 'image_count',
       'Temperature Avg', 'Sea-level Pressure',
       'Shortwave Downwelling Radiation', 'Relative Humidity', 'region',
       'season', 'minicube', 'usable'],
      dtype='object')
Sample data:
                                           file_name property        date  \
0  JAS17_minicube_0_29SND_39.29_-8.56_2017-05-15.png  context  2017-05-15   
1  JAS17_minicube_0_29SND_39.29_-8.56_2017-05-20.png  context  2017-05-20   
2  JAS17_minicube_0_29SND_39.29_-8.56_2017-05-25.png  context  2017-05-25   
3  JAS17_minicube_0_29SND_39.29_-8.56_2017-05-30.png  context  2017-05-30   
4  JAS17_minicube_0_29SND_39.29_-8.56_

FileNotFoundError: [WinError 3] The system cannot find the path specified: 'C:\\TjallingData\\greenearthnet_additional\\ood-s_chopped\\predictions_base_SDXL\\NIR'

Here, I still try to leverage the EarthNet package. First, we create NIR and RGB target image folders. 