### Get Models and Weights

In [None]:

# Download the zipped inference models/weights from the WeatherNet GitHub repository.
!wget https://raw.githubusercontent.com/YuvalRozner/WeatherNet/main/Backend/Model_Pytorch/utils/models_for_inference.zip
    

### Unzip Weights

In [None]:

import zipfile
# Expected location of the downloaded archive (assumes the notebook's working
# directory is /content, as on Colab - adjust if running elsewhere).
zip_file_path = "/content/models_for_inference.zip"

# Extract the archive into the fixed directory /content/models_for_inference.
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall("/content/models_for_inference")


### `data.py`

In [None]:
import pickle 
import pandas as pd
import os
import json
from sklearn.preprocessing import StandardScaler
import pickle
import numpy as np
import torch
from tqdm import tqdm
"""
this file let you load the data of the stations from the pkl files 
and the coordinates of the stations from the json file
"""

def normalize_coordinates(x_coords, y_coords):
    """
    Normalize the X and Y coordinates to the range [0, 1] (min-max scaling).

    Args:
        x_coords (array-like): X coordinates in meters (numpy array, pandas
            Series or a plain sequence of numbers).
        y_coords (array-like): Y coordinates in meters.

    Returns:
        tuple: Normalized X and Y coordinates as float32 torch tensors of shape
        [num_stations, 1]. An axis whose values are all identical (zero range,
        e.g. a single station) is mapped to 0.0 instead of NaN.
    """

    def _min_max(values):
        # Convert once so lists / Series behave like numpy arrays.
        arr = np.asarray(values, dtype=np.float64)
        lo, hi = arr.min(), arr.max()
        span = hi - lo
        if span == 0:
            # Avoid 0 / 0 = NaN when every coordinate on this axis is the same.
            return np.zeros_like(arr)
        return (arr - lo) / span

    x_normalized = torch.tensor(_min_max(x_coords), dtype=torch.float32).unsqueeze(1)  # [num_stations, 1]
    y_normalized = torch.tensor(_min_max(y_coords), dtype=torch.float32).unsqueeze(1)  # [num_stations, 1]

    return x_normalized, y_normalized
def drop_nan_rows_multiple_custom(df_list, custom_na=None, reset_indices=True):
    """
    Removes rows from all DataFrames in the list where any DataFrame has NaN or
    a custom NaN representation in any column.

    A row (index label) that is missing in any one DataFrame is dropped from all
    of them, so the DataFrames stay row-synchronized.

    Parameters:
    df_list (List[pd.DataFrame]): List of DataFrames to process. All must have
        the same number of rows.
    custom_na (List[str] | None): Strings to be treated as NaN. Defaults to
        ['-'] when None.
    reset_indices (bool): Whether to reset the index after dropping rows.
        Defaults to True.

    Returns:
    List[pd.DataFrame]: List of cleaned DataFrames. In each result the custom NaN
    tokens have been replaced by np.nan before the rows were dropped.

    Raises:
    ValueError: If df_list is empty or the DataFrames differ in row count.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if custom_na is None:
        custom_na = ['-']

    if not df_list:
        raise ValueError("The list of DataFrames is empty.")

    num_rows = df_list[0].shape[0]
    if any(df.shape[0] != num_rows for df in df_list):
        raise ValueError("All DataFrames must have the same number of rows.")

    # Step 0: turn custom NaN tokens into real NaN so isnull() can see them.
    replaced = [df.replace(custom_na, np.nan) for df in df_list]

    # Steps 1-2: a row is dropped if it has a NaN in any column of any DataFrame.
    combined_nan = pd.Series([False] * num_rows, index=df_list[0].index)
    for df in replaced:
        combined_nan = combined_nan | df.isnull().any(axis=1)
    indices_to_drop = combined_nan[combined_nan].index

    # Step 3: drop the same rows from every DataFrame.
    cleaned_df_list = []
    for df in tqdm(replaced, desc="Dropping NaN rows"):
        cleaned_df = df.drop(indices_to_drop)
        if reset_indices:
            cleaned_df = cleaned_df.reset_index(drop=True)
        cleaned_df_list.append(cleaned_df)

    return cleaned_df_list
def drop_nan_rows_multiple(df_list, reset_indices=True):
    """
    Drop, from every DataFrame, each row that holds a NaN in any of them.

    The DataFrames are treated as row-aligned tables: a row (index label) with a
    missing value in any frame is removed from all frames so they stay in sync.

    Parameters:
    df_list (List[pd.DataFrame]): DataFrames with identical row counts.
    reset_indices (bool): Reset each result's index to 0..n-1. Defaults to True.

    Returns:
    List[pd.DataFrame]: Cleaned DataFrames, in the same order as df_list.

    Raises:
    ValueError: If df_list is empty or the row counts differ.
    """
    if not df_list:
        raise ValueError("The list of DataFrames is empty.")

    expected_rows = df_list[0].shape[0]
    if any(frame.shape[0] != expected_rows for frame in df_list):
        raise ValueError("All DataFrames must have the same number of rows.")

    # OR together the per-frame "this row has a NaN" masks.
    row_has_nan = pd.Series([False] * expected_rows, index=df_list[0].index)
    for frame in df_list:
        row_has_nan = row_has_nan | frame.isnull().any(axis=1)
    labels_to_drop = row_has_nan[row_has_nan].index

    cleaned = []
    for frame in tqdm(df_list, desc="Dropping NaN rows"):
        kept = frame.drop(labels_to_drop)
        cleaned.append(kept.reset_index(drop=True) if reset_indices else kept)
    return cleaned

# Define the normalization function
def normalize_coordinates(x_coords, y_coords):
    """
    Normalize the X and Y coordinates to the range [0, 1] (min-max scaling).

    Args:
        x_coords (array-like): X coordinates (numpy array, pandas Series or a
            plain sequence of numbers).
        y_coords (array-like): Y coordinates.

    Returns:
        tuple: Normalized X and Y coordinates as float32 torch tensors of shape
        [num_stations, 1]. An axis whose values are all identical (zero range,
        e.g. a single station) is mapped to 0.0 instead of NaN.
    """

    def _min_max(values):
        # Convert once so lists / Series behave like numpy arrays.
        arr = np.asarray(values, dtype=np.float64)
        lo, hi = arr.min(), arr.max()
        span = hi - lo
        if span == 0:
            # Avoid 0 / 0 = NaN when every coordinate on this axis is the same.
            return np.zeros_like(arr)
        return (arr - lo) / span

    x_normalized = torch.tensor(_min_max(x_coords), dtype=torch.float32).unsqueeze(1)  # [num_stations, 1]
    y_normalized = torch.tensor(_min_max(y_coords), dtype=torch.float32).unsqueeze(1)  # [num_stations, 1]

    return x_normalized, y_normalized

def timeEncode(dataframes):
    """
    Replace 'Date Time' with daily and yearly sin/cos features, in place.

    Every DataFrame that has a 'Date Time' column (datetime values) gains the
    columns 'Day sin', 'Day cos', 'Year sin' and 'Year cos', and loses its
    'Date Time' column. DataFrames without that column are left untouched.

    Args:
        dataframes (Iterable[pd.DataFrame]): DataFrames to encode in place.
    """
    seconds_per_day = 24 * 60 * 60
    seconds_per_year = 365.2425 * seconds_per_day
    encodings = (
        ('Day sin', np.sin, seconds_per_day),
        ('Day cos', np.cos, seconds_per_day),
        ('Year sin', np.sin, seconds_per_year),
        ('Year cos', np.cos, seconds_per_year),
    )

    for frame in dataframes:
        if 'Date Time' not in frame.columns:
            continue
        seconds = frame['Date Time'].map(pd.Timestamp.timestamp)
        for column, wave, period in encodings:
            frame[column] = wave(seconds * (2 * np.pi / period))
        frame.drop(columns=['Date Time'], inplace=True)


def preprocessing_tensor_df(df):
    """
    Apply the same preprocessing steps as during training.

    Steps:
      1. Keep every 6th row starting at row 5 (hourly samples if the input is
         10-minute data - TODO confirm the sampling rate).
      2. Parse 'Date Time' ('%d.%m.%Y %H:%M:%S') and remove it from the columns.
      3. Set the -9999.0 sentinel in 'wv (m/s)' / 'max. wv (m/s)' to 0.0.
      4. Replace wind speeds and 'wd (deg)' with x/y vector components
         'Wx', 'Wy', 'max Wx', 'max Wy' (both use the 'wd (deg)' direction).
      5. Add 'Day sin', 'Day cos', 'Year sin', 'Year cos' time features.

    Args:
        df (pd.DataFrame): Raw table; it is not modified (a copy is processed).

    Returns:
        pd.DataFrame: The processed subsample.
    """
    print("preproccessing data...")
    # Copy the subsample so the edits below never touch the caller's frame.
    out = df[5::6].copy()
    date_time = pd.to_datetime(out.pop('Date Time'), format='%d.%m.%Y %H:%M:%S')

    # -9999.0 marks an invalid speed reading; treat it as 0 m/s.
    speeds = []
    for speed_column in ('wv (m/s)', 'max. wv (m/s)'):
        out.loc[out[speed_column] == -9999.0, speed_column] = 0.0
        speeds.append(out.pop(speed_column))
    wv, max_wv = speeds

    wd_rad = out.pop('wd (deg)') * np.pi / 180

    wind_components = (
        ('Wx', wv, np.cos),
        ('Wy', wv, np.sin),
        ('max Wx', max_wv, np.cos),
        ('max Wy', max_wv, np.sin),
    )
    for column, speed, projection in wind_components:
        out.loc[:, column] = speed * projection(wd_rad)

    seconds = date_time.map(pd.Timestamp.timestamp)
    day = 24 * 60 * 60
    year = 365.2425 * day
    time_features = (
        ('Day sin', np.sin, day),
        ('Day cos', np.cos, day),
        ('Year sin', np.sin, year),
        ('Year cos', np.cos, year),
    )
    for column, wave, period in time_features:
        out.loc[:, column] = wave(seconds * (2 * np.pi / period))

    return out

def normalize_data(train_data, val_data, scaler_path='./scaler.pkl'):
    """
    Standardize train and validation data with a scaler fitted on train only.

    The fitted StandardScaler is pickled to ``scaler_path`` so the exact same
    transformation can be reapplied later.

    Args:
        train_data (np.ndarray): Training data (2-D, samples x features).
        val_data (np.ndarray): Validation data with the same feature count.
        scaler_path (str): Where to pickle the fitted scaler.

    Returns:
        tuple: (train_data_scaled, val_data_scaled, scaler).
    """
    scaler = StandardScaler().fit(train_data)
    scaled_splits = tuple(scaler.transform(split) for split in (train_data, val_data))

    with open(scaler_path, 'wb') as handle:
        pickle.dump(scaler, handle)
    print(f"Scaler saved to {scaler_path}")

    return scaled_splits[0], scaled_splits[1], scaler

def preprocessing_our_df(df):
    """
    Subsample every 6th row (starting at row 5) and drop rows containing NaN.

    Args:
        df (pd.DataFrame): Input table; it is not modified.

    Returns:
        pd.DataFrame: The subsample without NaN rows, original index labels kept.
    """
    print("preproccessing data...")
    subsampled = df[5::6].copy()
    return subsampled.dropna()

def return_and_save_scaler_normalize_data(train_data, val_data, scaler_path='./scaler.pkl'):
    """
    Fit a StandardScaler on ``train_data``, pickle it and return it.

    No data is transformed here; ``val_data`` is accepted but not used.

    Args:
        train_data (np.ndarray): Data to fit the scaler on.
        val_data: Unused.
        scaler_path (str): Where to pickle the fitted scaler.

    Returns:
        StandardScaler: The fitted scaler.
    """
    fitted = StandardScaler().fit(train_data)

    with open(scaler_path, 'wb') as out_file:
        pickle.dump(fitted, out_file)
    print(f"Scaler saved to {scaler_path}")

    return fitted

def normalize_data_independent(train_data, val_data, scaler_dir='./scalers'):
    """
    Fit a StandardScaler per station on the training data and transform both train and val data.
    Save each scaler to disk for future use.

    Args:
        train_data (np.ndarray): Training data of shape (T_train, num_stations, num_features).
        val_data (np.ndarray): Validation data of shape (T_val, num_stations, num_features).
        scaler_dir (str): Directory path to save the scalers (created if missing).

    Returns:
        train_data_scaled (np.ndarray): Scaled training data of shape (T_train, num_stations, num_features).
        val_data_scaled (np.ndarray): Scaled validation data of shape (T_val, num_stations, num_features).
        scalers (list of StandardScaler): List containing a scaler for each station.

    Notes:
        Floating inputs keep their dtype; integer (or other non-float) inputs
        produce float64 outputs instead of silently truncated integers.
    """
    os.makedirs(scaler_dir, exist_ok=True)

    # Unpacking also enforces the 3-D (time, station, feature) layout.
    _, num_stations, _ = train_data.shape

    def _float_buffer(arr):
        # np.zeros_like would inherit an integer dtype and truncate the
        # standardized values written into it.
        dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.float64
        return np.zeros(arr.shape, dtype=dtype)

    train_data_scaled = _float_buffer(train_data)
    val_data_scaled = _float_buffer(val_data)

    scalers = []
    for station_idx in range(num_stations):
        # Fit on this station's training rows only: Shape (T_train, num_features).
        train_station_data = train_data[:, station_idx, :]
        scaler = StandardScaler()
        scaler.fit(train_station_data)
        scalers.append(scaler)

        train_data_scaled[:, station_idx, :] = scaler.transform(train_station_data)
        val_data_scaled[:, station_idx, :] = scaler.transform(val_data[:, station_idx, :])

        scaler_path = os.path.join(scaler_dir, f'scaler_station_{station_idx}.pkl')
        with open(scaler_path, 'wb') as f:
            pickle.dump(scaler, f)
        print(f"Scaler for Station {station_idx} saved to {scaler_path}")

    return train_data_scaled, val_data_scaled, scalers

def normalize_data_collective(train_data, val_data, scaler_path='./scaler.pkl'):
    """
    Fit a single StandardScaler across all stations and features.

    All (time step, station) rows of the training data are pooled, so each
    feature gets one mean/std shared by every station.

    Args:
        train_data (np.ndarray): Training data of shape (T_train, num_stations, num_features).
        val_data (np.ndarray): Validation data with the same last dimension.
        scaler_path (str): Path to save the scaler.

    Returns:
        train_scaled (np.ndarray), val_scaled (np.ndarray), scaler (StandardScaler)
        - the scaled arrays keep their input shapes.
    """
    # Unpacking also enforces the 3-D (time, station, feature) layout.
    _, _, num_features = train_data.shape

    # Flatten to (rows, num_features) so the scaler sees one row per station-step.
    train_reshaped = train_data.reshape(-1, num_features)
    val_reshaped = val_data.reshape(-1, num_features)

    scaler = StandardScaler()
    scaler.fit(train_reshaped)

    train_scaled = scaler.transform(train_reshaped).reshape(train_data.shape)
    val_scaled = scaler.transform(val_reshaped).reshape(val_data.shape)

    with open(scaler_path, 'wb') as f:
        pickle.dump(scaler, f)
    print(f"Scaler saved to {scaler_path}")

    return train_scaled, val_scaled, scaler

def load_pkl_file(station_name):
    """
    Load a station's pickled data from the repository ``data`` directory.

    The file is looked up at ``<this file's dir>/../../../data/<station_name>.pkl``.
    When ``__file__`` is undefined (e.g. when this code runs in a notebook
    cell), the current working directory is used as the base instead.

    Args:
        station_name (str): Station name; also the pickle's base filename.

    Returns:
        The unpickled object, or None if the file could not be opened or read
        (the error is printed).

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    try:
        current_path = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        current_path = os.getcwd()
    # os.path.join instead of hard-coded "\\" separators, so the path also
    # resolves on POSIX systems and not only on Windows.
    file_path = os.path.join(current_path, "..", "..", "..", "data", f"{station_name}.pkl")
    try:
        with open(file_path, 'rb') as file:
            data = pickle.load(file)
        print(f"data succsesfuly loaded from {file_path}")
        return data
    except Exception as e:
        # Deliberate best-effort: callers receive None for a missing/corrupt file.
        print(f"Failed to load file:\n{e}")
        return None

