# Inicializando

In [None]:
# Install notebook dependencies (IPython shell escape).
# NOTE(review): pytorch-forecasting and lightning are not imported by any visible cell
# (the code below uses TensorFlow/Keras, scikit-learn and tqdm_joblib) — confirm whether they are needed.
!pip install pytorch-forecasting lightning numpy pandas pyarrow matplotlib tqdm_joblib

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive so the BASE_DIR paths below resolve;
# force_remount=True refreshes an already-mounted drive.
drive.mount(mountpoint="/content/drive", force_remount=True)

In [6]:
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import joblib
import json
from joblib import Parallel, delayed
from tqdm import tqdm
from tqdm_joblib import tqdm_joblib
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import LSTM, Dense, Input, Concatenate, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
from tensorflow.keras.optimizers import Adam

# --- Paths and constants configuration ---
BASE_DIR = "/content/drive/MyDrive/lstm_projeto"
DATA_PATH = os.path.join(BASE_DIR, "data", "final_training_data.parquet")
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "model_checkpoint.keras")
# Swap only the final ".keras" extension; str.replace would also rewrite any
# ".keras" occurring in a directory name.
EPOCH_TRACK_PATH = os.path.splitext(CHECKPOINT_PATH)[0] + "_epoch.txt"
CITY_MAP_PATH = os.path.join(BASE_DIR, "city_to_idx.json")
EARLYSTOP_STATE = os.path.join(BASE_DIR, "earlystop_city_state.json")
TRAINING_HISTORY_PATH = os.path.join(BASE_DIR, "training_history.json")

SCALER_DIR = os.path.join(BASE_DIR, "scalers")

# Create the directories on Google Drive if they do not exist yet
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(SCALER_DIR, exist_ok=True)

print("Paths configurados com sucesso.")

Paths configurados com sucesso.


# Treinando

## Municipal

In [None]:
import os
import json
import numpy as np
import joblib
from typing import Tuple

import tensorflow as tf
from tensorflow.keras.callbacks import Callback

