In [None]:
# Refresh the package index and install the libGL system library in the notebook host.
# NOTE(review): presumably needed by an imaging dependency loaded by a later cell — confirm.
!apt update
!apt install -y libgl1-mesa-glx

In [None]:
# Execute setupModel.py inside this kernel.
# NOTE(review): later cells use names that are not defined in the visible part of this notebook
# (e.g. get_images_from_s3, extract_timestamp_from_image, batiment, InfluxDBClient);
# presumably this script provides them — confirm.
%run setupModel.py


In [None]:

# Load the images; later cells read image_dict.keys() and index image_dict by image key.
# NOTE(review): the meaning of the positional arguments (5, 3) is not visible here —
# document them or pass them by keyword.
image_dict=get_images_from_s3(5,3)

In [None]:
def params():
    """Return a fresh dictionary holding every tunable setting of the notebook."""
    settings = dict(
        threshold_minutes=15,
        ghi_resolution='15m',
        int_ghi_resolution=15,
        num_previous_images=2,
        Hdelta_t=16,
        delta_t=15,  # Be careful: multiplied by image frequency
        image_delta_t=1,
    )
    return settings
# Build the configuration once, then expose each entry as a notebook-level variable.
config = params()
(
    threshold_minutes,
    ghi_resolution,
    int_ghi_resolution,
    num_previous_images,
    Hdelta_t,
    delta_t,
    image_delta_t,
) = (
    config[setting_name]
    for setting_name in (
        'threshold_minutes',
        'ghi_resolution',
        'int_ghi_resolution',
        'num_previous_images',
        'Hdelta_t',
        'delta_t',
        'image_delta_t',
    )
)


In [None]:
#keys = get_images_keys_from_s3(images_every) # List of image keys
# The keys are taken from the images already held in image_dict.
image_keys = list(image_dict.keys())


In [None]:

class InfluxDBQueryClient:
    """
    Small wrapper around an InfluxDB client exposing the measured-GHI and Solcast
    queries used by this notebook.
    """

    def __init__(self, ip, port, token, org):
        """
        Create the underlying client for http://{ip}:{port}.
        :param ip: host name or IP address of the InfluxDB server
        :param port: port of the InfluxDB server
        :param token: authentication token
        :param org: organisation name, also sent with every query
        """
        self.url = f"http://{ip}:{port}"
        self.token = token
        self.org = org
        # NOTE(review): forwarded as the client's `timeout`; presumably milliseconds — confirm.
        self.timeout=5e5
        self.client = InfluxDBClient(url=self.url, token=self.token, org=self.org,timeout=self.timeout)

    def _run_logged_query(self, flux_query):
        """Run a Flux query and return its result; failures are printed, then re-raised."""
        query_api = self.client.query_api()
        try:
            return query_api.query_data_frame(org=self.org, query=flux_query)
        except Exception as e:
            print("Error al hacer la consulta:", e)
            raise

    def query_measured_ghi(self, start_time, end_time, resolution):
        """
        Query the measured GHI from the InfluxDB database
        :param start_time: datetime.datetime
        :param end_time: datetime.datetime
        :param resolution: str, inserted verbatim as the aggregateWindow `every` duration
        :return: pd.DataFrame (the first frame when the client returns several)
        """
        bucket = "microgrid_ST"

        # Convert datetime to Unix timestamp
        t0 = round(start_time.timestamp())
        tf = round(end_time.timestamp())

        # Prepare the Flux query
        flux_query = f"""
        from(bucket: "{bucket}")
        |> range(start: {t0}, stop: {tf})
          |> filter(fn: (r) => r["_measurement"] == "microgrid")
          |> filter(fn: (r) => r["Resource"] == "meteobox_roof")
          |> filter(fn: (r) => r["_field"] == "GHI")
          |> aggregateWindow(every: {resolution}, fn: mean, createEmpty: false)
          |> yield(name: "mean")
          |> pivot(rowKey:["_time"], columnKey: ["Resource"], valueColumn: "_value")
        """

        # Query the data
        query_api = self.client.query_api()
        measured_GHI_api15 = query_api.query_data_frame(org=self.org, query=flux_query)
        # The client may return either one frame or a list of frames; only a list is indexed,
        # because `frame[0]` on a single DataFrame would be a column lookup and fail.
        if isinstance(measured_GHI_api15, list):
            return measured_GHI_api15[0]
        return measured_GHI_api15

    def pull_solcast_forecast(self, forecast_time):
        """
        Pull Solcast forecast data for GHI from InfluxDB.
        :param forecast_time: datetime.datetime, start of the query range
                              (the range ends one minute later)
        :return: result of query_data_frame for that one-minute window
        """
        to = forecast_time + timedelta(minutes=1)
        # Convert datetime to Unix timestamp
        t0 = round(forecast_time.timestamp())
        tf = round(to.timestamp())

        bucket = "Forecasting_ST"

        # Prepare the Flux query
        flux_query = f"""
        from(bucket: "{bucket}")
        |> range(start: {t0}, stop: {tf})
        |> filter(fn: (r) => r["type"] == "Solcast")
        |> filter(fn: (r) => r["_field"] == "ghi")
        |> pivot(rowKey:["_time"], columnKey: ["prediction_time"], valueColumn: "_value")
        |> yield(name: "mean")
        """
        # Query the data
        query_api = self.client.query_api()
        solcast_forecast = query_api.query_data_frame(org=self.org, query=flux_query)

        return solcast_forecast

    def pull_solcast_forecast_bulk(self, start_time, end_time):
        """
        Pull Solcast forecast data for GHI from InfluxDB over a time range.
        :param start_time: datetime.datetime
        :param end_time: datetime.datetime
        :return: result of query_data_frame, pivoted on prediction_time
        :raises Exception: any query error is printed and re-raised unchanged
        """
        bucket = "Forecasting_ST"
        t0 = round(start_time.timestamp())
        tf = round(end_time.timestamp())

        flux_query = f"""
        from(bucket: "{bucket}")
        |> range(start: {t0}, stop: {tf})
        |> filter(fn: (r) => r["type"] == "Solcast")
        |> filter(fn: (r) => r["_field"] == "ghi")
        |> pivot(rowKey:["_time"], columnKey: ["prediction_time"], valueColumn: "_value")
        |> yield(name: "mean")
        """

        return self._run_logged_query(flux_query)

    def pull_solcast_forecast_bulk2(self, start_time, end_time):
        """
        Pull the Solcast fields ghi, ghi10 and ghi90 from InfluxDB over a time range.
        The result is pivoted on (prediction_time, _field), one column per pair.
        :param start_time: datetime.datetime
        :param end_time: datetime.datetime
        :return: result of query_data_frame
        :raises Exception: any query error is printed and re-raised unchanged
        """
        bucket = "Forecasting_ST"
        t0 = round(start_time.timestamp())
        tf = round(end_time.timestamp())

        flux_query = f"""
        from(bucket: "{bucket}")
        |> range(start: {t0}, stop: {tf})
        |> filter(fn: (r) => r["type"] == "Solcast")
        |> filter(fn: (r) => r["_field"] == "ghi" or r["_field"] == "ghi10" or r["_field"] == "ghi90")
        |> pivot(rowKey:["_time"], columnKey: ["prediction_time", "_field"], valueColumn: "_value")
        |> yield(name: "mean")
        """

        return self._run_logged_query(flux_query)
    