def openJsonFile():
    """
    Load the station-details JSON file.

    Reads ``<this file's dir>/../../data code files/stations_details_updated.json``.
    When ``__file__`` is undefined (e.g. notebook execution) the current working
    directory is used as the base instead.

    Returns:
        dict: The parsed JSON. loadCoordinatesNewIsraelData iterates it as
        key -> details dict with "name" and "coordinates_in_a_new_israe" entries.

    Raises:
        FileNotFoundError: If the JSON file does not exist at that location.
    """
    try:
        current_path = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        current_path = os.getcwd()
    # os.path.join keeps the path valid on POSIX as well as on Windows.
    file_path = os.path.join(current_path, "..", "..", "data code files", "stations_details_updated.json")
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(file_path, encoding="utf-8") as file:
        stations = json.load(file)
    return stations

def loadCoordinatesNewIsraelData(stations_details, station_name):
    """
    Return the (east, north) coordinates of the first station named ``station_name``.

    Args:
        stations_details (dict): Mapping whose values are station dicts with a
            "name" key and a "coordinates_in_a_new_israe" dict holding
            "east" and "north".
        station_name (str): Name to look for.

    Returns:
        tuple | None: (east, north) for the first match, or None if no station
        has that name.
    """
    for details in stations_details.values():
        if details["name"] != station_name:
            continue
        coords = details["coordinates_in_a_new_israe"]
        return coords["east"], coords["north"]
    return None

def loadData(station_names):
    """
    Load each station's pickled data together with its coordinates.

    Args:
        station_names (Iterable[str]): Station names to load.

    Returns:
        dict: station name -> (data, coordinates), where ``data`` comes from
        load_pkl_file (None if loading failed) and ``coordinates`` is the
        (east, north) pair from the stations JSON (None if the station is not listed).
    """
    details = openJsonFile()
    return {
        name: (load_pkl_file(name), loadCoordinatesNewIsraelData(details, name))
        for name in station_names
    }
"""
# example of use for this file
if __name__ == "__main__":
    # Load the data
    stations_data = loadData(["Afeq","Harashim"])
    if "Afeq" in stations_data:
        print("Data of Afeq:")
        print(stations_data["Afeq"][0].head())

        print("Coordinate of Afeq:")
        print(stations_data["Afeq"][1])

        print("First coordinate of Afeq:")
        print(stations_data["Afeq"][1][0])

        print("Second coordinate of Afeq:")
        print(stations_data["Afeq"][1][1])
    else:
        print("Afeq data not found")

    print("yey")

"""

### `constantsParams.py`

In [None]:
# Suffixes appended to a year string to build the IMS request range
# (see fetch_data_for_station). They appear to encode MMDDHHMM:
# Jan 1 00:00 through Dec 31 23:50 - confirm against the IMS API.
BEGINING_OF_YEAR = "01010000"
ENDING_OF_YEAR = "12312350"
# Default year range for the IMS download (years are fetched inclusively).
START_YEAR = 2005
END_YEAR = 2024

# Directory for station data files (presumably the per-station pickles used by
# save_dataframes_to_pickles / load_dataframes_from_pickles - confirm at the call site).
DATA_DIRECTORY = "data/"

# Station display name -> IMS station id. The id is inserted into the request
# URL built by fetch_weather_data.
STATIONS_LIST = {
    "Newe Yaar": "186",
    "Tavor Kadoorie": "13",
    "Yavneel": "11",
    "En Hashofet": "67",
    "Eden Farm": "206",
    "Eshhar": "205",
    "Afula Nir Haemeq": "16"
}

# Readable column names; they match the second element of each COLUMN_PAIRS
# entry. Not referenced by the code visible in this section.
columns = [
    "Date Time", "BP (hPa)", "DiffR (w/m^2)", "Grad (w/m^2)", "NIP (w/m^2)", "RH (%)",
    "TD (degC)", "TDmax (degC)", "TDmin (degC)", "WD (deg)", "WDmax (deg)",
    "WS (m/s)", "Ws1mm (m/s)", "Ws10mm (m/s)", "WSmax (m/s)", "STDwd (deg)"
]

# (raw IMS record key, readable column name) pairs; replace_column_names
# renames each record key from the first name to the second.
COLUMN_PAIRS = [
    ("date", "Date Time"),
    ("BP", "BP (hPa)"),
    ("DiffR", "DiffR (w/m^2)"),
    ("Grad", "Grad (w/m^2)"),
    ("NIP", "NIP (w/m^2)"),
    ("RH", "RH (%)"),
    ("TD", "TD (degC)"),
    ("TDmax", "TDmax (degC)"),
    ("TDmin", "TDmin (degC)"),
    ("WD", "WD (deg)"),
    ("WDmax", "WDmax (deg)"),
    ("WS", "WS (m/s)"),
    ("WS1mm", "Ws1mm (m/s)"),
    ("Ws10mm", "Ws10mm (m/s)"),
    ("WSmax", "WSmax (m/s)"),
    ("STDwd", "STDwd (deg)")
]

# Columns dropped during preprocessing (presumably passed as columns_to_remove
# to remove_unecessery_columns - confirm at the call site).
COLUMNS_TO_REMOVE = ['date_for_sort', 'BP (hPa)', 'Time', 'Grad (w/m^2)', 'DiffR (w/m^2)', 'NIP (w/m^2)', 'Ws10mm (m/s)', 'Ws1mm (m/s)']

# Columns whose short NaN gaps get interpolated (presumably the values_to_fill
# argument of the fill_*_missing_values helpers - confirm at the call site).
VALUES_TO_FILL = ['TD (degC)', 'TDmin (degC)', 'TDmax (degC)', 'RH (%)']

# Tokens treated as missing values; replaced with np.nan before plotting and gap filling.
NA_VALUES = ['None', 'null', '-', '', ' ', 'NaN', 'nan', 'NAN']

### `import_and_process_data.py`

In [None]:
import pandas as pd# type: ignore
import requests# type: ignore
import json
from datetime import datetime
import os
import numpy as np# type: ignore
import matplotlib.pyplot as plt # type: ignore
from tqdm import tqdm # type: ignore


