In [None]:
%matplotlib inline
import os
import random
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F

from torchviz import make_dot
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

from model import OriginalRelationshipLearner, Discriminator1, FlexibleUpsamplingModule, weights_init_normal, SSIM, TVLoss, PerceptualLoss
from datasets import CustomDataset, load_data_with_augmentation, load_data
from utils import plot_results
from taylorDiagram import TaylorDiagram

# Define the ModelTrainer class as provided
class ModelTrainer:
    """Train and evaluate an adversarial upsampling (downscaling) model.

    For every batch ``(lr_grace_05, lr_grace_025, hr_aux)`` the generator input
    is ``lr_grace_05`` bicubic-resized by 0.5, concatenated on dim 1 with
    ``hr_aux`` bicubic-resized by 0.25. The generator output is trained and
    evaluated against ``lr_grace_025``.

    Parameters:
        epochs (int): Maximum number of training epochs.
        batch_size (int): Batch size for all data loaders.
        relationship_learner: Accepted for interface compatibility; not used.
        relationship_output_channels (int): Stored on the instance; not used by this class.
        smoothing_method (callable or None): Applied to the augmented ``hr_aux``
            array before the train/test split.
        attention: Optional attention module; also forwarded to
            ``FlexibleUpsamplingModule`` as ``attention_type``.
        senet: Optional SE module; only applied when no attention module is set.
        rand (int): ``random_state`` for the train/test split.
    """

    def __init__(self, epochs, batch_size, relationship_learner, relationship_output_channels,
                 smoothing_method=None, attention=None, senet=None, rand=42):
        self.epochs = epochs
        self.batch_size = batch_size
        self.relationship_output_channels = relationship_output_channels
        self.smoothing_method = smoothing_method
        self.attention = attention
        self.senet = senet
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.rand = rand

        # Augmented data feeds the train/test split; the un-augmented copy
        # (attributes suffixed with "o") feeds full_loader.
        (
            [self.lr_grace_05, self.trend05],
            [self.lr_grace_025, self.trend25],
            self.hr_aux,
            self.grace_scaler_05,
            self.grace_scaler_025,
            self.aux_scalers,
        ) = load_data_with_augmentation()
        (
            [self.lr_grace_05o, self.trend05o],
            [self.lr_grace_025o, self.trend25o],
            self.hr_auxo,
            self.grace_scaler_05o,
            self.grace_scaler_025o,
            self.aux_scalerso,
        ) = load_data()

        # Optional smoothing of the auxiliary inputs (augmented copy only).
        if self.smoothing_method:
            self.hr_aux = self.smoothing_method(self.hr_aux)

        # Random 80/20 train/test split, reproducible via `rand`.
        (self.train_lr_grace_05, self.test_lr_grace_05,
         self.train_lr_grace_025, self.test_lr_grace_025,
         self.train_hr_aux, self.test_hr_aux) = train_test_split(
            self.lr_grace_05, self.lr_grace_025, self.hr_aux,
            test_size=0.2, random_state=self.rand)

        # Datasets
        self.train_dataset = CustomDataset(self.train_lr_grace_05, self.train_lr_grace_025, self.train_hr_aux)
        self.test_dataset = CustomDataset(self.test_lr_grace_05, self.test_lr_grace_025, self.test_hr_aux)
        self.full_dataset = CustomDataset(self.lr_grace_05o, self.lr_grace_025o, self.hr_auxo)

        # Dataloaders (only the training loader shuffles)
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
        self.full_loader = DataLoader(self.full_dataset, batch_size=self.batch_size, shuffle=False)

        # Models (construction order kept stable so seeded initialisation is reproducible)
        self.discriminator = Discriminator1().to(self.device)
        self.upsampling_module = FlexibleUpsamplingModule(
            input_channels=self.hr_aux.shape[-1] + 1, attention_type=self.attention).to(self.device)
        self.attention_module = self.attention.to(self.device) if self.attention else None
        self.senet_module = self.senet.to(self.device) if self.senet else None

        # Weight initialisation: discriminator, then upsampling, attention, SE.
        self.discriminator.apply(weights_init_normal)
        for module in self._generator_modules():
            module.apply(weights_init_normal)

        # Optimizers: one for the discriminator, one for all generator-side modules.
        hat_parameters = [p for module in self._generator_modules() for p in module.parameters()]
        self.optimizer_D = optim.AdamW(self.discriminator.parameters(), lr=0.0004, betas=(0.5, 0.999), weight_decay=1e-4)
        self.optimizer_U = optim.AdamW(hat_parameters, lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4)

        # Learning-rate schedulers (stepped once per epoch in train()).
        self.scheduler_D = CosineAnnealingWarmRestarts(self.optimizer_D, T_0=10, T_mult=2, eta_min=1e-6)
        self.scheduler_U = CosineAnnealingWarmRestarts(self.optimizer_U, T_0=10, T_mult=2, eta_min=1e-6)

        # Loss functions
        self.adversarial_loss = torch.nn.BCEWithLogitsLoss()
        self.pixelwise_loss = torch.nn.MSELoss()
        self.ssim_loss = SSIM(window_size=11, size_average=True).to(self.device)
        self.tv_loss = TVLoss(weight=1e-5).to(self.device)
        self.perceptual_loss = PerceptualLoss(use_gpu=torch.cuda.is_available())

    def smooth_data_gaussian(self, data, sigma=2):
        """Return ``data`` smoothed with ``scipy.ndimage.gaussian_filter``."""
        from scipy.ndimage import gaussian_filter
        return gaussian_filter(data, sigma=sigma)

    def smooth_data_median(self, data, size=3):
        """Return ``data`` smoothed with ``scipy.ndimage.median_filter``."""
        from scipy.ndimage import median_filter
        return median_filter(data, size=size)

    def smooth_data_savitzky_golay(self, data, window_length=5, polyorder=2):
        """Return ``data`` smoothed with ``scipy.signal.savgol_filter`` (last axis)."""
        from scipy.signal import savgol_filter
        return savgol_filter(data, window_length, polyorder)

    def _generator_modules(self):
        """Return the generator-side modules in a fixed order: upsampling, attention, SE."""
        modules = [self.upsampling_module]
        if self.attention_module is not None:
            modules.append(self.attention_module)
        if self.senet_module is not None:
            modules.append(self.senet_module)
        return modules

    def _set_generator_mode(self, training):
        """Put every generator-side module into train (True) or eval (False) mode."""
        for module in self._generator_modules():
            module.train(training)

    def _generate(self, lr_grace_05, hr_aux):
        """Build the generator input from one batch and return the upsampled output.

        ``lr_grace_05`` is bicubic-resized by 0.5 and ``hr_aux`` by 0.25; the two
        are concatenated on dim 1, passed through the attention module (or, if
        absent, the SE module) and then through the upsampling module.
        """
        lr_grace_05 = lr_grace_05.to(self.device)
        hr_aux = hr_aux.to(self.device)
        lr_grace = F.interpolate(lr_grace_05, scale_factor=0.5, mode='bicubic', align_corners=False)
        downsampled_aux = F.interpolate(hr_aux, scale_factor=0.25, mode='bicubic', align_corners=False)
        features = torch.cat([lr_grace, downsampled_aux], dim=1)
        # At most one feature re-weighting block is applied; attention wins.
        if self.attention_module is not None:
            features = self.attention_module(features)
        elif self.senet_module is not None:
            features = self.senet_module(features)
        return self.upsampling_module(features)

    def _restore_best(self, best_state):
        """Load the in-memory best snapshot into the upsampling module, if one exists."""
        if best_state is None:
            print('No improving epoch was recorded; keeping current weights.')
            return
        self.upsampling_module.load_state_dict(best_state)

    def train(self):
        """Run adversarial training with early stopping on the mean generator loss.

        Per batch: one discriminator update (real = ``lr_grace_025``, fake =
        detached generator output), then one generator update with loss
        ``(1 - w) * MSE + w * adversarial + TV + perceptual`` where
        ``w = epoch / epochs``. When the epoch-mean generator loss improves, the
        upsampling module's weights are snapshotted in memory and written to
        ``best_model.pth``. Training stops after ``patience`` epochs without
        improvement, and the best snapshot from this run is restored before returning.

        Returns:
            tuple[list[float], list[float]]: Per-epoch mean generator and
            discriminator losses for completed epochs (the epoch that triggers
            early stopping is not appended).

        Raises:
            ValueError: If training is requested but the training loader is empty.
        """
        train_losses_G = []
        train_losses_D = []
        patience = 20  # Early stopping patience
        min_delta = 0  # Minimum change in monitored value to qualify as improvement
        trigger_times = 0  # Consecutive epochs without improvement
        best_loss = float('inf')
        best_state = None  # In-memory copy of the best upsampling weights from THIS run

        n_batches = len(self.train_loader)
        if self.epochs > 0 and n_batches == 0:
            raise ValueError('train_loader is empty; cannot compute epoch losses.')

        for epoch in range(self.epochs):
            epoch_loss_G = 0.0
            epoch_loss_D = 0.0
            self._set_generator_mode(True)
            self.discriminator.train()
            loss_weight = epoch / self.epochs  # adversarial weight ramps linearly with epoch

            for lr_grace_05, lr_grace_025, hr_aux in self.train_loader:
                lr_grace_025 = lr_grace_025.to(self.device)
                hr_generated = self._generate(lr_grace_05, hr_aux)

                # Discriminator step: generator output is detached so only D is updated.
                self.optimizer_D.zero_grad()
                real_output = self.discriminator(lr_grace_025)
                fake_output = self.discriminator(hr_generated.detach())
                real_labels = torch.ones_like(real_output)
                fake_labels = torch.zeros_like(fake_output)
                loss_D = (self.adversarial_loss(real_output, real_labels)
                          + self.adversarial_loss(fake_output, fake_labels)) / 2
                loss_D.backward()
                self.optimizer_D.step()

                # Generator step (only optimizer_U's parameters are stepped).
                self.optimizer_U.zero_grad()
                fake_output = self.discriminator(hr_generated)
                loss_G_adv = self.adversarial_loss(fake_output, real_labels)
                loss_G_pixel = self.pixelwise_loss(hr_generated, lr_grace_025)
                loss_G_tv = self.tv_loss(hr_generated)
                loss_G_perceptual = self.perceptual_loss(hr_generated, lr_grace_025)
                # NOTE(review): self.ssim_loss is configured but not part of loss_G.
                loss_G = ((1 - loss_weight) * loss_G_pixel + loss_weight * loss_G_adv
                          + loss_G_tv + loss_G_perceptual)
                loss_G.backward()
                self.optimizer_U.step()

                epoch_loss_G += loss_G.item()
                epoch_loss_D += loss_D.item()

            avg_epoch_loss_G = epoch_loss_G / n_batches
            avg_epoch_loss_D = epoch_loss_D / n_batches

            # Early stopping on the training generator loss.
            if avg_epoch_loss_G < best_loss - min_delta:
                best_loss = avg_epoch_loss_G
                trigger_times = 0
                best_state = copy.deepcopy(self.upsampling_module.state_dict())
                torch.save(best_state, 'best_model.pth')
            else:
                trigger_times += 1
                print(f'EarlyStopping: {trigger_times}/{patience} epochs with no improvement.')
                if trigger_times >= patience:
                    print('Early stopping triggered.')
                    self._restore_best(best_state)
                    return train_losses_G, train_losses_D

            self.scheduler_D.step()
            self.scheduler_U.step()

            train_losses_G.append(avg_epoch_loss_G)
            train_losses_D.append(avg_epoch_loss_D)

            print(f'Epoch [{epoch+1}/{self.epochs}], Loss D: {avg_epoch_loss_D:.4f}, Loss G: {avg_epoch_loss_G:.4f}')

        self._restore_best(best_state)
        return train_losses_G, train_losses_D

    def evaluate(self):
        """Evaluate the generator on the test split and print summary metrics.

        Prints MSE, MAE, R² and the Pearson correlation between the flattened
        predictions and ``lr_grace_025`` targets.

        Returns:
            tuple[np.ndarray, np.ndarray, float]: Flattened predictions,
            flattened targets, and the R² score.

        Raises:
            ValueError: If the test loader yields no batches.
        """
        self._set_generator_mode(False)
        preds, trues = [], []
        with torch.no_grad():
            for lr_grace_05, lr_grace_025, hr_aux in self.test_loader:
                hr_generated = self._generate(lr_grace_05, hr_aux)
                preds.append(hr_generated.cpu().numpy())
                trues.append(lr_grace_025.cpu().numpy())

        if not preds:
            raise ValueError('test_loader is empty; nothing to evaluate.')

        preds = np.concatenate(preds, axis=0).reshape(-1)
        trues = np.concatenate(trues, axis=0).reshape(-1)

        cc = np.corrcoef(trues, preds)[0, 1]  # off-diagonal entry = Pearson r
        mse = mean_squared_error(trues, preds)
        mae = mean_absolute_error(trues, preds)
        r2 = r2_score(trues, preds)

        print(f"Test MSE: {mse}, Test MAE: {mae}, Test R²: {r2}, Test cc: {cc}")

        return preds, trues, r2