def create_sequences_multi_step(data: np.ndarray, seq_len: int, horizon: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build stacked sliding windows (X, y) for multi-step forecasting.

    Args:
        data: 2-D array ``(n_timesteps, n_features)``; column 0 is the target
            (``numero_casos`` in this notebook).
        seq_len: number of past timesteps in each input window.
        horizon: number of future timesteps of column 0 to predict.

    Returns:
        X: float32 array ``(n_samples, seq_len, n_features)``.
        y: float32 array ``(n_samples, horizon)`` with column 0 of the ``horizon``
           rows that follow each window.
        If the series is too short, ``n_samples == 0`` but the trailing dimensions
        are kept, so the arrays can still be reshaped/concatenated safely.
    """
    n_features = data.shape[1]
    num_samples = data.shape[0] - seq_len - horizon + 1
    if num_samples <= 0:
        # Keep the trailing dimensions (previously a 1-D float64 array was returned).
        return (np.zeros((0, seq_len, n_features), dtype=np.float32),
                np.zeros((0, horizon), dtype=np.float32))

    X = np.zeros((num_samples, seq_len, n_features), dtype=np.float32)
    y = np.zeros((num_samples, horizon), dtype=np.float32)
    for i in range(num_samples):
        X[i] = data[i : i + seq_len]
        y[i] = data[i + seq_len : i + seq_len + horizon, 0]
    return X, y


def asymmetric_mse(y_true, y_pred):
    """Squared error that penalises under-prediction more than over-prediction.

    With residual ``e = y_true - y_pred``, each element's squared error is weighted:
      * ``e <= 0`` (over-prediction): weight 1;
      * ``e > 0`` (under-prediction): weight ``1 + 5.0 * |e| / max(|y_true|, 1)``,
        so the extra penalty grows with the relative size of the miss.
    Returns the mean of the weighted squared errors over all elements.
    """
    under_penalty = 5.0
    residual = y_true - y_pred
    relative_miss = tf.abs(residual) / tf.maximum(tf.abs(y_true), 1.0)
    weight = tf.where(residual > 0, 1.0 + under_penalty * relative_miss, 1.0)
    return tf.reduce_mean(weight * tf.square(residual))


# --- Callbacks utilitários --- #
class EpochSaver(Callback):
    """Persist the index of the last completed epoch to a text file.

    Keras passes a 0-based ``epoch`` to ``on_epoch_end``; that number is written
    as plain text so a resumed run can start from the following epoch.
    Writing is best-effort: a failure is reported but never interrupts training.
    """

    def __init__(self, epoch_file_path: str):
        super().__init__()
        self.epoch_file_path = epoch_file_path

    def on_epoch_end(self, epoch, logs=None):
        try:
            with open(self.epoch_file_path, "w") as f:
                f.write(str(int(epoch)))
        except OSError as exc:
            # Best-effort, but visible: a silent failure would make resuming restart too early.
            print(f"[WARN] EpochSaver: could not write {self.epoch_file_path}: {exc}")


class HistorySaver(Callback):
    """Append each epoch's numeric logs to a JSON file so history survives restarts.

    The file maps metric name -> list of per-epoch values. It is re-read at every
    epoch end, so values accumulate across interrupted and resumed runs.
    """

    def __init__(self, save_path: str):
        super().__init__()
        self.save_path = save_path

    def _load(self) -> dict:
        """Return the stored history, or ``{}`` if the file is missing or unreadable."""
        try:
            with open(self.save_path, "r") as f:
                hist = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as exc:
            print(f"[WARN] HistorySaver: could not read {self.save_path} ({exc}); starting a new history.")
            return {}
        return hist if isinstance(hist, dict) else {}

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        hist = self._load()
        for k, v in logs.items():
            try:
                val = float(v)
            except (TypeError, ValueError):
                continue  # skip non-scalar log entries
            hist.setdefault(k, []).append(val)

        # Write to a temp file, then atomically swap it in: an interruption mid-write
        # (e.g. a Colab disconnect) can no longer leave a truncated JSON that would
        # make the next epoch discard the whole history.
        tmp_path = self.save_path + ".tmp"
        try:
            with open(tmp_path, "w") as f:
                json.dump(hist, f)
            os.replace(tmp_path, self.save_path)
        except OSError as exc:
            print(f"[WARN] HistorySaver: could not write {self.save_path}: {exc}")


class PersistentEarlyStoppingCityVal(Callback):
    """Early stopping on ``city_val_loss`` whose counters survive notebook restarts.

    ``wait`` (epochs without improvement) and ``best`` (lowest ``city_val_loss`` seen)
    are written to ``state_file`` after every epoch and reloaded on construction, so a
    resumed run continues the same patience count.

    Args:
        city_val_callback: callback that writes ``logs["city_val_loss"]``; it is
            invoked here only when that key is missing from ``logs``.
        patience: epochs without improvement before ``stop_training`` is set.
        restore_best_weights: keep an in-memory copy of the best weights and load it
            back when stopping. The copy is not persisted, so a ``best`` restored from
            ``state_file`` has no weights attached until a new improvement occurs.
        state_file: JSON file holding ``{"wait": int, "best": float}``.
    """

    def __init__(self, city_val_callback, patience=10, restore_best_weights=True, state_file="earlystop_city_state.json"):
        super().__init__()
        self.city_val_callback = city_val_callback
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.state_file = state_file
        self.wait = 0
        self.best_weights = None
        self.best = np.inf

        if os.path.exists(self.state_file):
            try:
                with open(self.state_file, "r") as f:
                    state = json.load(f)
                self.wait = int(state.get("wait", 0))
                self.best = float(state.get("best", np.inf))
            except (OSError, ValueError, TypeError, AttributeError) as exc:
                print(f"[WARN] Could not read early-stopping state {self.state_file} ({exc}); starting from scratch.")
                self.wait = 0
                self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        # Keep the dict Keras passed (even if empty) so later callbacks see the metric.
        logs = {} if logs is None else logs
        current = logs.get("city_val_loss")
        if current is None and hasattr(self.city_val_callback, "on_epoch_end"):
            self.city_val_callback.on_epoch_end(epoch, logs)
            current = logs.get("city_val_loss")
        if current is None:
            return

        current = float(current)
        logs["city_val_loss"] = current

        if current < self.best:
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                try:
                    self.best_weights = self.model.get_weights()
                except Exception as exc:
                    print(f"[WARN] Could not snapshot best weights: {exc}")
                    self.best_weights = None
        else:
            self.wait += 1

        try:
            with open(self.state_file, "w") as f:
                json.dump({"wait": int(self.wait), "best": float(self.best)}, f)
        except OSError as exc:
            print(f"[WARN] Could not save early-stopping state to {self.state_file}: {exc}")

        if self.wait >= self.patience:
            if self.restore_best_weights:
                if self.best_weights is not None:
                    try:
                        self.model.set_weights(self.best_weights)
                    except Exception as exc:
                        print(f"[WARN] Could not restore best weights: {exc}")
                else:
                    # e.g. the best epoch happened before a restart: nothing in memory to restore.
                    print("[WARN] Best weights are not available in memory; keeping the current weights.")
            print(f"\nEarly stopping triggered at epoch {epoch + 1} (city_val_loss: {self.best:.6f})")
            self.model.stop_training = True


class CityValLossWeighted(Callback):
    """Per-city, weight-averaged asymmetric validation loss computed every epoch.

    The validation arrays are expected to be concatenated city by city, in the same
    order as ``city_sizes`` (rows per city). For each city the mean of
    ``err**2 * p`` is taken, with ``err = y_true - y_pred`` and ``p = penalty_factor``
    when ``err > 0`` (under-prediction), else 1. City means are averaged with
    ``city_weights``. The result goes to ``logs["city_val_loss"]``; whenever it
    improves on ``best``, the model is saved to ``save_path``.

    Args:
        X_val, static_val, city_ids_val: model inputs for the validation set.
        y_val: validation targets, in the same units as the model output.
        city_sizes: number of validation rows of each city, in concatenation order.
        city_weights: one weight per city (same order as ``city_sizes``).
        save_path: where the best model is saved.
        penalty_factor: multiplier applied to squared under-prediction errors.
        initial_best: best value already reached (e.g. in a previous session). With
            the default ``np.inf`` the first epoch always overwrites ``save_path``.
    """

    def __init__(self, X_val, static_val, city_ids_val, y_val, city_sizes, city_weights, save_path, penalty_factor=10.0, initial_best=np.inf):
        super().__init__()
        self.X_val = X_val
        self.static_val = static_val
        self.city_ids_val = city_ids_val
        self.y_val = y_val
        self.city_sizes = city_sizes
        self.city_weights = city_weights
        self.save_path = save_path
        self.best = float(initial_best)
        self.penalty_factor = penalty_factor

    def _predict(self):
        """Predict the validation set; return None (with a warning) if the model cannot be called."""
        try:
            y_pred = self.model.predict([self.X_val, self.static_val, self.city_ids_val], batch_size=1024, verbose=0)
        except Exception:
            # Fallback for models built without the city-id input.
            try:
                y_pred = self.model.predict([self.X_val, self.static_val], batch_size=1024, verbose=0)
            except Exception as exc:
                print(f"[WARN] CityValLossWeighted: prediction failed, city_val_loss not computed: {exc}")
                return None
        # Multi-output models: the regression head is the first output.
        if isinstance(y_pred, (list, tuple)):
            return y_pred[0]
        return y_pred

    def on_epoch_end(self, epoch, logs=None):
        # Keep the dict Keras passed (even if empty) so later callbacks see the metric.
        logs = {} if logs is None else logs
        y_pred_reg = self._predict()
        if y_pred_reg is None:
            return

        if sum(self.city_sizes) != len(self.y_val):
            print(f"[WARN] CityValLossWeighted: sum(city_sizes)={sum(self.city_sizes)} != len(y_val)={len(self.y_val)}")

        start = 0
        weighted_losses = []
        total_weight = 0.0
        for i, n in enumerate(self.city_sizes):
            y_true_city = self.y_val[start:start + n]
            y_pred_city = y_pred_reg[start:start + n]
            start += n
            if y_true_city.size == 0:
                continue  # np.mean of an empty slice is NaN and would poison the total

            error = y_true_city - y_pred_city
            penalty = np.where(error > 0, self.penalty_factor, 1.0)
            asymmetric_loss_city = np.mean(np.square(error) * penalty)

            weight = float(self.city_weights[i])
            weighted_losses.append(asymmetric_loss_city * weight)
            total_weight += weight

        city_val_loss = float(np.sum(weighted_losses) / max(total_weight, 1e-8))
        logs["city_val_loss"] = city_val_loss

        if city_val_loss < self.best:
            self.best = city_val_loss
            try:
                self.model.save(self.save_path)
            except Exception as exc:
                print(f"[WARN] CityValLossWeighted: could not save model to {self.save_path}: {exc}")
            print(f"\nEpoch {epoch+1}: Melhor modelo salvo com city_val_loss ponderado: {city_val_loss:.6f}")


def build_and_save_city_mapping(codigo_ibge_array, save_path):
    """Assign each distinct city code a 1-based index, save it as JSON and return it.

    Indices follow ascending integer order of the codes and start at 1, leaving 0
    free as a fallback index. JSON stores the keys as strings.
    """
    ordered_ids = np.sort(np.unique(codigo_ibge_array).astype(int))
    mapping = {}
    for position, city_id in enumerate(ordered_ids, start=1):
        mapping[int(city_id)] = position
    with open(save_path, "w") as handle:
        json.dump(mapping, handle)
    return mapping


def load_city_mapping(path):
    """Read a city-code -> index mapping from JSON, converting keys and values back to int.

    JSON object keys are always strings, hence the explicit casts.
    """
    with open(path, "r") as handle:
        raw = json.load(handle)
    mapping = {}
    for key, value in raw.items():
        mapping[int(key)] = int(value)
    return mapping


# Cache of loaded global scalers, keyed by (path, mtime). process_municipio runs once
# per municipality (in parallel threads), so re-unpickling three scalers from Google
# Drive on every call was needlessly slow.
_GLOBAL_SCALER_CACHE = {}


def _load_global_scaler(filename):
    """Return the scaler stored at ``SCALER_DIR/filename``, loading each file version only once."""
    import os
    import joblib
    path = os.path.join(SCALER_DIR, filename)
    key = (path, os.path.getmtime(path))  # a rewritten file gets a new key and is reloaded
    scaler = _GLOBAL_SCALER_CACHE.get(key)
    if scaler is None:
        scaler = joblib.load(path)
        _GLOBAL_SCALER_CACHE[key] = scaler
    return scaler


def process_municipio(municipio_id):
    """Build scaled train/test windows for one municipality.

    Reads the module-level ``df``, ``dynamic_features``, ``static_features``,
    ``SEQUENCE_LENGTH``, ``HORIZON``, ``city_to_idx`` and the global scalers in
    ``SCALER_DIR``. Windows are split chronologically: the first 80% go to train,
    the rest to test.

    Returns:
        ``None`` if the municipality has too little history for both splits;
        otherwise the 12-tuple ``(X_train, y_train, X_test, y_test, static_train,
        static_test, y_train_raw, y_test_raw, city_ids_train, city_ids_test,
        notif_train, notif_test)``. ``X``/``static`` are scaled with the dynamic/static
        scalers, ``y`` with the target scaler, and ``*_raw`` are the unscaled targets.
        ``city_ids`` are embedding indices (0 if the municipality is missing from
        ``city_to_idx``). ``notif`` is the ``notificacao`` value of each window's
        last input week.
    """
    municipio_id = int(municipio_id)
    df_mun = df[df["codigo_ibge"] == municipio_id].sort_values(by=["ano", "semana"])
    if len(df_mun) < SEQUENCE_LENGTH + HORIZON:
        return None

    dynamic_data = df_mun[dynamic_features].values
    static_data = df_mun[static_features].iloc[0].values.reshape(1, -1)
    X_mun_raw, y_mun_raw = create_sequences_multi_step(dynamic_data, SEQUENCE_LENGTH, HORIZON)
    n_windows = X_mun_raw.shape[0]
    if n_windows == 0:
        return None

    # Chronological 80/20 split; both parts must be non-empty.
    split_idx = int(n_windows * 0.8)
    if split_idx == 0 or split_idx == n_windows:
        return None

    X_train_raw, X_test_raw = X_mun_raw[:split_idx], X_mun_raw[split_idx:]
    y_train_raw, y_test_raw = y_mun_raw[:split_idx], y_mun_raw[split_idx:]

    static_train = np.repeat(static_data, X_train_raw.shape[0], axis=0)
    static_test = np.repeat(static_data, X_test_raw.shape[0], axis=0)

    scaler_dyn = _load_global_scaler("scaler_dyn_global.pkl")
    scaler_static = _load_global_scaler("scaler_static_global.pkl")
    scaler_target = _load_global_scaler("scaler_target_global.pkl")

    # The dynamic scaler works on 2-D (rows, features): flatten the windows, scale, restore shape.
    X_train_scaled = scaler_dyn.transform(X_train_raw.reshape(-1, X_train_raw.shape[-1])).reshape(X_train_raw.shape)
    X_test_scaled = scaler_dyn.transform(X_test_raw.reshape(-1, X_test_raw.shape[-1])).reshape(X_test_raw.shape)
    static_train_scaled = scaler_static.transform(static_train)
    static_test_scaled = scaler_static.transform(static_test)

    # The target scaler was fitted on a single column: scale every horizon step with it.
    y_train_scaled = scaler_target.transform(y_train_raw.reshape(-1, 1)).reshape(y_train_raw.shape)
    y_test_scaled = scaler_target.transform(y_test_raw.reshape(-1, 1)).reshape(y_test_raw.shape)

    city_idx = city_to_idx.get(municipio_id, 0)
    city_ids_train = np.repeat(city_idx, X_train_scaled.shape[0]).astype(np.int32)
    city_ids_test = np.repeat(city_idx, X_test_scaled.shape[0]).astype(np.int32)

    # The last input week of window i is row i + SEQUENCE_LENGTH - 1.
    flags_weeks = df_mun["notificacao"].values
    seq_flags_all = flags_weeks[SEQUENCE_LENGTH - 1 : SEQUENCE_LENGTH - 1 + n_windows]
    notif_train = seq_flags_all[:split_idx].astype(float)
    notif_test = seq_flags_all[split_idx:].astype(float)
    return (X_train_scaled, y_train_scaled, X_test_scaled, y_test_scaled,
            static_train_scaled, static_test_scaled,
            y_train_raw, y_test_raw,
            city_ids_train, city_ids_test,
            notif_train, notif_test)

In [None]:
import os
import json
import numpy as np
import pandas as pd
import joblib
from joblib import Parallel, delayed
from tqdm import tqdm
from tqdm_joblib import tqdm_joblib

import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import LSTM, Dense, Input, Concatenate, Dropout, Embedding, Flatten
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam

BASE_DIR = "/content/drive/MyDrive/lstm_projeto"
DATA_PATH = os.path.join(BASE_DIR, "data", "final_training_data.parquet")
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "model_checkpoint.keras")
CHECKPOINT_BEST_CITY = os.path.join(CHECKPOINT_DIR, "model_checkpoint_best_city.keras")
# Swap only the final ".keras" extension; str.replace would also rewrite any
# ".keras" occurring in a directory name.
EPOCH_TRACK_PATH = os.path.splitext(CHECKPOINT_PATH)[0] + "_epoch.txt"
SCALER_DIR = os.path.join(BASE_DIR, "scalers")
CITY_MAP_PATH = os.path.join(BASE_DIR, "city_to_idx.json")
TRAINING_HISTORY_PATH = os.path.join(BASE_DIR, "training_history.json")
EARLYSTOP_STATE = os.path.join(BASE_DIR, "earlystop_city_state.json")

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(SCALER_DIR, exist_ok=True)
os.makedirs(os.path.join(BASE_DIR, "data"), exist_ok=True)

# Training hyper-parameters: 12 weeks of input -> 6 weeks forecast.
TARGET_COLUMN = "numero_casos"
SEQUENCE_LENGTH = 12
HORIZON = 6
EPOCHS = 50
BATCH_SIZE = 64

print("\n--- Etapa 4: Carregando e processando os dados ---")
df = pd.read_parquet(DATA_PATH)
df['codigo_ibge'] = df['codigo_ibge'].astype(int)
# Chronological order within each municipality: the diff/rolling features depend on it.
df = df.sort_values(by=["codigo_ibge", "ano", "semana"])

# 1.0 for weeks in 2021/2022, 0.0 otherwise.
df["notificacao"] = df["ano"].isin([2021, 2022]).astype(float)
# Per-municipality first/second differences and 4-week moving average of the case count.
df['casos_velocidade'] = df.groupby('codigo_ibge')['numero_casos'].diff().fillna(0)
df['casos_aceleracao'] = df.groupby('codigo_ibge')['casos_velocidade'].diff().fillna(0)
df['casos_mm_4_semanas'] = df.groupby('codigo_ibge')['numero_casos'].transform(lambda x: x.rolling(4, min_periods=1).mean())

# Cyclical week-of-year encoding (period 52) and the year scaled to [0, 1].
df["week_sin"] = np.sin(2 * np.pi * df["semana"] / 52)
df["week_cos"] = np.cos(2 * np.pi * df["semana"] / 52)
# max(1.0, ...) avoids a division by zero (NaN column) when the data covers a single year.
df["year_norm"] = (df["ano"] - df["ano"].min()) / max(1.0, df["ano"].max() - df["ano"].min())

# "numero_casos" must stay first: create_sequences_multi_step uses column 0 as the target.
dynamic_features = [
    "numero_casos",
    "casos_velocidade", "casos_aceleracao", "casos_mm_4_semanas",
    "T2M", "T2M_MAX", "T2M_MIN",
    "PRECTOTCORR", "RH2M", "ALLSKY_SFC_SW_DWN",
    "week_sin", "week_cos", "year_norm", "notificacao"
]
static_features = ["latitude", "longitude"]

municipios = np.sort(df["codigo_ibge"].unique())

if os.path.exists(CITY_MAP_PATH):
    print("Carregando mapeamento cidade -> idx")
    city_to_idx = load_city_mapping(CITY_MAP_PATH)
    # A mapping saved by an earlier run may not cover municipalities added since;
    # those fall back to embedding index 0, so report them instead of hiding it.
    missing_cities = [int(m) for m in municipios if int(m) not in city_to_idx]
    if missing_cities:
        print(f"[WARN] {len(missing_cities)} município(s) ausente(s) de {CITY_MAP_PATH}; usarão o índice 0 do embedding.")
else:
    print("Criando mapeamento city_to_idx e salvando")
    city_to_idx = build_and_save_city_mapping(municipios, CITY_MAP_PATH)

# +1 for the reserved index 0 (mapped indices start at 1).
num_cities = len(city_to_idx) + 1

scaler_dyn_path = os.path.join(SCALER_DIR, "scaler_dyn_global.pkl")
scaler_static_path = os.path.join(SCALER_DIR, "scaler_static_global.pkl")
scaler_target_path = os.path.join(SCALER_DIR, "scaler_target_global.pkl")

if os.path.exists(scaler_dyn_path) and os.path.exists(scaler_static_path) and os.path.exists(scaler_target_path):
    print("Scalers globais já existem. Pulando a etapa de criação.")
else:
    # NOTE: regenerating the scalers invalidates any checkpoint trained with the old ones.
    print("Scalers globais não encontrados. Iniciando processo de criação (pode demorar)...")
    all_dyn_train, all_static_train, all_target_train = [], [], []

    # Fit only on each municipality's TRAINING windows (first 80%, the same split as
    # process_municipio). Including the test windows leaked their min/max into training.
    for mid, df_mun in tqdm(df.groupby("codigo_ibge", sort=True), total=len(municipios),
                            desc="Coletando sequências para scalers"):
        df_mun = df_mun.sort_values(by=["ano", "semana"])
        dynamic_data = df_mun[dynamic_features].values
        static_data = df_mun[static_features].iloc[0].values.reshape(1, -1)
        X_seq, y_seq = create_sequences_multi_step(dynamic_data, SEQUENCE_LENGTH, HORIZON)
        split_idx = int(X_seq.shape[0] * 0.8)
        if split_idx == 0 or split_idx == X_seq.shape[0]:
            continue  # too short: process_municipio discards these municipalities as well
        X_seq, y_seq = X_seq[:split_idx], y_seq[:split_idx]

        all_dyn_train.append(X_seq.reshape(-1, X_seq.shape[-1]))
        all_static_train.append(np.repeat(static_data, X_seq.shape[0], axis=0))
        all_target_train.append(y_seq.reshape(-1, 1))

    if not all_dyn_train:
        raise RuntimeError("Nenhum município tem histórico suficiente para ajustar os scalers.")

    # Concatenate and fit
    all_dyn_train = np.concatenate(all_dyn_train, axis=0)
    all_static_train = np.concatenate(all_static_train, axis=0)
    all_target_train = np.concatenate(all_target_train, axis=0)

    from sklearn.preprocessing import MinMaxScaler
    scaler_dyn_global = MinMaxScaler().fit(all_dyn_train)
    scaler_static_global = MinMaxScaler().fit(all_static_train)
    scaler_target_global = MinMaxScaler().fit(all_target_train)

    joblib.dump(scaler_dyn_global, scaler_dyn_path)
    joblib.dump(scaler_static_global, scaler_static_path)
    joblib.dump(scaler_target_global, scaler_target_path)
    print("Scalers globais criados e salvos.")

print("Iniciando processamento por município com scalers globais")
# Thread backend: workers share the in-memory `df` and globals instead of pickling
# them; tqdm_joblib routes joblib's completion events into the tqdm bar.
with tqdm_joblib(tqdm(desc="Processando municípios", total=len(municipios))):
    results = Parallel(n_jobs=6, backend='threading')(
        delayed(process_municipio)(int(codigo)) for codigo in municipios
    )

# process_municipio returns None for municipalities with too little history.
results = list(filter(lambda item: item is not None, results))
if not results:
    raise RuntimeError("Nenhum município retornou dados válidos. Verifique parâmetros.")

# Transpose the list of 12-tuples into 12 tuples with one entry per municipality.
(X_train_list, y_train_list, X_test_list, y_test_list,
 static_train_list, static_test_list,
 y_train_raw_list, y_test_raw_list,
 city_ids_train_list, city_ids_test_list,
 notif_train_parts, notif_test_parts) = zip(*results)


print("\n--- Etapa 5: Agregando dados para o treinamento ---")
X_train = np.concatenate(X_train_list, axis=0)
y_train = np.concatenate(y_train_list, axis=0)
X_test  = np.concatenate(X_test_list, axis=0)
y_test  = np.concatenate(y_test_list, axis=0)
static_train = np.concatenate(static_train_list, axis=0)
static_test  = np.concatenate(static_test_list, axis=0)

# One integer city id per sample, shaped (n, 1) for the city_id_input.
city_ids_train = np.concatenate(city_ids_train_list, axis=0).reshape(-1,1)
city_ids_test  = np.concatenate(city_ids_test_list, axis=0).reshape(-1,1)

# Unscaled targets, used to derive sample weights.
y_train_raw_agg = np.concatenate(y_train_raw_list, axis=0)
y_test_raw_agg  = np.concatenate(y_test_raw_list, axis=0)

print(f"Total de sequências de Treino: {X_train.shape[0]}, Teste: {X_test.shape[0]}")

# Per-window weight from its raw target peak: min(1 + log1p(peak), 10) ** 2,
# i.e. 1 for a zero-case window up to 100 for the largest outbreaks.
base_weights = 1.0 + np.log1p(np.max(y_train_raw_agg, axis=1))
base_weights = np.minimum(base_weights, 10.0)
sample_weights = (base_weights ** 2).astype(np.float32)

notif_flags_train = np.concatenate(notif_train_parts, axis=0) if len(notif_train_parts) > 0 else np.zeros(0, dtype=float)

# Safety net: the flags must align 1:1 with the training windows.
if notif_flags_train.shape[0] != X_train.shape[0]:
    print(f"[WARN] notif_flags_train ({notif_flags_train.shape[0]}) != X_train ({X_train.shape[0]}). Ajustando com pad/corte.")
    if notif_flags_train.shape[0] < X_train.shape[0]:
        pad = np.zeros(X_train.shape[0] - notif_flags_train.shape[0], dtype=float)
        notif_flags_train = np.concatenate([notif_flags_train, pad], axis=0)
    else:
        notif_flags_train = notif_flags_train[: X_train.shape[0]]

# Multiplier applied to the sample weight of windows flagged by `notificacao`
# (despite the name, this is a weight factor, not an optimizer learning rate).
NOTIF_LEARNING_RATE = 0.5
sample_weights *= np.where(notif_flags_train > 0.5, NOTIF_LEARNING_RATE, 1.0).astype(np.float32)

sample_weights = np.clip(sample_weights, 0.05, 100.0).astype(np.float32)


# Per-city bookkeeping for CityValLossWeighted: test rows per city (X_test is
# concatenated in the same city order) and each city's training peak as its weight.
# A zero (or NaN) peak now falls back to 1.0, like the state-level `state_peak`, so a
# city with no training cases still has its test errors counted instead of weighted 0.
city_sizes = [x.shape[0] for x in X_test_list]
city_train_peaks = [
    peak if peak > 0 else 1.0
    for peak in (float(np.max(y_raw)) if y_raw.size > 0 else 1.0 for y_raw in y_train_raw_list)
]

print("\n--- Etapa 6: Carregando ou criando o modelo ---")
# Resume support: load the last checkpoint if it exists (the custom loss must be passed
# so the saved model can be deserialised); otherwise build and compile a new model.
initial_epoch = 0
model = None
if os.path.exists(CHECKPOINT_PATH):
    print(f"Carregando modelo existente de: {CHECKPOINT_PATH}")
    model = load_model(CHECKPOINT_PATH, custom_objects={'asymmetric_mse': asymmetric_mse})
    if os.path.exists(EPOCH_TRACK_PATH):
        with open(EPOCH_TRACK_PATH, "r") as f:
            try:
                # The file holds the last completed 0-based epoch; resume at the next one.
                initial_epoch = int(f.read().strip()) + 1
            except Exception:
                initial_epoch = 0
    print(f"Continuando o treinamento a partir da época: {initial_epoch}")
else:
    print("Nenhum checkpoint encontrado. Criando um novo modelo.")
    # Embedding width grows with the number of cities, clamped to [4, 32].
    embed_dim = min(32, max(4, num_cities // 10))

    # Dynamic branch: two stacked 128-unit LSTMs over (SEQUENCE_LENGTH, n_dynamic_features).
    input_dyn = Input(shape=(SEQUENCE_LENGTH, len(dynamic_features)), name='dynamic_input')
    lstm_out = LSTM(128, return_sequences=True)(input_dyn)
    lstm_out = Dropout(0.1)(lstm_out)
    lstm_out = LSTM(128)(lstm_out)
    lstm_out = Dropout(0.1)(lstm_out)

    # Static features (one row per sample) and an integer city id per sample.
    input_static = Input(shape=(len(static_features),), name='static_input')
    input_city = Input(shape=(1,), dtype='int32', name='city_id_input')

    # Learned per-city vector; num_cities includes the reserved index 0.
    city_emb = Embedding(input_dim=num_cities, output_dim=embed_dim, name='city_emb')(input_city)
    city_emb_flat = Flatten()(city_emb)

    concat = Concatenate()([lstm_out, input_static, city_emb_flat])

    # One linear output per forecast step (HORIZON values, in scaled target units).
    reg_output = Dense(HORIZON, activation='linear', name='reg_output')(concat)

    model = Model(inputs=[input_dyn, input_static, input_city], outputs=reg_output)
    # Small learning rate with gradient-norm clipping.
    optimizer = Adam(learning_rate=0.0001, clipnorm=1.0)
    model.compile(optimizer=optimizer, loss=asymmetric_mse)
    print("Novo modelo criado com embedding e compilado (saída reg_output).")

print("\n--- Etapa 7: Iniciando o treinamento do modelo ---")
epoch_saver = EpochSaver(EPOCH_TRACK_PATH)
# Full model saved after every epoch, used for resuming.
checkpoint_last = ModelCheckpoint(filepath=CHECKPOINT_PATH, save_best_only=False, save_freq="epoch", verbose=1)

city_val_loss_cb = CityValLossWeighted(
    X_val=X_test,
    static_val=static_test,
    city_ids_val=city_ids_test,
    y_val=y_test,
    city_sizes=city_sizes,
    city_weights=city_train_peaks,
    save_path=CHECKPOINT_BEST_CITY,
    penalty_factor=10.0
)

early_stopping_city = PersistentEarlyStoppingCityVal(
    city_val_callback=city_val_loss_cb,
    patience=10,
    restore_best_weights=True,
    state_file=EARLYSTOP_STATE
)

history_saver = HistorySaver(TRAINING_HISTORY_PATH)

# Callbacks run in list order at each epoch end:
#  1. city_val_loss_cb writes logs["city_val_loss"], read by early stopping and history;
#  2. checkpoint_last runs after early stopping, so restored best weights are what gets saved;
#  3. epoch_saver runs LAST, so the epoch counter only advances once that epoch's
#     checkpoint is on disk. Previously it ran first, and an interruption in between
#     resumed from the older weights while skipping an epoch.
callbacks = [city_val_loss_cb, early_stopping_city, history_saver, checkpoint_last, epoch_saver]

history = model.fit(
    [X_train, static_train, city_ids_train],
    y_train,
    validation_data=([X_test, static_test, city_ids_test], y_test),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    initial_epoch=initial_epoch,
    verbose=1,
    sample_weight=sample_weights
)

# HistorySaver already appends every epoch's logs to TRAINING_HISTORY_PATH during
# training; merging history.history into the same file again duplicated every entry.
# Only merge when that callback was not used.
if history_saver not in callbacks:
    try:
        try:
            with open(TRAINING_HISTORY_PATH, "r") as f:
                existing = json.load(f)
        except (OSError, ValueError):
            existing = {}
        if not isinstance(existing, dict):
            existing = {}

        for k, v in getattr(history, "history", {}).items():
            existing.setdefault(k, []).extend(float(x) for x in v)
        with open(TRAINING_HISTORY_PATH, "w") as f:
            json.dump(existing, f)
    except OSError as exc:
        print(f"[WARN] Could not update {TRAINING_HISTORY_PATH}: {exc}")

print("\n--- TREINAMENTO FINALIZADO! ---")


## Estadual

In [None]:
import os
import json
import numpy as np
import joblib
import tensorflow as tf
from typing import Tuple
from tensorflow.keras.callbacks import Callback

def create_sequences_multi_step(data: np.ndarray, seq_len: int, horizon: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build stacked sliding windows (X, y) for multi-step forecasting.

    Args:
        data: 2-D array ``(n_timesteps, n_features)``; column 0 is the target.
        seq_len: number of past timesteps in each input window.
        horizon: number of future timesteps of column 0 to predict.

    Returns:
        X: float32 array ``(n_samples, seq_len, n_features)``.
        y: float32 array ``(n_samples, horizon)`` with column 0 of the ``horizon``
           rows following each window.
        Too-short series give ``n_samples == 0`` with the trailing dimensions kept.
    """
    n_features = data.shape[1]
    n = data.shape[0] - seq_len - horizon + 1
    if n <= 0:
        # float32 like the non-empty case (np.empty defaulted to float64).
        return (np.zeros((0, seq_len, n_features), dtype=np.float32),
                np.zeros((0, horizon), dtype=np.float32))
    X = np.zeros((n, seq_len, n_features), dtype=np.float32)
    y = np.zeros((n, horizon), dtype=np.float32)
    for i in range(n):
        X[i] = data[i : i + seq_len]
        y[i] = data[i + seq_len : i + seq_len + horizon, 0]
    return X, y

def asymmetric_mse(y_true, y_pred):
    """Squared error with an extra, relative penalty on under-prediction.

    With residual ``e = y_true - y_pred``, each squared error is weighted by 1 when
    ``e <= 0`` and by ``1 + 1.0 * |e| / max(|y_true|, 1)`` when ``e > 0``.
    Returns the mean weighted squared error over all elements.
    """
    under_penalty = 1.0
    residual = y_true - y_pred
    relative_miss = tf.abs(residual) / tf.maximum(tf.abs(y_true), 1.0)
    weight = tf.where(residual > 0, 1.0 + under_penalty * relative_miss, 1.0)
    return tf.reduce_mean(weight * tf.square(residual))

class EpochSaver(Callback):
    """Persist the index of the last completed epoch to a text file.

    Keras passes a 0-based ``epoch`` to ``on_epoch_end``; that number is written
    as plain text so a resumed run can start from the following epoch.
    Writing is best-effort: a failure is reported but never interrupts training.
    """

    def __init__(self, epoch_file_path: str):
        super().__init__()
        self.epoch_file_path = epoch_file_path

    def on_epoch_end(self, epoch, logs=None):
        try:
            with open(self.epoch_file_path, "w") as f:
                f.write(str(int(epoch)))
        except OSError as exc:
            # Best-effort, but visible: a silent failure would make resuming restart too early.
            print(f"[WARN] EpochSaver: could not write {self.epoch_file_path}: {exc}")

class HistorySaver(Callback):
    """Append each epoch's numeric logs to a JSON file so history survives restarts.

    The file maps metric name -> list of per-epoch values. It is re-read at every
    epoch end, so values accumulate across interrupted and resumed runs.
    """

    def __init__(self, save_path: str):
        super().__init__()
        self.save_path = save_path

    def _load(self) -> dict:
        """Return the stored history, or ``{}`` if the file is missing or unreadable."""
        try:
            with open(self.save_path, "r") as f:
                hist = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as exc:
            print(f"[WARN] HistorySaver: could not read {self.save_path} ({exc}); starting a new history.")
            return {}
        return hist if isinstance(hist, dict) else {}

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        hist = self._load()
        for k, v in logs.items():
            try:
                val = float(v)
            except (TypeError, ValueError):
                continue  # skip non-scalar log entries
            hist.setdefault(k, []).append(val)

        # Write to a temp file, then atomically swap it in: an interruption mid-write
        # can no longer leave a truncated JSON that would wipe the history next epoch.
        tmp_path = self.save_path + ".tmp"
        try:
            with open(tmp_path, "w") as f:
                json.dump(hist, f)
            os.replace(tmp_path, self.save_path)
        except OSError as exc:
            print(f"[WARN] HistorySaver: could not write {self.save_path}: {exc}")

class PersistentEarlyStoppingStateVal(Callback):
    """Early stopping on ``state_val_loss`` whose counters survive notebook restarts.

    ``wait`` (epochs without improvement) and ``best`` (lowest ``state_val_loss``)
    are written to ``state_file`` after every epoch and reloaded on construction.

    Args:
        state_val_callback: callback that writes ``logs["state_val_loss"]``; invoked
            here only when that key is missing from ``logs``.
        patience: epochs without improvement before ``stop_training`` is set.
        restore_best_weights: keep an in-memory copy of the best weights and load it
            back when stopping (not persisted across restarts).
        state_file: JSON file holding ``{"wait": int, "best": float}``.
    """

    def __init__(self, state_val_callback, patience=10, restore_best_weights=True, state_file="earlystop_state_state.json"):
        super().__init__()
        self.state_val_callback = state_val_callback
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.state_file = state_file
        self.wait = 0
        self.best_weights = None
        self.best = np.inf
        if os.path.exists(self.state_file):
            try:
                with open(self.state_file, "r") as f:
                    state = json.load(f)
                self.wait = int(state.get("wait", 0))
                self.best = float(state.get("best", np.inf))
            except (OSError, ValueError, TypeError, AttributeError) as exc:
                print(f"[WARN] Could not read early-stopping state {self.state_file} ({exc}); starting from scratch.")
                self.wait = 0
                self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        # Keep the dict Keras passed (even if empty) so later callbacks see the metric.
        logs = {} if logs is None else logs
        current = logs.get("state_val_loss")
        if current is None and hasattr(self.state_val_callback, "on_epoch_end"):
            self.state_val_callback.on_epoch_end(epoch, logs)
            current = logs.get("state_val_loss")
        if current is None:
            return
        current = float(current)
        logs["state_val_loss"] = current
        if current < self.best:
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                try:
                    self.best_weights = self.model.get_weights()
                except Exception as exc:
                    print(f"[WARN] Could not snapshot best weights: {exc}")
                    self.best_weights = None
        else:
            self.wait += 1
        try:
            with open(self.state_file, "w") as f:
                json.dump({"wait": int(self.wait), "best": float(self.best)}, f)
        except OSError as exc:
            print(f"[WARN] Could not save early-stopping state to {self.state_file}: {exc}")
        if self.wait >= self.patience:
            if self.restore_best_weights:
                if self.best_weights is not None:
                    try:
                        self.model.set_weights(self.best_weights)
                    except Exception as exc:
                        print(f"[WARN] Could not restore best weights: {exc}")
                else:
                    # e.g. the best epoch happened before a restart: nothing in memory to restore.
                    print("[WARN] Best weights are not available in memory; keeping the current weights.")
            print(f"\nEarly stopping triggered at epoch {epoch + 1} (state_val_loss: {self.best:.6f})")
            self.model.stop_training = True

class StateValLossWeighted(Callback):
    """Per-state, weight-averaged asymmetric validation loss computed every epoch.

    The validation arrays are expected to be concatenated state by state, in the
    order of ``state_sizes``. For each state the mean of ``err**2 * p`` is taken
    (``p = penalty_factor`` when ``err = y_true - y_pred > 0``, else 1), then averaged
    with ``state_weights``. The result goes to ``logs["state_val_loss"]``; whenever it
    improves on ``best``, the model is saved to ``save_path``.

    Args:
        X_val, static_val, state_ids_val: model inputs for the validation set.
        y_val: validation targets, in the same units as the model output.
        state_sizes: number of validation rows of each state, in concatenation order.
        state_weights: one weight per state (same order).
        save_path: where the best model is saved.
        penalty_factor: multiplier on squared under-prediction errors.
        clip_scaled_pred: clip predictions to [0, 1] before scoring (presumably the
            min-max scaled target range — confirm against the caller).
        initial_best: best value already reached (e.g. in a previous session); with the
            default ``np.inf`` the first epoch always overwrites ``save_path``.
    """

    def __init__(self, X_val, static_val, state_ids_val, y_val, state_sizes, state_weights, save_path, penalty_factor=1.0, clip_scaled_pred=False, initial_best=np.inf):
        super().__init__()
        self.X_val = X_val
        self.static_val = static_val
        self.state_ids_val = state_ids_val
        self.y_val = y_val
        self.state_sizes = state_sizes
        self.state_weights = state_weights
        self.save_path = save_path
        self.best = float(initial_best)
        self.penalty_factor = penalty_factor
        self.clip_scaled_pred = clip_scaled_pred

    def _predict(self):
        """Predict the validation set; return None (with a warning) if the model cannot be called."""
        try:
            y_pred = self.model.predict([self.X_val, self.static_val, self.state_ids_val], batch_size=1024, verbose=0)
        except Exception:
            # Fallback for models built without the state-id input.
            try:
                y_pred = self.model.predict([self.X_val, self.static_val], batch_size=1024, verbose=0)
            except Exception as exc:
                print(f"[WARN] StateValLossWeighted: prediction failed, state_val_loss not computed: {exc}")
                return None
        # Multi-output models: the regression head is the first output.
        if isinstance(y_pred, (list, tuple)):
            return y_pred[0]
        return y_pred

    def on_epoch_end(self, epoch, logs=None):
        # Keep the dict Keras passed (even if empty) so later callbacks see the metric.
        logs = {} if logs is None else logs
        y_pred_reg = self._predict()
        if y_pred_reg is None:
            return

        if self.clip_scaled_pred:
            y_pred_reg = np.clip(y_pred_reg, 0.0, 1.0)

        if sum(self.state_sizes) != len(self.y_val):
            print(f"[WARN] StateValLossWeighted: sum(state_sizes)={sum(self.state_sizes)} != len(y_val)={len(self.y_val)}")

        start = 0
        weighted_losses = []
        total_weight = 0.0
        for i, n in enumerate(self.state_sizes):
            y_true_state = self.y_val[start:start+n]
            y_pred_state = y_pred_reg[start:start+n]
            start += n
            if y_true_state.size == 0:
                continue  # np.mean of an empty slice is NaN and would poison the total
            error = y_true_state - y_pred_state
            penalty = np.where(error > 0, self.penalty_factor, 1.0)
            asymmetric_loss_state = np.mean(np.square(error) * penalty)
            weight = float(self.state_weights[i])
            weighted_losses.append(asymmetric_loss_state * weight)
            total_weight += weight
        state_val_loss = float(np.sum(weighted_losses) / max(total_weight, 1e-8))
        logs["state_val_loss"] = state_val_loss
        if state_val_loss < self.best:
            self.best = state_val_loss
            try:
                self.model.save(self.save_path)
            except Exception as exc:
                print(f"[WARN] StateValLossWeighted: could not save model to {self.save_path}: {exc}")
            print(f"\nEpoch {epoch+1}: Best model saved with weighted state_val_loss: {state_val_loss:.6f}")


In [None]:
import os
import json
import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import LSTM, Dense, Input, Concatenate, Dropout, Embedding, Flatten
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam

# --- State-level (estadual) training configuration ---
# NOTE(review): several of these names (BASE_DIR, DATA_PATH, CHECKPOINT_DIR,
# CHECKPOINT_PATH, SCALER_DIR, TRAINING_HISTORY_PATH) are also defined by the
# first (municipal) cell; running this cell overwrites them for the session.
BASE_DIR = "/content/drive/MyDrive/lstm_projeto"
DATA_PATH = os.path.join(BASE_DIR, "data", "final_training_data_estadual.parquet")
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "model_state_checkpoint.keras")
CHECKPOINT_BEST_STATE = os.path.join(CHECKPOINT_DIR, "model_checkpoint_best_state.keras")
SCALER_DIR = os.path.join(BASE_DIR, "scalers")
STATE_MAP_PATH = os.path.join(BASE_DIR, "state_to_idx.json")
STATE_PEAK_PATH = os.path.join(BASE_DIR, "state_peak.json")  # per-state max weekly casos_soma, written below
TRAINING_HISTORY_PATH = os.path.join(BASE_DIR, "training_history_state.json")
EPOCH_TRACK_PATH_STATE = CHECKPOINT_PATH.replace(".keras", "_epoch.txt")

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(SCALER_DIR, exist_ok=True)