In [None]:


def create_merged_data_dict(image_keys,
                            threshold_minutes=15,
                            num_previous_images=3,
                            image_delta_t=1,
                            Hdelta_t=3,
                            delta_t=1):
    start_total = time.time()

    # --- 1) Prepare image_df ---
    t1 = time.time()
    image_timestamps = [extract_timestamp_from_image(k) for k in image_keys]
    image_df = pd.DataFrame({
        'image_timestamp': pd.to_datetime(image_timestamps),
        'image_key': image_keys
    }).sort_values('image_timestamp').reset_index(drop=True)
    image_df['img_idx'] = image_df.index
    print(f"Step 1 (Prepare image_df) took {time.time() - t1:.2f}s")

    # --- 2) Time range for GHI query ---
    t2 = time.time()
    dates = image_df['image_timestamp'].dt.date.unique()
    starts = [pd.Timestamp(d).tz_localize('UTC') - pd.Timedelta(hours=4) for d in dates]
    ends   = [pd.Timestamp(d).tz_localize('UTC') + pd.Timedelta(days=1, hours=4) for d in dates]
    start_time, end_time = min(starts), max(ends) + pd.Timedelta(minutes=1)
    print(f"Step 2 (Compute query window) took {time.time() - t2:.2f}s")

    # --- 3) Query GHI from InfluxDB ---
    t3 = time.time()
    client = InfluxDBQueryClient(
        ip='***',
        port=***,
        token='***',
        org="DESL-EPFL",
    )
    
    ghi_df = client.query_measured_ghi(start_time, end_time, '1m')
         
     
    ghi_df = (ghi_df.rename(columns={'_time':'time','_value':'GHI'})
                  .assign(time=lambda df: pd.to_datetime(df['time']).dt.tz_convert('UTC'))
                  .set_index('time'))
    ghi_df['GHI_15min_avg'] = ghi_df['GHI'].rolling(window=15, center=True, min_periods=1).mean()
    print(f"Step 3 (Query GHI) took {time.time() - t3:.2f}s")

    # --- 4) Clear-sky calculation ---
    t4 = time.time()
    time_range = pd.date_range(start_time, end_time, freq='1min', tz='UTC')
    cs = batiment.get_clearsky(time_range).ghi.rename('clear_sky_GHI').to_frame()
    print(f"Step 4 (Clear-sky calc) took {time.time() - t4:.2f}s")

    # --- 5) Merge asof measured GHI at image times ---
    t5 = time.time()
    merged = pd.merge_asof(
        image_df,
        ghi_df[['GHI','GHI_15min_avg']].reset_index().rename(columns={'time':'closest_time'}),
        left_on='image_timestamp', right_on='closest_time',
        direction='nearest', tolerance=pd.Timedelta(minutes=threshold_minutes)
    )
    merged['time_diff'] = (merged['image_timestamp'] - merged['closest_time']).abs().dt.total_seconds()/60
    merged.loc[merged['time_diff']>threshold_minutes, ['GHI','GHI_15min_avg']] = np.nan
    merged = merged.dropna(subset=['GHI'])
    print(f"Step 5 (As-of merge GHI) took {time.time() - t5:.2f}s")

    # --- 6) Merge asof clear-sky at image times ---
    t6 = time.time()
    merged = pd.merge_asof(
        merged.sort_values('image_timestamp'),
        cs.reset_index().rename(columns={'index':'closest_time'}),
        on='closest_time', direction='nearest', tolerance=pd.Timedelta(minutes=threshold_minutes)
    ).dropna(subset=['clear_sky_GHI'])
    print(f"Step 6 (As-of merge clear-sky) took {time.time() - t6:.2f}s")

    # --- 7) Daily max clear-sky ---
    t7 = time.time()
    merged['date'] = merged['image_timestamp'].dt.date
    merged['max_clear_sky_ghi'] = merged.groupby('date')['clear_sky_GHI'].transform('max')
    print(f"Step 7 (Max clear-sky) took {time.time() - t7:.2f}s")

    # --- 8) Normalize ---
    t8 = time.time()
    merged['GHI_normalized'] = merged['GHI'] / merged['max_clear_sky_ghi']
    merged['clear_sky_ghi_normalized'] = merged['clear_sky_GHI'] / merged['max_clear_sky_ghi']
    print(f"Step 8 (Normalization) took {time.time() - t8:.2f}s")

    # --- 9) Past images list ---
    t9 = time.time()
    for i in range(1, num_previous_images+1):
        merged[f'prev_img_{i}'] = merged['image_key'].shift(i * image_delta_t)
    merged['image_keys'] = merged.apply(
        lambda r: [r['image_key']] + [r[f'prev_img_{i}'] for i in range(1, num_previous_images+1) if pd.notna(r[f'prev_img_{i}'])], axis=1
    )
    print(f"Step 9 (Past images) took {time.time() - t9:.2f}s")

    # --- 10) Build future-offset table ---
    t10 = time.time()
    img = merged[['img_idx','image_timestamp','max_clear_sky_ghi']]
    steps = np.arange(0, Hdelta_t+1) * delta_t
    offsets = pd.DataFrame({
        'img_idx': np.repeat(img['img_idx'].values, len(steps)),
        'step': np.tile(steps, len(img))
    })
    offsets['forecast_time'] = offsets['img_idx'].map(img.set_index('img_idx')['image_timestamp']) + pd.to_timedelta(offsets['step'], unit='m')
    print(f"Step 10 (Build offsets) took {time.time() - t10:.2f}s")

    # --- 11) Bulk asof for GHI and clear-sky at offsets ---
    t11 = time.time()
    ghi_lu = ghi_df.reset_index()[['time','GHI']]
    cs_lu  = cs.reset_index().rename(columns={'index':'time'})[['time','clear_sky_GHI']]
    offsets = pd.merge_asof(offsets.sort_values('forecast_time'), ghi_lu.sort_values('time'), left_on='forecast_time', right_on='time', direction='nearest', tolerance=pd.Timedelta(minutes=threshold_minutes))
    offsets = pd.merge_asof(offsets.sort_values('forecast_time'), cs_lu.sort_values('time'), left_on='forecast_time', right_on='time', direction='nearest', tolerance=pd.Timedelta(minutes=threshold_minutes))
    print(f"Step 11 (Bulk asof) took {time.time() - t11:.2f}s")

    # --- 12) Group into sequences ---
    t12 = time.time()
    seqs = (offsets.sort_values(['img_idx','step'])
            .groupby('img_idx')
            .agg({'GHI': lambda x: x.tolist(), 'clear_sky_GHI': lambda x: x.tolist()})
           )
    seqs['max_cs'] = img.set_index('img_idx')['max_clear_sky_ghi']
    seqs['ghi_values_normalized'] = seqs.apply(lambda r: [v/r['max_cs'] if pd.notna(v) else np.nan for v in r['GHI']], axis=1)
    seqs['clear_sky_ghi_normalized_seq'] = seqs.apply(lambda r: [v/r['max_cs'] if pd.notna(v) else np.nan for v in r['clear_sky_GHI']], axis=1)
    seqs = seqs.rename(columns={'GHI':'ghi_values','clear_sky_GHI':'clear_sky_ghi_seq'})
    print(f"Step 12 (Group sequences) took {time.time() - t12:.2f}s")

    # --- 13) Join sequences back and build dict ---
    t13 = time.time()
    merged = merged.join(seqs, on='img_idx')
    print(f"Step 13 (Join sequences) took {time.time() - t13:.2f}s")

    t14 = time.time()
    records = merged.dropna(subset=['ghi_values'])[[
        'image_timestamp','image_key','ghi_values','ghi_values_normalized',
        'clear_sky_ghi_seq','clear_sky_ghi_normalized_seq','max_clear_sky_ghi','image_keys'
    ]].to_dict(orient='records')

    merged_data_dict = {
        rec['image_timestamp']: dict(
            image_key=rec['image_key'],
            ghi_values=rec['ghi_values'],
            ghi_values_normalized=rec['ghi_values_normalized'],
            clear_sky_ghi_seq=rec['clear_sky_ghi_seq'],
            clear_sky_ghi_normalized_seq=rec['clear_sky_ghi_normalized_seq'],
            max_clear_sky_ghi=rec['max_clear_sky_ghi'],
            image_keys=rec['image_keys'],
            key_timestamp=rec['image_timestamp']
        ) for rec in records
    }
    print(f"Step 14 (Build dict) took {time.time() - t14:.2f}s; Records: {len(merged_data_dict)}")

        # --- 15) Solcast: per-key latest forecast ---
    t15 = time.time()
    x_splits=5
    # 1) Compute overall time window
    t_start = min(v['key_timestamp'] for v in merged_data_dict.values()) - timedelta(minutes=10)
    t_end   = max(v['key_timestamp'] for v in merged_data_dict.values()) + timedelta(minutes=10)
    
    # 2) Generate split boundaries
    total_duration = t_end - t_start
    splits = [
        t_start + i * (total_duration / x_splits)
        for i in range(x_splits + 1)
    ]
    
    # 3) Pull each chunk
    dfs = []
    for i in range(x_splits):
        start_i = splits[i]
        end_i   = splits[i+1]
        t0 = time.time()
        df_i = client.pull_solcast_forecast_bulk2(start_i, end_i)
        print(f"Pull {i+1}/{x_splits} finished in {time.time() - t0:.2f} s")
        dfs.append(df_i)
    
    # 4) Combine, parse and index
    sol_df = pd.concat(dfs, axis=0)
    sol_df['_time'] = pd.to_datetime(sol_df['_time'])
    sol_df = sol_df.drop_duplicates(subset='_time').set_index('_time').sort_index()




    
    print(f"Step 15.1 query (Solcast) took {time.time() - t15:.2f}s")
    t15 = time.time()
    lead_minutes = (np.arange(1, Hdelta_t+1) * 15.0).tolist()
    mean_cols = [f"{m}_ghi" for m in lead_minutes]
    q10_cols  = [f"{m}_ghi10" for m in lead_minutes]
    q90_cols  = [f"{m}_ghi90" for m in lead_minutes]

    # For each timestamp, find the last Solcast row at or before that time
    sol_index = sol_df.index
    for ts, rec in merged_data_dict.items():
        pos = sol_index.searchsorted(ts) - 1  # index of last row <= ts
        max_clearsky = rec.get('max_clear_sky_ghi', np.nan)
        if pos < 0:
            # no forecast available before ts
            preds = np.full((3, Hdelta_t), np.nan)
        else:
            row = sol_df.iloc[pos]
            # extract values
            q10_vals  = row[q10_cols].values.astype(float)/max_clearsky
            mean_vals = row[mean_cols].values.astype(float)/max_clearsky
            q90_vals  = row[q90_cols].values.astype(float)/max_clearsky
            preds = np.vstack([q10_vals, mean_vals, q90_vals])  # Shape: (3, Hdelta_t)
        rec['solcast_predictions'] = preds.tolist()
    print(f"Step 15 (Solcast) took {time.time() - t15:.2f}s")

    # --- 16) Re-index and clean ---
    t16 = time.time()
    new_dict = {i: v for i, v in enumerate(merged_data_dict.values())}
    def has_nan(x):
        if isinstance(x, (int, float)):
            return np.isnan(x)
        if isinstance(x, (list, np.ndarray)):
            return any(has_nan(elem) for elem in x)
        return False
    cleaned = {k: v for k, v in new_dict.items() if not any(has_nan(val) for val in v.values())}
    print(f"Total execution time: {time.time() - start_total:.2f}s; Dropped {len(new_dict)-len(cleaned)} items")
    return cleaned