##    functions used the first time to download the data from the IMS ##
######################################################################################################################
def fetch_weather_data(station_id, start_date, end_date, timeout=120):
    """
    Download all measured channels for one IMS station over a time range.

    Args:
        station_id: IMS station id (inserted into the request URL).
        start_date (str): Range start, as built by the callers (year + date suffix).
        end_date (str): Range end, same format as ``start_date``.
        timeout (float | None): Seconds requests waits for the server to connect
            or to send more data before giving up. None waits indefinitely.

    Returns:
        dict: The decoded JSON payload; records are under ``data['data']['records']``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    url = f"https://ims.gov.il/he/envista_station_all_data_time_range/{station_id}/BP%26DiffR%26Grad%26NIP%26RH%26TD%26TDmax%26TDmin%26TW%26WD%26WDmax%26WS%26WS1mm%26Ws10mm%26Ws10maxEnd%26WSmax%26STDwd%26Rain/{start_date}/{end_date}/1/S"
    # Without a timeout a stalled connection would block the download forever.
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures clearly instead of as a JSON decode error on an error page.
    response.raise_for_status()
    data = json.loads(response.content)
    return data

def fetch_data_for_station(station_id, start_year, end_year):
    """
    Fetch every full calendar year in [start_year, end_year] for one station.

    Delegates to fetch_data_for_station_manual_time_range with the
    BEGINING_OF_YEAR / ENDING_OF_YEAR suffixes, so both code paths share a
    single implementation instead of two copies of the same loop.

    Args:
        station_id: IMS station id.
        start_year (int): First year (inclusive).
        end_year (int): Last year (inclusive).

    Returns:
        pd.DataFrame: All yearly records concatenated with a fresh RangeIndex.
    """
    return fetch_data_for_station_manual_time_range(
        station_id, start_year, end_year, BEGINING_OF_YEAR, ENDING_OF_YEAR
    )

def fetch_data_for_station_manual_time_range(station_id, start_year, end_year, start_date, end_date):
    """
    Fetch one station's records for the same date window in each year.

    For every year in [start_year, end_year] the request range is
    ``<year> + start_date`` .. ``<year> + end_date``. Each yearly payload is
    cleaned with process_data before being converted to a DataFrame.

    Returns:
        pd.DataFrame: All yearly records concatenated with a fresh RangeIndex.
    """
    yearly_frames = []
    years = range(start_year, end_year + 1)
    for year in tqdm(years, desc="Fetching data by year"):
        year_text = f"{year}"
        payload = fetch_weather_data(station_id, year_text + start_date, year_text + end_date)
        process_data(payload)
        yearly_frames.append(pd.DataFrame(payload['data']['records']))
    return pd.concat(yearly_frames, ignore_index=True)

def get_station_data(station_id, start_year, end_year):
    """
    Download a station's full-year records and return them in IMS payload shape.

    Returns:
        dict: ``{'data': {'records': [...]}}`` with one dict per row, after
        process_data has removed unwanted keys and renamed columns.
    """
    records = fetch_data_for_station(station_id, start_year, end_year).to_dict(orient='records')
    payload = {'data': {'records': records}}
    process_data(payload)
    return payload

def get_station_data_manual_time_range(station_id, start_year, end_year, start_date, end_date):
    """
    Download a station's records for a per-year date window, in IMS payload shape.

    Returns:
        dict: ``{'data': {'records': [...]}}`` with one dict per row, after
        process_data has removed unwanted keys and renamed columns.
    """
    frame = fetch_data_for_station_manual_time_range(station_id, start_year, end_year, start_date, end_date)
    payload = {'data': {'records': frame.to_dict(orient='records')}}
    process_data(payload)
    return payload

def remove_unwanted_keys(data):
    """
    Strip the 'sid', 'TW' and 'sname' keys from every IMS record, in place.

    Records lacking some of these keys are left as they are.

    Args:
        data (dict): Payload with the records under ``data['data']['records']``.
    """
    unwanted = ('sid', 'TW', 'sname')
    for record in data['data']['records']:
        for key in unwanted:
            record.pop(key, None)

def replace_column_names(data):
    """
    Rename IMS record keys to readable column names, in place.

    Every ``(short_name, full_name)`` pair of COLUMN_PAIRS is applied to every
    record; keys not listed there are untouched.

    Args:
        data (dict): Payload with the records under ``data['data']['records']``.
    """
    records = data['data']['records']
    for short_name, full_name in COLUMN_PAIRS:
        for record in records:
            if short_name in record:
                record[full_name] = record.pop(short_name)

def process_data(data):
    """Clean raw IMS records in place: drop unwanted keys, then rename columns."""
    for step in (remove_unwanted_keys, replace_column_names):
        step(data)

def get_data_of_stations_from_ims_by_constants_params(StationsList, startYear, endYear):
    """
    Download full-year data for every station in ``StationsList``.

    Args:
        StationsList (dict[str, str]): Station name -> IMS station id
            (e.g. STATIONS_LIST).
        startYear (int): First year to download (inclusive), e.g. START_YEAR.
        endYear (int): Last year to download (inclusive), e.g. END_YEAR.

    Returns:
        dict[str, pd.DataFrame]: Station name -> processed records. A station
        whose download raises IndexError is reported and skipped.
    """
    dataframes = {}
    for station_name, station_id in StationsList.items():
        print(f"\n Downloading data for {station_name}")
        try:
            # Honour the caller's year range instead of the module-level
            # START_YEAR / END_YEAR, which made these arguments dead.
            data = get_station_data(station_id, startYear, endYear)
            df = pd.DataFrame(data['data']['records'])
            dataframes[station_name] = df
        except IndexError as e:
            print(f"Error processing data for {station_name}: {e}")
    return dataframes

def get_data_of_stations_from_ims_manual_time_range(StationsList, startYear, endYear, startDate, endDate):
    """
    Download a per-year date window of data for every station in ``StationsList``.

    Args:
        StationsList (dict[str, str]): Station name -> IMS station id.
        startYear (int): First year (inclusive).
        endYear (int): Last year (inclusive).
        startDate (str): Date suffix appended to each year for the range start.
        endDate (str): Date suffix appended to each year for the range end.

    Returns:
        dict[str, pd.DataFrame]: Station name -> processed records. A station
        whose download raises IndexError is reported and skipped.
    """
    frames = {}
    for name, station_id in StationsList.items():
        print(f"\n Downloading data for {name}")
        try:
            payload = get_station_data_manual_time_range(station_id, startYear, endYear, startDate, endDate)
            frames[name] = pd.DataFrame(payload['data']['records'])
        except IndexError as e:
            print(f"Error processing data for {name}: {e}")
    return frames
######################################################################################################################


##    functions used for loading and saving the data to pickles   ##
######################################################################################################################
def save_dataframes_to_pickles(dataframes, DATA_DIRECTORY):
    """
    Pickle each DataFrame to ``<DATA_DIRECTORY>/<name>.pkl``.

    The target directory is created first if it does not exist yet.

    Args:
        dataframes (dict[str, pd.DataFrame]): Name -> DataFrame to save.
        DATA_DIRECTORY (str): Output directory.
    """
    if DATA_DIRECTORY:
        # to_pickle refuses to write into a directory that does not exist.
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
    for df_name, df in dataframes.items():
        file_path = os.path.join(DATA_DIRECTORY, f"{df_name}.pkl")
        df.to_pickle(file_path)
        print(f"Saved {df_name} to {file_path}")

def load_dataframes_from_pickles(DATA_DIRECTORY):
    """
    Load every ``*.pkl`` file in a directory into a dict of DataFrames.

    Files are loaded in sorted filename order, so the resulting dict order is
    deterministic (``os.listdir`` order is arbitrary). This matters because the
    display helpers plot only the first entry of such a dict.

    Args:
        DATA_DIRECTORY (str): Directory containing the pickle files.

    Returns:
        dict[str, pd.DataFrame]: File base name (without ``.pkl``) -> DataFrame.

    Note:
        Unpickling can execute arbitrary code; only load trusted files.
    """
    data_files = sorted(f for f in os.listdir(DATA_DIRECTORY) if f.endswith('.pkl'))

    dataframes = {}
    for file in tqdm(data_files, desc="Loading Pickle files of year data"):
        file_path = os.path.join(DATA_DIRECTORY, file)
        df_name = os.path.splitext(file)[0]
        dataframes[df_name] = pd.read_pickle(file_path)

    return dataframes
######################################################################################################################


##    functions used for displaying ##
######################################################################################################################
def display_dataframes_heads(dataframes):
    """Print each DataFrame's name, its first rows, and a blank separator."""
    for name, frame in dataframes.items():
        print(f"Heads of {name}:", frame.head(), "\n", sep="\n")

def display_wind_before_vectorize(dataframes):
    """
    Plot 2D histograms of raw wind and gust (direction vs. speed) for the first station.

    Only the first DataFrame of ``dataframes`` (dict order) is plotted. Its
    'WD (deg)', 'WS (m/s)', 'WDmax (deg)' and 'WSmax (m/s)' columns are cleaned
    in place: NA_VALUES tokens become NaN and the columns are coerced to numeric.

    Args:
        dataframes (dict[str, pd.DataFrame]): Station name -> DataFrame.
    """
    first_df_name = list(dataframes.keys())[0]
    first_df = dataframes[first_df_name]

    for column in ('WD (deg)', 'WS (m/s)', 'WDmax (deg)', 'WSmax (m/s)'):
        first_df[column] = first_df[column].replace(NA_VALUES, np.nan).infer_objects(copy=False)
        # Anything that still isn't a number becomes NaN.
        first_df[column] = pd.to_numeric(first_df[column], errors='coerce')

    # plt.subplots creates the figure itself; an extra plt.figure() call would
    # only leave an empty figure behind.
    fig, ax = plt.subplots(1, 2, figsize=(14, 6))

    panels = (
        ('WD (deg)', 'WS (m/s)', 'Wind Direction [deg]', 'Wind Velocity [m/s]', 'Wind'),
        ('WDmax (deg)', 'WSmax (m/s)', 'Gust Direction [deg]', 'Gust Velocity [m/s]', 'Gust'),
    )
    for axis, (dir_col, speed_col, x_label, y_label, kind) in zip(ax, panels):
        # Keep only rows where both direction and speed are present.
        mask = first_df[dir_col].notna() & first_df[speed_col].notna()
        hist = axis.hist2d(
            first_df.loc[mask, dir_col],
            first_df.loc[mask, speed_col],
            bins=(50, 50),
            vmax=400
        )
        fig.colorbar(hist[3], ax=axis)
        axis.set_xlabel(x_label)
        axis.set_ylabel(y_label)
        axis.set_title(f'2D Histogram of {kind} for {first_df_name}')

    plt.tight_layout()
    plt.show()

def display_wind_after_vectorize(dataframes):
    """
    Plot 2D histograms of the vectorized wind and gust components of the first station.

    Args:
        dataframes (dict[str, pd.DataFrame]): Station name -> DataFrame with
            'Wind_x', 'Wind_y', 'Gust_x' and 'Gust_y' columns. Only the first
            entry (dict order) is plotted.
    """
    station = list(dataframes.keys())[0]
    frame = dataframes[station]

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    for axis, kind in zip(axes, ('Wind', 'Gust')):
        x_col, y_col = f'{kind}_x', f'{kind}_y'
        # Only rows where both components are present are binned.
        valid = frame[x_col].notna() & frame[y_col].notna()
        counts = axis.hist2d(frame.loc[valid, x_col], frame.loc[valid, y_col], bins=(50, 50), vmax=400)
        fig.colorbar(counts[3], ax=axis)
        axis.set_xlabel(f'{kind} X Component')
        axis.set_ylabel(f'{kind} Y Component')
        axis.set_title(f'2D Histogram of {kind} Components for {station}')

    plt.tight_layout()
    plt.show()

def print_length_of_dataframes(dataframes):
    """Print a header followed by the row count of every DataFrame."""
    lines = [f"Length of dataframe {name}: {len(frame)}" for name, frame in dataframes.items()]
    print("\n  Length of dataframes:")
    for line in lines:
        print(line)
######################################################################################################################


##    functions used for syncing the dataframes ##
######################################################################################################################
def sort_dataframes(dataframes):
    """Replace each DataFrame in the dict with a copy sorted by 'date_for_sort' (index kept)."""
    for name in list(dataframes):
        dataframes[name] = dataframes[name].sort_values(by='date_for_sort')

def slice_dataframes_beginning(dataframes, begin_date):
    """
    Drop every row before the first row whose 'Date Time' equals ``begin_date``.

    Matching DataFrames are replaced in the dict by the sliced frame with a
    fresh RangeIndex; DataFrames without that timestamp are left unchanged and
    reported.
    """
    for df_name in list(dataframes):
        df = dataframes[df_name]
        matches = df[df['Date Time'] == begin_date].index
        if matches.empty:
            print(f"{begin_date} not found in dataframe {df_name}")
            continue
        start_label = matches[0]
        print(f"Index to keep from dataframe {df_name}: {start_label}")
        # Label-based slice: keeps the matching row and everything after it.
        dataframes[df_name] = df.loc[start_label:].reset_index(drop=True)

def delete_rows_not_existing_in_all_dataframes(dataframes):
    """
    Keep only the rows whose 'Date Time' value appears in every DataFrame.

    Each DataFrame in ``dataframes`` is replaced (in the dict) by its filtered
    version with a fresh RangeIndex.

    Args:
        dataframes (dict[str, pd.DataFrame]): Station name -> DataFrame with a
            'Date Time' column. Modified in place.

    Returns:
        tuple: (deleted_rows, latest_deleted_date). ``deleted_rows`` maps each
        name to the number of rows removed; ``latest_deleted_date`` is the largest
        'Date Time' value removed from any DataFrame, or None if nothing was
        removed. An empty dict yields ({}, None).
    """
    if not dataframes:
        # set.intersection() with no arguments raises TypeError.
        return {}, None

    common_keys = set.intersection(*(set(df['Date Time']) for df in dataframes.values()))

    deleted_rows = {}
    latest_deleted_date = None
    for df_name in list(dataframes):
        df = dataframes[df_name]
        in_all = df['Date Time'].isin(common_keys)  # computed once, reused below
        df_filtered = df[in_all].reset_index(drop=True)
        deleted = len(df) - len(df_filtered)
        dataframes[df_name] = df_filtered
        deleted_rows[df_name] = deleted
        if deleted > 0:
            max_deleted_date = df.loc[~in_all, 'Date Time'].max()
            if latest_deleted_date is None or max_deleted_date > latest_deleted_date:
                latest_deleted_date = max_deleted_date

    return deleted_rows, latest_deleted_date
