<a href="https://colab.research.google.com/github/ZackAel/beginners-python/blob/master/moe-f-blanka.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import os

## --- Setup LOGGING --- ##
from datetime import datetime

import fire
import torch.nn.functional as _F
from loguru import logger
from tqdm import tqdm

# Log DEBUG to .log file:
# The log file lives in a "logs" directory next to this script and is named
# "<script base name>_<YYYY-MM-DD>.log".
try:
    _source_path = os.path.abspath(__file__)
except NameError:
    # `__file__` is not defined when this code runs as a notebook cell or in an
    # interactive session; fall back to the current working directory and a
    # fixed base name so the logging setup does not crash.
    _source_path = os.path.join(os.getcwd(), "moef_notebook.py")
script_dir = os.path.dirname(_source_path)
base_name = os.path.basename(_source_path).split(".")[0]

log_dir = os.path.join(script_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d")
log_path = os.path.join(log_dir, f"{base_name}_{timestamp}.log")
logger.add(log_path,
           format="{time} {level} {message}",
           level="DEBUG",
           # Keep only records whose level name is exactly "DEBUG"; records of
           # any other level are not written to this file sink.
           filter=lambda record: record["level"].name == "DEBUG",
           mode='w'
)

## --- CONSTANTS --- ###
# Small positive offset used for numerical stability in this file: it is added
# to exact-zero entries of A_bar and B, added to Delta_W_bar in the weight
# update, and used as the lower clamp bound for the expert weights.
EPSILON = 1e-5

## --- HELPERS --- ###
from utils_common.utils_common import *
from utils_mse import utils_moef as um

## -- MSE HELPER FUNCTIONS -- ##
def compute_A_mse(
        w_i: torch.Tensor,
        y_t: torch.Tensor,
        X_until_t: torch.Tensor,
        experts: list,
        delta: float,
        t: int,
        F: torch.Tensor = None,
        DeltaF: torch.Tensor = None,
) -> torch.Tensor:
    """Compute the A term for one weight vector under the squared-error loss.

    The result is ``2 * (y_t - F) * (w_i . F - DeltaF + 1)``, evaluated
    element-wise, where ``w_i . F`` is the ``w_i``-weighted sum of the expert
    predictions over the expert axis.

    :param w_i: Weight vector with one entry per expert (reshaped to a column
        and broadcast against ``F``).
    :param y_t: Target at the current step; must broadcast against ``F``.
    :param X_until_t: Inputs up to and including the current step. Only used
        when ``F`` or ``DeltaF`` has to be computed here.
    :param experts: Callables mapping inputs to a prediction. Only used when
        ``F`` or ``DeltaF`` has to be computed here, so it may be ``None`` when
        both are supplied.
    :param delta: Unused; kept for signature compatibility.
    :param t: Unused; kept for signature compatibility.
    :param F: Optional precomputed expert predictions, one row per expert.
    :param DeltaF: Optional precomputed change in expert predictions relative
        to the previous step.
    :return: Tensor with the same shape as ``F`` (one row per expert), not a
        flat ``(N,)`` vector.
    """
    if F is None:
        F = torch.stack([expert(X_until_t) for expert in experts], dim=0)
    if DeltaF is None:
        if X_until_t.shape[0] > 1:
            # Previous-step predictions: evaluate the experts on all but the last row.
            F_prior = torch.stack([expert(X_until_t[:-1]) for expert in experts], dim=0)
        else:
            # No previous step exists yet, so the prior prediction is taken as zero.
            F_prior = torch.zeros_like(F, dtype=torch.float)
        DeltaF = F - F_prior

    F = F.float()
    # Weighted combination of the expert predictions (sum over the expert axis).
    wF = torch.sum(w_i.view(-1, 1) * F, dim=0)
    LHS = (y_t - F)
    RHS = wF - DeltaF + 1
    A = 2 * LHS * RHS

    return A

def compute_B_mse(
        Y: torch.Tensor,
        x: torch.Tensor,
        experts: list,
        delta: float,
        t: int,
        F: torch.Tensor = None,
) -> torch.Tensor:
    """
    Compute the B term: the residual ``Y - F`` scaled by ``2 ** (3 / 2)``.

    :param Y: Target; must broadcast against ``F``.
    :param x: Inputs handed to every expert when ``F`` is not supplied.
    :param experts: Callables producing one prediction each; only used when
        ``F`` is not supplied.
    :param delta: Unused by the current formula.
    :param t: Unused by the current formula.
    :param F: Optional precomputed stack of expert predictions.
    :return: Tensor with the broadcast shape of ``Y - F``.
    """
    predictions = F
    if predictions is None:
        per_expert = [expert(x) for expert in experts]
        predictions = torch.stack(per_expert, dim=0)
    residual = Y - predictions
    scale = 2 ** (3 / 2)
    return scale * residual

def compute_ell_mse(
        Y: torch.Tensor,
        X_until_t: torch.Tensor,
        experts: list,
        F: torch.Tensor = None,
) -> torch.Tensor:
    """
    Compute the mean-squared-error loss of every expert's prediction against ``Y``.

    :param Y: Target compared against each expert prediction.
    :param X_until_t: Inputs handed to every expert when ``F`` is not supplied.
    :param experts: Callables producing one prediction each; only used when
        ``F`` is not supplied.
    :param F: Optional precomputed stack of expert predictions (one row per expert).
    :return: A new 1-D tensor holding one scalar loss per row of ``F``.
    """
    predictions = F
    if predictions is None:
        predictions = torch.stack([expert(X_until_t) for expert in experts], dim=0)

    mse = torch.nn.MSELoss()
    per_expert_losses = []
    for prediction in predictions:
        per_expert_losses.append(mse(prediction, Y))
    # Re-wrapping the scalar losses with torch.tensor builds a fresh tensor of values.
    return torch.tensor(per_expert_losses)

## -- END - MSE HELPER FUNCTIONS -- ##

def compute_A_bar(
        pi: torch.Tensor,
        y_t: torch.Tensor,
        X_until_t: torch.Tensor,
        experts: list,
        delta: float,
        t: int,
        F: torch.Tensor = None,
        DeltaF: torch.Tensor = None,
        use_uniform_w: bool = False,
        normalize_A: bool = True,
        *args, **configs
) -> torch.Tensor:
    """
        Aggregate the per-weight-vector A terms into a single pi-weighted A_bar.

        One A term is computed with `compute_A_mse` for each of the N weight
        vectors, the results are stacked along a new leading axis, optionally
        normalized, multiplied by `pi.T.unsqueeze(-1)` and summed over that
        leading axis.

        :param pi: Expert weights with shape (N, N); row i is the weight vector
                   used for the i-th A term.
        :param y_t: Target at the current step.
        :param X_until_t: Inputs up to the current step (forwarded to `compute_A_mse`).
        :param experts: A list of expert functions; its length defines N.
        :param delta: Forwarded to `compute_A_mse`; when falsy, `configs['delta_val']`
                      (default 1) is used instead.
        :param t: Current time step (forwarded to `compute_A_mse`).
        :param F: Optional precomputed expert predictions.
        :param DeltaF: Optional precomputed change in expert predictions.
        :param use_uniform_w: When True, the rows of the N x N identity matrix are
                      used as weight vectors instead of the rows of `pi`.
        :param normalize_A: When True, the stacked A is divided by its sum over
                      axis 1; sums that are exactly zero are replaced by 1 first.

        :return: The stacked A summed over its leading axis after weighting;
                 presumably (N, H*C) when F is (N, H*C) -- confirm against the caller.
    """
    delta = delta or configs.get('delta_val', 1)

    num_experts = len(experts)
    if use_uniform_w:
        basis = torch.eye(num_experts)
        weight_vectors = [basis[idx] for idx in range(num_experts)]
    else:
        weight_vectors = [pi[idx] for idx in range(num_experts)]

    per_weight_A = []
    for weights in weight_vectors:
        per_weight_A.append(
            compute_A_mse(w_i=weights, y_t=y_t, X_until_t=X_until_t, experts=experts,
                          delta=delta, t=t, F=F, DeltaF=DeltaF)
        )
    stacked_A = torch.stack(per_weight_A, dim=0)

    if normalize_A:
        denominators = stacked_A.sum(dim=1, keepdim=True)
        # Guard the division: an exactly-zero sum is replaced by 1.
        denominators = torch.where(denominators == 0, torch.ones_like(denominators), denominators)
        stacked_A = stacked_A / denominators

    weighted = stacked_A * pi.T.unsqueeze(-1)
    return weighted.sum(dim=0)


def _normalized_z_score(values: torch.Tensor) -> torch.Tensor:
    """Standardize `values` by subtracting its mean and dividing by `values.std()`."""
    return (values - values.mean()) / values.std()


def mixture_of_filters(X: torch.Tensor, Y: torch.Tensor, seed: int, experts: list,
                       lam: float,
                       delta: float,
                       *args, **configs) -> tuple:
    """
    Combine predictions from multiple experts using a dynamic weighting algorithm.

    For every time step the expert predictions are evaluated, the per-expert
    weight matrix `pi` is updated row by row (drift `Q @ pi_n` plus a diffusion
    term, z-score normalized and passed through a softmax), and the rows of
    `pi @ F` are then mixed with softmin weights over their MSE to produce the
    final prediction for that step.

    :param X: Expert inputs with shape (T, N, H*C); this is asserted.
    :param Y: Targets with shape (T, H*C); this is asserted.
    :param seed: Only used in log/diagnostic messages.
    :param experts: List of N callables; each maps `X[:t + 1]` to a prediction.
                    NOTE: the initial Q divides by (N - 1), so N must be at least 2
                    for finite values.
    :param lam: Scale applied to the losses before the softmin; when falsy,
                `configs['lambda_val']` (default 1.0) is used.
    :param delta: Forwarded to the helper functions; when falsy,
                `configs['delta_val']` (default 1) is used.
    :param configs: Optional settings read here: `update_Q`, `normalize_pi`,
                `is_log_values`, `alpha_val`, `T`, `H`, `C`. All of them are also
                forwarded to `compute_A_bar`.
    :return: Tuple `(Y_hat, None, Q, pi, PI_BAR_LIST)`: the predictions (same shape
             as Y), a `None` placeholder, the final Q and pi matrices, and the list
             of per-step mixing weights.
    """
    lam = lam or configs.get('lambda_val', 1.0)
    delta = delta or configs.get('delta_val', 1)
    update_Q = configs.get('update_Q', True)
    normalize_pi = configs.get('normalize_pi', False)
    is_log_values = configs.get('is_log_values', False)
    # N.B. lowering Identity dominance (<0.10) may lead to numerical instability (complex parts)
    alpha = configs.get("alpha_val", 0.10)

    T = configs.get('T', 2785)
    H = configs.get('H', 96)
    C = configs.get('C', 7)
    N = len(experts)
    assert X.shape == (T, N, H*C), f"X must have shape {(T, N, H * C)}, got {tuple(X.shape)}"
    assert Y.shape == (T, H*C), f"Y must have shape {(T, H * C)}, got {tuple(Y.shape)}"

    if update_Q:
        # Import the submodule explicitly: a bare `import scipy` does not guarantee
        # that `scipy.linalg` is loaded. Done once here rather than on every step.
        import scipy.linalg

    identity = torch.eye(N)
    mse_loss = torch.nn.MSELoss()

    pi = torch.ones((N, N), dtype=torch.float) / N
    Q = torch.ones((N, N)) / (N - 1) - torch.eye(N) / (N - 1) - torch.eye(N)
    L_prior = torch.zeros(N)
    F_prior = torch.zeros((N, H * C), dtype=torch.float)
    next_pi = torch.zeros_like(pi, dtype=torch.float)
    Y_hat = torch.zeros_like(Y)
    PI_BAR_LIST = []
    for t in tqdm(range(T)):
        X_until_t = X[:t + 1]
        F = torch.stack([expert(X_until_t) for expert in experts], dim=0)       # (N, H*C)
        DeltaF = F - F_prior
        F_prior = F
        A_bar = compute_A_bar(pi=pi, y_t=Y[t], X_until_t=X_until_t, experts=experts, delta=delta, t=t, F=F,
                              DeltaF=DeltaF, *args, **configs)  # A_bar: (N, H*C)
        if torch.eq(A_bar, 0).any():
            logger.warning(f"WARNING: POTENTIAL 0 DIV: 0 value in A_bar at {t}")
            A_bar = A_bar + (A_bar == 0).float() * EPSILON
            logger.info(f"To A_bar(after): {A_bar} by adding {EPSILON} for numerical stability")

        B = compute_B_mse(Y=Y[t], x=X_until_t, experts=experts, delta=delta, t=t, F=F,)
        if torch.isnan(B).any():
            logger.warning(f"ERROR: BUG: NaN value in B at {t}")
        if torch.eq(B, 0).any():
            logger.warning(f"WARNING: POTENTIAL 0 DIV: 0 value in B at {t}, htat will cause 0 division resulting in 'inf'")
            logger.info(f"Adjusting B (before): {B}")
            B = B + (B==0).float() * EPSILON
            logger.info(f"To B(after): {B} by adding {EPSILON} for numerical stability")

        # Per-expert loss at this step; named `ell` so it cannot be confused with
        # the look-back length `L` used elsewhere in the configuration.
        ell = compute_ell_mse(Y=Y[t], X_until_t=X_until_t, experts=experts, F=F,)
        if torch.isnan(ell).any():
            logger.warning(f"ERROR: BUG: NaN value in L at {t} with seed: {seed}")
        DeltaL = ell - L_prior
        DeltaL = DeltaL.unsqueeze(-1)  # Reshape DeltaL to be compatible with A_bar and B: (N,) -> (N, 1)
        L_prior = ell
        Delta_W_bar = (DeltaL - A_bar) / B
        # A terms for the one-hot weight vectors (rows of the identity matrix).
        As = torch.stack([compute_A_mse(w_i=identity[i],
                                    y_t=Y[t],
                                    X_until_t=X_until_t,
                                    experts=experts, delta=delta, t=t, F=F,
                                    DeltaF=DeltaF, )
                            for i in range(N)]
        )
        check_for_nans([pi, Q, next_pi, F, As, B, ell],
                       ["pi", "Q", "next_pi", "F", "As", "B", "L"], t, seed)
        log_args = {
            f"Q_{t}_before_update": Q.clone().detach(),
            f"pi_{t}_before_update": pi.clone().detach(),
            f"L_{t}": ell.clone().detach(),
            'DeltaL': DeltaL.clone().detach(),
        }
        for n in range(N):
            pi_n = pi[n].unsqueeze(-1)          # (N, 1)
            drift = Q @ pi_n
            diffusion = pi_n * (As[:, n] - A_bar[n]) / B[n]
            update = drift + diffusion * (Delta_W_bar[n] + EPSILON)
            # Average the update over the last axis, then z-score normalize it.
            update_mean = torch.mean(update, dim=-1)
            update_normalized = _normalized_z_score(update_mean)
            pi_n_updated = pi_n.squeeze() + update_normalized.squeeze()
            pi_n_updated = torch.nn.functional.softmax(pi_n_updated, dim=0) # Valid Prob. dist: Normalize to ensure it sums to 1
            log_args.update({
                f"pi_update_{t}_{n}": pi_n_updated
            })

            next_pi[n] = pi_n_updated
            next_pi[n] = torch.clamp(next_pi[n], min=0)

            if torch.isnan(next_pi[n]).any():
                logger.warning(f"ERROR: BUG: NaN value in next_pi[{n}] at {t}")
            if normalize_pi:
                sum_next_pi_n = torch.sum(next_pi[n])
                if sum_next_pi_n > 0:
                    next_pi[n] = next_pi[n] / sum_next_pi_n
                else:
                    logger.warning(
                        f"Sum of probabilities for next_pi[{n}] was zero at time {t}. Assigning equal probabilities.")
                    next_pi[n] = torch.full_like(next_pi[n], fill_value=1/len(next_pi[n]))
            next_pi[n] = torch.clamp(next_pi[n], EPSILON, 1.0)  # Clamp values if necessary (adjust range as needed)

        pi = next_pi.clone().detach()
        # Each row of pi mixes the expert predictions; score every mixed row by its MSE.
        Y_hat_t = pi @ F
        s_mse = torch.tensor([mse_loss(f, Y[t]) for f in Y_hat_t])
        pi_bar_t = _F.softmin(lam*s_mse, dim=0)   # (N,): lower loss -> larger weight

        PI_BAR_LIST.append(pi_bar_t)    # Add to PI_BAR_LIST for plotting weight composition
        Y_moef_t = pi_bar_t @ Y_hat_t
        Y_hat[t] = Y_moef_t
        log_args.update({
            f"s_mse_{t}": s_mse.clone().detach(),
            f"pi_bar_{t}": pi_bar_t.clone().detach(),
        })
        ## Ln 26-32: Update Q
        if update_Q:
            # Blend the repeated pi_bar rows with the identity, take the real part of
            # the matrix logarithm, keep its non-negative entries and subtract the
            # column sums on the diagonal.
            P_tilde = pi_bar_t.repeat(N, 1)
            P = (1 - alpha) * P_tilde + alpha * identity
            log_P = torch.tensor(scipy.linalg.logm(P.numpy()).real, dtype=torch.float32)
            M  = _F.relu(log_P)
            RHS = torch.sum(M, dim=0, keepdim=True) * identity
            Q  = M - RHS

        if is_log_values:
            log_values(t, **log_args)
        logger.info("-" * 25)
        logger.info(f"==== Run for t:{t} completed ====")
        logger.info("-" * 25)
    # End Loop For 1...T
    return Y_hat, None, Q, pi, PI_BAR_LIST


def run_experiments_data(id: str, seed: int, experiments_dir: str = EXPERIMENTS_PATH, **configs):
    """Run one MoE-F experiment, log its metrics and optionally append them to a result file.

    :param id: Experiment identifier; used in log messages and in the results
        directory name (`<id>_seed-<seed>`).
    :param seed: Seed forwarded to the data loaders and to `mixture_of_filters`.
    :param experiments_dir: Base directory for results, used when the configs do
        not provide a non-empty `EXPERIMENTS_PATH`.
    :param configs: Experiment settings. Read here: `is_save_results`,
        `lambda_val`, `delta_val`, `alpha_val`, `DATASET`, `is_generate_csvs`,
        `T`, `L`, `H`, `C`, `EXPERIMENTS_PATH`. The full dict is also forwarded
        to the data loaders and to `mixture_of_filters`.
    :raises ValueError: If the combined prediction contains NaN values.
    """
    is_save_results = configs.get('is_save_results', True)
    lambda_val = configs.get('lambda_val', 1.0)
    delta_val = configs.get('delta_val', 1)
    alpha_val = configs.get('alpha_val', 0.10)
    dataset = configs.get('DATASET')
    is_generate_csvs = configs.get('is_generate_csvs', False)

    # Use the values in your experiments
    logger.info(f"{dataset}: Running experiment for ID {id} with seed {seed}")

    ### ETTh data 720_96 config:
    columns = ['HUFL', 'HULL', 'MUFL', 'MULL', 'LUFL', 'LULL', 'OT']
    T = configs.get('T', 2785)
    L = configs.get('L', 720)
    H = configs.get('H', 96)
    # Prefer the configured channel count so the reshape below agrees with the
    # shape check in `mixture_of_filters`, which also reads `C` from the configs.
    C = configs.get('C', len(columns))

    logger.info(f"Configs: T={T}, lambda_val={lambda_val}, delta_val={delta_val}")
    ## 1. Setup I/O: Extract True Label and N Expert Predictions
    if is_generate_csvs:
        _df_Y, _expert_df_X_dict, X, Y = um.setup_moef_io_all_experts(seed=seed, **configs)
        # Should get back Y: (2785, 96, 7) and X_old: (3, 2785, 96, 7) -> X_permuted: (2785, 3, 96, 7)
        # This loader returns no per-expert metrics; use an empty mapping so the
        # reporting code below does not hit an unbound name.
        expert_metrics = {}
    else:
        # Load X, Ys directly from the test_results folder
        X, Y, _expert_names, expert_metrics = um.get_moef_io_for_all_experts_from_test_results(seed=seed, **configs)

    ## 2. Setup Experts, Loss Function and Init Variables
    num_of_experts = N = X.shape[1]
    X = X.reshape(T, N, H*C)
    Y = Y.reshape(T, H*C)

    def make_expert(i):
        # Expert i returns its own column of the most recent row; with no rows
        # available it falls back to random noise of the expected width.
        return lambda x: x[-1, i, :] if x.shape[0] > 0 else torch.randn(H * C)
    experts = [make_expert(i) for i in range(num_of_experts)]

    ## 3. Call MoE-F ALGORITHM with Setup
    Y_hat, *_ = mixture_of_filters(X, Y, seed=seed,
                                   experts=experts, lam=lambda_val, delta=delta_val,
                                   **configs)
    if torch.isnan(Y_hat).any():
        logger.warning(f"Warning: NaN values found in Y_hat with window size T = {T}.")
        logger.warning("Existing program - as going forward is pointless")
        raise ValueError("NaN values detected in Y_hat tensor.")
    else:
        logger.info(f"Yay! No NaN values in Y_hat with window size T = {T}")

    # Calculate performance
    loss_fn = torch.nn.MSELoss()
    loss_mse = loss_fn(Y_hat, Y)
    loss_mae = torch.mean(torch.abs(Y_hat - Y))
    loss_rse = torch.sqrt(torch.sum((Y - Y_hat) ** 2)) / torch.sqrt(torch.sum((Y - Y.mean()) ** 2))

    logger.info(f"MoE-F Loss MSE: {loss_mse}, MAE: {loss_mae}, RSE: {loss_rse}")
    logger.info("-" * 20)
    for k, v in expert_metrics.items():
        # expert_metrics[expert_name] = (mse, mae, rse)
        mse, mae, rse = v
        logger.info(f"\t{k} Loss MSE: {mse}, MAE: {mae}, RSE: {rse}")

    if is_save_results:
        ## Save the results in "experiments_tsf/ETTh1_sl720_pl96/all-experts_seed-0/df_moef.csv"
        # Fall back to the `experiments_dir` argument when the configs carry no path,
        # instead of passing None to os.path.join.
        results_save_dir = configs.get("EXPERIMENTS_PATH") or experiments_dir
        results_dir = os.path.join(results_save_dir, f"{id}_seed-{seed}")
        os.makedirs(results_dir, exist_ok=True)
        result_file_path = os.path.join(results_dir, "result.txt")
        logger.info(f"Saving results to {result_file_path}")
        # Append mode: every run adds a new section to the same result file.
        with open(result_file_path, mode='a') as result_file:
            result_file.write("\n")
            result_file.write(f"{dataset}_{L}_{H}_MoE-F_test_0_seed{seed}_lambda-{lambda_val}_alpha-{alpha_val}_delta-{delta_val}\n")
            result_file.write("\n")
            result_file.write(f"Experiment ID: {id}, Seed: {seed}, Lambda: {lambda_val}, Alpha: {alpha_val}, Delta: {delta_val}\n")
            result_file.write( f"MoE-F MSE: {loss_mse.item():.4f}, MAE: {loss_mae.item():.4f}, RSE: {loss_rse.item():.4f}\n")
            result_file.write("\n")

            # Log and save expert metrics
            result_file.write("Expert Models Performance:\n")
            result_file.write("-" * 25 + "\n")
            for expert_name, (mse, mae, rse) in expert_metrics.items():
                result_file.write(f"{expert_name}\t MSE: {mse:.4f}, MAE: {mae:.4f}, RSE: {rse:.4f}\n")

            result_file.write("\n")
            result_file.write("="*100 + "\n")

        logger.info(f"Results saved to {result_file_path}")
    logger.info("=" * 50)

def main(config_file: str = "config_720_96.yaml"):
    """Load the experiment configuration and run every configured experiment ID.

    :param config_file: Path or name of the config file passed to `load_config`.

    Only the first entry of `configs['SEEDS']` is used for each experiment ID.
    """
    configs = load_config(config_file)

    # Experiment IDs to iterate through, and the seeds taken from the config
    experiment_ids = ["all-experts"]
    seeds = configs['SEEDS']
    separator = "=" * 50

    for experiment_id in experiment_ids:
        logger.info(separator)
        logger.info(f"Running for: {experiment_id} ... ")
        first_seed_only = seeds[:1]
        for seed in first_seed_only:
            logger.info(f"Using seed: {seed}")
            # Seed torch's RNG before each run
            torch.manual_seed(seed)
            run_experiments_data(id=experiment_id, seed=seed, **configs)
        logger.info(separator)

# Script entry point: when run directly, hand `main` to python-fire so that its
# `config_file` parameter can be supplied from the command line.
if __name__ == "__main__":
    fire.Fire(main)