In [None]:
# Build the dataset, retrying up to max_retries times with a fixed pause between attempts.
# NOTE(review): every exception type is retried, including deterministic failures;
# presumably the intent is to ride out transient database errors — confirm.
retries=0
max_retries=10
wait_seconds=5
while True:
    try:
        merged_data_dict = create_merged_data_dict(
            image_keys,
            threshold_minutes,
            num_previous_images,
            image_delta_t,
            Hdelta_t,
            delta_t
        )
        print("✅ Merged data dict created successfully.")
        break
    except Exception as e:
        retries += 1
        print(f"⚠️ Attempt {retries} failed: {e}")
        if retries >= max_retries:
            # Chain the last failure explicitly so the root cause is reported as the cause.
            raise RuntimeError("❌ Max retries exceeded while creating merged_data_dict.") from e
        print(f"⏳ Retrying in {wait_seconds} seconds...")
        time.sleep(wait_seconds)

In [None]:
def solcast_batchloader(keys, merged_data_dict, image_dict, batch_size,
                         expected_num_images, Hdelta_t, use_ghi_now):
    """
    Build a batched tf.data.Dataset yielding (inputs, ghi_target) pairs.

    :param keys: record keys, converted to an int64 array and used to index merged_data_dict
    :param merged_data_dict: mapping key -> record with 'image_keys',
                             'clear_sky_ghi_normalized_seq', 'ghi_values_normalized'
                             and 'solcast_predictions'
    :param image_dict: mapping image key -> image array
    :param batch_size: batch size; an incomplete final batch is dropped
    :param expected_num_images: number of images kept per sample
    :param Hdelta_t: forecast horizon; sequences hold Hdelta_t + 1 values
    :param use_ghi_now: when falsy, the current-GHI input is replaced by 0.0
    :return: tf.data.Dataset of ({'image_sequence', 'clear_sky_input', 'ghi_now_input',
             'solcast_input'}, ghi_target)
    """
    horizon_len = Hdelta_t + 1

    def _load_record(key_tensor):
        # Runs eagerly inside tf.py_function: turn the key tensor back into a dict key.
        key = key_tensor.numpy()
        if isinstance(key, bytes):
            key = key.decode('utf-8')
        sample = merged_data_dict[key]

        # Image stack, clear-sky sequence and current GHI
        frames = [image_dict[name] for name in sample['image_keys'][:expected_num_images]]
        images = np.stack(frames, axis=0)
        clear_sky = np.array(sample['clear_sky_ghi_normalized_seq'][:horizon_len])
        if use_ghi_now:
            ghi_now = sample['ghi_values_normalized'][0]
        else:
            ghi_now = 0.0

        # Solcast block (3 x Hdelta_t) and the target sequence
        solcast = np.array(sample['solcast_predictions'], dtype=np.float32)
        target = np.array(sample['ghi_values_normalized'][:horizon_len])

        return (
            images.astype(np.float32),
            clear_sky.astype(np.float32),
            np.float32(ghi_now),
            solcast,
            target.astype(np.float32)
        )

    def _to_model_inputs(key_tensor):
        images, clear_sky, ghi_now, solcast, target = tf.py_function(
            func=_load_record,
            inp=[key_tensor],
            Tout=[tf.float32, tf.float32, tf.float32, tf.float32, tf.float32]
        )

        # Static shapes are declared explicitly on the py_function outputs.
        images.set_shape([expected_num_images, 250, 250, 3])
        clear_sky.set_shape([horizon_len])
        ghi_now.set_shape([])
        solcast.set_shape([3, Hdelta_t])
        target.set_shape([horizon_len])

        features = {
            'image_sequence': images,
            'clear_sky_input': clear_sky,
            'ghi_now_input': tf.expand_dims(ghi_now, axis=-1),
            'solcast_input': solcast
        }
        return features, target

    key_array = np.array(keys, dtype=np.int64)
    dataset = tf.data.Dataset.from_tensor_slices(key_array)
    dataset = dataset.map(_to_model_inputs, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size, drop_remainder=True)