# Window sizes: SEQUENCE_LENGTH input weeks -> HORIZON predicted weeks.
SEQUENCE_LENGTH = 12
HORIZON = 6
EPOCHS = 120
BATCH_SIZE = 32

# Despite the name, NOTIF_LEARNING_RATE is a sample-weight multiplier applied to
# windows whose last input week is flagged as notificacao (years 2021/2022).
NOTIF_LEARNING_RATE = 0.8
# Exponent of the (population / median population) factor in the sample weights.
POP_ALPHA = 0.25
# Final sample weights are clipped to this (min, max) range.
SAMPLE_WEIGHT_MINMAX = (0.1, 2.0)

# Load the state-level weekly dataset, ordered so the per-state diff/rolling
# operations below see consecutive weeks.
df = pd.read_parquet(DATA_PATH)
df = df.sort_values(by=["estado_sigla","year","week"]).reset_index(drop=True)

df["casos_soma"] = df["casos_soma"].astype(float)
df["populacao_total"] = df["populacao_total"].astype(float)
# Binary flag for the years 2021 and 2022.
df["notificacao"] = df["year"].isin([2021,2022]).astype(float)

# Case dynamics computed within each state on the raw weekly counts.
df["casos_velocidade"] = df.groupby("estado_sigla")["casos_soma"].diff().fillna(0)
df["casos_aceleracao"] = df.groupby("estado_sigla")["casos_velocidade"].diff().fillna(0)
df["casos_mm_4_semanas"] = df.groupby("estado_sigla")["casos_soma"].transform(lambda x: x.rolling(4, min_periods=1).mean())
# Cyclical week-of-year encoding and a [0, 1] year trend (denominator guarded for a single year).
df["week_sin"] = np.sin(2*np.pi*df["week"]/52)
df["week_cos"] = np.cos(2*np.pi*df["week"]/52)
df["year_norm"] = (df["year"] - df["year"].min()) / max(1.0, (df["year"].max() - df["year"].min()))