# Define the EnsembleTrainer class
class EnsembleTrainer:
    """Train, load and run an ensemble of ``ModelTrainer`` upsampling modules.

    Member ``i`` (1-based) is trained with global seed ``41 + i`` but the same
    train/test split (``rand=42``), and its upsampling module is saved to
    ``<ensemble_dir>/best_model_member_<i>.pth``.
    """

    def __init__(self, num_ensemble, model_trainer_kwargs, ensemble_dir='ensemble_models'):
        """
        Initializes the EnsembleTrainer and creates ``ensemble_dir`` if needed.

        Parameters:
            num_ensemble (int): Number of ensemble members.
            model_trainer_kwargs (dict): Keyword arguments to pass to ModelTrainer.
            ensemble_dir (str): Directory to save ensemble models.
        """
        self.num_ensemble = num_ensemble
        self.model_trainer_kwargs = model_trainer_kwargs
        self.ensemble_dir = ensemble_dir
        os.makedirs(self.ensemble_dir, exist_ok=True)
        self.seeds = [42 + i for i in range(num_ensemble)]  # Different seed for each member

    def set_seed(self, seed):
        """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    @staticmethod
    def _model_device(model):
        """Return the device of a module's first parameter (CPU if it has none).

        Plain ``torch.nn.Module`` objects expose no ``.device`` attribute, so the
        device is read from the parameters instead.
        """
        try:
            return next(model.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    def train_ensemble(self):
        """Train each ensemble member in turn and save its upsampling module."""
        for i, seed in enumerate(self.seeds):
            print(f"\n===== Training Ensemble Member {i+1}/{self.num_ensemble} with Seed {seed} =====")
            self.set_seed(seed)
            # Same data split for every member; only the global seed differs.
            trainer_kwargs = copy.deepcopy(self.model_trainer_kwargs)
            trainer_kwargs['rand'] = 42

            trainer = ModelTrainer(**trainer_kwargs)
            trainer.train()  # restores the member's best weights before returning

            model_save_path = os.path.join(self.ensemble_dir, f'best_model_member_{i+1}.pth')
            torch.save(trainer.upsampling_module.state_dict(), model_save_path)
            print(f"Saved Ensemble Member {i+1} model to '{model_save_path}'")

    def load_ensemble_models(self, model_class, device, input_channels, attention_type=None):
        """
        Loads all ensemble models from the ensemble directory, in eval mode.

        Parameters:
            model_class (class): The class of the upsampling module.
            device (torch.device): Device to load the models onto.
            input_channels (int): Number of input channels for the model.
            attention_type (str or None): Type of attention used, if any.

        Returns:
            list: List of loaded models.

        Raises:
            FileNotFoundError: If any member checkpoint is missing.
        """
        models = []
        for i in range(1, self.num_ensemble + 1):
            model_path = os.path.join(self.ensemble_dir, f'best_model_member_{i}.pth')
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file '{model_path}' does not exist.")
            model = model_class(input_channels=input_channels, attention_type=attention_type).to(device)
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.eval()
            models.append(model)
            print(f"Loaded Ensemble Member {i} from '{model_path}'")
        return models

    def predict_ensemble(self, models, test_loader):
        """
        Generates predictions from all ensemble members.

        Each model is put into eval mode and run on its own device. Inputs are
        built as in ModelTrainer (0.5x bicubic ``lr_grace_05`` concatenated with
        0.25x bicubic ``hr_aux``). Predictions and targets are inverse-transformed
        with ``grace_scaler_05o`` from ``load_data()``.

        Parameters:
            models (list): List of trained models.
            test_loader (DataLoader): Yields ``(lr_grace_05, lr_grace_025, hr_aux)`` batches.

        Returns:
            np.ndarray: Array of predictions from each ensemble member.
                        Shape: (num_ensemble, num_samples, channels, lat, lon)
            np.ndarray: Array of ground truth values.
                        Shape: (num_samples, channels, lat, lon)

        Raises:
            ValueError: If ``models`` is empty.
        """
        if not models:
            raise ValueError('predict_ensemble requires at least one model.')

        # NOTE(review): targets are the `lr_grace_025` field, yet both predictions
        # and targets are inverse-transformed with the 0.5 scaler; load_data() also
        # returns a 0.25 scaler (4th vs 5th item) — confirm which one is intended.
        (_, _), (_, _), _, grace_scaler_05o, _, _ = load_data()

        all_preds = []
        all_trues = None

        for idx, model in enumerate(models):
            print(f"\n===== Predicting with Ensemble Member {idx+1} =====")
            device = self._model_device(model)
            model.eval()
            collect_truth = idx == 0  # ground truth is identical for every member

            preds_per_model = []
            trues_per_model = []
            with torch.no_grad():
                for lr_grace_05, lr_grace_025, hr_aux in test_loader:
                    lr_grace = F.interpolate(lr_grace_05.to(device), scale_factor=0.5, mode='bicubic', align_corners=False)
                    downsampled_aux = F.interpolate(hr_aux.to(device), scale_factor=0.25, mode='bicubic', align_corners=False)
                    combined_input = torch.cat([lr_grace, downsampled_aux], dim=1)

                    hr_generated = model(combined_input)

                    preds_per_model.append(hr_generated.cpu().numpy())
                    if collect_truth:
                        trues_per_model.append(lr_grace_025.cpu().numpy())

            preds_per_model = np.concatenate(preds_per_model, axis=0)
            preds_per_model = grace_scaler_05o.inverse_transform(preds_per_model.reshape(-1, 1)).reshape(preds_per_model.shape)
            print(f"Ensemble Member {idx+1} Predictions Shape: {preds_per_model.shape}")
            all_preds.append(preds_per_model)

            if collect_truth:
                trues_per_model = np.concatenate(trues_per_model, axis=0)
                all_trues = grace_scaler_05o.inverse_transform(trues_per_model.reshape(-1, 1)).reshape(trues_per_model.shape)

        all_preds = np.stack(all_preds, axis=0)

        print(f"\nAll Ensemble Predictions Shape: {all_preds.shape}")
        print(f"Ground Truths Shape: {all_trues.shape}")

        return all_preds, all_trues

    def compute_uncertainty(self, all_preds, trues, mask_path='tpb_h.npy'):
        """
        Summarises ensemble spread on spatially averaged (masked) series.

        Grid cells where the array loaded from ``mask_path`` equals 0 are set to
        NaN in both predictions and truth; each field is then averaged over its
        last two axes, ignoring NaNs.

        Parameters:
            all_preds (np.ndarray): Shape (num_ensemble, num_samples, channels, lat, lon).
            trues (np.ndarray): Shape (num_samples, channels, lat, lon).
            mask_path (str): Path to a .npy array whose shape matches (lat, lon).

        Returns:
            np.ndarray: Ensemble mean of the spatial-mean series, shape (num_samples, channels).
            np.ndarray: Ensemble standard deviation (ddof=0) of that series, same shape.
            float: R² between the truth series and the ensemble-mean series over non-NaN entries.
        """
        mask = np.load(mask_path) == 0

        all_preds_masked = all_preds.copy()
        all_preds_masked[:, :, :, mask] = np.nan
        trues_masked = trues.copy()
        trues_masked[:, :, mask] = np.nan

        preds_ts = np.nanmean(all_preds_masked, axis=(3, 4))  # (num_ensemble, T, C)
        trues_ts = np.nanmean(trues_masked, axis=(2, 3))      # (T, C)

        mean_preds = np.nanmean(preds_ts, axis=0)  # (T, C)
        std_preds = np.nanstd(preds_ts, axis=0)    # (T, C)

        valid_mask = ~np.isnan(trues_ts) & ~np.isnan(mean_preds)
        r2 = r2_score(trues_ts[valid_mask], mean_preds[valid_mask])

        return mean_preds, std_preds, r2

    def save_uncertainty(self, std_preds, save_path='ensemble_uncertainty.npy'):
        """Saves the uncertainty estimates to a .npy file."""
        np.save(save_path, std_preds)
        print(f"Saved uncertainty estimates to '{save_path}'")

    def save_mean_predictions(self, mean_preds, save_path='ensemble_mean_predictions.npy'):
        """Saves the mean predictions to a .npy file."""
        np.save(save_path, mean_preds)
        print(f"Saved mean predictions to '{save_path}'")

def visualize_uncertainty(std_preds, sample_indices=(0,), channel=0):
    """
    Plots one uncertainty map per requested sample (one figure each).

    Parameters:
        std_preds (np.ndarray): Uncertainty estimates with shape
            (num_samples, channels, lat, lon); ``std_preds[idx, channel]`` is drawn.
        sample_indices (iterable of int): Sample indices to visualize (default: first sample).
        channel (int): Channel index to visualize (default is 0).

    Raises:
        ValueError: If ``std_preds`` is not 4-D, e.g. a spatially averaged
            (num_samples, channels) array, which has no map to draw.
    """
    if np.ndim(std_preds) != 4:
        raise ValueError(
            f"std_preds must be 4-D (samples, channels, lat, lon); got shape {np.shape(std_preds)}")
    for idx in sample_indices:
        plt.figure(figsize=(8, 6))
        plt.imshow(std_preds[idx, channel], cmap='viridis')
        plt.colorbar(label='Uncertainty (Std Dev)')
        plt.title(f'Uncertainty Map for Sample {idx}')
        plt.xlabel('Longitude')
        plt.ylabel('Latitude')
        plt.show()

def main():
    """Train the ensemble, predict on the full un-augmented dataset, and save results.

    Writes member checkpoints under ``ensemble_models/`` plus
    ``ensemble_mean_predictions.npy``, ``ensemble_uncertainty_averaged.npy``,
    ``ensemble_trues.npy`` and ``ensemble_mean_preds.npy`` to the working directory.
    """
    num_ensemble = 5  # Number of ensemble members; adjust as needed.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_trainer_kwargs = {
        'epochs': 100,
        'batch_size': 16,
        'relationship_learner': None,
        'relationship_output_channels': 64,
        'smoothing_method': None,
        'attention': None,
        'senet': None,
    }

    ensemble_trainer = EnsembleTrainer(num_ensemble=num_ensemble,
                                       model_trainer_kwargs=model_trainer_kwargs,
                                       ensemble_dir='ensemble_models')
    ensemble_trainer.train_ensemble()

    # A non-training ModelTrainer (same kwargs, same split seed) is used only for
    # its data: it exposes full_loader over the un-augmented data, and its hr_aux
    # gives the channel count the members' upsampling modules were built with.
    temp_trainer_kwargs = copy.deepcopy(model_trainer_kwargs)
    temp_trainer_kwargs['rand'] = 42
    temp_trainer_kwargs['epochs'] = 0
    temp_trainer = ModelTrainer(**temp_trainer_kwargs)

    models = ensemble_trainer.load_ensemble_models(
        model_class=FlexibleUpsamplingModule,
        device=device,
        input_channels=temp_trainer.hr_aux.shape[-1] + 1,
        attention_type=None,
    )

    # Predict on the full dataset, not just the test split.
    all_preds, trues = ensemble_trainer.predict_ensemble(models, temp_trainer.full_loader)

    # Spatially averaged ensemble mean / spread over the entire dataset.
    mean_preds, std_preds, r2 = ensemble_trainer.compute_uncertainty(all_preds, trues)

    ensemble_trainer.save_mean_predictions(mean_preds, save_path='ensemble_mean_predictions.npy')
    ensemble_trainer.save_uncertainty(std_preds, save_path='ensemble_uncertainty_averaged.npy')
    np.save('ensemble_trues.npy', trues)
    np.save('ensemble_mean_preds.npy', mean_preds)

    # std_preds here is a (samples, channels) series, so the per-pixel map
    # plot from visualize_uncertainty does not apply to it.
    print(f"Ensemble Mean R² Score (Full Dataset): {r2}")


if __name__ == "__main__":
    main()

(181, 90, 44)
(181, 180, 88, 1)
[509.70157107]
[-32767.]
[-32767.]
[-32767.]
Combined HR Aux Data Shape: (181, 180, 88, 45)
0.0
65.5
Sliced HR Aux Data Shape: (181, 180, 88, 45)
-5.350948318234112
(180, 88, 7)
最大误差: 8.881784197001252e-16


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def visualize_uncertainty(std_preds, sample_indices=(0,), channel=0, rotate=False):
    """
    Plots one uncertainty map per requested sample (one figure each).

    Parameters:
        std_preds (np.ndarray): Uncertainty estimates with shape
            (num_samples, channels, lat, lon); ``std_preds[idx, channel]`` is drawn.
        sample_indices (iterable of int): Sample indices to visualize (default: first sample).
        channel (int): Channel index to visualize (default is 0).
        rotate (bool): Whether to rotate each map 90 degrees anti-clockwise (``np.rot90``).

    Raises:
        ValueError: If ``std_preds`` is not 4-D.
    """
    if np.ndim(std_preds) != 4:
        raise ValueError(
            f"std_preds must be 4-D (samples, channels, lat, lon); got shape {np.shape(std_preds)}")
    for idx in sample_indices:
        data = std_preds[idx, channel]
        if rotate:
            data = np.rot90(data)  # 90 degrees anti-clockwise

        plt.figure(figsize=(8, 6))
        # aspect='auto' stretches the image to fill the axes (pixels need not be square).
        plt.imshow(data, cmap='viridis', aspect='auto')
        plt.colorbar(label='Uncertainty (Std Dev)', pad=0.01)
        plt.title(f'Uncertainty Map for Sample {idx}', pad=10)
        plt.xlabel('Longitude', labelpad=5)
        plt.ylabel('Latitude', labelpad=5)
        plt.xticks(fontsize=8)
        plt.yticks(fontsize=8)
        plt.tight_layout()
        plt.show()
# Load saved ensemble uncertainty, report its overall mean, blank out cells
# where tpb_h == 0, and plot a few samples.
unc = np.load('ensemble_uncertainty.npy')
print(np.shape(unc))
mee = np.nanmean(unc)  # mean over all cells, computed before masking
print(mee)

tpbh = np.load('tpb_h.npy')
if unc.ndim != 4:
    # main() saves a spatially averaged (samples, channels) array under
    # 'ensemble_uncertainty_averaged.npy'; this cell needs per-pixel maps.
    raise ValueError(
        f"Expected a 4-D (samples, channels, lat, lon) uncertainty array, got shape {unc.shape}")
unc[:, :, tpbh == 0] = np.nan
visualize_uncertainty(unc, sample_indices=[0, 10, 20], channel=0, rotate=True)  # Adjust sample_indices as needed