# In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from scipy.ndimage import gaussian_filter
from sklearn.metrics import r2_score
import yaml
import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
import matplotlib.pyplot as plt
sys.path.append(os.path.join(os.getcwd(), '../../'))
from data_loader import HDF5Dataset
from model import CNN_LSTM
from scipy import stats

def load_yaml_config(file_path):
    """
    Read a YAML file and return its parsed contents.

    ``yaml.safe_load`` is used, so only standard YAML types (dicts, lists,
    scalars) are constructed. No arbitrary Python objects are created.

    Parameters:
    - file_path: path to the YAML file.

    Returns:
    - the parsed document (for a typical config file, a dict).
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
# Parse the experiment configuration once, at import/cell-execution time.
# The path is relative, so it is resolved against the current working directory
# rather than the location of this file.
config = load_yaml_config('config/config.yaml')

def apply_gaussian_perturbation(data, location, std_dev, amplitude=1, flip_sign=False):
    """
    Apply a Gaussian perturbation to the input data.

    The same 2-D Gaussian bump is added to (or subtracted from) every
    (latitude, longitude) slice of ``data``. The input array is not modified.

    Parameters:
    - data: array of weather data whose last two axes are (latitude, longitude),
      typically 4D (batch, time, latitude, longitude).
    - location: Tuple (lat_idx, lon_idx) indicating the center of the Gaussian
      perturbation. Fractional or off-grid indices are allowed.
    - std_dev: Standard deviation of the Gaussian kernel, in grid cells. Must be > 0.
    - amplitude: Value of the perturbation at its largest grid point.
    - flip_sign: Boolean, if True the perturbation is subtracted, otherwise added.

    Returns:
    - perturbed_data: copy of ``data`` with the perturbation applied
      (same shape and dtype as the copy of ``data``).

    Raises:
    - ValueError: if ``std_dev`` is not positive, ``data`` has fewer than two
      axes, or the Gaussian underflows to zero on the whole grid (``location``
      too far from the grid for this ``std_dev``).
    """
    # `not x > 0` also rejects NaN; std_dev == 0 would otherwise yield all-NaN output.
    if not std_dev > 0:
        raise ValueError(f"std_dev must be positive, got {std_dev!r}")

    perturbed_data = np.copy(data)
    if perturbed_data.ndim < 2:
        raise ValueError(
            f"data must have at least 2 dimensions (lat, lon), got shape {perturbed_data.shape}"
        )
    lat, lon = perturbed_data.shape[-2:]

    # A column of latitude (row) indices and a row of longitude (column) indices.
    # They broadcast to the full (lat, lon) grid.
    y = np.arange(lat)[:, np.newaxis]
    x = np.arange(lon)[np.newaxis, :]
    gaussian = np.exp(-((x - location[1])**2 + (y - location[0])**2) / (2 * std_dev**2))

    peak = gaussian.max()
    if not peak > 0:
        # Every cell underflowed to 0; normalizing would divide by zero and produce NaN.
        raise ValueError(
            f"Gaussian centred at {location} with std_dev={std_dev} vanishes on the "
            f"{lat}x{lon} grid"
        )
    # Rescale so the largest grid value equals `amplitude`. This matters when no
    # grid cell sits exactly on a fractional/off-grid centre.
    gaussian = amplitude * gaussian / peak

    # In-place update broadcasts the (lat, lon) bump over all leading axes and
    # keeps the dtype of the copy (e.g. float32 stays float32).
    if flip_sign:
        perturbed_data -= gaussian
    else:
        perturbed_data += gaussian
    return perturbed_data

def run_inference(model, data_loader, device):
    """
    Run ``model`` over every batch of ``data_loader`` and collect the outputs.

    Parameters:
    - model: torch module called as ``model(ppt, tmin, tmax)``. It is switched to
      eval mode and left in eval mode afterwards.
    - data_loader: iterable of dict-like batches holding 'ppt', 'tmin' and
      'tmax' tensors.
    - device: torch device the inputs are moved to before each forward pass.

    Returns:
    - numpy array of all batch outputs concatenated along axis 0.

    Raises:
    - ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    all_outputs = []
    # Inference only: disabling autograd avoids building and holding a graph for
    # every batch, which otherwise wastes (GPU) memory.
    with torch.no_grad():
        for batch in data_loader:
            ppt = batch['ppt'].to(device)
            tmin = batch['tmin'].to(device)
            tmax = batch['tmax'].to(device)
            outputs = model(ppt, tmin, tmax)
            all_outputs.append(outputs.detach().cpu().numpy())
    if not all_outputs:
        raise ValueError("data_loader yielded no batches; nothing to run inference on")
    return np.concatenate(all_outputs, axis=0)