# Per-state peak weekly count, used so each state's series peaks at 1 after
# normalisation. Missing or non-positive peaks fall back to 1.0 to avoid dividing by 0.
state_peak = df.groupby("estado_sigla")["casos_soma"].max().to_dict()
state_peak = {k: (v if (v is not None and v > 0) else 1.0) for k,v in state_peak.items()}
with open(STATE_PEAK_PATH, "w") as f:
    json.dump(state_peak, f)

# Vectorised per-row lookup of the state's peak. This replaces a row-wise
# DataFrame.apply, which ran a Python lambda for every row. States absent from the
# map use 1.0, matching the former `state_peak.get(..., 1.0)` fallback. astype(float)
# keeps this working when estado_sigla is stored as a categorical column.
df["casos_norm"] = df["casos_soma"] / df["estado_sigla"].map(state_peak).astype(float).fillna(1.0)
df["casos_norm_log"] = np.log1p(df["casos_norm"])

# Model inputs. Column order matters for two reasons. First, it is the column order
# the dynamic scaler is fitted on. Second, column 0 (casos_norm_log) is also the
# forecast target, because create_sequences_multi_step takes y from the first column.
# NOTE(review): casos_velocidade / casos_aceleracao / casos_mm_4_semanas are derived
# from the raw casos_soma counts, not from the normalised target series.
dynamic_features = [
    "casos_norm_log",  # target: log1p(casos_soma / state peak)
    "casos_velocidade","casos_aceleracao","casos_mm_4_semanas",
    "T2M_mean","T2M_std","PRECTOTCORR_mean","PRECTOTCORR_std",
    "RH2M_mean","RH2M_std","ALLSKY_SFC_SW_DWN_mean","ALLSKY_SFC_SW_DWN_std",
    "week_sin","week_cos","year_norm","notificacao"
]
static_features = ["populacao_total"]

estados = np.sort(df["estado_sigla"].unique())

# State -> embedding-index map, persisted to disk and reused on later runs.
# Indices start at 1. Index 0 is not assigned here (hence num_states = len + 1) and is
# the fallback id that the inference code passes for unknown states.
# NOTE(review): when the map is loaded from disk, a state present in the data but
# absent from the file makes `state_to_idx[st]` in the sequence loop raise KeyError.
if os.path.exists(STATE_MAP_PATH):
    with open(STATE_MAP_PATH,"r") as f:
        state_to_idx = json.load(f)
else:
    state_to_idx = {str(s): i+1 for i,s in enumerate(estados)}
    with open(STATE_MAP_PATH,"w") as f:
        json.dump(state_to_idx,f)
num_states = len(state_to_idx) + 1
idx_to_state = {int(v): k for k,v in state_to_idx.items()}

# Global (all-states) MinMax scalers. They are cached on disk and reused whenever all
# three files already exist.
# NOTE(review): cached scalers are loaded as-is even if dynamic_features or the feature
# engineering above changed since they were fitted; delete the files to force a refit.
scaler_dyn_path = os.path.join(SCALER_DIR, "scaler_dyn_global_state.pkl")
scaler_static_path = os.path.join(SCALER_DIR, "scaler_static_global_state.pkl")
scaler_target_path = os.path.join(SCALER_DIR, "scaler_target_global_state.pkl")

if not (os.path.exists(scaler_dyn_path) and os.path.exists(scaler_static_path) and os.path.exists(scaler_target_path)):
    all_dyn, all_static, all_tgt = [], [], []
    for st in estados:
        df_st = df[df["estado_sigla"]==st].sort_values(by=["year","week"])
        X = df_st[dynamic_features].values
        # Same minimum length as window building: shorter states produce no windows.
        if X.shape[0] < SEQUENCE_LENGTH+HORIZON:
            continue
        all_dyn.append(X)
        # Static features come from the state's first row, repeated once per row.
        S = df_st[static_features].iloc[0].values.reshape(1,-1)
        all_static.append(np.repeat(S, X.shape[0], axis=0))

        # The target scaler is fitted on the same series as dynamic column 0.
        yvals = df_st["casos_norm_log"].values.reshape(-1,1)
        all_tgt.append(yvals)

    if len(all_dyn) == 0:
        raise RuntimeError("Sem sequências suficientes para ajustar scalers. Verifique dados.")
    all_dyn = np.concatenate(all_dyn, axis=0)
    all_static = np.concatenate(all_static, axis=0)
    all_tgt = np.concatenate(all_tgt, axis=0)

    # NOTE(review): fitted on each state's full series, which includes the weeks that
    # later form the 20% validation windows (mild look-ahead into validation data).
    scaler_dyn = MinMaxScaler().fit(all_dyn)
    scaler_static = MinMaxScaler().fit(all_static)
    scaler_target = MinMaxScaler().fit(all_tgt)
    joblib.dump(scaler_dyn, scaler_dyn_path)
    joblib.dump(scaler_static, scaler_static_path)
    joblib.dump(scaler_target, scaler_target_path)
else:
    scaler_dyn = joblib.load(scaler_dyn_path)
    scaler_static = joblib.load(scaler_static_path)
    scaler_target = joblib.load(scaler_target_path)

# Per-state sliding windows with a chronological 80/20 split: the first 80% of each
# state's windows go to training, the rest to validation. X stays unscaled here and is
# scaled further down; y is scaled immediately.
X_train_list, y_train_list, static_train_list, ids_train_list, notif_train_parts = [], [], [], [], []
X_test_list, y_test_list, static_test_list, ids_test_list, notif_test_parts = [], [], [], [], []

for st in tqdm(estados, desc="preparing sequences per state"):
    df_st = df[df["estado_sigla"]==st].sort_values(by=["year","week"])
    arr = df_st[dynamic_features].values
    X_all, y_all = create_sequences_multi_step(arr, SEQUENCE_LENGTH, HORIZON)
    if X_all.shape[0] == 0:
        continue

    # y holds casos_norm_log values (column 0 of arr); apply the target scaler element-wise.
    y_all_scaled = scaler_target.transform(y_all.reshape(-1,1)).reshape(y_all.shape)

    split = int(0.8 * X_all.shape[0])
    static_rep = df_st[static_features].iloc[0].values.reshape(1,-1)

    # notificacao flag at the last input week of each window (window i ends at row
    # i + SEQUENCE_LENGTH - 1). Used later to reweight samples via NOTIF_LEARNING_RATE.
    flags = df_st["notificacao"].values
    seq_flags_all = flags[SEQUENCE_LENGTH - 1 : SEQUENCE_LENGTH - 1 + X_all.shape[0]]
    notif_train = seq_flags_all[:split].astype(float) if split>0 else np.empty((0,))
    notif_test = seq_flags_all[split:].astype(float) if split < X_all.shape[0] else np.empty((0,))

    if split > 0:
        X_train_list.append(X_all[:split])
        y_train_list.append(y_all_scaled[:split])
        static_train_list.append(np.repeat(static_rep, split, axis=0))
        ids_train_list.append(np.repeat(state_to_idx[st], split).reshape(-1,1))
        notif_train_parts.append(notif_train)
    if split < X_all.shape[0]:
        ntest = X_all.shape[0] - split
        X_test_list.append(X_all[split:])
        y_test_list.append(y_all_scaled[split:])
        static_test_list.append(np.repeat(static_rep, ntest, axis=0))
        ids_test_list.append(np.repeat(state_to_idx[st], ntest).reshape(-1,1))
        notif_test_parts.append(notif_test)

if len(X_train_list) == 0:
    raise RuntimeError("Nenhuma sequência de treino encontrada.")

# Concatenate across states. The order is the `estados` order, which the per-state
# validation callback relies on through state_sizes.
X_train = np.concatenate(X_train_list, axis=0)
y_train = np.concatenate(y_train_list, axis=0)
static_train = np.concatenate(static_train_list, axis=0)
state_ids_train = np.concatenate(ids_train_list, axis=0).reshape(-1,1)

X_test = np.concatenate(X_test_list, axis=0)
y_test = np.concatenate(y_test_list, axis=0)
static_test = np.concatenate(static_test_list, axis=0)
state_ids_test = np.concatenate(ids_test_list, axis=0).reshape(-1,1)

# --- Per-sample training weights ---
# Population (first row of each state) and peak weekly count per state.
# NOTE(review): this rebinds state_peak, which above was the normalisation map saved to
# state_peak.json. The new version lacks the "<= 0 -> 1.0" guard and only feeds the weights.
state_pop = {st: float(df[df["estado_sigla"]==st]["populacao_total"].iloc[0]) for st in estados}
median_pop = float(np.median(list(state_pop.values())))
state_peak = {st: float(df[df["estado_sigla"]==st]["casos_soma"].max() if not df[df["estado_sigla"]==st].empty else 1.0) for st in estados}

train_state_ids = state_ids_train.flatten()
per_sample_peak = np.array([state_peak[idx_to_state[int(s)]] for s in train_state_ids], dtype=float)
per_sample_pop = np.array([state_pop[idx_to_state[int(s)]] for s in train_state_ids], dtype=float)

# weight = sqrt(1 + peak cases per 100k inhabitants) * (pop / median_pop) ** POP_ALPHA
peak_per100k = per_sample_peak / (per_sample_pop / 100000.0 + 1e-8)
base_w = np.sqrt(1.0 + peak_per100k)
pop_factor = (per_sample_pop / (median_pop + 1e-8)) ** POP_ALPHA
sample_weights = base_w * pop_factor

if len(notif_train_parts) > 0:
    notif_flags_train = np.concatenate(notif_train_parts, axis=0)
else:
    notif_flags_train = np.zeros(sample_weights.shape[0], dtype=float)

# Defensive length check. The flags are built per state alongside the training windows,
# so the lengths should already match; otherwise pad with zeros or truncate.
if notif_flags_train.shape[0] != sample_weights.shape[0]:
    print(f"[WARN] notif_flags_train ({notif_flags_train.shape[0]}) != sample_weights ({sample_weights.shape[0]}). Ajustando.")
    if notif_flags_train.shape[0] < sample_weights.shape[0]:
        pad = np.zeros(sample_weights.shape[0] - notif_flags_train.shape[0], dtype=float)
        notif_flags_train = np.concatenate([notif_flags_train, pad], axis=0)
    else:
        notif_flags_train = notif_flags_train[: sample_weights.shape[0]]

# Flagged windows are multiplied by NOTIF_LEARNING_RATE. All weights are then clipped
# to SAMPLE_WEIGHT_MINMAX and cast to float32 for Keras.
sample_weights = sample_weights * np.where(notif_flags_train > 0.5, NOTIF_LEARNING_RATE, 1.0)
sample_weights = np.clip(sample_weights, SAMPLE_WEIGHT_MINMAX[0], SAMPLE_WEIGHT_MINMAX[1]).astype(np.float32)

# Apply the global scalers. Windows are flattened to (windows * timesteps, features)
# because MinMaxScaler expects 2-D input, then reshaped back to 3-D.
ns, sl, nf = X_train.shape
X_train_2d = X_train.reshape(-1, nf)
X_test_2d = X_test.reshape(-1, nf)
X_train_scaled = scaler_dyn.transform(X_train_2d).reshape(ns, sl, nf)
X_test_scaled = scaler_dyn.transform(X_test_2d).reshape(X_test.shape[0], sl, nf)
static_train_scaled = scaler_static.transform(static_train)
static_test_scaled = scaler_static.transform(static_test)

# --- Model ---
# A two-layer LSTM reads the dynamic window. Its output is concatenated with the static
# features and a learned per-state embedding, then a linear Dense layer emits HORIZON values.
embed_dim = 8
input_dyn = Input(shape=(SEQUENCE_LENGTH, len(dynamic_features)), name='dynamic_input')
input_static = Input(shape=(len(static_features),), name='static_input')
input_state = Input(shape=(1,), dtype='int32', name='state_id_input')

x = LSTM(128, return_sequences=True)(input_dyn)
x = Dropout(0.1)(x)
x = LSTM(128)(x)
x = Dropout(0.1)(x)