######################################################################################################################


##    functions used for preprocessing the dataframes ##
######################################################################################################################
def remove_unecessery_columns(dataframes, columns_to_remove):
    """Replace each DataFrame with a copy lacking those of ``columns_to_remove`` it has."""
    for name in list(dataframes):
        frame = dataframes[name]
        present = [column for column in columns_to_remove if column in frame.columns]
        dataframes[name] = frame.drop(columns=present)

def format_the_time_column(dataframes):
    """
    Parse 'Date Time' strings ('%d/%m/%Y %H:%M') and add a 'Year' column, in place.

    The parsed 'Date Time' column is moved to the end, followed by 'Year'.
    DataFrames without a 'Date Time' column are skipped.
    """
    for frame in dataframes.values():
        if 'Date Time' not in frame.columns:
            continue
        raw = frame.pop('Date Time')
        frame['Date Time'] = pd.to_datetime(raw, format="%d/%m/%Y %H:%M")
        frame['Year'] = frame['Date Time'].dt.year

def take_round_hours(dataframes):
    """Keep only on-the-hour rows (minute == 0) of every DataFrame with a 'Date Time' column."""
    for name in list(dataframes):
        frame = dataframes[name]
        if 'Date Time' in frame.columns:
            on_the_hour = frame['Date Time'].dt.minute == 0
            dataframes[name] = frame[on_the_hour]

def fill_1_missing_values(dataframes, values_to_fill, should_print=False, na_values=None):
    """
    Fill isolated single-sample gaps with the mean of their two neighbours, in place.

    For each column of ``values_to_fill`` present in a DataFrame, missing-value
    tokens are first converted to NaN. Then every NaN whose immediate
    predecessor and successor are both present is replaced by their average.
    Gaps of two or more consecutive NaNs are left untouched.

    Args:
        dataframes (dict[str, pd.DataFrame]): Station name -> DataFrame (modified in place).
        values_to_fill (Iterable[str]): Column names to repair.
        should_print (bool): Print how many values were filled per column.
        na_values (list | None): Tokens treated as missing; defaults to NA_VALUES.
    """
    if na_values is None:
        na_values = NA_VALUES
    for df_name, df in dataframes.items():
        for value in values_to_fill:
            if value not in df.columns:
                continue
            # Work on an explicit copy and assign it back at the end: writing
            # through Series.values only works while it happens to be a view and
            # is read-only under pandas Copy-on-Write.
            value_values = df[value].replace(na_values, np.nan).to_numpy(copy=True)

            # Fill NaN values wrapped with two non-NaN values:
            nan_wrapped_count = 0
            for i in range(1, len(value_values) - 1):
                if pd.isna(value_values[i]) and not pd.isna(value_values[i - 1]) and not pd.isna(value_values[i + 1]):
                    try:
                        value_values[i] = (float(value_values[i - 1]) + float(value_values[i + 1])) / 2
                        nan_wrapped_count += 1
                    except ValueError as e:
                        print(f"ValueError encountered in {df_name} at index {i} for column {value}: {e}")
            df[value] = value_values

            if should_print:
                # Guard the percentage against an empty DataFrame.
                share = nan_wrapped_count / len(df) * 100 if len(df) else 0.0
                print(f"Number of NaN values wrapped with two non-NaN values and filled in {df_name} station for column {value}: {nan_wrapped_count} which is {share}% of the data")

def fill_2_missing_values(dataframes, values_to_fill, should_print=False, na_values=None):
    """
    Linearly interpolate gaps of exactly two consecutive NaNs, in place.

    A gap at positions i, i+1 is filled only when the two samples before it
    (i-2, i-1) and the two after it (i+2, i+3) are all present, and the trend
    is the same on both sides (both decreasing or both non-decreasing). The
    filled values lie on the straight line from sample i-1 to sample i+2.

    Args:
        dataframes (dict[str, pd.DataFrame]): Station name -> DataFrame (modified in place).
        values_to_fill (Iterable[str]): Column names to repair.
        should_print (bool): Print how many values were filled per column.
        na_values (list | None): Tokens treated as missing; defaults to NA_VALUES.
    """
    if na_values is None:
        na_values = NA_VALUES
    for value in values_to_fill:
        for df_name, df in dataframes.items():
            if value not in df.columns:
                continue
            # Work on an explicit copy and assign it back at the end: writing
            # through Series.values only works while it happens to be a view and
            # is read-only under pandas Copy-on-Write.
            value_values = df[value].replace(na_values, np.nan).to_numpy(copy=True)

            nan_wrapped_count = 0
            i = 2  # Start from index 2 to ensure i-2 is valid
            # The check reads value_values[i + 3], so i must stay below len - 3.
            while i < len(value_values) - 3:
                is_gap = (pd.isna(value_values[i]) and pd.isna(value_values[i + 1]) and
                          not pd.isna(value_values[i - 2]) and not pd.isna(value_values[i - 1]) and
                          not pd.isna(value_values[i + 2]) and not pd.isna(value_values[i + 3]))
                if is_gap:
                    try:
                        val1 = float(value_values[i - 2])
                        val2 = float(value_values[i - 1])
                        val3 = float(value_values[i + 2])
                        val4 = float(value_values[i + 3])
                    except ValueError as e:
                        # Non-numeric neighbour: report it and leave the gap as is.
                        print(f"ValueError encountered in {df_name} at indices {i} and {i+1} for column {value}: {e}")
                    else:
                        # Only interpolate when the trend agrees on both sides.
                        if (val2 < val1) == (val4 < val3):
                            diff = val3 - val2
                            value_values[i] = val2 + diff / 3
                            value_values[i + 1] = val2 + diff * 2 / 3
                            nan_wrapped_count += 2
                            i += 2  # both gap positions are now filled
                            continue
                i += 1
            df[value] = value_values

            if should_print:
                # Guard the percentage against an empty DataFrame.
                share = nan_wrapped_count / len(df) * 100 if len(df) else 0.0
                print(f"Number of NaN values wrapped with two non-NaN values and filled in {df_name} station for column {value}: {nan_wrapped_count} which is {share:.2f}% of the data")
    
def fill_3_missing_values(dataframes, values_to_fill, should_print=False):
  """
  Linearly interpolate runs of exactly three consecutive NaNs in selected columns.

  A gap at positions i, i+1, i+2 is filled only when positions i-2, i-1 (before)
  and i+3, i+4 (after) all hold non-NaN values, and the direction of change before
  the gap (i-2 -> i-1) matches the direction after it (i+3 -> i+4). The filled
  values lie at 1/4, 2/4 and 3/4 of the way from value[i-1] to value[i+3].

  Args:
    dataframes (dict[str, pd.DataFrame]): station name -> DataFrame; columns are
      updated in place (the frames inside the dict are modified).
    values_to_fill (Iterable[str]): column names to process; columns absent from a
      DataFrame are skipped.
    should_print (bool): print per-station / per-column fill statistics.
  """
  for value in values_to_fill:
    for df_name, df in dataframes.items():
      if value not in df.columns:
        continue
      df[value] = df[value].replace(NA_VALUES, np.nan)

      # Fill three consecutive NaN values wrapped with two non-NaN values on each side.
      nan_wrapped_count = 0
      # Work on an explicit copy and assign it back below: `.values` is not guaranteed
      # to be a writable view of the frame (e.g. it is read-only under pandas Copy-on-Write).
      value_values = df[value].to_numpy(copy=True)
      i = 2  # Start from index 2 to ensure i-2 is valid
      # The condition reads value_values[i + 4], so i must stay below len - 4.
      while i < len(value_values) - 4:
        if (pd.isna(value_values[i]) and pd.isna(value_values[i + 1]) and pd.isna(value_values[i + 2]) and
            not pd.isna(value_values[i - 2]) and not pd.isna(value_values[i - 1]) and
            not pd.isna(value_values[i + 3]) and not pd.isna(value_values[i + 4])):
          try:
            # Conversion is what can fail on non-numeric leftovers, so it belongs in the try.
            val1 = float(value_values[i - 2])
            val2 = float(value_values[i - 1])
            val3 = float(value_values[i + 3])
            val4 = float(value_values[i + 4])
          except ValueError as e:
            print(f"ValueError encountered in {df_name} at indices {i}, {i+1}, and {i+2} for column {value}: {e}")
          else:
            # Only fill when the trend before the gap matches the trend after it.
            if (val2 < val1) == (val4 < val3):
              diff = val3 - val2
              value_values[i] = val2 + diff / 4
              value_values[i + 1] = val2 + (diff * 2) / 4
              value_values[i + 2] = val2 + (diff * 3) / 4
              nan_wrapped_count += 3
              i += 3  # All three gap positions are filled; continue after them
              continue
        i += 1
      df[value] = value_values
      if should_print:
        print(f"Number of NaN values wrapped with three non-NaN values and filled in {df_name} station for column {value}: {nan_wrapped_count} which is {nan_wrapped_count / len(df) * 100:.2f}% of the data")

def replace_time_with_cyclic_representation(dataframes):
  """
  Encode the 'Date Time' column as sine/cosine pairs with daily and yearly periods.

  For every DataFrame that has a 'Date Time' column, four columns are added in place
  ('Day sin', 'Day cos', 'Year sin', 'Year cos'), computed from the POSIX timestamp in
  seconds. The dict entry is then replaced by a copy without 'Date Time'. DataFrames
  lacking the column are left untouched.

  Args:
    dataframes (dict[str, pd.DataFrame]): station name -> DataFrame; updated in place.
  """
  seconds_per_day = 24 * 60 * 60
  seconds_per_year = 365.2425 * seconds_per_day
  encodings = (
      ('Day sin', 'Day cos', seconds_per_day),
      ('Year sin', 'Year cos', seconds_per_year),
  )

  for station, frame in dataframes.items():
    if 'Date Time' not in frame.columns:
      continue
    seconds = frame['Date Time'].map(pd.Timestamp.timestamp)
    for sin_name, cos_name, period in encodings:
      angle = seconds * (2 * np.pi / period)
      frame[sin_name] = np.sin(angle)
      frame[cos_name] = np.cos(angle)
    # Replacing an existing key while iterating is safe: the dict size does not change.
    dataframes[station] = frame.drop(columns=['Date Time'])

def _vectorize_pair(df, speed_col, direction_col, x_col, y_col):
  """Replace a (speed, direction-in-degrees) column pair of df with x/y components.

  Both source columns are checked before anything is removed, so a missing column
  raises KeyError (naming the first missing column) and leaves df unchanged.
  """
  missing = [col for col in (speed_col, direction_col) if col not in df.columns]
  if missing:
    raise KeyError(missing[0])
  speed = pd.to_numeric(df.pop(speed_col), errors='coerce')
  direction_rad = pd.to_numeric(df.pop(direction_col), errors='coerce') * np.pi / 180
  df[x_col] = speed * np.cos(direction_rad)
  df[y_col] = speed * np.sin(direction_rad)
  # Components are undefined when either input could not be parsed.
  df.loc[speed.isna() | direction_rad.isna(), [x_col, y_col]] = np.nan


def vectorize_wind(dataframes):
  """
  Convert wind and gust (speed, direction) columns into Cartesian x/y components.

  'WS (m/s)' / 'WD (deg)' become 'Wind_x' / 'Wind_y', and 'WSmax (m/s)' / 'WDmax (deg)'
  become 'Gust_x' / 'Gust_y'. Non-numeric values become NaN. The two pairs are handled
  independently: a missing column in one pair is reported and does not prevent the
  other pair from being processed, nor does it drop the pair's other column.

  Args:
    dataframes (dict[str, pd.DataFrame]): station name -> DataFrame; modified in place.
  """
  print("vectorizing wind.")
  pairs = (
      ('WS (m/s)', 'WD (deg)', 'Wind_x', 'Wind_y'),
      ('WSmax (m/s)', 'WDmax (deg)', 'Gust_x', 'Gust_y'),
  )
  for df_name, df in dataframes.items():
    for speed_col, direction_col, x_col, y_col in pairs:
      try:
        _vectorize_pair(df, speed_col, direction_col, x_col, y_col)
      except KeyError as e:
        print(f"KeyError encountered in {df_name}: {e}")
      except TypeError as e:
        print(f"TypeError encountered in {df_name}: {e}")