def calculate_sensitivity(original_outputs, perturbed_outputs):
    """
    Measure how far each output moved between the original and perturbed runs.

    For every output column ``j`` the sensitivity is the mean over samples
    (axis 0) of ``|perturbed_outputs[:, j] - original_outputs[:, j]|``.

    Parameters:
    - original_outputs: numpy array of shape (num_samples, 61)
    - perturbed_outputs: numpy array of shape (num_samples, 61)

    Returns:
    - sensitivity: numpy array of shape (61,)
    """
    absolute_change = np.abs(perturbed_outputs - original_outputs)
    sensitivity = np.mean(absolute_change, axis=0)
    return sensitivity

def generate_sensitivity_maps(sensitivities, save_path, ks_distances=None, area_fractions=None):
    """
    Save one heat-map PNG per streamflow location.

    Parameters:
    - sensitivities: array whose first axis indexes streamflow locations.
      Each ``sensitivities[i]`` is drawn with ``imshow`` (presumably a
      (lat, lon) map, e.g. shape (61, lat, lon) overall — confirm with caller).
    - save_path: string prefix for the output files. File ``i`` is written to
      ``f"{save_path}sensitivity_output_{i+1}.png"``, so include a trailing path
      separator when ``save_path`` is a directory.
    - ks_distances: optional per-location K-S distances. When given, the value
      for location ``i`` is added to that map's title.
    - area_fractions: optional per-location area fractions. When given, the
      value for location ``i`` is added to that map's title.
    """
    num_locations = sensitivities.shape[0]
    for i in range(num_locations):
        title = f"Sensitivity Map for Streamflow Location {i+1}"
        metric_labels = []
        if ks_distances is not None:
            metric_labels.append(f"K-S distance = {ks_distances[i]:.3f}")
        if area_fractions is not None:
            metric_labels.append(f"area fraction = {area_fractions[i]:.3f}")
        if metric_labels:
            title += "\n" + ", ".join(metric_labels)

        # A dedicated figure avoids drawing onto (and then closing) whatever
        # figure happened to be current, e.g. one the notebook user had open.
        fig, ax = plt.subplots()
        try:
            image = ax.imshow(sensitivities[i], cmap='hot', interpolation='nearest')
            fig.colorbar(image, ax=ax)
            ax.set_title(title)
            fig.savefig(f"{save_path}sensitivity_output_{i+1}.png")
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

def calculate_sensitivity_metrics(sensitivities, watershed_mask):
    """
    Calculate Kolmogorov-Smirnov distance and area fraction metrics for sensitivity maps.

    Parameters:
    - sensitivities: numpy array of shape (61, lat, lon) containing sensitivity values.
    - watershed_mask: mask of shape (lat, lon) shared by every gauge, or of shape
      (61, lat, lon) with one mask per gauge. Non-boolean masks (e.g. 0/1 ints)
      are converted to bool, so they select cells instead of being used as
      integer indices.

    Returns:
    - ks_distances: array of two-sample K-S statistics comparing sensitivities
      inside vs outside the mask, one per gauge.
    - area_fractions: array with, per gauge, the fraction of all grid cells whose
      sensitivity exceeds half of that gauge's maximum sensitivity.

    Raises:
    - ValueError: if the mask for some gauge selects no cells or every cell
      (the K-S test needs two non-empty samples), or if the mask shape cannot
      be broadcast to ``sensitivities``.
    """
    sensitivities = np.asarray(sensitivities)
    # Cast to bool: integer masks would otherwise trigger integer fancy indexing,
    # and `~` on ints is bitwise NOT (negative indices), silently giving wrong metrics.
    mask = np.asarray(watershed_mask, dtype=bool)
    # A shared (lat, lon) mask is broadcast to every gauge; a per-gauge mask is used as-is.
    masks = np.broadcast_to(mask, sensitivities.shape)

    ks_distances = []
    area_fractions = []

    for gauge_idx in range(sensitivities.shape[0]):
        gauge_sens = sensitivities[gauge_idx]
        gauge_mask = masks[gauge_idx]

        # K-S distance between the distributions inside and outside the watershed.
        # Boolean indexing already yields 1-D arrays.
        inside_dist = gauge_sens[gauge_mask]
        outside_dist = gauge_sens[~gauge_mask]
        if inside_dist.size == 0 or outside_dist.size == 0:
            raise ValueError(
                f"watershed mask for gauge {gauge_idx} selects {inside_dist.size} of "
                f"{gauge_sens.size} cells; both inside and outside must be non-empty"
            )
        ks_distances.append(stats.ks_2samp(inside_dist, outside_dist).statistic)

        # Fraction of the whole grid above half of this gauge's max sensitivity.
        half_max = gauge_sens.max() / 2
        area_fractions.append(np.mean(gauge_sens > half_max))

    return np.array(ks_distances), np.array(area_fractions)

# NameError: name '__file__' is not defined  (notebook cell output captured in export)