# input_dim = num_states (= mapped states + 1), so id 0 is also a valid embedding row.
emb = Embedding(input_dim=num_states, output_dim=embed_dim)(input_state)
embf = Flatten()(emb)
cat = Concatenate()([x, input_static, embf])

activation = 'linear'
out = Dense(HORIZON, activation=activation, name='reg_output')(cat)

model = Model([input_dyn, input_static, input_state], out)

optimizer = Adam(learning_rate=1e-4, clipnorm=1.0)

# NOTE(review): asymmetric_mse is not defined in this cell. It must come from an earlier
# cell. Later cells re-register the same name with different penalty factors, so confirm
# which version is in scope when this runs.
model.compile(optimizer=optimizer, loss=asymmetric_mse)

# --- Callbacks ---
# EpochSaver, HistorySaver, StateValLossWeighted and PersistentEarlyStoppingStateVal are
# custom callbacks defined elsewhere in the notebook.
epoch_saver = EpochSaver(EPOCH_TRACK_PATH_STATE)
history_saver = HistorySaver(TRAINING_HISTORY_PATH)
checkpoint_last = ModelCheckpoint(filepath=CHECKPOINT_PATH, save_best_only=False, save_freq="epoch", verbose=1)
# NOTE(review): checkpoint_best is created but not included in `callbacks` below, so it
# never runs. The best-model file at this path is written by state_val_cb instead.
checkpoint_best = ModelCheckpoint(filepath=CHECKPOINT_BEST_STATE, save_best_only=True, monitor="val_loss", verbose=1)

# Per-state validation inputs:
# - state_sizes: number of test windows per state, in concatenation order.
# - state_weights: each state's weight, i.e. its maximum scaled target value in the test split.
state_sizes = [x.shape[0] for x in X_test_list]
state_weights = [np.max(y_raw) if y_raw.size > 0 else 1.0 for y_raw in y_test_list]
# penalty_factor=1.0 means under-predictions get no extra weight in the per-state loss.
state_val_cb = StateValLossWeighted(
    X_val=X_test_scaled,
    static_val=static_test_scaled,
    state_ids_val=state_ids_test,
    y_val=y_test,
    state_sizes=state_sizes,
    state_weights=state_weights,
    save_path=CHECKPOINT_BEST_STATE,
    penalty_factor=1.0
)

early_stopping_state = PersistentEarlyStoppingStateVal(
    state_val_callback=state_val_cb,
    patience=10,
    restore_best_weights=True,
    state_file=os.path.join(BASE_DIR, "earlystop_state_state.json")
)

# state_val_cb is listed before early_stopping_state. Presumably the early-stopping
# callback reads the state_val_loss that state_val_cb writes into `logs`; keep this order.
callbacks = [state_val_cb, early_stopping_state, history_saver, epoch_saver, checkpoint_last]

# validation_data reuses the same test windows. Keras therefore also logs val_loss
# (the compiled loss, without sample weights) next to the weighted state_val_loss.
history = model.fit(
    [X_train_scaled, static_train_scaled, state_ids_train],
    y_train,
    validation_data=([X_test_scaled, static_test_scaled, state_ids_test], y_test),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    sample_weight=sample_weights,
    verbose=1
)

# The scalers were already saved or loaded above; re-saving them is redundant but
# harmless. model.save overwrites the per-epoch checkpoint with the weights the model
# holds after fit().
joblib.dump(scaler_dyn, os.path.join(SCALER_DIR,"scaler_dyn_global_state.pkl"))
joblib.dump(scaler_static, os.path.join(SCALER_DIR,"scaler_static_global_state.pkl"))
joblib.dump(scaler_target, os.path.join(SCALER_DIR,"scaler_target_global_state.pkl"))
model.save(CHECKPOINT_PATH)

print("Treinamento estadual concluído.")

# Usando

## Municipal

In [None]:
import os
import json
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from warnings import filterwarnings
import tensorflow as tf
from tensorflow.keras.utils import register_keras_serializable
from tensorflow.keras.models import load_model

# Keep notebook output clean: ignore Python warnings and lower TensorFlow's logger to ERROR.
filterwarnings("ignore")
tf.get_logger().setLevel("ERROR")

@register_keras_serializable(package="Custom", name="asymmetric_mse")
def asymmetric_mse(y_true, y_pred):
    """Squared error that penalises under-prediction more than over-prediction.

    For elements where the model under-predicts (y_true > y_pred), the squared
    residual is multiplied by ``1 + 10 * |residual| / max(|y_true|, 1)``.
    Over-predictions keep weight 1. Returns the mean over all elements.

    NOTE(review): this copy uses a factor of 10.0, while the evaluation cell
    registers the same name with 5.0. This cell loads its model with
    compile=False, so the factor does not affect inference here.
    """
    residual = y_true - y_pred
    relative = tf.abs(residual) / tf.maximum(tf.abs(y_true), 1.0)
    weight = tf.where(residual > 0, 1.0 + 10.0 * relative, 1.0)
    return tf.reduce_mean(tf.square(residual) * weight)

# --- Municipal inference configuration ---
BASE_DIR = "/content/drive/MyDrive/lstm_projeto"
DATA_PATH = os.path.join(BASE_DIR, "data", "inference_data.parquet")
MODEL_PATH = os.path.join(BASE_DIR, "checkpoints", "model_checkpoint_best_city.keras")
SCALER_DIR = os.path.join(BASE_DIR, "scalers")
CITY_MAP_PATH = os.path.join(BASE_DIR, "city_to_idx.json")

# These presumably must match the trained municipal model: SEQUENCE_LENGTH input weeks
# are reshaped into the model input, and HORIZON is used for labels. Confirm against
# the training configuration.
HORIZON = 6
SEQUENCE_LENGTH = 12
BATCH_SIZE = 64

# Order must match the column order the global dynamic scaler was fitted on.
# Column 0 (numero_casos) is the forecast target.
DYN_FEATS = [
    "numero_casos", "casos_velocidade", "casos_aceleracao", "casos_mm_4_semanas",
    "T2M", "T2M_MAX", "T2M_MIN", "PRECTOTCORR", "RH2M", "ALLSKY_SFC_SW_DWN",
    "week_sin", "week_cos", "year_norm", "notificacao"
]
STATIC_FEATS = ["latitude", "longitude"]

# Fixed year range used for year_norm at inference time.
# NOTE(review): the evaluation cell normalises with the data's own min/max year instead;
# confirm which range training used.
YEAR_MIN_TRAIN = 2014
YEAR_MAX_TRAIN = 2025

# compile=False: only predict() is needed, so the training loss/optimizer are not restored.
model = load_model(MODEL_PATH, compile=False)
scaler_dyn = joblib.load(os.path.join(SCALER_DIR, "scaler_dyn_global.pkl"))
scaler_static = joblib.load(os.path.join(SCALER_DIR, "scaler_static_global.pkl"))
scaler_target = joblib.load(os.path.join(SCALER_DIR, "scaler_target_global.pkl"))

# JSON keys are strings, so convert back to int IBGE codes. Without the file the map is
# empty and every city uses the fallback index 0 in prever_para_municipio.
if os.path.exists(CITY_MAP_PATH):
    with open(CITY_MAP_PATH, "r") as f:
        raw_map = json.load(f)
        city_to_idx = {int(k): int(v) for k, v in raw_map.items()}
else:
    city_to_idx = {}

# Chronological order per municipality. prever_para_municipio treats consecutive row
# positions as consecutive weeks (assumes no missing weeks; verify upstream).
df = pd.read_parquet(DATA_PATH)
df['codigo_ibge'] = df['codigo_ibge'].astype(int)
df['ano'] = df['ano'].astype(int)
df['semana'] = df['semana'].astype(int)
df = df.sort_values(by=["codigo_ibge", "ano", "semana"]).reset_index(drop=True)

# Approximate calendar date per (ano, semana): %W is the Monday-based week of the year
# and weekday %w=0 selects that week's Sunday.
# NOTE(review): this is not the epidemiological-week calendar, so dates may be offset by
# a few days. Week numbers are not zero-padded and rely on strptime's pattern matching.
try:
    df['date'] = pd.to_datetime(df['ano'].astype(str) + df['semana'].astype(str) + '0', format='%Y%W%w', errors='coerce')
except Exception:
    df['date'] = pd.NaT

def prever_para_municipio(ibge: int, start_year: int, start_week: int, show_plot=True, display_history_weeks=250):
    """Forecast HORIZON weekly case counts for one municipality.

    The model input is the SEQUENCE_LENGTH weeks immediately before
    (start_year, start_week). The forecast covers start_week and the following
    HORIZON - 1 weeks.

    Args:
        ibge: IBGE municipality code, matched against df["codigo_ibge"].
        start_year: year of the first week to forecast.
        start_week: week number (df["semana"]) of the first week to forecast.
        show_plot: draw a matplotlib chart of history, future actuals and forecast.
        display_history_weeks: how many past weeks to include in the chart.

    Returns:
        A dict with municipio, ibge, last_known_index (row position within the
        municipality's sorted history), prediction_scaled (raw model output) and
        prediction_counts (inverse-scaled, clipped at 0). Returns None when the
        municipality or start week is not found, or when fewer than
        SEQUENCE_LENGTH weeks precede the start week.
    """
    df_mun = df[df['codigo_ibge'] == int(ibge)].copy().reset_index(drop=True)
    if df_mun.empty:
        return None

    # Display name if the data has a 'municipio' column, otherwise the IBGE code as text.
    mun_name = df_mun.get('municipio', pd.Series([str(ibge)])).iloc[0]

    # Rebuild the engineered dynamic features for this municipality.
    df_mun['notificacao'] = df_mun['ano'].isin([2021, 2022]).astype(float)
    df_mun["week_sin"] = np.sin(2 * np.pi * df_mun["semana"] / 52)
    df_mun["week_cos"] = np.cos(2 * np.pi * df_mun["semana"] / 52)
    df_mun["year_norm"] = (df_mun["ano"] - YEAR_MIN_TRAIN) / (YEAR_MAX_TRAIN - YEAR_MIN_TRAIN)
    df_mun['casos_velocidade'] = df_mun['numero_casos'].diff().fillna(0)
    df_mun['casos_aceleracao'] = df_mun['casos_velocidade'].diff().fillna(0)
    df_mun['casos_mm_4_semanas'] = df_mun['numero_casos'].rolling(4, min_periods=1).mean()

    idx_list = df_mun.index[(df_mun['ano'] == start_year) & (df_mun['semana'] == start_week)].tolist()
    if not idx_list:
        return None

    # Row positions are treated as consecutive weeks (df is sorted by ano/semana).
    pred_point_idx = idx_list[0]
    last_known_idx = pred_point_idx - 1
    if last_known_idx < SEQUENCE_LENGTH - 1:
        return None

    start_idx = last_known_idx - SEQUENCE_LENGTH + 1
    input_sequence_df = df_mun.iloc[start_idx : last_known_idx + 1]

    static_raw = input_sequence_df[STATIC_FEATS].iloc[0].values.reshape(1, -1)
    dyn_hist_raw = input_sequence_df[DYN_FEATS].values

    input_dyn_scaled = scaler_dyn.transform(dyn_hist_raw).reshape(1, SEQUENCE_LENGTH, len(DYN_FEATS))
    input_static_scaled = scaler_static.transform(static_raw)

    # Municipalities missing from the map use embedding index 0.
    city_idx = city_to_idx.get(int(ibge), 0)
    input_city = np.array([[city_idx]], dtype=np.int32)

    y_pred = model.predict([input_dyn_scaled, input_static_scaled, input_city], verbose=0)
    y_pred_reg = y_pred[0] if isinstance(y_pred, (list, tuple)) else y_pred

    # Undo the target scaling element-wise, then clip negative counts to zero.
    y_pred_inv = scaler_target.inverse_transform(y_pred_reg.reshape(-1, 1)).reshape(y_pred_reg.shape)
    prediction_cases = np.maximum(y_pred_inv.flatten(), 0.0)

    # Prepend the last observed value so the forecast line joins the history.
    last_known_case = df_mun.iloc[last_known_idx]['numero_casos']
    connected_prediction = np.insert(prediction_cases, 0, last_known_case)

    hist_to_display = df_mun.iloc[max(0, last_known_idx - display_history_weeks) : last_known_idx + 1]
    has_dates = 'date' in hist_to_display.columns and not hist_to_display['date'].isna().all()

    if has_dates:
        start_date_for_pred = hist_to_display['date'].iloc[-1]
        # Step exactly 7 days from the last observed date, so the connected point sits on
        # that date and the forecast stays on the history's weekday (Sunday, from %w=0).
        # The previous freq='W-MON' snapped every forecast point to the following Monday.
        x_pred = pd.date_range(start=start_date_for_pred, periods=len(connected_prediction), freq='7D')
    else:
        x_pred = np.arange(len(hist_to_display)-1, len(hist_to_display)-1 + len(connected_prediction))

    if show_plot:
        plt.figure(figsize=(14,6))
        if has_dates:
            plt.plot(hist_to_display['date'], hist_to_display['numero_casos'], 'o-', label="Historical")
            future_real = df_mun.iloc[last_known_idx + 1 : last_known_idx + 1 + HORIZON]
            if not future_real.empty and 'date' in future_real.columns:
                plt.plot(future_real['date'], future_real['numero_casos'], 'o-', color='green', label="Future (real)")
            plt.plot(x_pred, connected_prediction, 'o--', color='red', label=f"Forecast (H={HORIZON} weeks)")
            plt.axvline(x=hist_to_display['date'].iloc[-1], color="black", linestyle=":", linewidth=2, label="Forecast point")
            plt.xlabel("Date (approx. epi week)")
        else:
            plt.plot(range(len(hist_to_display)), hist_to_display['numero_casos'], 'o-', label="Historical")
            plt.plot(x_pred, connected_prediction, 'o--', color='red', label=f"Forecast (H={HORIZON} weeks)")
            plt.axvline(x=len(hist_to_display)-1, color="black", linestyle=":", linewidth=2, label="Forecast point")
            plt.xlabel("Index")
        plt.title(f"Weekly Cases Forecast - {mun_name} ({ibge})")
        plt.ylabel("Weekly cases")
        plt.grid(True, linestyle='--', linewidth=0.5)
        plt.legend()
        plt.tight_layout()
        plt.show()

    return {
        "municipio": mun_name,
        "ibge": int(ibge),
        "last_known_index": int(last_known_idx),
        "prediction_scaled": y_pred_reg,
        "prediction_counts": prediction_cases
    }


In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import plotly.graph_objects as go

# Input form for the interactive municipal forecast. The default values are just examples.
ibge_widget = widgets.IntText(description="IBGE:", value=3509502)
year_widget = widgets.IntText(description="Year:", value=2025)
week_widget = widgets.IntText(description="Week:", value=20)
button = widgets.Button(description="Generate Forecast")
# Everything the callback prints or plots is captured in this output area.
output = widgets.Output()

def safe_get_column(df, *names):
    """Return the real column name in *df* matching the first usable candidate.

    Matching is case-insensitive. Empty or None candidates are skipped. If
    several columns differ only by case, the last one in ``df.columns`` wins.
    Returns None when no candidate matches.
    """
    by_lower = {}
    for column in df.columns:
        by_lower[column.lower()] = column
    hit = next((name for name in names if name and name.lower() in by_lower), None)
    if hit is None:
        return None
    return by_lower[hit.lower()]