In [None]:
from tensorflow.keras import layers, Model, regularizers
import tensorflow as tf

def create_model_quantiles_with_ghi_and_solcast(
    num_previous_images,
    Hdelta_t,
    dropout_rate=0.1,
    l2_lambda=1e-5,
    lambda_mean_loss=1.0,
    lambda_quantile_loss=150.0,
    lambda_width_loss=5
):
    """
    Build, compile and summarise the quantile forecasting model.

    Inputs: an image sequence of num_previous_images + 1 frames (250 x 250 x 3), a
    clear-sky sequence of Hdelta_t + 1 values, the current GHI (one value) and a
    Solcast block of shape (3, Hdelta_t). Output: a tensor of shape (Hdelta_t + 1, 3)
    holding [q10, mean, q90] per horizon.

    :param num_previous_images: number of past frames fed in addition to the current one
    :param Hdelta_t: forecast horizon
    :param dropout_rate: rate shared by every dropout layer
    :param l2_lambda: L2 factor applied to the first dense layer of the head
    :param lambda_mean_loss: weight of the mean (MSE) term of the loss
    :param lambda_quantile_loss: weight of the pinball term of the loss
    :param lambda_width_loss: weight of the interval-width term of the loss
    :return: the compiled tf.keras Model (its summary is printed as a side effect)
    """
    n_horizons = Hdelta_t + 1

    # --- inputs ---
    image_input = layers.Input(shape=(num_previous_images + 1, 250, 250, 3), name='image_sequence')
    clear_sky_input = layers.Input(shape=(n_horizons,), name='clear_sky_input')
    ghi_now_input = layers.Input(shape=(1,), name='ghi_now_input')
    solcast_input = layers.Input(shape=(3, Hdelta_t), name='solcast_input')  # [3, Hdelta_t]

    # --- 3D-CNN backbone: two (conv, conv, spatial pool) stages ---
    visual = image_input
    for n_filters in (32, 64):
        visual = layers.Conv3D(n_filters, (3, 3, 3), padding='same', activation='relu')(visual)
        visual = layers.Conv3D(n_filters, (3, 3, 3), padding='same', activation='relu')(visual)
        visual = layers.MaxPooling3D((1, 2, 2))(visual)

    # --- stacked ConvLSTM2D: only the first one keeps the time axis ---
    for keep_sequence in (True, False):
        visual = layers.ConvLSTM2D(
            64, (3, 3), padding='same', return_sequences=keep_sequence, activation='tanh'
        )(visual)
    visual = layers.LayerNormalization()(visual)
    visual = layers.SpatialDropout2D(dropout_rate)(visual)

    # --- spatial pooling ---
    visual = layers.GlobalAveragePooling2D()(visual)

    # --- clear-sky feature embedding ---
    clear_sky_feat = layers.Dense(128, activation='relu')(clear_sky_input)
    clear_sky_feat = layers.Dropout(dropout_rate)(clear_sky_feat)
    for n_units in (64, 32):
        clear_sky_feat = layers.Dense(n_units, activation='relu')(clear_sky_feat)

    # --- solcast embedding ---
    solcast_feat = layers.Flatten()(solcast_input)  # flatten [3, Hdelta_t]
    solcast_feat = layers.Dense(64, activation='relu')(solcast_feat)
    solcast_feat = layers.Dropout(dropout_rate)(solcast_feat)
    solcast_feat = layers.Dense(32, activation='relu')(solcast_feat)

    # --- concatenate all features ---
    fused = layers.Concatenate()([visual, clear_sky_feat, ghi_now_input, solcast_feat])

    # --- MLP head with residual connection ---
    hidden = layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(l2_lambda))(fused)
    hidden = layers.Dropout(dropout_rate)(hidden)
    hidden = layers.Dense(128, activation='relu')(hidden)

    shortcut = layers.Dense(128)(fused)  # project the fused features to the width of `hidden`
    hidden = layers.Add()([hidden, shortcut])  # residual connection

    hidden = layers.Dense(64, activation='relu')(hidden)

    # --- output: [H, 3] for [q10, mean, q90] per horizon ---
    flat_quantiles = layers.Dense(3 * n_horizons, activation='linear')(hidden)
    output = layers.Reshape((n_horizons, 3), name='quantiles')(flat_quantiles)

    model = Model(
        inputs=[image_input, clear_sky_input, ghi_now_input, solcast_input],
        outputs=output
    )

    # --- metrics ---
    weighted_mse_metric = weighted_compressed_mse(
        n_horizons, initial_weight=1.0, second_weight=1.0, decay=0.5
    )
    weighted_quantile_metric = weighted_quantile_loss(
        n_horizons, initial_weight=1.0, second_weight=1.0, decay=0.5
    )

    # --- optimizer with a step-decayed learning rate ---
    initial_lr = 3e-4
    drop_every_x_steps = 210 * 5
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=StepDecay(initial_lr, drop_every_x_steps)
    )

    # --- loss ---
    loss_fn = combined_quantile_loss_factory(
        n_horizons,
        lambda_mean=lambda_mean_loss,
        lambda_q=lambda_quantile_loss,
        lambda_width=lambda_width_loss
    )

    def mean_prediction_interval_width(y_true, y_pred):
        # Average q90 - q10 gap over the whole batch and every horizon.
        lower, _, upper = tf.unstack(y_pred, axis=-1)
        return tf.reduce_mean(upper - lower)

    tracked_metrics = [
        mean_channel_mae,
        weighted_mse_metric,
        weighted_quantile_metric,
        mean_prediction_interval_width
    ]
    model.compile(optimizer=optimizer, loss=loss_fn, metrics=tracked_metrics)

    model.summary()
    return model


In [None]:
# --- 3) custom MAE restricted to the mean channel ---
def mean_channel_mae(y_true, y_pred):
    """
    Mean absolute error between y_true and the mean prediction,
    which is read from index 1 of the last axis of y_pred.
    """
    abs_error = tf.abs(y_true - y_pred[..., 1])
    return tf.reduce_mean(abs_error)