def drop_nan_rows_multiple(dataframes, reset_indices=True):
    """
    Remove, from every DataFrame, each row label at which ANY of the DataFrames has a NaN.

    The NaN mask is built on the first DataFrame's index and combined with the other
    frames' masks, so the DataFrames are expected to share row labels.

    The cleaned DataFrames are written back into ``dataframes`` under their original keys
    (like the other in-place helpers in this module) and are also returned as a list.

    Parameters:
    dataframes (dict[str, pd.DataFrame]): station name -> DataFrame; entries are replaced.
    reset_indices (bool): Whether to reset the index after dropping rows. Defaults to True.

    Returns:
    Tuple[List[pd.DataFrame], Optional[pd.Timestamp]]: the cleaned DataFrames (in the dict's
    iteration order) and the latest parseable 'Date Time' among the removed rows, or None.

    Raises:
    ValueError: if ``dataframes`` is empty.
    """
    if not dataframes:
        raise ValueError("The list of DataFrames is empty.")

    # Steps 1-2: rows with a NaN in any column of any DataFrame.
    first_df = next(iter(dataframes.values()))
    combined_nan = pd.Series(False, index=first_df.index)
    for df in dataframes.values():
        combined_nan = combined_nan | df.isnull().any(axis=1)

    # Get the indices to drop
    indices_to_drop = combined_nan[combined_nan].index
    print(" number of rows with NaNs to drop: ", len(indices_to_drop))

    # Step 3: Find the latest 'Date Time' of the rows to be removed
    latest_removed_date = None
    for df in dataframes.values():
        if 'Date Time' not in df.columns:
            continue
        removed_dates = df.loc[indices_to_drop, 'Date Time']
        if removed_dates.empty:
            continue
        max_date = pd.to_datetime(removed_dates, errors='coerce').max()
        if pd.isna(max_date):
            # No parseable date here; storing NaT would make every later comparison False.
            continue
        if latest_removed_date is None or max_date > latest_removed_date:
            latest_removed_date = max_date

    # Step 4: Drop the identified indices from all DataFrames and write them back.
    cleaned_df_list = []
    for name, df in list(dataframes.items()):
        cleaned_df = df.drop(indices_to_drop)
        if reset_indices:
            cleaned_df = cleaned_df.reset_index(drop=True)
        dataframes[name] = cleaned_df
        cleaned_df_list.append(cleaned_df)

    return cleaned_df_list, latest_removed_date
######################################################################################################################

def get_prccessed_latest_data_by_hour_and_station(stations_list, hours_back, begin_forecast_time=None):
  """
  Fetch the last 7 days of station data, run the preprocessing pipeline and keep only
  the most recent `hours_back` rows of each station.

  Args:
    stations_list: stations to fetch, e.g. STATIONS_LIST (station name -> station id).
    hours_back (int): number of most recent rows to keep per station.
    begin_forecast_time (datetime | None): reference time, truncated to the hour.
      Defaults to datetime.now() evaluated at call time.

  Returns:
    tuple: (dataframes, last_hour, last_date, success)
      - dataframes: the processed per-station DataFrames.
      - last_hour ("HH:MM") / last_date ("YYYY-MM-DD"): time of the last row of the first
        station, taken before NaN-row removal and trimming; None if nothing was fetched.
      - success: False when nothing was fetched, or when a deletion step removed rows
        dated after the reference hour.
  """
  if begin_forecast_time is None:
    # Evaluated per call; a `datetime.now()` default would be frozen at definition time.
    begin_forecast_time = datetime.now()
  success = True
  end_datetime = begin_forecast_time.replace(minute=0, second=0, microsecond=0)
  start_datetime = end_datetime - pd.Timedelta(days=7) # 7 days back
  startYear, endYear = start_datetime.year, end_datetime.year
  startDate, endDate = start_datetime.strftime("%m%d%H%M"), end_datetime.strftime("%m%d%H%M")
  print(f"startDate: {startDate}, endDate: {endDate}")

  dataframes = get_data_of_stations_from_ims_manual_time_range(stations_list, startYear, endYear, startDate, endDate)
  if len(dataframes) == 0:
    # Nothing to process; the pipeline below would fail on the empty collection.
    return dataframes, None, None, False

  sort_dataframes(dataframes)
  latest_deleted_date = delete_rows_not_existing_in_all_dataframes(dataframes)
  if latest_deleted_date and isinstance(latest_deleted_date, datetime) and latest_deleted_date > end_datetime:
    success = False
  remove_unecessery_columns(dataframes, COLUMNS_TO_REMOVE)
  format_the_time_column(dataframes)
  # Imputation before and after rounding to whole hours.
  fill_1_missing_values(dataframes, VALUES_TO_FILL)
  fill_2_missing_values(dataframes, VALUES_TO_FILL)
  fill_3_missing_values(dataframes, VALUES_TO_FILL)
  take_round_hours(dataframes)
  fill_1_missing_values(dataframes, VALUES_TO_FILL)
  fill_2_missing_values(dataframes, VALUES_TO_FILL)
  last_datetime = pd.to_datetime(dataframes[list(dataframes.keys())[0]]['Date Time'].iloc[-1], format="%d/%m/%Y %H:%M")
  last_hour = last_datetime.strftime("%H:%M")
  last_date = last_datetime.strftime("%Y-%m-%d")
  print(f"last hour: {last_hour}, last hour date: {last_date}")

  replace_time_with_cyclic_representation(dataframes)
  vectorize_wind(dataframes)
  # drop_nan_rows_multiple returns (cleaned list, latest removed date): unpack it and
  # rebind the cleaned frames to their keys so the NaN rows are actually dropped.
  cleaned_dataframes, latest_deleted_date = drop_nan_rows_multiple(dataframes)
  dataframes = dict(zip(dataframes.keys(), cleaned_dataframes))
  if latest_deleted_date and isinstance(latest_deleted_date, datetime) and latest_deleted_date > end_datetime:
    success = False

  for df_name, df in dataframes.items(): # return only the last hours_back hours
      if len(df) > hours_back:
          dataframes[df_name] = df.iloc[-hours_back:].reset_index(drop=True)
  return dataframes, last_hour, last_date, success

## how to use this function:
    # stations_list = STATIONS_LIST
    # hours_back = 72
    # dataframes, last_hour, last_date, success = get_prccessed_latest_data_by_hour_and_station(stations_list, hours_back)
    # print(f"len of df: {len(dataframes)}")
    # print(f"len of df[0]: {len(dataframes[list(dataframes.keys())[0]])}")
    # print(f"len of df[1]: {len(dataframes[list(dataframes.keys())[1]])}")
    # print(f"len of df[2]: {len(dataframes[list(dataframes.keys())[2]])}")
    # print(f"Last hour: {last_hour}")
    # print(f"Last hour date: {last_date}")
    # print(f"success: {success}")


def main():
  """
  Offline data pipeline controlled by the boolean flags below.

  It fetches station data from IMS (or loads it from DATA_DIRECTORY), synchronizes the
  stations, and then preprocesses them. Preprocessing removes columns, formats time,
  imputes short gaps, rounds to whole hours, adds cyclic time features, vectorizes wind
  and drops NaN rows. Finally it optionally saves the result to pickles in DATA_DIRECTORY.
  """
  # menu:
  ## load from:
  get_data_from_ims = True # (including the first process)
  load_data_from_directory = not get_data_from_ims and True
  ## save to:
  save_data_to_pickles_in_the_end = True
  ## sync:
  sync = True #master switch
  should_sort_dataframes = sync and True
  should_slice_dataframes_beginning = sync and True
  should_delete_rows_not_existing_in_all_dataframes = sync and True
  ## preprocessing:
  preprocess = True
  should_remove_unecessery_columns = preprocess and True
  should_format_the_time_column = preprocess and True
  data_imputation = preprocess and True
  should_fill_data_1_missing_value = data_imputation and True
  should_fill_data_2_missing_values = data_imputation and True
  should_fill_data_3_missing_values = data_imputation and True
  should_take_round_hours = preprocess and True
  # NOTE(review): a previous comment here said "leave it false" while the flag is enabled; confirm.
  should_replace_time_with_cyclic_representation = preprocess and True
  should_vectorize_wind = preprocess and True
  should_drop_nan_rows = preprocess and True
  ## display:
  should_display_heads_of_dataframes = False
  should_print_length_of_dataframes = True
  should_display_wind_before_vectorize = False
  should_display_wind_after_vectorize = should_vectorize_wind and False


  if get_data_from_ims:
    dataframes = get_data_of_stations_from_ims_by_constants_params(STATIONS_LIST, START_YEAR, END_YEAR)

  if load_data_from_directory:
    dataframes = load_dataframes_from_pickles(DATA_DIRECTORY)

  if should_sort_dataframes:
    sort_dataframes(dataframes)

  if should_slice_dataframes_beginning:
    slice_dataframes_beginning(dataframes, '01/04/2005 00:00')

  if should_print_length_of_dataframes:
    print_length_of_dataframes(dataframes)

  if should_delete_rows_not_existing_in_all_dataframes:
    delete_rows_not_existing_in_all_dataframes(dataframes)

  if sync and should_print_length_of_dataframes:
    print_length_of_dataframes(dataframes)

  if should_remove_unecessery_columns:
    remove_unecessery_columns(dataframes, COLUMNS_TO_REMOVE)

  if should_format_the_time_column:
    format_the_time_column(dataframes)

  if data_imputation: # imputation before rounding hours
    if should_fill_data_1_missing_value:
      fill_1_missing_values(dataframes, VALUES_TO_FILL, should_print=True)
    if should_fill_data_2_missing_values:
      fill_2_missing_values(dataframes, VALUES_TO_FILL, should_print=True)
    if should_fill_data_3_missing_values:
      fill_3_missing_values(dataframes, VALUES_TO_FILL, should_print=True)

  if should_take_round_hours:
    take_round_hours(dataframes)

  if data_imputation: # imputation after rounding hours
    if should_fill_data_1_missing_value:
      fill_1_missing_values(dataframes, VALUES_TO_FILL, should_print=True)
    if should_fill_data_2_missing_values:
      fill_2_missing_values(dataframes, VALUES_TO_FILL, should_print=True)

  if should_replace_time_with_cyclic_representation:
    replace_time_with_cyclic_representation(dataframes)

  if should_display_wind_before_vectorize:
    display_wind_before_vectorize(dataframes)

  if should_vectorize_wind:
    vectorize_wind(dataframes)

  if should_display_wind_after_vectorize:
    display_wind_after_vectorize(dataframes)

  if should_drop_nan_rows:
    # drop_nan_rows_multiple returns (cleaned list, latest removed date); rebind the
    # cleaned frames to their keys so the dropped rows are not saved below.
    cleaned_dataframes, _ = drop_nan_rows_multiple(dataframes)
    dataframes = dict(zip(dataframes.keys(), cleaned_dataframes))

  if should_display_heads_of_dataframes:
    display_dataframes_heads(dataframes)

  if save_data_to_pickles_in_the_end:
    save_dataframes_to_pickles(dataframes, DATA_DIRECTORY)
"""
if __name__ == "__main__":
    main()
"""

### `model.py`

In [None]:
# model.py

import torch
import torch.nn as nn
import math