def show_interactive(df_display, ibge, year, week):
    """Render an interactive Plotly chart of real vs. forecast values.

    Columns are located case-insensitively with safe_get_column. The x axis
    falls back to the first column when no date-like column exists and is
    converted to strings. A series whose column cannot be found is not drawn.
    """
    frame = df_display.copy()
    x_col = safe_get_column(frame, "date", "week", "period") or frame.columns[0]
    actual_col = safe_get_column(frame, "real", "actual", "cases", "observed", "numero_casos", "casos_soma")
    forecast_col = safe_get_column(frame, "pred", "forecast", "prediction", "predicted", "forecast")
    if x_col in frame.columns:
        frame[x_col] = frame[x_col].astype(str)

    figure = go.Figure()
    series_specs = (
        (actual_col, "Real", dict(color="green")),
        (forecast_col, "Forecast", dict(color="red", dash="dash")),
    )
    for column, label, line_style in series_specs:
        if not (column and column in frame.columns):
            continue
        figure.add_trace(go.Scatter(x=frame[x_col], y=frame[column], mode="lines+markers", name=label, line=line_style))
    figure.update_layout(title=f"Real vs Forecast — {ibge} ({year} W{week})", hovermode="x unified", margin=dict(l=40, r=20, t=40, b=30))
    figure.show()

def on_button_clicked(b):
    """Button callback: forecast the selected municipality and chart the result.

    Reads the IBGE/year/week widgets, runs prever_para_municipio (which may also draw
    its own matplotlib chart), and shows the last ~50 historical weeks plus the
    forecast in an interactive Plotly chart inside `output`. Every failure path
    prints an explanation into the output area. Previously these paths returned
    silently, so a click could appear to do nothing.
    """
    with output:
        clear_output(wait=True)
        try:
            ibge = int(ibge_widget.value)
            year = int(year_widget.value)
            week = int(week_widget.value)
        except (TypeError, ValueError) as err:
            print(f"Invalid input: {err}")
            return
        try:
            result = prever_para_municipio(ibge, year, week)
        except NameError as err:
            print(f"Forecast prerequisites are not defined yet (run the inference cell first): {err}")
            return
        except Exception as err:
            # UI boundary: report the failure instead of letting the callback die silently.
            print(f"Forecast failed for IBGE {ibge} ({year} W{week}): {type(err).__name__}: {err}")
            return
        if result is None:
            print(f"No forecast for IBGE {ibge} at {year} W{week}: municipality/week not found "
                  f"or not enough history before that week.")
            return

        last_idx = result["last_known_index"]
        hist = df[(df["codigo_ibge"] == int(ibge))].reset_index(drop=True)
        hist_slice = hist.iloc[max(0, last_idx - 50):last_idx + 1].copy()
        pred_vals = list(result["prediction_counts"])
        n_hist, n_pred = len(hist_slice), len(pred_vals)
        # Real values cover the history only; forecast values cover the future only.
        real = list(hist_slice['numero_casos']) + [None] * n_pred
        forecast = [None] * n_hist + pred_vals
        if 'date' in hist_slice.columns and not hist_slice['date'].isna().all():
            last_date = pd.to_datetime(hist_slice['date'].iloc[-1])
            future_dates = [last_date + pd.to_timedelta(7 * i, unit='d') for i in range(1, n_pred + 1)]
            display_df = pd.DataFrame({"date": list(hist_slice['date']) + future_dates, "real": real, "forecast": forecast})
        else:
            display_df = pd.DataFrame({"index": list(range(n_hist + n_pred)), "real": real, "forecast": forecast})
        show_interactive(display_df, ibge, year, week)

# Attach the callback and render the form; results appear in the `output` widget.
button.on_click(on_button_clicked)

display(ibge_widget, year_widget, week_widget, button, output)


In [None]:
import os
import json
import numpy as np
import pandas as pd
import joblib
from joblib import Parallel, delayed
from tqdm import tqdm
from tqdm_joblib import tqdm_joblib
import tensorflow as tf
from tensorflow.keras.models import load_model
from sklearn.metrics import mean_absolute_error, mean_squared_error


from typing import Tuple
from tensorflow.keras.utils import register_keras_serializable

@register_keras_serializable(package="Custom", name="asymmetric_mse")
def asymmetric_mse(y_true, y_pred):
    """Asymmetric MSE in which under-predictions (y_true > y_pred) are up-weighted.

    weight = 1 + 5 * |y_true - y_pred| / max(|y_true|, 1) for under-predictions,
    and 1 otherwise. Returns mean(weight * (y_true - y_pred) ** 2) over all elements.

    NOTE(review): the municipal inference cell registers this same name with a
    factor of 10.0. Confirm which definition is registered when a model is loaded.
    """
    penalty_factor = 5.0
    diff = y_true - y_pred
    under_predicted = diff > 0
    rel_err = tf.abs(diff) / tf.maximum(tf.abs(y_true), 1.0)
    penalty = tf.where(under_predicted, 1.0 + penalty_factor * rel_err, 1.0)
    return tf.reduce_mean(tf.square(diff) * penalty)

def create_sequences_multi_step(data: np.ndarray, seq_len: int, horizon: int) -> Tuple[np.ndarray, np.ndarray]:
    """Build sliding windows for multi-step forecasting.

    Window i uses rows ``i .. i+seq_len-1`` of *data* as input. Its target is
    column 0 of rows ``i+seq_len .. i+seq_len+horizon-1``.

    Args:
        data: 2-D array (n_timesteps, n_features); column 0 is the target series.
        seq_len: number of input timesteps per window.
        horizon: number of future timesteps to predict.

    Returns:
        X: float32 array of shape (n_windows, seq_len, n_features).
        y: float32 array of shape (n_windows, horizon).
        When the series is too short, both arrays are empty but keep their rank,
        with shapes (0, seq_len, n_features) and (0, horizon). Callers can then
        still read ``.shape[-1]`` or concatenate them.
    """
    data = np.asarray(data)
    num_samples = data.shape[0] - seq_len - horizon + 1
    if num_samples <= 0:
        return (np.zeros((0, seq_len, data.shape[1]), dtype=np.float32),
                np.zeros((0, horizon), dtype=np.float32))
    # starts + offsets gives an index matrix whose entry [i, j] is row i + j.
    starts = np.arange(num_samples)[:, None]
    X = data[starts + np.arange(seq_len)].astype(np.float32)
    y = data[starts + seq_len + np.arange(horizon), 0].astype(np.float32)
    return X, y

def load_city_mapping(path):
    """Read the city -> embedding-index JSON file and return it with int keys and values.

    JSON object keys are always strings, so both sides are converted back to int.
    """
    with open(path, "r") as handle:
        mapping = json.load(handle)
    converted = {}
    for ibge_code, index in mapping.items():
        converted[int(ibge_code)] = int(index)
    return converted

# Memoised global scalers, keyed by scaler directory. Loading the three pickles once
# replaces reloading them from disk for every municipality. Re-running this cell
# recreates the dict, so refitted scalers are picked up.
_GLOBAL_SCALER_CACHE = {}


def _load_global_scalers(scaler_dir):
    """Return (scaler_dyn, scaler_static, scaler_target) from *scaler_dir*, loading them once."""
    scalers = _GLOBAL_SCALER_CACHE.get(scaler_dir)
    if scalers is None:
        scalers = tuple(
            joblib.load(os.path.join(scaler_dir, fname))
            for fname in ("scaler_dyn_global.pkl", "scaler_static_global.pkl", "scaler_target_global.pkl")
        )
        # Under the threading backend two workers may both miss and load. They store
        # equivalent objects, so the race is harmless.
        _GLOBAL_SCALER_CACHE[scaler_dir] = scalers
    return scalers


def process_municipio(municipio_id):
    """Rebuild the scaled train/test windows for one municipality.

    Reads the module-level df, SEQUENCE_LENGTH, HORIZON, dynamic_features,
    static_features, SCALER_DIR and city_to_idx. The municipality's windows are
    split chronologically: the first 80% form the training set, the rest the test set.

    Returns:
        A 10-tuple (X_train_scaled, y_train_scaled, X_test_scaled, y_test_scaled,
        static_train_scaled, static_test_scaled, y_train_raw, y_test_raw,
        city_ids_train, city_ids_test). Returns None when the series is too short
        or the split would leave either side empty.
    """
    municipio_id = int(municipio_id)
    df_mun = df[df["codigo_ibge"] == municipio_id].copy().sort_values(by=["ano","semana"])
    if len(df_mun) < SEQUENCE_LENGTH + HORIZON:
        return None

    dynamic_data = df_mun[dynamic_features].values
    # Static features are constant per municipality, so the first row is used.
    static_data = df_mun[static_features].iloc[0].values.reshape(1, -1)
    X_mun_raw, y_mun_raw = create_sequences_multi_step(dynamic_data, SEQUENCE_LENGTH, HORIZON)
    if X_mun_raw.shape[0] == 0:
        return None

    split_idx = int(len(X_mun_raw) * 0.8)
    if split_idx == 0 or split_idx == len(X_mun_raw):
        return None

    X_train_raw = X_mun_raw[:split_idx]
    y_train_raw = y_mun_raw[:split_idx]
    X_test_raw  = X_mun_raw[split_idx:]
    y_test_raw  = y_mun_raw[split_idx:]

    static_train = np.repeat(static_data, X_train_raw.shape[0], axis=0)
    static_test  = np.repeat(static_data, X_test_raw.shape[0], axis=0)

    scaler_dyn, scaler_static, scaler_target = _load_global_scalers(SCALER_DIR)

    # MinMaxScaler expects 2-D input: flatten the timestep axis, then restore it.
    X_train_scaled = scaler_dyn.transform(X_train_raw.reshape(-1, X_train_raw.shape[-1])).reshape(X_train_raw.shape)
    X_test_scaled  = scaler_dyn.transform(X_test_raw.reshape(-1, X_test_raw.shape[-1])).reshape(X_test_raw.shape)
    static_train_scaled = scaler_static.transform(static_train)
    static_test_scaled  = scaler_static.transform(static_test)

    y_train_scaled = scaler_target.transform(y_train_raw.reshape(-1, 1)).reshape(y_train_raw.shape)
    y_test_scaled  = scaler_target.transform(y_test_raw.reshape(-1, 1)).reshape(y_test_raw.shape)

    # Municipalities missing from the map use embedding index 0.
    city_idx = city_to_idx.get(municipio_id, 0)
    city_ids_train = np.repeat(city_idx, X_train_scaled.shape[0]).astype(np.int32)
    city_ids_test  = np.repeat(city_idx, X_test_scaled.shape[0]).astype(np.int32)

    return (X_train_scaled, y_train_scaled, X_test_scaled, y_test_scaled,
            static_train_scaled, static_test_scaled,
            y_train_raw, y_test_raw,
            city_ids_train, city_ids_test)

def safe_mape(y_true, y_pred):
    """Compute the Mean Absolute Percentage Error (in %), ignoring entries with y_true < 1.

    Targets below 1 (e.g. zero-case weeks) are excluded to avoid division by
    zero and exploding ratios. Returns NaN when no entry qualifies.
    """
    actual = np.array(y_true)
    predicted = np.array(y_pred)
    keep = actual >= 1
    if keep.any():
        ratios = np.abs((actual[keep] - predicted[keep]) / actual[keep])
        return np.mean(ratios) * 100.0
    return np.nan

# --- Evaluation of the municipal model on the recreated test windows ---
BASE_DIR = "/content/drive/MyDrive/lstm_projeto"
DATA_PATH = os.path.join(BASE_DIR, "data", "final_training_data.parquet")
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
MODEL_TO_EVALUATE = os.path.join(CHECKPOINT_DIR, "model_checkpoint_best_city.keras")
SCALER_DIR = os.path.join(BASE_DIR, "scalers")
CITY_MAP_PATH = os.path.join(BASE_DIR, "city_to_idx.json")

SEQUENCE_LENGTH = 12
HORIZON = 6
BATCH_SIZE = 64

print("\n--- Etapa 1: Carregando e processando os dados ---")
df = pd.read_parquet(DATA_PATH)
df['codigo_ibge'] = df['codigo_ibge'].astype(int)
# Chronological order within each municipality, so the diff/rolling below are week-over-week.
df = df.sort_values(by=["codigo_ibge", "ano", "semana"])

df["notificacao"] = df["ano"].isin([2021, 2022]).astype(float)
df['casos_velocidade'] = df.groupby('codigo_ibge')['numero_casos'].diff().fillna(0)
df['casos_aceleracao'] = df.groupby('codigo_ibge')['casos_velocidade'].diff().fillna(0)
df['casos_mm_4_semanas'] = df.groupby('codigo_ibge')['numero_casos'].transform(lambda x: x.rolling(4, min_periods=1).mean())

df["week_sin"] = np.sin(2 * np.pi * df["semana"] / 52)
df["week_cos"] = np.cos(2 * np.pi * df["semana"] / 52)
# Guard the denominator, as the state-level cell does. With a single year in the data,
# the unguarded 0/0 would make year_norm NaN. With two or more years the result is unchanged.
# NOTE(review): the inference cell normalises with fixed YEAR_MIN_TRAIN/YEAR_MAX_TRAIN
# instead of the data's min/max; confirm both match the training range.
year_span = df["ano"].max() - df["ano"].min()
df["year_norm"] = (df["ano"] - df["ano"].min()) / max(1, year_span)

# Order must match the column order the global dynamic scaler was fitted on.
# Column 0 (numero_casos) is the target.
dynamic_features = [
    "numero_casos", "casos_velocidade", "casos_aceleracao", "casos_mm_4_semanas",
    "T2M", "T2M_MAX", "T2M_MIN", "PRECTOTCORR", "RH2M", "ALLSKY_SFC_SW_DWN",
    "week_sin", "week_cos", "year_norm", "notificacao"
]
static_features = ["latitude", "longitude"]

municipios = np.sort(df["codigo_ibge"].unique())

print("Carregando mapeamento cidade -> idx")
city_to_idx = load_city_mapping(CITY_MAP_PATH)

# Rebuild each municipality's chronological 80/20 windows in parallel. The threading
# backend runs the workers as threads in this process, so they read the module-level
# df / city_to_idx directly without copying them.
print("Iniciando processamento por município (recriando conjuntos de treino/teste)")
with tqdm_joblib(tqdm(desc="Processando municípios", total=len(municipios))):
    results = Parallel(n_jobs=6, backend='threading')(
        delayed(process_municipio)(int(mid)) for mid in municipios
    )

# Drop municipalities with too little history (process_municipio returned None).
results = [r for r in results if r is not None]
if len(results) == 0:
    raise RuntimeError("Nenhum município retornou dados válidos.")

# Transpose the list of per-municipality 10-tuples into ten per-field tuples.
(X_train_list, y_train_list, X_test_list, y_test_list,
 static_train_list, static_test_list,
 y_train_raw_list, y_test_raw_list,
 city_ids_train_list, city_ids_test_list) = zip(*results)