def weighted_compressed_mse(H, initial_weight=1.0, second_weight=1.0, decay=0.5):
    """
    Build a metric computing a horizon-weighted MSE of the mean channel, with the
    error multiplied by 30 before squaring.
    Horizon weights:
      - horizon 0 gets initial_weight
      - horizon i >= 1 gets second_weight * decay^(i-1)
    The weighted sum is divided by the sum of the weights.
    """
    def Mean_metric(y_true, y_pred):
        # Horizon weights, shape [H]
        tail_weights = second_weight * tf.pow(decay, tf.cast(tf.range(H - 1), tf.float32))
        horizon_weights = tf.concat([[initial_weight], tail_weights], axis=0)

        # Scaled squared error of the mean channel, averaged over the batch: shape [H]
        scaled_error = (y_true - y_pred[..., 1]) * 30
        per_horizon_mse = tf.reduce_mean(tf.square(scaled_error), axis=0)

        weighted_total = tf.reduce_sum(horizon_weights * per_horizon_mse)
        return weighted_total / tf.reduce_sum(horizon_weights)

    return Mean_metric
def weighted_quantile_loss(H, initial_weight=1.0, second_weight=1.0, decay=0.5):
    """
    Build a metric computing the horizon-weighted pinball loss of the q10 and q90
    channels (indices 0 and 2 of the last axis of y_pred), multiplied by 50.
    Horizon 0 is weighted by initial_weight, horizon i >= 1 by second_weight * decay^(i-1).
    """
    def Q_metric(y_true, y_pred):
        lower, upper = y_pred[..., 0], y_pred[..., 2]

        # Pinball loss averaged over the batch, one value per horizon
        lower_pinball = tf.reduce_mean(quantile_loss(0.1, y_true, lower), axis=0)
        upper_pinball = tf.reduce_mean(quantile_loss(0.9, y_true, upper), axis=0)

        # Horizon weights, shape [H]
        tail_weights = second_weight * tf.pow(decay, tf.cast(tf.range(H - 1), tf.float32))
        horizon_weights = tf.concat([[initial_weight], tail_weights], axis=0)

        weighted_total = tf.reduce_sum(horizon_weights * (lower_pinball + upper_pinball))
        return 50*weighted_total / tf.reduce_sum(horizon_weights)

    return Q_metric



In [None]:

class StepDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Staircase learning-rate schedule: the rate starts at initial_lr and is multiplied
    by 0.3 each time another drop_every_x_steps optimizer steps have elapsed.
    """

    def __init__(self, initial_lr, drop_every_x_steps):
        """
        :param initial_lr: learning rate used for the first drop_every_x_steps steps
        :param drop_every_x_steps: number of steps between two consecutive drops
        """
        self.initial_lr = initial_lr
        self.drop_every_x_steps = drop_every_x_steps

    def __call__(self, step):
        """Return the learning rate for the given optimizer step."""
        # Number of completed drop intervals
        factor = tf.math.floor(step / self.drop_every_x_steps)
        return self.initial_lr * tf.math.pow(0.3, factor)

    def get_config(self):
        """Return the constructor arguments so Keras can serialize this schedule."""
        return {
            'initial_lr': self.initial_lr,
            'drop_every_x_steps': self.drop_every_x_steps,
        }

In [None]:
def combined_quantile_loss_factory(
    H, lambda_mean=1.0, lambda_q=1.0, lambda_width=0.0,
    initial_weight=1.0, second_weight=1.0, decay=0.7):
    """
    Build a loss(y_true, y_pred) made of three horizon-weighted terms:
      - lambda_mean  * MSE of the mean channel (error multiplied by 30 before squaring),
      - lambda_q     * pinball loss of the q10 and q90 channels,
      - lambda_width * mean interval width q90 - q10.
    Horizon 0 is weighted by initial_weight, horizon i >= 1 by second_weight * decay^(i-1);
    the total is divided by the sum of the weights.
    """
    def loss_fn(y_true, y_pred):
        q10, mu, q90 = tf.unstack(y_pred, axis=-1)  # each [batch, H]

        # Horizon weights: [initial_weight, second_weight * decay^0, ..., decay^{H-2}]
        tail_weights = second_weight * tf.pow(decay, tf.cast(tf.range(H - 1), tf.float32))  # [H-1]
        weights = tf.concat([[initial_weight], tail_weights], axis=0)  # [H]

        def weighted(per_horizon):
            # Weighted sum of a per-horizon vector
            return tf.reduce_sum(weights * per_horizon)

        mse_term = weighted(tf.reduce_mean(tf.square((y_true - mu) * 30), axis=0))
        pinball_term = weighted(
            tf.reduce_mean(quantile_loss(0.1, y_true, q10), axis=0)
            + tf.reduce_mean(quantile_loss(0.9, y_true, q90), axis=0)
        )
        width_term = weighted(tf.reduce_mean(q90 - q10, axis=0))

        combined = (
            lambda_mean * mse_term +
            lambda_q * pinball_term +
            lambda_width * width_term
        )
        return combined / tf.reduce_sum(weights)

    return loss_fn



In [None]:
def quantile_loss(q, y_true, y_pred):
    """
    Element-wise pinball loss for quantile q.
    With e = y_true - y_pred the loss is max(q * e, (q - 1) * e);
    the result has the same shape as the inputs.
    """
    residual = y_true - y_pred
    under_penalty = q * residual
    over_penalty = (q - 1.0) * residual
    return tf.maximum(under_penalty, over_penalty)


In [None]:
# Each sample holds the current image plus num_previous_images earlier ones.
expected_num_images=num_previous_images+1

In [None]:
def short_train_model(
    merged_data_dict,
    image_dict,
    batch_size,
    num_previous_images,
    Hdelta_t,
    expected_num_images,
    epochs=30,
    use_ghi_now=True,
    train_fraction=0.1
):
    """
    Train the quantile model on a fresh random subset of the training keys at each epoch.

    The usable keys are split 80% / 10% / 10% into training pool, validation and test.
    Every epoch draws train_fraction of the training pool without replacement. After
    training, the per-horizon test RMSE is printed and the forecast-skill plot is drawn.

    :param merged_data_dict: mapping key -> record, as consumed by solcast_batchloader
    :param image_dict: mapping image key -> image array
    :param batch_size: batch size of every dataset
    :param num_previous_images: number of past images; also the number of leading keys skipped
    :param Hdelta_t: forecast horizon; also the number of trailing keys skipped
    :param expected_num_images: number of images per sample
    :param epochs: number of training epochs
    :param use_ghi_now: forwarded to solcast_batchloader
    :param train_fraction: share of the training pool sampled at each epoch
    :return: (model, history_all) where history_all maps each metric name to its
             per-epoch values
    """
    import numpy as np
    from sklearn.model_selection import train_test_split

    # Prepare keys, skipping initial/history and final horizon.
    # The end index is computed explicitly: a slice ending at -0 would be empty
    # when Hdelta_t is 0.
    all_keys = sorted(merged_data_dict.keys())
    start = num_previous_images
    end = max(len(all_keys) - Hdelta_t, 0)
    keys = np.array(all_keys)[start:end]

    # initial split: 80% train_val, 20% temp (to split into 10% val, 10% test)
    train_val_keys, temp_keys = train_test_split(
        keys, test_size=0.2, shuffle=True, random_state=42
    )
    # split temp into validation and test (each 10% of total)
    val_keys, test_keys = train_test_split(
        temp_keys, test_size=0.5, shuffle=True, random_state=42
    )

    # prepare static datasets for validation and testing
    val_ds = solcast_batchloader(
        val_keys, merged_data_dict, image_dict,
        batch_size, expected_num_images, Hdelta_t, use_ghi_now
    )
    test_ds = solcast_batchloader(
        test_keys, merged_data_dict, image_dict,
        batch_size, expected_num_images, Hdelta_t, use_ghi_now
    )

    # build model
    model = create_model_quantiles_with_ghi_and_solcast(
        num_previous_images, Hdelta_t
    )

    # Initialize history_all dictionary
    history_all = {}

    # training loop: each epoch uses a fresh sample of the train_val keys
    for epoch in range(epochs):
        # sample train_fraction of train_val_keys
        k = int(train_fraction * len(train_val_keys))
        sub_keys = np.random.choice(train_val_keys, size=k, replace=False)

        # build train dataset for this epoch
        train_ds = solcast_batchloader(
            sub_keys, merged_data_dict, image_dict,
            batch_size, expected_num_images, Hdelta_t, use_ghi_now
        )
        steps_per_epoch = train_ds.cardinality().numpy()
        val_steps = val_ds.cardinality().numpy()

        print(f"Epoch {epoch+1}/{epochs}: training on {k} samples")
        # epochs/initial_epoch make each fit() call run exactly one epoch.
        hist = model.fit(
            train_ds,
            epochs=epoch+1,
            initial_epoch=epoch,
            steps_per_epoch=steps_per_epoch,
            validation_data=val_ds,
            validation_steps=val_steps,
            verbose=2,
            shuffle=False
        )

        # Accumulate history
        for key, values in hist.history.items():
            history_all.setdefault(key, []).extend(values)

    # final evaluation on test set
    y_trues, y_q10s, y_means, y_q90s = [], [], [], []
    for x_batch, y_batch in test_ds:
        y_trues.append(y_batch.numpy())
        preds = model.predict(x_batch, verbose=0)
        y_q10s.append(preds[..., 0])
        y_means.append(preds[..., 1])
        y_q90s.append(preds[..., 2])

    y_true = np.vstack(y_trues)
    y_q10 = np.vstack(y_q10s)
    y_mean = np.vstack(y_means)
    y_q90 = np.vstack(y_q90s)

    # scale back using max clear-sky; the factors follow the order of test_keys and are
    # truncated to the number of evaluated samples (the last partial batch is dropped)
    max_factors = np.array([merged_data_dict[k]['max_clear_sky_ghi'] for k in test_keys])
    max_factors = np.repeat(max_factors[:, None], Hdelta_t+1, axis=1)
    y_true_real = y_true * max_factors[:len(y_true)]
    y_mean_real = y_mean * max_factors[:len(y_true)]

    # compute and print RMSE per horizon
    rmse = np.sqrt(np.mean((y_true_real - y_mean_real)**2, axis=0))
    print("\nTest RMSE per horizon:")
    for i, v in enumerate(rmse):
        print(f"  Horizon {i}: RMSE = {v:.3f}")

    # plot forecast skill for horizons 1..Hdelta_t
    forecast_skill_plot(merged_data_dict, test_keys, Hdelta_t, rmse[1:])

    return model, history_all


In [None]:
# Train with a batch size of 16 for 60 epochs; keeps the trained model and its history.
model,history=short_train_model(
    merged_data_dict,
    image_dict,
    16,
    num_previous_images,
    Hdelta_t,
    expected_num_images,
    epochs=60,
    use_ghi_now=True
)

In [None]:
# Notebook-level copies of the hyper-parameters; the values equal the defaults of
# create_model_quantiles_with_ghi_and_solcast. The lambda_* values are read by the
# re-compilation cell that follows.
dropout_rate=0.1
l2_lambda=1e-5
lambda_mean_loss=1.0
lambda_quantile_loss=150.0
lambda_width_loss=5

In [None]:
# Re-compile the trained model with a default AdamW optimizer (no StepDecay schedule),
# keeping the same loss and metrics as the ones used when the model was built.
# NOTE(review): presumably done so that the model can be saved without the custom
# learning-rate schedule — confirm.
optimizer = tf.keras.optimizers.AdamW()
H = Hdelta_t + 1

loss_fn = combined_quantile_loss_factory(
    H, lambda_mean=lambda_mean_loss, lambda_q=lambda_quantile_loss, lambda_width=lambda_width_loss
)
weighted_mse_metric = weighted_compressed_mse(
    H, initial_weight=1.0, second_weight=1.0, decay=0.5
)
weighted_quantile_metric = weighted_quantile_loss(
    H, initial_weight=1.0, second_weight=1.0, decay=0.5
)
# Same metric as the one nested inside create_model_quantiles_with_ghi_and_solcast:
# average q90 - q10 gap over the batch and every horizon.
def mean_prediction_interval_width(y_true, y_pred):
    q10, _, q90 = tf.unstack(y_pred, axis=-1)
    width = q90 - q10
    return tf.reduce_mean(width)

model.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=[
        mean_channel_mae,
        weighted_mse_metric,
        weighted_quantile_metric,
        mean_prediction_interval_width
    ])

In [None]:
# `model` is the trained (and re-compiled) tf.keras.Model from the cells above.
model.save('with_solcast_model.h5')  
# This writes to the path 'with_solcast_model.h5'.
# NOTE(review): the .h5 suffix presumably selects Keras' single-file HDF5 format rather than a SavedModel directory — confirm.


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
from datetime import timedelta

def _plot_metric_per_horizon(series_per_split, labels, step, ylabel, title,
                             filename, zero_line=False):
    """Plot one curve per CV split against forecast horizon, save it, then show it.

    `series_per_split[i]` holds one value per horizon (first forecast horizon
    first) and is labelled with `labels[i]`; x positions are ``(h + 1) * step``.
    """
    plt.figure(figsize=(9, 6))
    for values, label in zip(series_per_split, labels):
        plt.plot([(h + 1) * step for h in range(len(values))], values,
                 marker='o', label=label)
    if zero_line:
        # Skill 0 means "on par with the reference forecast".
        plt.axhline(0, color='r', linestyle='--')
    plt.xlabel("Horizon (min)")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(filename)
    plt.show()


def Q_short_cross_validate_model_withGHI_and_solcast(
    merged_data_dict,
    image_dict,
    batch_size,
    num_previous_images,
    Hdelta_t,
    expected_num_images,
    x_splits=5,
    epochs=30,
    train_fraction=0.3,
    use_ghi_now=True
):
    """Blocked cross-validation of the quantile model with GHI and Solcast inputs.

    The sorted keys of `merged_data_dict` are trimmed at both ends, cut into
    `x_splits` contiguous blocks, and each block is used once as the test set.
    For every split a fresh model is built with
    `create_model_quantiles_with_ghi_and_solcast`, trained for `epochs` epochs
    (each epoch on a new random `train_fraction` subset of the training keys)
    and evaluated on the test block.

    Reads the notebook-level global `delta_t` for key trimming and for the
    x-axis of the plots.

    Parameters
    ----------
    merged_data_dict : dict
        Per-key records. This function reads 'key_timestamp',
        'max_clear_sky_ghi' and 'solcast_predictions' from each record.
    image_dict : dict
        Passed through to `solcast_batchloader`.
    batch_size : int
        Batch size passed to `solcast_batchloader`.
    num_previous_images, Hdelta_t, expected_num_images : int
        Passed to the model builder and/or the batch loader; `Hdelta_t` is the
        number of forecast horizons (outputs have `Hdelta_t + 1` columns).
    x_splits : int
        Number of contiguous test blocks (must be >= 1).
    epochs : int
        Training epochs per split.
    train_fraction : float
        Fraction of the training keys sampled (without replacement) per epoch.
    use_ghi_now : bool
        Passed through to `solcast_batchloader`.

    Returns
    -------
    dict
        'skill': per split, a list of 1 - RMSE_model / RMSE_solcast for
        horizons 1..Hdelta_t (NaN where the Solcast RMSE is 0);
        'picp_model' and 'pinaw_model': per split, arrays over horizons
        1..Hdelta_t for the model's channel-0/channel-2 interval.

    Raises
    ------
    ValueError
        If `x_splits` < 1 or there are too few keys to give every split at
        least one test sample.

    Side effects: prints progress, shows three figures and writes
    'q_ghi_solcast_skill.png', 'q_ghi_solcast_picp.png' and
    'q_ghi_solcast_pinaw.png' to the working directory.
    """
    if x_splits < 1:
        raise ValueError(f"x_splits must be >= 1, got {x_splits}")

    all_keys = sorted(merged_data_dict.keys())
    # Drop keys at both ends of the sorted sequence -- presumably those lacking
    # enough past images / future targets; TODO confirm against the loader.
    start = num_previous_images * delta_t
    end = -Hdelta_t * delta_t
    keys = np.array(all_keys)[start:end]
    split_size = len(keys) // x_splits
    if split_size == 0:
        raise ValueError(
            f"Only {len(keys)} usable keys for {x_splits} splits; "
            "every split needs at least one test sample."
        )

    skill_splits       = []
    picp_model_splits  = []
    pinaw_model_splits = []
    legend_labels      = []

    for split_idx in range(x_splits):
        print(f"\n=== Split {split_idx+1}/{x_splits} ===")
        test_keys  = keys[split_idx*split_size:(split_idx+1)*split_size]
        train_keys = np.setdiff1d(keys, test_keys)

        # First 10% of the (sorted) remaining keys serve as validation data.
        val_n    = int(0.1 * len(train_keys))
        val_keys = train_keys[:val_n]
        tr_keys  = train_keys[val_n:]

        val_ds  = solcast_batchloader(val_keys, merged_data_dict, image_dict,
                                      batch_size, expected_num_images, Hdelta_t, use_ghi_now)
        test_ds = solcast_batchloader(test_keys, merged_data_dict, image_dict,
                                      batch_size, expected_num_images, Hdelta_t, use_ghi_now)

        model = create_model_quantiles_with_ghi_and_solcast(num_previous_images, Hdelta_t)

        for epoch in range(epochs):
            # Every epoch trains on a freshly drawn subset of the training keys.
            subset_size = int(train_fraction * len(tr_keys))
            epoch_keys = np.random.choice(tr_keys, subset_size, replace=False)
            train_ds = solcast_batchloader(epoch_keys, merged_data_dict, image_dict,
                                           batch_size, expected_num_images, Hdelta_t, use_ghi_now)

            # epochs=epoch+1 with initial_epoch=epoch runs exactly one epoch
            # per call while keeping Keras' epoch counter continuous.
            model.fit(train_ds,
                      epochs=epoch+1, initial_epoch=epoch,
                      steps_per_epoch=train_ds.cardinality().numpy(),
                      validation_data=val_ds,
                      validation_steps=val_ds.cardinality().numpy(),
                      verbose=2, shuffle=False)

        # --- collect test predictions batch by batch ---
        y_true_blocks, q10_blocks, q50_blocks, q90_blocks = [], [], [], []
        for x_b, y_b in test_ds:
            preds = model.predict(x_b, verbose=0)
            y_true_blocks.append(y_b.numpy())
            q10_blocks.append(preds[..., 0])
            q50_blocks.append(preds[..., 1])
            q90_blocks.append(preds[..., 2])
        y_true = np.vstack(y_true_blocks)
        q10    = np.vstack(q10_blocks)
        q50    = np.vstack(q50_blocks)
        q90    = np.vstack(q90_blocks)

        # Rows of y_true are taken to correspond to the leading test keys
        # (the loader may yield fewer samples than keys).
        n_samples = y_true.shape[0]
        eval_keys = test_keys[:n_samples]

        timestamps = pd.to_datetime([merged_data_dict[k]['key_timestamp'] for k in eval_keys])
        split_start = timestamps.min().strftime('%Y-%m-%d')
        split_end   = timestamps.max().strftime('%Y-%m-%d')
        legend_labels.append(f"Split {split_idx+1}: {split_start} → {split_end}")

        # Undo the per-sample normalisation to get back to physical units.
        maxfs = np.array([merged_data_dict[k]['max_clear_sky_ghi'] for k in eval_keys])
        maxfs = np.repeat(maxfs[:, None], Hdelta_t+1, axis=1)
        y_true_r = y_true * maxfs
        q10_r    = q10    * maxfs
        q50_r    = q50    * maxfs
        q90_r    = q90    * maxfs

        rmse_model = np.sqrt(np.mean((y_true_r - q50_r)**2, axis=0))
        print(f"→ RMSE model: {np.round(rmse_model, 2)}")

        # Reference forecast: Solcast mean. 'solcast_predictions'[1][h] is the
        # mean forecast for horizon h+1; it is assumed to be normalised like
        # the targets (the standalone Solcast PICP cell compares the two
        # directly) -- TODO confirm.
        solcast_mean = np.array(
            [merged_data_dict[k]['solcast_predictions'][1] for k in eval_keys],
            dtype=float
        )[:, :Hdelta_t]
        solcast_mean_r = solcast_mean * maxfs[:, 1:]
        rmse_solcast = np.sqrt(np.mean((y_true_r[:, 1:] - solcast_mean_r)**2, axis=0))
        print(f"→ RMSE Solcast: {np.round(rmse_solcast, 2)}")

        # Forecast skill vs Solcast for horizons 1..Hdelta_t; NaN when the
        # reference RMSE is 0 and the ratio is undefined.
        skill = [
            1 - rmse_model[h] / rmse_solcast[h - 1] if rmse_solcast[h - 1] > 0 else float('nan')
            for h in range(1, Hdelta_t+1)
        ]
        skill_splits.append(skill)

        # PICP and PINAW over horizons 1..Hdelta_t (column 0 is skipped)
        yt = y_true_r[:, 1:]
        Lm = q10_r[:, 1:]
        Um = q90_r[:, 1:]
        Mf = maxfs[:, 1:]

        picp_m  = ((yt >= Lm) & (yt <= Um)).mean(axis=0)
        pinaw_m = ((Um - Lm) / Mf).mean(axis=0)

        picp_model_splits.append(picp_m)
        pinaw_model_splits.append(pinaw_m)

    # --- PLOTTING ---
    _plot_metric_per_horizon(skill_splits, legend_labels, delta_t,
                             "Forecast Skill", "Forecast Skill vs Horizon",
                             'q_ghi_solcast_skill.png', zero_line=True)
    _plot_metric_per_horizon(picp_model_splits, legend_labels, delta_t,
                             "PICP (Model)", "PICP across horizons",
                             'q_ghi_solcast_picp.png')
    _plot_metric_per_horizon(pinaw_model_splits, legend_labels, delta_t,
                             "PINAW (Model)", "PINAW across horizons",
                             'q_ghi_solcast_pinaw.png')

    return {
        "skill": skill_splits,
        "picp_model": picp_model_splits,
        "pinaw_model": pinaw_model_splits
    }


In [None]:
from tensorflow.keras import layers, Model, regularizers
import tensorflow as tf

def create_model_quantiles_with_ghi_and_solcast(
    num_previous_images,
    Hdelta_t,
    dropout_rate=0.1,
    l2_lambda=1e-5,
    lambda_mean_loss=1.0,
    lambda_quantile_loss=150.0,
    lambda_width_loss=5
):
    """Build and compile the four-input Keras model with a (Hdelta_t + 1, 3) output.

    Inputs (in this order): an image sequence of `num_previous_images + 1`
    frames of 250x250x3, a clear-sky vector of length `Hdelta_t + 1`, a scalar
    'ghi_now_input', and a Solcast block of shape (3, Hdelta_t).

    The image branch is Conv3D -> ConvLSTM2D -> global average pooling; the
    clear-sky and Solcast inputs each go through a small dense embedding. All
    branches are concatenated and fed to a dense head with one projected
    residual connection. The output has three channels per step; the width
    metric defined below treats channel 0 as the lower and channel 2 as the
    upper bound.

    `dropout_rate` is used by every dropout layer, `l2_lambda` regularises the
    first dense layer of the head, and the three `lambda_*` values are passed
    to `combined_quantile_loss_factory`.

    Prints the model summary and returns the compiled model.
    """
    H = Hdelta_t + 1

    # --- the four model inputs ---
    image_input = layers.Input(
        shape=(num_previous_images + 1, 250, 250, 3), name='image_sequence')
    clear_sky_input = layers.Input(shape=(H,), name='clear_sky_input')
    ghi_now_input = layers.Input(shape=(1,), name='ghi_now_input')
    solcast_input = layers.Input(shape=(3, Hdelta_t), name='solcast_input')

    # --- image branch: two Conv3D stages, each (conv, conv, spatial-only pool) ---
    img = image_input
    for n_filters in (32, 64):
        for _ in range(2):
            img = layers.Conv3D(
                n_filters, (3, 3, 3), padding='same', activation='relu')(img)
        img = layers.MaxPooling3D((1, 2, 2))(img)

    # Two stacked ConvLSTM2D layers; only the first keeps the time axis.
    for keep_sequence in (True, False):
        img = layers.ConvLSTM2D(
            64, (3, 3), padding='same',
            return_sequences=keep_sequence, activation='tanh')(img)
    img = layers.LayerNormalization()(img)
    img = layers.SpatialDropout2D(dropout_rate)(img)
    img = layers.GlobalAveragePooling2D()(img)

    # --- clear-sky embedding ---
    clear_sky = layers.Dense(128, activation='relu')(clear_sky_input)
    clear_sky = layers.Dropout(dropout_rate)(clear_sky)
    clear_sky = layers.Dense(64, activation='relu')(clear_sky)
    clear_sky = layers.Dense(32, activation='relu')(clear_sky)

    # --- Solcast embedding (the (3, Hdelta_t) block is flattened first) ---
    solcast = layers.Flatten()(solcast_input)
    solcast = layers.Dense(64, activation='relu')(solcast)
    solcast = layers.Dropout(dropout_rate)(solcast)
    solcast = layers.Dense(32, activation='relu')(solcast)

    # --- fuse all branches ---
    fused = layers.Concatenate()([img, clear_sky, ghi_now_input, solcast])

    # --- dense head; the fused vector is projected to 128 units so it can be
    #     added back as a residual ---
    hidden = layers.Dense(
        256, activation='relu',
        kernel_regularizer=regularizers.l2(l2_lambda))(fused)
    hidden = layers.Dropout(dropout_rate)(hidden)
    hidden = layers.Dense(128, activation='relu')(hidden)
    shortcut = layers.Dense(128)(fused)
    hidden = layers.Add()([hidden, shortcut])
    hidden = layers.Dense(64, activation='relu')(hidden)

    # --- output: 3 values per step, reshaped to (H, 3) ---
    flat_out = layers.Dense(3 * H, activation='linear')(hidden)
    output = layers.Reshape((H, 3), name='quantiles')(flat_out)

    model = Model(
        inputs=[image_input, clear_sky_input, ghi_now_input, solcast_input],
        outputs=output
    )

    # --- metrics, optimizer and loss ---
    weighted_mse_metric = weighted_compressed_mse(
        H, initial_weight=1.0, second_weight=1.0, decay=0.5)
    weighted_quantile_metric = weighted_quantile_loss(
        H, initial_weight=1.0, second_weight=1.0, decay=0.5)

    # StepDecay(initial learning rate, number of steps between drops)
    lr_schedule = StepDecay(3e-4, 210 * 5)
    optimizer = tf.keras.optimizers.AdamW(learning_rate=lr_schedule)

    loss_fn = combined_quantile_loss_factory(
        H,
        lambda_mean=lambda_mean_loss,
        lambda_q=lambda_quantile_loss,
        lambda_width=lambda_width_loss,
    )

    def mean_prediction_interval_width(y_true, y_pred):
        # Mean of (channel 2 - channel 0); y_true is unused.
        lower, _, upper = tf.unstack(y_pred, axis=-1)
        return tf.reduce_mean(upper - lower)

    model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[
            mean_channel_mae,
            weighted_mse_metric,
            weighted_quantile_metric,
            mean_prediction_interval_width,
        ]
    )

    model.summary()
    return model


In [None]:
# Run the cross-validation with a deliberately small budget: one epoch per
# split, each trained on a 10% sample of that split's training keys.
Q_short_cross_validate_model_withGHI_and_solcast(
    merged_data_dict=merged_data_dict,
    image_dict=image_dict,
    batch_size=32,
    num_previous_images=num_previous_images,
    Hdelta_t=Hdelta_t,
    expected_num_images=expected_num_images,
    x_splits=5,
    epochs=1,
    train_fraction=0.1,
    use_ghi_now=True,
)

In [None]:
import numpy as np

# Coverage (PICP) of the Solcast interval itself, per forecast horizon.
keys = list(merged_data_dict.keys())
# The horizon count is read from the first record's lower-bound series.
# NOTE(review): this rebinds the notebook-level `Hdelta_t`.
Hdelta_t = len(merged_data_dict[keys[0]]['solcast_predictions'][0])

# One row per key, one column per horizon. 'solcast_predictions' is indexed
# [0] = 10th percentile, [1] = mean, [2] = 90th percentile; only the two
# bounds are needed here.
sol_lower = np.array(
    [merged_data_dict[k]['solcast_predictions'][0][:Hdelta_t] for k in keys])
sol_upper = np.array(
    [merged_data_dict[k]['solcast_predictions'][2][:Hdelta_t] for k in keys])
# Solcast index h is compared against the normalised GHI at position h + 1.
observed = np.array(
    [merged_data_dict[k]['ghi_values_normalized'][1:Hdelta_t + 1] for k in keys])

# Fraction of samples, per horizon, that fall inside [lower, upper].
inside_interval = (observed >= sol_lower) & (observed <= sol_upper)
picp = inside_interval.mean(axis=0)

# Print per-horizon PICP
for h, p in enumerate(picp, start=1):
    print(f"Horizon {h*delta_t} min: PICP = {p:.3f}")

print(f"\nOverall PICP (80% interval): {picp.mean():.3f}")