class StationCNN(nn.Module):
    """
    Per-station temporal feature extractor.

    Two depthwise (grouped by input feature) 1-D convolutions over the time axis, each
    followed by an optional BatchNorm and a ReLU, plus an optional 1x1 depthwise residual
    projection of the input.

    Input:  [batch_size, input_features, time_steps]
    Output: [batch_size, output_per_feature, time_steps, input_features]
    """
    def __init__(self,
                 input_features=15,
                 output_per_feature=3,
                 kernel_size=3,
                 use_batch_norm=False,
                 use_residual=False):
        """
        Args:
            input_features (int): Number of input features per station.
            output_per_feature (int): Number of output channels per feature.
            kernel_size (int): Size of the convolutional kernel (odd or even).
            use_batch_norm (bool): Whether to use Batch Normalization.
            use_residual (bool): Whether to use residual connections.
        """
        super(StationCNN, self).__init__()
        self.output_per_feature = output_per_feature
        self.use_batch_norm = use_batch_norm
        self.use_residual = use_residual

        # Total out_channels = input_features * output_per_feature
        self.out_channels = input_features * output_per_feature

        # padding='same' keeps time_steps unchanged for any kernel_size (stride 1). For odd
        # kernels it equals the previous padding=kernel_size // 2; for even kernels that
        # symmetric padding produced time_steps + 1 outputs and the view() in forward failed.
        # Padding holds no parameters, so existing checkpoints still load.

        # First convolutional layer
        self.conv1 = nn.Conv1d(
            in_channels=input_features,
            out_channels=self.out_channels,
            kernel_size=kernel_size,
            padding='same',
            groups=input_features  # Depthwise convolution
        )
        self.relu1 = nn.ReLU()

        # Optional Batch Normalization
        if self.use_batch_norm:
            self.bn1 = nn.BatchNorm1d(self.out_channels)

        # Second convolutional layer
        self.conv2 = nn.Conv1d(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=kernel_size,
            padding='same',
            groups=input_features  # Maintain feature independence
        )
        self.relu2 = nn.ReLU()

        # Optional Batch Normalization
        if self.use_batch_norm:
            self.bn2 = nn.BatchNorm1d(self.out_channels)

        # Optional Residual Connection
        if self.use_residual:
            self.residual_conv = nn.Conv1d(
                in_channels=input_features,
                out_channels=self.out_channels,
                kernel_size=1,
                groups=input_features  # Depthwise 1x1 convolution
            )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, input_features, time_steps]

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, output_per_feature, time_steps, input_features]
        """
        b, f, t = x.shape  # [batch_size, input_features, time_steps]

        # First convolution
        out = self.conv1(x)  # [batch_size, input_features * output_per_feature, time_steps]
        if self.use_batch_norm:
            out = self.bn1(out)
        out = self.relu1(out)

        # Second convolution
        out = self.conv2(out)  # [batch_size, input_features * output_per_feature, time_steps]
        if self.use_batch_norm:
            out = self.bn2(out)
        out = self.relu2(out)

        # Optional Residual Connection
        if self.use_residual:
            residual = self.residual_conv(x)
            out = out + residual  # Ensures a new tensor is produced
            out = self.relu2(out)

        # NOTE(review): grouped convolutions lay out output channels group-major (each input
        # feature's output_per_feature channels are contiguous), so this [k, f] view does
        # not keep the last axis aligned with input features. Downstream code flattens k*f
        # into one axis, and trained checkpoints depend on this order, so it is left as is.
        # Reshape to [batch_size, output_per_feature, input_features, time_steps]
        out = out.view(b, self.output_per_feature, f, t)
        # Permute to [batch_size, output_per_feature, time_steps, input_features]
        out = out.permute(0, 1, 3, 2)  # [batch_size, output_per_feature, time_steps, features]

        return out  # [batch_size, output_per_feature, time_steps, features]

class CoordinatePositionalEncoding(nn.Module):
    """
    Embed per-station 2-D coordinates into a d_model-wide vector.

    Each coordinate passes through its own Linear(1 -> half) layer. The halves are
    concatenated and passed through ReLU. For odd d_model the second half gets the extra
    unit, so the output width is exactly d_model and can be added to d_model-wide features.
    For even d_model the layer shapes are unchanged (existing checkpoints still load).
    """
    def __init__(self, d_model):
        super(CoordinatePositionalEncoding, self).__init__()
        # Assuming two coordinates: X and Y
        self.lat_linear = nn.Linear(1, d_model // 2)
        self.lon_linear = nn.Linear(1, d_model - d_model // 2)
        self.activation = nn.ReLU()

    def forward(self, lat, lon):
        """
        Args:
            lat (torch.Tensor): [num_stations, 1] - Normalized X coordinates
            lon (torch.Tensor): [num_stations, 1] - Normalized Y coordinates
        Returns:
            torch.Tensor: [num_stations, d_model]
        """
        lat_enc = self.lat_linear(lat)  # [num_stations, d_model//2]
        lon_enc = self.lon_linear(lon)  # [num_stations, d_model - d_model//2]
        spatial_emb = self.activation(torch.cat([lat_enc, lon_enc], dim=1))  # [num_stations, d_model]
        return spatial_emb

class TemporalPositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding along the time axis.

    Even feature indices hold sin(position * freq) and odd indices hold
    cos(position * freq), with frequencies 10000^(-2i/d_model). The table is
    registered as the buffer 'pe' with shape [1, max_len, d_model].
    """
    def __init__(self, d_model, max_len=5000):
        super(TemporalPositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies; trim to fit
        # (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): [batch_size, num_stations, time_steps, d_model]
                (time_steps must not exceed max_len)
        Returns:
            torch.Tensor: [batch_size, num_stations, time_steps, d_model]
        """
        x = x + self.pe[:, :x.size(2), :].unsqueeze(1)  # [batch, num_stations, time_steps, d_model]
        return x

class TargetedWeatherPredictionModel(nn.Module):
    """
    Multi-station forecaster for a single target station.

    Pipeline: one StationCNN per station -> Linear to d_model -> temporal sinusoidal
    encoding + coordinate embedding -> Transformer encoder over the flattened
    (station, time) sequence -> the target station's last time step -> Linear to
    label_width outputs.

    NOTE(review): submodule attribute names and construction order define the state_dict
    keys and the random-initialization order; keep them stable for saved checkpoints.
    """
    def __init__(self, num_stations, time_steps, feature_dim, kernel_size,
                 d_model, nhead, num_layers, target_station_idx, label_width=1,
                 output_per_feature=3, use_batch_norm=False, use_residual=False):
        """
        Args:
            num_stations (int): Number of stations.
            time_steps (int): Number of time steps in the sliding window; also used as
                max_len of the temporal positional encoding, so inputs may not be longer.
            feature_dim (int): Number of features per station.
            kernel_size (int): Size of the CNN kernel.
            d_model (int): Dimension of the model (for Transformer).
            nhead (int): Number of attention heads in the Transformer.
            num_layers (int): Number of Transformer encoder layers.
            target_station_idx (int): Index of the target station.
            label_width (int): Number of prediction steps.
            output_per_feature (int): Number of output channels per feature in CNN.
            use_batch_norm (bool): Whether to use Batch Normalization in CNNs.
            use_residual (bool): Whether to use residual connections in CNNs.
        """
        super(TargetedWeatherPredictionModel, self).__init__()
        self.num_stations = num_stations
        self.time_steps = time_steps
        self.target_station_idx = target_station_idx
        self.label_width = label_width
        self.output_per_feature = output_per_feature

        # Initialize separate CNNs for each station
        self.station_cnns = nn.ModuleList([
            StationCNN(
                input_features=feature_dim,
                output_per_feature=output_per_feature,
                kernel_size=kernel_size,
                use_batch_norm=use_batch_norm,
                use_residual=use_residual
            )
            for _ in range(num_stations)
        ])

        # Coordinate Positional Encoding
        self.coord_pos_encoding = CoordinatePositionalEncoding(d_model=d_model)

        # Linear layer to map CNN features to d_model
        # New feature_dim after CNN: output_per_feature * original feature_dim
        self.feature_mapping = nn.Linear(feature_dim * output_per_feature, d_model)

        # Temporal Positional Encoding
        self.temporal_pos_encoding = TemporalPositionalEncoding(d_model=d_model, max_len=time_steps)

        # Transformer Encoder (default batch_first=False: expects [seq, batch, d_model])
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final prediction layer
        self.fc_out = nn.Linear(d_model, label_width)  # Output label_width predictions

    def forward(self, x, lat, lon):
        """
        Args:
            x (torch.Tensor): [batch_size, num_stations, time_steps, feature_dim]
            lat (torch.Tensor): [num_stations, 1] - Normalized X coordinates
            lon (torch.Tensor): [num_stations, 1] - Normalized Y coordinates
        Returns:
            torch.Tensor: [batch_size, label_width]
        """
        batch_size, num_stations, time_steps, feature_dim = x.size()

        # Extract temporal features for each station
        # Initialize a list to collect CNN outputs
        temporal_features = []
        for i in range(num_stations):
            station_data = x[:, i, :, :]  # [batch_size, time_steps, feature_dim]
            station_data = station_data.permute(0, 2, 1)  # [batch_size, feature_dim, time_steps]
            cnn_out = self.station_cnns[i](station_data)  # [batch_size, output_per_feature, time_steps, feature_dim]
            temporal_features.append(cnn_out)

        # Stack temporal features: [batch_size, num_stations, output_per_feature, time_steps, feature_dim]
        temporal_features = torch.stack(temporal_features, dim=1)  # [batch, num_stations, output_per_feature, time_steps, features]

        # Reshape to combine output_per_feature and features dimensions
        # New shape: [batch_size, num_stations, time_steps, output_per_feature * feature_dim]
        # (the first view below is a no-op: the tensor already has this shape)
        temporal_features = temporal_features.view(batch_size, num_stations, self.output_per_feature, time_steps, feature_dim)
        temporal_features = temporal_features.permute(0, 1, 3, 2, 4)  # [batch, num_stations, time_steps, output_per_feature, features]
        temporal_features = temporal_features.contiguous().view(batch_size, num_stations, time_steps, self.output_per_feature * feature_dim)  # [batch, num_stations, time_steps, output_per_feature * features]

        # Spatial positional encoding using coordinates
        spatial_emb = self.coord_pos_encoding(lat, lon)  # [num_stations, d_model]
        spatial_emb = spatial_emb.unsqueeze(0).unsqueeze(2)  # [1, num_stations, 1, d_model]

        # Map temporal features to d_model
        temporal_features = self.feature_mapping(temporal_features)  # [batch_size, num_stations, time_steps, d_model]

        # Apply temporal positional encoding
        temporal_features = self.temporal_pos_encoding(temporal_features)  # [batch, num_stations, time_steps, d_model]

        # Combine temporal and spatial features (spatial embedding broadcast over batch and time)
        combined_features = temporal_features + spatial_emb  # [batch, num_stations, time_steps, d_model]

        # Reshape for Transformer: [batch_size, num_stations * time_steps, d_model]
        # (station-major: all time steps of station 0, then station 1, ...)
        combined_features = combined_features.view(batch_size, num_stations * time_steps, -1)

        # Transpose for Transformer: [sequence_length, batch_size, d_model]
        combined_features = combined_features.permute(1, 0, 2)  # [num_stations * time_steps, batch_size, d_model]

        # Transformer expects [sequence_length, batch_size, d_model]
        transformer_out = self.transformer_encoder(combined_features)  # [sequence_length, batch_size, d_model]

        # Reshape back: [batch_size, num_stations, time_steps, d_model]
        transformer_out = transformer_out.permute(1, 0, 2)  # [batch_size, sequence_length, d_model]
        transformer_out = transformer_out.view(batch_size, num_stations, time_steps, -1)  # [batch_size, num_stations, time_steps, d_model]

        # Select target station's features: [batch_size, time_steps, d_model]
        target_features = transformer_out[:, self.target_station_idx, :, :]  # [batch_size, time_steps, d_model]

        # Instead of mean pooling, retain temporal information or use other aggregation
        # Here, we'll take the last time step's features for simplicity
        last_time_step_features = target_features[:, -1, :]  # [batch_size, d_model]

        # Final prediction
        prediction = self.fc_out(last_time_step_features)  # [batch_size, label_width]

        return prediction


### `parameters.py`

In [None]:
import torch
import os
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Determine device

# must define these 3 variables below!!!
###########################################################################################################################################################

# for training - where the output will be saved.
# Base folder holding one sub-folder per model (each with parameters.py, checkpoints/, scalers/).
# Local alternative (previously assigned and immediately overwritten by the Colab path):
#   'C:\\Users\\dorsh\\Documents\\GitHub\\WeatherNet\\backend\\Model_Pytorch\\AdvancedModel\\models'
colab_path = r'/content/models_for_inference'
name_of_the_model_to_save_train = r'1_12'

# for inference
models_paths_dir_names_for_inference = ['1_12','12_24','24_36','36_60'] # for instance for alot of models we want to inference: ['model_1','model_2' ... ] for one : ['model_1']

###########################################################################################################################################################
base_path = colab_path
output_path = os.path.join(base_path, name_of_the_model_to_save_train)
checkpoints_path = os.path.join(output_path, 'checkpoints')
scalers_path = os.path.join(output_path, 'scalers')
inference_output_path = os.path.join(output_path, 'inference_output')

STATIONS_COORDINATES = {
    'Tavor Kadoorie':           (238440, 734540), #station id: 13
    'Newe Yaar':                (217010, 734820), #station id: 186
    'Yavneel':                  (248110, 733730), #station id: 11
    'En Hashofet':              (209310, 723170), #station id: 67
    'Eden Farm':                (246190, 708240), #station id: 206
    'Eshhar':                   (228530, 754390), #station id: 205
    'Afula Nir Haemeq':         (226260, 722410)  #station id: 16
}

# Colab Google-Drive folder holding one '<station name>.pkl' file per station.
_COLAB_DATA_DIR = '/content/drive/MyDrive/final data'

# Same coordinates, keyed by each station's pickle path in Colab.
STATIONS_COORDINATES_COLAB = {
    f'{_COLAB_DATA_DIR}/{station}.pkl': coordinates
    for station, coordinates in STATIONS_COORDINATES.items()
}

STATIONS_LIST = {
    "Tavor Kadoorie":   "13",
    "Newe Yaar":        "186",
    "Yavneel":          "11",
    "En Hashofet":      "67",
    "Eden Farm":        "206",
    "Eshhar":           "205",
    "Afula Nir Haemeq": "16"
}

# Station order; index 0 is 'Tavor Kadoorie', matching target_station_id below.
_STATION_FILE_NAMES = ['Tavor Kadoorie', 'Newe Yaar', 'Yavneel', 'En Hashofet', 'Eden Farm', 'Eshhar', 'Afula Nir Haemeq']

PARAMS = {
    'paths_in_colab':   [f'{_COLAB_DATA_DIR}/{station}.pkl' for station in _STATION_FILE_NAMES],
    'fileNames':        list(_STATION_FILE_NAMES),
    'target_station':   'Tavor Kadoorie',
    'target_station_desplay_name':   'Tavor Kadoorie',
    'target_station_id': 0,
    'device' :           device,
    'in_channels' :      15, # how many features we have
    'output_path':       output_path,
    'checkpoints_path':  checkpoints_path,
    'scalers_path':      scalers_path,
    'inference_output_path': inference_output_path
}

# Active sliding-window configuration (presumably the '1_12' model named above — confirm).
WINDOW_PARAMS = dict(
    input_width=72,                  # window input size
    label_width=12,                  # how many hours to predict to the future
    shift=1,
    label_columns=['TD (degC)'],
)

# Alternative window configurations, kept for reference (swap into WINDOW_PARAMS to use):
#   input_width=72, label_width=12, shift=13, label_columns=['TD (degC)']
#   input_width=72, label_width=12, shift=25, label_columns=['TD (degC)']
#   input_width=72, label_width=24, shift=37, label_columns=['TD (degC)']
# NOTE(review): these look like the '12_24', '24_36' and '36_60' models — confirm.

TRAIN_PARAMS = dict(
    epochs=50,
    batch_size=32,
    lr=1e-5,                         # tried: 1e-3, 1e-4, 1e-5
    checkpoint_dir=PARAMS['checkpoints_path'],
    resume=False,
    device=PARAMS['device'],
    early_stopping_patience=10,      # epochs without improvement before stopping
    scheduler_patience=3,            # epochs to wait before reducing the learning rate
    scheduler_factor=0.5,            # factor applied when reducing the learning rate
    min_lr=1e-7,
    logger_path=PARAMS['output_path'],
)

ADVANCED_MODEL_PARAMS = dict(
    num_stations=len(PARAMS['fileNames']),
    time_steps=WINDOW_PARAMS['input_width'],
    feature_dim=PARAMS['in_channels'],
    kernel_size=3,                   # CNN filter size; candidates: 4, 5, 6, 7
    d_model=64,                      # transformer width; candidates: 64, 128
    nhead=8,                         # attention heads; candidates: 8, 16
    num_layers=4,                    # transformer layers; candidates: 6 - 12
    target_station_idx=PARAMS['target_station_id'],
    label_width=WINDOW_PARAMS['label_width'],
    output_per_feature=3,            # candidates: 4, 5
    use_batch_norm=False,
    use_residual=True,
)

# One absolute folder per model listed for inference.
models_paths_dir_names_full_paths = [
    os.path.join(base_path, model_dir_name)
    for model_dir_name in models_paths_dir_names_for_inference
]

INFERENCE_PARAMS = dict(
    params_path=[os.path.join(model_dir, 'parameters.py') for model_dir in models_paths_dir_names_full_paths],
    weights_paths=[os.path.join(model_dir, 'checkpoints', 'best_checkpoint.pth') for model_dir in models_paths_dir_names_full_paths],
    # NOTE(review): one scaler folder (the training model's) is shared by every inference model — confirm intended.
    scaler_folder_path=PARAMS['scalers_path'],
    # per-model folders where each model's inference output is saved
    inference_output_path_per_model=models_paths_dir_names_full_paths,
    # shared folder for the combined inference output of all models (later analyze.py will use it)
    inference_output_path=os.path.join(base_path, 'inference_output'),
)

### `inference.py`

In [None]:
# inference.py

import torch
import pickle
import os
from tqdm import tqdm
import pandas as pd
from sklearn.preprocessing import StandardScaler
import numpy as np
import json
from datetime import datetime, timedelta
import importlib.util
import sys
from pathlib import Path
import pytz








def load_params(params_path):
    """
    Import a model's parameters file (e.g. ``parameters.py``) as a module.

    Each file is registered in ``sys.modules`` under a name derived from its
    resolved path. Loading several parameter files, one per model, therefore
    never overwrites an earlier entry.

    Args:
        params_path (str | os.PathLike): Path to the parameters file. Relative
            paths are resolved against the current working directory.

    Returns:
        module: The executed module. Its top-level names (e.g.
        ``WINDOW_PARAMS``, ``ADVANCED_MODEL_PARAMS``) are attributes.

    Raises:
        FileNotFoundError: If ``params_path`` is not an existing file.
        ImportError: If no import spec can be built for the path
            (e.g. an unrecognised file suffix).
    """
    import hashlib

    params_path = Path(params_path).resolve()
    if not params_path.is_file():
        raise FileNotFoundError(f"Parameters file not found at {params_path}")

    # A fixed module name would make every call clobber the previous
    # sys.modules entry; derive a stable, path-specific name instead.
    path_digest = hashlib.sha1(str(params_path).encode("utf-8")).hexdigest()[:12]
    module_name = f"params_module_{path_digest}"

    spec = importlib.util.spec_from_file_location(module_name, params_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {params_path}")

    params_module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = params_module
    try:
        spec.loader.exec_module(params_module)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise

    return params_module


def flatten_data(predictions, actuals):
    """
    Flatten per-window predictions and actuals into one long table.

    Args:
        predictions: Sequence of windows; each window is an iterable of
            predicted values.
        actuals: Sequence of windows with the matching actual values.

    Returns:
        pd.DataFrame: Columns 'Predicted', 'Actual' and
        'Error' (= Predicted - Actual), one row per value.
    """
    def _unroll(windows):
        # Concatenate the windows end-to-end, preserving order.
        merged = []
        for window in windows:
            merged.extend(window)
        return merged

    frame = pd.DataFrame({
        'Predicted': _unroll(predictions),
        'Actual': _unroll(actuals),
    })
    frame['Error'] = frame['Predicted'] - frame['Actual']
    return frame

def generate_forecast_json(city_name, date_str, starting_hour, temperatures, output_file):
    """
    Generates a JSON file containing hourly forecast data.

    The i-th value of ``temperatures`` is placed ``starting_hour + i`` hours
    after midnight of ``date_str``. Entries are grouped by calendar date and
    keyed by "HH:00".

    Parameters:
    - city_name (str): Name of the city, stored as the forecast title.
    - date_str (str): Date of hour 0, in "YYYY-MM-DD" format.
    - starting_hour (int): Hour of the first reading. Values 0-23 are hours
      of ``date_str``. Larger values roll over into later days, e.g. 24 is
      00:00 of the next day. This lets callers pass ``last_hour + 1``
      without special-casing 23.
    - temperatures (np.ndarray): Temperature readings, one per hour.
    - output_file (str): Path to the output JSON file.

    Raises:
    - TypeError: if city_name is not a str, starting_hour is not an integer,
      or temperatures is not a NumPy array.
    - ValueError: if date_str is not "YYYY-MM-DD" or starting_hour is negative.
    """
    # Validate inputs
    if not isinstance(city_name, str):
        raise TypeError("city_name must be a string.")

    try:
        base_date = datetime.strptime(date_str, "%Y-%m-%d")
    except ValueError:
        raise ValueError("date_str must be in 'YYYY-MM-DD' format.") from None

    # bool is a subclass of int, but True/False are not meaningful hours.
    if isinstance(starting_hour, bool) or not isinstance(starting_hour, (int, np.integer)):
        raise TypeError("starting_hour must be an integer.")
    if starting_hour < 0:
        raise ValueError("starting_hour must be non-negative.")

    if not isinstance(temperatures, np.ndarray):
        raise TypeError("temperatures must be a NumPy array.")

    # Use datetime arithmetic so hour/day/month/year roll-over is handled
    # uniformly (including a starting_hour of 24 or more).
    first_slot = base_date + timedelta(hours=int(starting_hour))

    forecast_data = {}
    for offset, temp in enumerate(temperatures):
        slot = first_slot + timedelta(hours=offset)
        day_entry = forecast_data.setdefault(slot.strftime("%Y-%m-%d"), {"hourly": {}})
        # Temperature is stored as a string with one decimal place.
        day_entry["hourly"][slot.strftime("%H:00")] = {
            "temperature": "{:.1f}".format(temp)
        }

    # Construct the final JSON structure
    data = {
        "data": {
            "title": city_name,
            "forecast_data": forecast_data
        }
    }

    # Write the data to a JSON file
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)

    print(f"Forecast data successfully written to {output_file}")


def load_scalers(scaler_dir='./output/scalers', num_stations=None):
    """
    Load the previously saved scalers, one per station.

    Scalers are read from '<scaler_dir>/scaler_station_<i>.pkl' for
    i = 0 .. num_stations - 1 and returned in station order.

    Args:
        scaler_dir (str): Directory containing the pickled scalers.
        num_stations (int | None): How many station scalers to load. When
            None, ``ADVANCED_MODEL_PARAMS['num_stations']`` is used (the
            original behaviour).

    Returns:
        list: The unpickled scalers; element i belongs to station i.

    Raises:
        FileNotFoundError: If any expected scaler file is missing.

    Note:
        ``pickle.load`` can execute arbitrary code, so only load scaler files
        from a trusted source.
    """
    if num_stations is None:
        num_stations = ADVANCED_MODEL_PARAMS['num_stations']

    scalers = []
    for i in range(num_stations):
        scaler_path = os.path.join(scaler_dir, f'scaler_station_{i}.pkl')
        if not os.path.exists(scaler_path):
            raise FileNotFoundError(f"Scaler file not found at {scaler_path}")
        with open(scaler_path, 'rb') as f:
            scalers.append(pickle.load(f))
        print(f"Scaler for Station {i} loaded from {scaler_path}")
    return scalers


def load_model_for_inference(checkpoint_path, model_params, device='cpu'):
    """
    Rebuild a TargetedWeatherPredictionModel and restore its trained weights.

    Args:
        checkpoint_path (str): Path to a checkpoint whose loaded object has a
            'model_state_dict' entry.
        model_params (dict): Keyword arguments for the model constructor. They
            must describe the same architecture the checkpoint was saved from.
        device (str | torch.device): Where the weights are mapped and the
            model is moved.

    Returns:
        torch.nn.Module: The restored model on ``device``, in eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")

    net = TargetedWeatherPredictionModel(**model_params)
    # weights_only=False lets torch.load unpickle arbitrary objects:
    # only load checkpoints from a trusted source.
    saved = torch.load(checkpoint_path, map_location=device, weights_only=False)
    net.load_state_dict(saved['model_state_dict'])
    net.to(device)
    net.eval()

    print(f"Model loaded from {checkpoint_path}")
    return net


def load_window_multi_station_return_only_input_window(data_np, window_size, scalers, idx=0):
    """
    Cut one scaled input window out of multi-station data (no target).

    Args:
        data_np (np.ndarray): Raw data indexed as
            [time, station, feature].
        window_size (int): Number of time steps in the window.
        scalers (Sequence): One fitted scaler per station; ``scalers[s]``
            transforms station s.
        idx (int): First time index of the window.

    Returns:
        torch.Tensor: float32 tensor of shape
        (1, num_stations, window_size, num_features).

    Raises:
        ValueError: If the window runs past the end of ``data_np``.
    """
    stop = idx + window_size
    if stop > len(data_np):
        raise ValueError(f"Index {idx} with window size {window_size} exceeds data length {len(data_np)}.")

    raw = data_np[idx:stop, :, :]  # (window_size, num_stations, num_features)

    # Each station is normalised with its own scaler.
    per_station = [scalers[s].transform(raw[:, s, :]) for s in range(raw.shape[1])]
    stacked = np.stack(per_station, axis=1)  # (window_size, num_stations, num_features)

    # Add a batch axis, then move stations ahead of time:
    # (1, window_size, S, F) -> (1, S, window_size, F).
    batched = torch.tensor(stacked, dtype=torch.float32).unsqueeze(0)
    return batched.permute(0, 2, 1, 3)


def load_window_multi_station(data_np, window_size, shift, label_width, scalers, target_column_index, idx=0,
                              target_station_idx=None):
    """
    Cut one (scaled input window, raw target) pair out of multi-station data.

    Time layout, starting at ``idx``:
        input : [idx, idx + window_size)
        target: [idx + window_size + shift - 1, ... + label_width)
    With shift == 1 the target immediately follows the input window.

    Args:
        data_np (np.ndarray): Raw data indexed as [time, station, feature].
        window_size (int): Number of input time steps.
        shift (int): Offset of the first target step (see layout above).
        label_width (int): Number of target time steps.
        scalers (Sequence): One fitted scaler per station; ``scalers[s]``
            transforms station s.
        target_column_index (int): Feature column holding the target.
        idx (int): First time index of the input window; must be >= 0.
        target_station_idx (int | None): Station whose values form the
            target. Defaults to ``ADVANCED_MODEL_PARAMS['target_station_idx']``
            (the original behaviour).

    Returns:
        tuple:
            - window_tensor (torch.Tensor): float32, shape
              (1, num_stations, window_size, num_features).
            - actual_target (np.ndarray): Unscaled target values, shape
              (label_width,).

    Raises:
        ValueError: If ``idx`` is negative or the window plus target runs
            past the end of ``data_np``.
    """
    total_window_size = window_size + shift - 1 + label_width
    # Negative indices would silently slice from the end of the array.
    if idx < 0:
        raise ValueError(f"Index {idx} must be non-negative.")
    if idx + total_window_size > len(data_np):
        raise ValueError(f"Index {idx} with window size {total_window_size} exceeds data length {len(data_np)}.")

    if target_station_idx is None:
        target_station_idx = ADVANCED_MODEL_PARAMS['target_station_idx']

    # Input window: (window_size, num_stations, num_features)
    window = data_np[idx:idx + window_size, :, :]

    # Target values of one station / one feature, left unscaled: (label_width,)
    target_start = idx + window_size + shift - 1
    target_end = target_start + label_width
    actual_target = data_np[target_start:target_end, target_station_idx, target_column_index]

    # Each station is normalised with its own scaler.
    scaled_window = np.stack(
        [scalers[s].transform(window[:, s, :]) for s in range(window.shape[1])],
        axis=1,
    )  # (window_size, num_stations, num_features)

    # Add a batch axis, then move stations ahead of time:
    # (1, window_size, S, F) -> (1, S, window_size, F).
    window_tensor = torch.tensor(scaled_window, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1, 3)

    return window_tensor, actual_target


@torch.no_grad()
def predict(model, input_window, lat, lon, device='cpu'):
    """
    Run the multi-station model on one input window, without gradients.

    Args:
        model (torch.nn.Module): Trained TargetedWeatherPredictionModel
            (any callable ``model(x, lat, lon)`` returning a tensor works).
        input_window (torch.Tensor): Input of shape
            (1, num_stations, time_steps, feature_dim).
        lat (torch.Tensor): Normalized station coordinates, shape (num_stations, 1).
        lon (torch.Tensor): Normalized station coordinates, shape (num_stations, 1).
        device (str): Device to run inference on.

    Returns:
        np.ndarray: The model output, still in the SCALED space, flattened to
        shape (N, 1). Callers must inverse-transform it themselves.
    """
    raw_output = model(input_window.to(device), lat.to(device), lon.to(device))
    as_numpy = raw_output.squeeze(-1).cpu().numpy()
    return as_numpy.reshape(-1, 1)


if __name__ == "__main__":
    """
    1. define INFERENCE_PARAMS in your file - all the INFERENCE_PARAMS are mandatory in addition to that all the parameters that are in section 2

    2. there is an assumption that yours parameters file (not what define in the infarance parameters nor what in the folders)
        has the same values in these parameters:
        PARAMS['paths_in_colab'] - list of the stations names, if you have different names it wont make sense 
        PARAMS['target_station_id'] - the index of the target station in the list of stations
        ADVANCED_MODEL_PARAMS['target_station_idx'] - the index of the target station in the list of stations
        WINDOW_PARAMS['label_columns'] - the label column
        PARAMS['device']
    """
    inference_mode = 'live'  # Options: 'live', 'analyze'
    analyze_stop_at = 0  # Number of predictions to analyze
    local_tz = pytz.timezone('Asia/Jerusalem')
    delta = timedelta(hours=0)
    start_time_to_predict =  datetime.now(local_tz) - delta
    print(f"start_time_to_predict: {start_time_to_predict}")
    print(f"delta: {delta}")
    verbos_get_prccessed_latest_data_by_hour_and_station = False


    parameters_files = [] # load parameters files
    for path in INFERENCE_PARAMS['params_path']:
        parameters_files.append(load_params(path))

    east = []
    north = []
    filenames = PARAMS['paths_in_colab'] 
    for filename in filenames: # get the coordinates of the stations
        east.append(STATIONS_COORDINATES_COLAB[filename][0])
        north.append(STATIONS_COORDINATES_COLAB[filename][1])

    east = np.array(east)
    north = np.array(north)
    east_normalized, north_normalized = normalize_coordinates(east, north)
    
    scalers = load_scalers(scaler_dir=INFERENCE_PARAMS['scaler_folder_path'])
    # model_params = ADVANCED_MODEL_PARAMS.copy()

    
    model_params = []
    for params_file in parameters_files:
        model_params.append(params_file.ADVANCED_MODEL_PARAMS)

    models = []
    for i, weights_path in enumerate(INFERENCE_PARAMS['weights_paths']):
        model = load_model_for_inference(weights_path, model_params[i], device=PARAMS['device'])
        models.append(model)

    window_params = []
    for params_file in parameters_files:
        window_params.append(params_file.WINDOW_PARAMS)
    
    max_input_width = max([window_param['input_width'] for window_param in window_params])

    device = PARAMS['device']
    target_station_idx = PARAMS['target_station_id']

    if inference_mode == 'live':        
        dataframes, last_hour, last_date, success = get_prccessed_latest_data_by_hour_and_station(STATIONS_LIST, max_input_width)
        last_hour = int(last_hour.split(':')[0])
        if verbos_get_prccessed_latest_data_by_hour_and_station:
            print(f"len of df: {len(dataframes)}")
            print(f"len of df[0]: {len(dataframes[list(dataframes.keys())[0]])}")
            print(f"len of df[1]: {len(dataframes[list(dataframes.keys())[1]])}")
            print(f"Last hour: {last_hour}")
            print(f"Last hour date: {last_date}")
            print(f"success: {success}")
        # datafreames is a dictionary with the station name as key and the dataframe as value - convering it into a list of dataframes
        dataframes_list = [dataframes[station] for station in STATIONS_LIST]
        
        representative_df = dataframes_list[0]
        column_indices = {name: i for i, name in enumerate(representative_df.columns)}
        label_columns = [column_indices[WINDOW_PARAMS['label_columns'][0]]]
        target_col_index = label_columns[0]

        list_of_values = [df.values for df in dataframes_list]
        combined_window = np.stack(list_of_values, axis=1) 
    
        target_scaler = scalers[ADVANCED_MODEL_PARAMS['target_station_idx']]
        predictions_of_models = []
        for i, model in enumerate(models):
            input_width = parameters_files[i].WINDOW_PARAMS['input_width']
            input_window = load_window_multi_station_return_only_input_window(
                data_np=    combined_window[-input_width:],
                window_size=input_width,
                scalers=    scalers
            )
            y_pred_scaled = predict(model, input_window, east_normalized, north_normalized,  device=device)
            dummy = np.zeros((y_pred_scaled.shape[0], target_scaler.mean_.shape[0]))
            dummy[:, target_col_index] = y_pred_scaled[:, 0]
            y_pred_original = target_scaler.inverse_transform(dummy)[:, target_col_index]
            predictions_of_models.append(y_pred_original)

        generate_forecast_json(PARAMS['target_station_desplay_name'], last_date, last_hour+1, np.concatenate(predictions_of_models), "forecast.json")

    elif inference_mode == 'analyze':
        dfs = []
        for filename in filenames:
            df = pd.read_pickle(filename)
            dfs.append(df)

        print("Original size of data:")
        for i, df in enumerate(dfs):
            print(f"Station {i}: {df.shape}")

        list_of_values = [df.values for df in dfs]

        # Train/Validation Split per Station
        train_size = int(0.8 * len(list_of_values[0]))
        list_of_train_data = []
        list_of_val_data = []
        for values in list_of_values:
            train_data = values[:train_size]
            val_data = values[train_size:]
            list_of_train_data.append(train_data)
            list_of_val_data.append(val_data)

        # Combine Data into 3D Arrays
        combined_train_data = np.stack(list_of_train_data, axis=1)  # (T_train, num_stations, num_features)
        combined_val_data = np.stack(list_of_val_data, axis=1)  # (T_val, num_stations, num_features)

        # Ensure consistent column indexing
        representative_df = dfs[0]
        column_indices = {name: i for i, name in enumerate(representative_df.columns)}
        label_columns = [column_indices[WINDOW_PARAMS['label_columns'][0]]]

        # Define target station index
        target_col_index = label_columns[0]

        for i, model in enumerate(models):
            input_width = parameters_files[i].WINDOW_PARAMS['input_width']
            shift = parameters_files[i].WINDOW_PARAMS['shift']
            label_width = parameters_files[i].WINDOW_PARAMS['label_width']

            total_window_size = input_width + shift - 1 + label_width
            end = len(combined_val_data) - total_window_size if analyze_stop_at == 0 else min(analyze_stop_at, len(combined_val_data) - total_window_size)
            
            predictions = []
            actual_temps = []
            for j in tqdm(range(0, end), desc="Predicting"):
                try:
                    input_window, actual_temp = load_window_multi_station(
                        data_np=combined_val_data,
                        window_size=    input_width,
                        shift=          shift,
                        label_width=    label_width,
                        scalers=        scalers,
                        target_column_index=target_col_index,
                        idx=    j
                    )
                    y_pred_scaled = predict(model, input_window, east_normalized, north_normalized,  device=device)
                    # Inverse transform
                    target_scaler = scalers[ADVANCED_MODEL_PARAMS['target_station_idx']]
                    dummy = np.zeros((y_pred_scaled.shape[0], target_scaler.mean_.shape[0]))
                    dummy[:, target_col_index] = y_pred_scaled[:, 0]
                    y_pred_original = target_scaler.inverse_transform(dummy)[:, target_col_index]

                    if len(y_pred_original) != len(actual_temp):
                        continue
                    predictions.append(y_pred_original)
                    actual_temps.append(actual_temp)
                except ValueError as ve:
                    print(f"Skipping index {j}: {ve}")
                    continue
            
            output_dir_for_all = os.path.join(os.path.dirname(__file__), INFERENCE_PARAMS['inference_output_path'])
            os.makedirs(output_dir_for_all, exist_ok=True)
            output_dir_per_folder = os.path.join(os.path.dirname(__file__), INFERENCE_PARAMS['inference_output_path_per_model'][i])
            os.makedirs(output_dir_per_folder, exist_ok=True)
            predictions_actuals_df = flatten_data(predictions, actual_temps)
            predictions_actuals_df['input_width'] = input_width
            predictions_actuals_df['label_width'] = label_width
            predictions_actuals_df.to_csv(os.path.join(output_dir_per_folder, f'{i}_predictions_{i}.csv'), index=False)
            predictions_actuals_df.to_csv(os.path.join(output_dir_for_all, f'{i}_predictions_{i}.csv'), index=False)
    else:
        print(f"Invalid inference_mode : {inference_mode}.")