print("\n--- Etapa 2: Agregando dados de teste/validação ---")
X_test  = np.concatenate(X_test_list, axis=0)
y_test  = np.concatenate(y_test_list, axis=0)
static_test  = np.concatenate(static_test_list, axis=0)
city_ids_test  = np.concatenate(city_ids_test_list, axis=0).reshape(-1,1)

# Unscaled targets, shape (n_windows, HORIZON). Metrics are reported in case units.
y_test_raw_agg  = np.concatenate(y_test_raw_list, axis=0)

print(f"Total de sequências de Teste/Validação recriadas: {X_test.shape[0]}")

print(f"\n--- Etapa 3: Carregando modelo de {MODEL_TO_EVALUATE} ---")
model = load_model(MODEL_TO_EVALUATE, custom_objects={'asymmetric_mse': asymmetric_mse})
print("Modelo carregado.")

print("Carregando scaler do alvo (target)...")
scaler_target_path = os.path.join(SCALER_DIR, "scaler_target_global.pkl")
scaler_target = joblib.load(scaler_target_path)
print("Scaler carregado.")

print("\n--- Etapa 4: Executando predição no conjunto de teste ---")
y_pred_scaled = model.predict([X_test, static_test, city_ids_test], batch_size=BATCH_SIZE, verbose=1)

y_true_flat = y_test_raw_agg.flatten()
y_pred_scaled_flat = y_pred_scaled.flatten().reshape(-1, 1)

# Undo the target scaling and clip negative predictions to zero counts.
print("Des-escalando previsões...")
y_pred_real_flat = scaler_target.inverse_transform(y_pred_scaled_flat).flatten()

y_pred_real_flat = np.maximum(y_pred_real_flat, 0.0)

print("Calculando métricas...")

# Overall metrics over every (window, horizon step) pair.
mae_geral = mean_absolute_error(y_true_flat, y_pred_real_flat)
rmse_geral = np.sqrt(mean_squared_error(y_true_flat, y_pred_real_flat))
mape_geral = safe_mape(y_true_flat, y_pred_real_flat)

print("\n--- MÉTRICAS DE VALIDAÇÃO (agregando todas as semanas) ---")
print(f"  MAE  (Erro Médio Absoluto): {mae_geral:.3f} casos")
print(f"  RMSE (Erro Quadrático Médio): {rmse_geral:.3f} casos")
print(f"  MAPE (Erro Percentual Médio): {mape_geral:.3f} %")

print(f"\n--- Métricas por Semana de Previsão (H=1 a {HORIZON}) ---")
print("  Semana |   MAE   |  RMSE  |  MAPE (%)")
print("  --------------------------------------")

# Same inverse transform as above, kept in (n_windows, HORIZON) shape for per-step metrics.
# NOTE(review): this duplicates the flat inverse transform; y_pred_real_flat equals
# y_pred_real_matrix.flatten(), so one of the two computations could be derived from the other.
y_true_matrix = y_test_raw_agg
y_pred_real_matrix = scaler_target.inverse_transform(y_pred_scaled.reshape(-1, 1)).reshape(y_pred_scaled.shape)
y_pred_real_matrix = np.maximum(y_pred_real_matrix, 0.0)

# Per-horizon metrics: column h holds the (h+1)-week-ahead forecasts.
for h in range(HORIZON):
    y_true_h = y_true_matrix[:, h]
    y_pred_h = y_pred_real_matrix[:, h]

    mae_h = mean_absolute_error(y_true_h, y_pred_h)
    rmse_h = np.sqrt(mean_squared_error(y_true_h, y_pred_h))
    mape_h = safe_mape(y_true_h, y_pred_h)

    print(f"   H+{h+1}   | {mae_h:7.3f} | {rmse_h:6.3f} | {mape_h:8.3f}")

print("\n--- Avaliação Concluída ---")

## Estadual

In [None]:
import matplotlib.pyplot as plt


# --- State-level (estadual) inference setup ---
# Per-state peak weekly counts saved by the state training cell. They are used to rebuild
# casos_norm_log the same way as in training. Without the file the map is empty and
# prever_para_estado falls back to a peak of 1.0.
STATE_PEAK_PATH = os.path.join(BASE_DIR, "state_peak.json")
if os.path.exists(STATE_PEAK_PATH):
    with open(STATE_PEAK_PATH, "r") as f:
        state_peak_map = json.load(f)
else:
    state_peak_map = {}

# Must match the state training cell's dynamic_features / static_features order.
# NOTE(review): this overwrites the municipal DYN_FEATS / STATIC_FEATS globals, so
# prever_para_municipio stops working after this cell runs. prever_para_estado also reads
# the globals df, model, scaler_dyn, scaler_static and state_to_idx. In a top-to-bottom
# run, df, model and the scalers were last set by municipal cells, so load the
# state-level data, model and scalers before calling it.
DYN_FEATS = [
    "casos_norm_log",
    "casos_velocidade", "casos_aceleracao", "casos_mm_4_semanas",
    "T2M_mean","T2M_std","PRECTOTCORR_mean","PRECTOTCORR_std",
    "RH2M_mean","RH2M_std","ALLSKY_SFC_SW_DWN_mean","ALLSKY_SFC_SW_DWN_std",
    "week_sin","week_cos","year_norm","notificacao"
]

STATIC_FEATS = ["populacao_total"]

def prever_para_estado(state_sigla: str, start_year: int, start_week: int, show_plot=True, display_history_weeks=250):
    """Forecast the next HORIZON weekly case counts for one state.

    The forecast origin is the row matching (start_year, start_week). The model
    receives the SEQUENCE_LENGTH weeks that end right before that row.

    Relies on notebook globals: df, model, scaler_dyn, scaler_static,
    scaler_target, state_to_idx, state_peak_map, SEQUENCE_LENGTH, HORIZON,
    DYN_FEATS and STATIC_FEATS.

    Args:
        state_sigla: state code as stored in df["estado_sigla"].
        start_year: year of the first week to forecast.
        start_week: week number of the first week to forecast.
        show_plot: if True, draw history, the real future (when available) and
            the forecast with matplotlib.
        display_history_weeks: number of past weeks shown in the plot.

    Returns:
        None when the state has no rows, the (year, week) row does not exist, or
        fewer than SEQUENCE_LENGTH weeks precede it. Otherwise a dict with:
        "state", "last_known_index" (position of the last known week in the
        state's year/week-sorted series), "prediction_scaled" (raw model
        output) and "prediction_counts" (forecast counts, clipped at 0).
    """
    df_st = df[df['estado_sigla'] == state_sigla].copy().sort_values(['year', 'week']).reset_index(drop=True)
    if df_st.empty:
        return None

    # Rebuild the model features the same way the state-level evaluation cell does.
    df_st["notificacao"] = df_st["year"].isin([2021, 2022]).astype(float)
    df_st["week_sin"] = np.sin(2 * np.pi * df_st["week"] / 52)
    df_st["week_cos"] = np.cos(2 * np.pi * df_st["week"] / 52)
    # Normalise the year over the whole dataset (all states), as the evaluation
    # cell does. A per-state range would shift this input for states whose
    # history covers a different span of years.
    year_min, year_max = df["year"].min(), df["year"].max()
    df_st["year_norm"] = (df_st["year"] - year_min) / max(1.0, year_max - year_min)
    df_st['casos_velocidade'] = df_st['casos_soma'].diff().fillna(0)
    df_st['casos_aceleracao'] = df_st['casos_velocidade'].diff().fillna(0)
    # Rolling mean over the state's full history. It must NOT be recomputed on
    # the input window later: that would restart the mean at the window edge.
    df_st['casos_mm_4_semanas'] = df_st['casos_soma'].rolling(4, min_periods=1).mean()

    idx_list = df_st.index[(df_st['year'] == start_year) & (df_st['week'] == start_week)].tolist()
    if not idx_list:
        return None
    pred_point_idx = idx_list[0]
    last_known_idx = pred_point_idx - 1
    if last_known_idx < SEQUENCE_LENGTH - 1:
        return None

    # Normalisation peak: use the persisted value when available, otherwise the
    # state's largest observed weekly count (the evaluation cell's rule).
    # Defaulting to 1.0 would feed raw counts into log1p, far outside the
    # [0, ln 2] range the peak normalisation produces.
    state_peak = state_peak_map.get(state_sigla)
    if state_peak is None:
        state_peak = df_st["casos_soma"].max()
    state_peak = float(state_peak)
    if not state_peak > 0:  # non-positive or NaN
        state_peak = 1.0

    df_st["casos_norm"] = df_st["casos_soma"] / state_peak
    df_st["casos_norm_log"] = np.log1p(df_st["casos_norm"])

    start_idx = last_known_idx - SEQUENCE_LENGTH + 1
    input_seq = df_st.iloc[start_idx:last_known_idx + 1]

    # NOTE(review): the static feature is taken from the first row of the input
    # window, whereas the evaluation cell uses the state's first row overall.
    # These are identical if populacao_total is constant per state — confirm.
    static_raw = input_seq[STATIC_FEATS].iloc[0].values.reshape(1, -1)
    dyn_raw = input_seq[DYN_FEATS].values

    input_dyn_scaled = scaler_dyn.transform(dyn_raw).reshape(1, SEQUENCE_LENGTH, len(DYN_FEATS))
    input_static_scaled = scaler_static.transform(static_raw)
    # NOTE(review): an unknown state silently falls back to id 0.
    state_idx = state_to_idx.get(str(state_sigla), 0)
    input_state = np.array([[state_idx]], dtype=np.int32)

    y_pred = model.predict([input_dyn_scaled, input_static_scaled, input_state], verbose=0)
    # A multi-output model returns a list; use its first output.
    y_pred_reg = y_pred[0] if isinstance(y_pred, (list, tuple)) else y_pred
    # Undo, in reverse order: target scaler -> log1p -> peak normalisation.
    y_pred_log_norm = scaler_target.inverse_transform(y_pred_reg.reshape(-1, 1)).reshape(y_pred_reg.shape)
    y_pred_norm = np.expm1(y_pred_log_norm)
    prediction_counts = np.maximum(y_pred_norm.flatten() * state_peak, 0.0)

    # Prepend the last observed value so the plotted forecast line starts on the history.
    last_known_case = df_st.iloc[last_known_idx]['casos_soma']
    connected_prediction = np.insert(prediction_counts, 0, last_known_case)

    hist_to_display = df_st.iloc[max(0, last_known_idx - display_history_weeks):last_known_idx + 1]

    has_dates = 'date' in hist_to_display.columns and not hist_to_display['date'].isna().all()
    if has_dates:
        # Parse once so history and forecast share one datetime axis.
        x_hist = pd.to_datetime(hist_to_display['date'])
        # Use exact 7-day steps from the last known date. freq='W-MON' would
        # snap the first point to the next Monday whenever that date is not a
        # Monday, detaching the forecast from the history.
        x_pred = pd.date_range(start=x_hist.iloc[-1], periods=HORIZON + 1, freq='7D')
    else:
        x_hist = np.arange(len(hist_to_display))
        x_pred = np.arange(len(hist_to_display) - 1, len(hist_to_display) + HORIZON)
    x_future = x_pred[1:HORIZON + 1]

    if show_plot:
        plt.figure(figsize=(14, 6))
        plt.plot(x_hist, hist_to_display['casos_soma'], 'o-', label="Historical")
        future_real = df_st.iloc[last_known_idx + 1:last_known_idx + 1 + HORIZON]
        if not future_real.empty:
            plt.plot(x_future[:len(future_real)], future_real['casos_soma'], 'o-', color='green', label="Future (real)")
        plt.plot(x_pred, connected_prediction, 'o--', color='red', label=f"Forecast (H={HORIZON} weeks)")
        forecast_origin = x_hist.iloc[-1] if has_dates else x_hist[-1]
        plt.axvline(x=forecast_origin, color='black', linestyle=":", linewidth=2, label="Forecast point")
        plt.title(f"Weekly Cases Forecast - {state_sigla}")
        plt.xlabel("Date" if has_dates else "Index")
        plt.ylabel("Weekly cases")
        plt.grid(True, linestyle='--', linewidth=0.5)
        plt.legend()
        plt.tight_layout()
        plt.show()

    return {
        "state": state_sigla,
        "last_known_index": int(last_known_idx),
        "prediction_scaled": y_pred_reg,
        "prediction_counts": prediction_counts,
    }


In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import plotly.graph_objects as go

# Input controls for the interactive state forecast.
state_widget = widgets.Text(value="AC", description="State:")
year_widget = widgets.IntText(value=2025, description="Year:")
week_widget = widgets.IntText(value=20, description="Week:")
button = widgets.Button(description="Generate Forecast")
output = widgets.Output()  # area where each click's messages and figure are rendered

def safe_get_column(df, *names):
    """Return the actual name of the first column of ``df`` matching one of ``names``.

    Matching is case-insensitive, and candidates are tried in the order given.
    Falsy candidates (None, "") are skipped. Returns None when nothing matches.
    """
    by_lower = {}
    for column in df.columns:
        by_lower[column.lower()] = column
    return next(
        (by_lower[name.lower()] for name in names if name and name.lower() in by_lower),
        None,
    )

def show_interactive(df_display, state_sigla, year, week):
    """Plot the real and forecast series of ``df_display`` with Plotly.

    Columns are looked up case-insensitively with safe_get_column:
      * x axis: the first of "date" / "week" / "period", else the first column.
        Values are converted to strings, so Plotly draws them in row order.
      * real series: "real" / "actual" / "cases" / "observed" / "casos_soma".
      * forecast series: "pred" / "forecast" / "prediction" / "predicted".
    A series with no matching column is simply not drawn.

    Args:
        df_display: frame holding the x column and the value columns.
        state_sigla, year, week: used only in the plot title.
    """
    df2 = df_display.copy()
    # safe_get_column only returns existing column names, and the fallback is
    # columns[0], so date_col always exists in df2.
    date_col = safe_get_column(df2, "date", "week", "period") or df2.columns[0]
    real_col = safe_get_column(df2, "real", "actual", "cases", "observed", "casos_soma")
    pred_col = safe_get_column(df2, "pred", "forecast", "prediction", "predicted")
    df2[date_col] = df2[date_col].astype(str)

    fig = go.Figure()
    # (column, trace name, line style) for each series that may be present.
    series = (
        (real_col, "Real", dict(color="green")),
        (pred_col, "Forecast", dict(color="red", dash="dash")),
    )
    for column, trace_name, line_style in series:
        if column is None:
            continue
        fig.add_trace(go.Scatter(
            x=df2[date_col],
            y=df2[column],
            mode="lines+markers",
            name=trace_name,
            line=line_style,
            marker=dict(color=line_style["color"]),
        ))

    fig.update_layout(
        title=f"Real vs Forecast — {state_sigla} ({year} W{week})",
        hovermode="x unified",
        margin=dict(l=40, r=20, t=40, b=30),
        xaxis_title="Date / Index",
        yaxis_title="Cases"
    )
    fig.show()

def on_button_clicked(b):
    """Button callback: forecast for the widget inputs and plot it in ``output``.

    Calls prever_para_estado for the selected state/year/week. It then shows,
    with Plotly, up to the last 51 known weeks followed by the forecast.
    ``b`` is the clicked Button; it is unused.
    """
    with output:
        clear_output(wait=True)
        state_sigla = str(state_widget.value).strip()
        year = int(year_widget.value)
        week = int(week_widget.value)
        result = prever_para_estado(state_sigla, year, week)
        if result is None:
            print("Ponto de previsão não encontrado ou sequência insuficiente.")
            return
        last_idx = result["last_known_index"]
        # Sort exactly as prever_para_estado does, so last_known_index refers to
        # the same week in this frame.
        hist = (
            df[df["estado_sigla"] == state_sigla]
            .sort_values(["year", "week"])
            .reset_index(drop=True)
        )
        start = max(0, last_idx - 50)
        hist_slice = hist.iloc[start:last_idx + 1].copy()
        pred_vals = list(result["prediction_counts"])
        n_hist, n_pred = len(hist_slice), len(pred_vals)

        # History rows carry only "real"; forecast rows carry only "forecast".
        real_values = list(hist_slice['casos_soma']) + [None] * n_pred
        forecast_values = [None] * n_hist + pred_vals

        if 'date' in hist_slice.columns and not hist_slice['date'].isna().all():
            # Parse the history dates once, so they share one dtype with the
            # future Timestamps. Mixing raw strings with Timestamps produces
            # inconsistently formatted axis labels after str conversion.
            hist_dates = pd.to_datetime(hist_slice['date'])
            last_date = hist_dates.iloc[-1]
            future_dates = [last_date + pd.to_timedelta(7 * i, unit='d') for i in range(1, n_pred + 1)]
            display_df = pd.DataFrame({
                "date": list(hist_dates) + future_dates,
                "real": real_values,
                "forecast": forecast_values,
            })
        else:
            display_df = pd.DataFrame({
                "index": list(range(n_hist + n_pred)),
                "real": real_values,
                "forecast": forecast_values,
            })
        show_interactive(display_df, state_sigla, year, week)


button.on_click(on_button_clicked)
display(state_widget, year_widget, week_widget, button, output)


In [None]:
import os
import json
import numpy as np
import pandas as pd
import joblib
from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras.models import load_model
from sklearn.metrics import mean_absolute_error, mean_squared_error
from typing import Tuple
from tensorflow.keras.utils import register_keras_serializable


@register_keras_serializable(package="Custom", name="asymmetric_mse")
def asymmetric_mse(y_true, y_pred):
    """Mean squared error that penalises under-prediction more than over-prediction.

    When y_true > y_pred, the squared residual is multiplied by
    1 + penalty_factor * |y_true - y_pred| / max(|y_true|, 1).
    Over-predictions keep weight 1. Returns the mean over all elements.
    """
    penalty_factor = 1.0
    residual = y_true - y_pred
    # Relative size of the miss; the max(., 1) floor avoids dividing by ~0 targets.
    relative_miss = tf.abs(residual) / tf.maximum(tf.abs(y_true), 1.0)
    weight = tf.where(residual > 0, 1.0 + penalty_factor * relative_miss, 1.0)
    return tf.reduce_mean(tf.square(residual) * weight)

def create_sequences_multi_step(data: np.ndarray, seq_len: int, horizon: int) -> Tuple[np.ndarray, np.ndarray]:
    """Cut ``data`` into sliding (input, target) windows for multi-step forecasting.

    Window i uses rows [i, i + seq_len) of every column as its input. Its target
    is the next ``horizon`` values of column 0.

    Returns:
        X: float32 array of shape (n, seq_len, n_features).
        y: float32 array of shape (n, horizon).
        Here n = len(data) - seq_len - horizon + 1. When n <= 0, both arrays are
        empty and use NumPy's default dtype.
    """
    n_rows, n_features = data.shape[0], data.shape[1]
    n_windows = n_rows - seq_len - horizon + 1
    if n_windows <= 0:
        return np.empty((0, seq_len, n_features)), np.empty((0, horizon))

    window_starts = np.arange(n_windows)[:, None]
    # Row indices of every input window and every target window (fancy indexing).
    input_rows = window_starts + np.arange(seq_len)
    target_rows = window_starts + seq_len + np.arange(horizon)
    X = data[input_rows].astype(np.float32)
    y = data[target_rows, 0].astype(np.float32)
    return X, y

def safe_mape(y_true, y_pred):
    """Mean absolute percentage error (in %) over points whose true value is >= 1.

    Points with y_true < 1 (including zeros) are ignored, which avoids division
    by zero and exploding ratios. Returns NaN when no point qualifies.
    """
    truth = np.array(y_true)
    guess = np.array(y_pred)
    keep = truth >= 1
    if not keep.any():
        return np.nan
    truth, guess = truth[keep], guess[keep]
    return np.mean(np.abs((truth - guess) / truth)) * 100.0

# --- State-level evaluation: paths and hyper-parameters ---
BASE_DIR = "/content/drive/MyDrive/lstm_projeto"
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
SCALER_DIR = os.path.join(BASE_DIR, "scalers")

DATA_PATH = os.path.join(BASE_DIR, "data", "final_training_data_estadual.parquet")
MODEL_TO_EVALUATE = os.path.join(CHECKPOINT_DIR, "model_checkpoint_best_state.keras")
STATE_MAP_PATH = os.path.join(BASE_DIR, "state_to_idx.json")
STATE_PEAK_PATH = os.path.join(BASE_DIR, "state_peak.json")

# Pickled scalers (loaded later with joblib) for dynamic inputs, static inputs and the target.
SCALER_DYN_PATH = os.path.join(SCALER_DIR, "scaler_dyn_global_state.pkl")
SCALER_STATIC_PATH = os.path.join(SCALER_DIR, "scaler_static_global_state.pkl")
SCALER_TARGET_PATH = os.path.join(SCALER_DIR, "scaler_target_global_state.pkl")

# Input window length (weeks), forecast horizon (weeks), prediction batch size.
SEQUENCE_LENGTH, HORIZON, BATCH_SIZE = 12, 6, 32

print("\n--- Etapa 1: Carregando e processando os dados estaduais ---")
df = pd.read_parquet(DATA_PATH)
df = df.sort_values(by=["estado_sigla", "year", "week"]).reset_index(drop=True)

df["casos_soma"] = df["casos_soma"].astype(float)
df["populacao_total"] = df["populacao_total"].astype(float)
# 1.0 for the years 2021 and 2022, 0.0 otherwise.
df["notificacao"] = df["year"].isin([2021, 2022]).astype(float)

# Case dynamics computed within each state, so one state's series never bleeds into the next.
df["casos_velocidade"] = df.groupby("estado_sigla")["casos_soma"].diff().fillna(0)
df["casos_aceleracao"] = df.groupby("estado_sigla")["casos_velocidade"].diff().fillna(0)
df["casos_mm_4_semanas"] = df.groupby("estado_sigla")["casos_soma"].transform(lambda x: x.rolling(4, min_periods=1).mean())
# Cyclic week-of-year encoding and a [0, 1] year trend over the whole dataset.
df["week_sin"] = np.sin(2 * np.pi * df["week"] / 52)
df["week_cos"] = np.cos(2 * np.pi * df["week"] / 52)
df["year_norm"] = (df["year"] - df["year"].min()) / max(1.0, (df["year"].max() - df["year"].min()))

# Per-state normalisation peak. Start from the largest observed weekly count,
# then prefer the peaks persisted at STATE_PEAK_PATH (the same file
# prever_para_estado reads). This keeps evaluation and interactive forecasting
# on one normalisation.
state_peak_dict = df.groupby("estado_sigla")["casos_soma"].max().to_dict()
if os.path.exists(STATE_PEAK_PATH):
    with open(STATE_PEAK_PATH, "r") as f:
        saved_peaks = json.load(f)
    for sigla, peak in saved_peaks.items():
        if sigla in state_peak_dict and peak is not None:
            state_peak_dict[sigla] = float(peak)
# Guard against non-positive / NaN peaks (NaN > 0 is False).
state_peak_dict = {k: (v if (v is not None and v > 0) else 1.0) for k, v in state_peak_dict.items()}

# Vectorised per-row division by the row's state peak (replaces a slow row-wise apply).
df["casos_norm"] = df["casos_soma"] / df["estado_sigla"].map(state_peak_dict).fillna(1.0)
df["casos_norm_log"] = np.log1p(df["casos_norm"])

# Column order must match the one the dynamic scaler was fitted with; column 0 is the target.
dynamic_features = [
    "casos_norm_log",
    "casos_velocidade", "casos_aceleracao", "casos_mm_4_semanas",
    "T2M_mean", "T2M_std", "PRECTOTCORR_mean", "PRECTOTCORR_std",
    "RH2M_mean", "RH2M_std", "ALLSKY_SFC_SW_DWN_mean", "ALLSKY_SFC_SW_DWN_std",
    "week_sin", "week_cos", "year_norm", "notificacao"
]
static_features = ["populacao_total"]

estados = np.sort(df["estado_sigla"].unique())

with open(STATE_MAP_PATH, "r") as f:
    state_to_idx = json.load(f)
idx_to_state = {int(v): k for k, v in state_to_idx.items()}

print("Iniciando recriação do conjunto de validação estadual...")
X_test_list, y_test_scaled_list, static_test_list, ids_test_list = [], [], [], []
y_test_raw_list = []

scaler_dyn = joblib.load(SCALER_DYN_PATH)
scaler_static = joblib.load(SCALER_STATIC_PATH)
scaler_target = joblib.load(SCALER_TARGET_PATH)

for st in tqdm(estados, desc="Processando estados"):
    df_st = df[df["estado_sigla"] == st].sort_values(by=["year", "week"])

    # Model inputs. Column 0 ("casos_norm_log") also provides the model-space target.
    arr_log_norm = df_st[dynamic_features].values
    X_all, y_all_log_norm = create_sequences_multi_step(arr_log_norm, SEQUENCE_LENGTH, HORIZON)
    if X_all.shape[0] == 0:
        continue

    # Ground-truth raw weekly counts aligned with the same windows. Only column 0
    # feeds y, so windowing the single "casos_soma" column is enough.
    _, y_all_raw = create_sequences_multi_step(df_st[["casos_soma"]].values, SEQUENCE_LENGTH, HORIZON)

    y_all_scaled = scaler_target.transform(y_all_log_norm.reshape(-1, 1)).reshape(y_all_log_norm.shape)

    # Chronological split: the last 20% of each state's windows form the validation set.
    split = int(0.8 * X_all.shape[0])
    static_rep = df_st[static_features].iloc[0].values.reshape(1, -1)

    if split < X_all.shape[0]:
        ntest = X_all.shape[0] - split
        X_test_list.append(X_all[split:])
        y_test_scaled_list.append(y_all_scaled[split:])
        static_test_list.append(np.repeat(static_rep, ntest, axis=0))
        ids_test_list.append(np.repeat(state_to_idx[st], ntest).reshape(-1, 1))
        y_test_raw_list.append(y_all_raw[split:])

if len(X_test_list) == 0:
    raise RuntimeError("Nenhuma sequência de validação encontrada.")

print("\n--- Etapa 2: Agregando dados de teste/validação ---")
X_test_unscaled = np.concatenate(X_test_list, axis=0)
y_test_scaled = np.concatenate(y_test_scaled_list, axis=0)
static_test_unscaled = np.concatenate(static_test_list, axis=0)
state_ids_test = np.concatenate(ids_test_list, axis=0).reshape(-1, 1)

y_test_raw_agg = np.concatenate(y_test_raw_list, axis=0)

print(f"Total de sequências de Teste/Validação recriadas: {X_test_unscaled.shape[0]}")

print("Escalando inputs de teste...")
ns, sl, nf = X_test_unscaled.shape
# The dynamic scaler works on 2-D (rows, features): flatten the windows, scale, restore.
X_test_scaled = scaler_dyn.transform(X_test_unscaled.reshape(-1, nf)).reshape(ns, sl, nf)
static_test_scaled = scaler_static.transform(static_test_unscaled)

print(f"\n--- Etapa 3: Carregando modelo de {MODEL_TO_EVALUATE} ---")
model = load_model(MODEL_TO_EVALUATE, custom_objects={'asymmetric_mse': asymmetric_mse})
print("Modelo carregado.")

print("\n--- Etapa 4: Executando predição no conjunto de teste estadual ---")
y_pred_scaled = model.predict([X_test_scaled, static_test_scaled, state_ids_test], batch_size=BATCH_SIZE, verbose=1)
# A multi-output model returns a list; take the first output, exactly as
# prever_para_estado does, instead of failing on .reshape below.
if isinstance(y_pred_scaled, (list, tuple)):
    y_pred_scaled = y_pred_scaled[0]
y_pred_scaled = np.asarray(y_pred_scaled)

print("Revertendo transformações (des-escalar, expm1, des-normalizar)...")

# Undo, in reverse order: target scaler -> log1p -> per-state peak normalisation.
y_pred_log_norm_matrix = scaler_target.inverse_transform(
    y_pred_scaled.reshape(-1, 1)
).reshape(y_pred_scaled.shape)

y_pred_norm_matrix = np.expm1(y_pred_log_norm_matrix)

# Peak of the state each validation sample belongs to, aligned with the rows.
per_sample_peak_vector = np.array(
    [state_peak_dict[idx_to_state[int(idx)]] for idx in state_ids_test.flatten()]
)

# Negative case counts are impossible: clip at zero.
y_pred_real_matrix = np.maximum(y_pred_norm_matrix * per_sample_peak_vector[:, np.newaxis], 0.0)

print("Calculando métricas...")

y_true_flat = y_test_raw_agg.flatten()
y_pred_flat = y_pred_real_matrix.flatten()

# Pooled scores over every (sample, horizon-step) pair.
mae_geral = mean_absolute_error(y_true_flat, y_pred_flat)
rmse_geral = np.sqrt(mean_squared_error(y_true_flat, y_pred_flat))
mape_geral = safe_mape(y_true_flat, y_pred_flat)

print("\n".join([
    "\n--- MÉTRICAS DE VALIDAÇÃO ESTADUAL (agregando todas as semanas) ---",
    f"  MAE  (Erro Médio Absoluto): {mae_geral:.3f} casos",
    f"  RMSE (Erro Quadrático Médio): {rmse_geral:.3f} casos",
    f"  MAPE (Erro Percentual Médio): {mape_geral:.3f} %",
]))

print("\n".join([
    f"\n--- Métricas por Semana de Previsão (H=1 a {HORIZON}) ---",
    "  Semana |   MAE   |  RMSE  |  MAPE (%)",
    "  --------------------------------------",
]))

# Column h of each matrix is the forecast h+1 weeks ahead (create_sequences_multi_step layout).
y_true_matrix = y_test_raw_agg

for h in range(HORIZON):
    y_true_h, y_pred_h = y_true_matrix[:, h], y_pred_real_matrix[:, h]
    mae_h = mean_absolute_error(y_true_h, y_pred_h)
    rmse_h = np.sqrt(mean_squared_error(y_true_h, y_pred_h))
    mape_h = safe_mape(y_true_h, y_pred_h)
    print(f"   H+{h+1}   | {mae_h:7.3f} | {rmse_h:6.3f} | {mape_h:8.3f}")

print("\n--- Avaliação Estadual Concluída ---")