# Inicializando

In [64]:
from google.colab import drive
# Mount Google Drive under /content/drive (remount is forced) so the project
# paths used by the training cells are reachable.
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


# Treinando

In [None]:
import os
import json
import numpy as np
import joblib
from typing import Tuple

import tensorflow as tf
from tensorflow.keras.callbacks import Callback

def create_sequences_multi_step(data: np.ndarray, seq_len: int, horizon: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build stacked sliding windows (X, y) for multi-step forecasting.

    Args:
        data: 2-D array of shape (n_timesteps, n_features). Column 0 is the
            forecasting target.
        seq_len: number of consecutive timesteps in each input window.
        horizon: number of timesteps to predict right after each window.

    Returns:
        X: float32 array of shape (n_samples, seq_len, n_features).
        y: float32 array of shape (n_samples, horizon), taken from column 0.
        When the series is too short to produce a single window, both arrays
        have zero rows but keep their trailing dimensions, so callers can
        still test `.shape[0] == 0` or concatenate them with other results.
    """
    data = np.asarray(data)
    n_features = data.shape[1]
    num_samples = data.shape[0] - seq_len - horizon + 1
    if num_samples <= 0:
        return (np.zeros((0, seq_len, n_features), dtype=np.float32),
                np.zeros((0, horizon), dtype=np.float32))

    # Index arrays replace the Python loop: row i of `starts + offsets` lists
    # the timesteps that belong to window i.
    starts = np.arange(num_samples)[:, None]
    X = data[starts + np.arange(seq_len)].astype(np.float32)
    y = data[starts + seq_len + np.arange(horizon), 0].astype(np.float32)
    return X, y


def asymmetric_mse(y_true, y_pred):
    """
    Asymmetric squared-error loss, evaluated in whatever space y_true / y_pred are given.

    Elements where the model under-predicts (y_true > y_pred) are weighted by
    1 + 5 * |error| / max(|y_true|, 1); all other elements keep weight 1.
    Returns the mean of the weighted squared errors over every element.
    """
    under_penalty = 5.0  # extra weight applied to under-estimates
    diff = y_true - y_pred
    # Flooring the magnitude at 1.0 avoids dividing by very small targets.
    scale = tf.maximum(tf.abs(y_true), 1.0)
    relative_error = tf.abs(diff) / scale
    is_under_estimate = diff > 0
    weights = tf.where(is_under_estimate, 1.0 + under_penalty * relative_error, 1.0)
    weighted_sq_error = tf.square(diff) * weights
    return tf.reduce_mean(weighted_sq_error)


# --- Callbacks utilitários --- #

class EpochSaver(Callback):
    """
    Write the index of the last completed epoch to a text file.

    The value stored is the 0-based epoch number passed to `on_epoch_end`, so
    code that resumes training is expected to add 1 to it.
    """
    def __init__(self, epoch_file_path: str):
        super().__init__()
        self.epoch_file_path = epoch_file_path

    def on_epoch_end(self, epoch, logs=None):
        # Best effort: a failed write must not abort training, but it is
        # reported so a stale resume marker does not go unnoticed.
        try:
            with open(self.epoch_file_path, "w") as f:
                f.write(str(int(epoch)))
        except Exception as exc:
            print(f"[WARN] EpochSaver: falha ao gravar {self.epoch_file_path}: {exc}")


class HistorySaver(Callback):
    """
    Append each epoch's metrics to a JSON history file.

    The file maps metric name -> list of float values. An existing file is
    loaded and extended, so a resumed run keeps the earlier epochs.
    """
    def __init__(self, save_path: str):
        super().__init__()
        self.save_path = save_path

    def _load_history(self) -> dict:
        """Return the stored history, or an empty dict when none is usable."""
        try:
            with open(self.save_path, "r") as f:
                hist = json.load(f)
        except FileNotFoundError:
            # First epoch of a fresh run: nothing to load yet.
            return {}
        except Exception as exc:
            print(f"[WARN] HistorySaver: falha ao ler {self.save_path} ({exc}); histórico reiniciado.")
            return {}
        if not isinstance(hist, dict):
            print(f"[WARN] HistorySaver: conteúdo inesperado em {self.save_path}; histórico reiniciado.")
            return {}
        return hist

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            logs = {}
        hist = self._load_history()

        for k, v in logs.items():
            # Values that cannot be converted to a plain float are skipped
            # because they would not be JSON-serialisable.
            try:
                val = float(v)
            except Exception:
                continue
            hist.setdefault(k, []).append(val)

        # Write to a temporary file and rename it over the target, so an
        # interrupted write cannot leave a truncated history behind.
        tmp_path = f"{self.save_path}.tmp"
        try:
            with open(tmp_path, "w") as f:
                json.dump(hist, f)
            os.replace(tmp_path, self.save_path)
        except Exception as exc:
            print(f"[WARN] HistorySaver: falha ao gravar {self.save_path}: {exc}")


class PersistentEarlyStoppingCityVal(Callback):
    """
    Early stopping driven by the `city_val_loss` entry of the epoch logs.

    The counters `wait` and `best` are written to a JSON file after every
    evaluated epoch and reloaded in the constructor, so a restarted run keeps
    its patience budget. The best weights are held in memory only, so they
    are not available after a restart until a new best epoch is seen.

    Args:
        city_val_callback: callback whose `on_epoch_end` fills
            `logs["city_val_loss"]`; it is invoked here only when that key is
            missing from the logs.
        patience: number of non-improving epochs tolerated before stopping.
        restore_best_weights: when True, the in-memory best weights are set
            back on the model at the moment training is stopped.
        state_file: path of the JSON file holding `wait` and `best`.
    """
    def __init__(self, city_val_callback, patience=10, restore_best_weights=True, state_file="earlystop_city_state.json"):
        super().__init__()
        self.city_val_callback = city_val_callback
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.state_file = state_file
        self.wait = 0
        self.best_weights = None
        self.best = np.inf

        if os.path.exists(self.state_file):
            try:
                with open(self.state_file, "r") as f:
                    state = json.load(f)
                self.wait = int(state.get("wait", 0))
                self.best = float(state.get("best", np.inf))
            except Exception as exc:
                # An unreadable state file falls back to a fresh start.
                print(f"[WARN] PersistentEarlyStoppingCityVal: estado ilegível em {self.state_file} ({exc}); reiniciando.")
                self.wait = 0
                self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        # Keep the caller's dict (even when empty) so the value written below
        # stays visible to the callbacks that run afterwards.
        if logs is None:
            logs = {}
        current = logs.get("city_val_loss")
        if current is None and hasattr(self.city_val_callback, "on_epoch_end"):
            self.city_val_callback.on_epoch_end(epoch, logs)
            current = logs.get("city_val_loss")
        if current is None:
            return

        current = float(current)
        logs["city_val_loss"] = current

        if current < self.best:
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                try:
                    self.best_weights = self.model.get_weights()
                except Exception as exc:
                    print(f"[WARN] PersistentEarlyStoppingCityVal: falha ao copiar pesos: {exc}")
                    self.best_weights = None
        else:
            self.wait += 1

        # Persist the counters; a failure is reported but does not stop training.
        try:
            with open(self.state_file, "w") as f:
                json.dump({"wait": int(self.wait), "best": float(self.best)}, f)
        except Exception as exc:
            print(f"[WARN] PersistentEarlyStoppingCityVal: falha ao gravar {self.state_file}: {exc}")

        if self.wait >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                try:
                    self.model.set_weights(self.best_weights)
                except Exception as exc:
                    print(f"[WARN] PersistentEarlyStoppingCityVal: falha ao restaurar pesos: {exc}")
            print(f"\nEarly stopping triggered at epoch {epoch + 1} (city_val_loss: {self.best:.6f})")
            self.model.stop_training = True


class CityValLossWeighted(Callback):
    """
    Compute a per-city weighted, asymmetric validation loss after each epoch.

    The validation arrays are expected to be laid out city by city: the first
    `city_sizes[0]` rows belong to city 0, the next `city_sizes[1]` rows to
    city 1, and so on. For each city the squared error is averaged, with
    under-predictions (y_true > y_pred) multiplied by `penalty_factor`. The
    per-city values are combined as a weighted mean using `city_weights` and
    stored in `logs["city_val_loss"]`. Whenever that value improves on the
    best seen by this instance, the model is saved to `save_path`.

    NOTE(review): `best` starts at infinity on every instantiation, so the
    first evaluated epoch of a resumed run always overwrites `save_path`.
    """
    def __init__(self, X_val, static_val, city_ids_val, y_val, city_sizes, city_weights, save_path, penalty_factor=10.0):
        super().__init__()
        self.X_val = X_val
        self.static_val = static_val
        self.city_ids_val = city_ids_val
        self.y_val = y_val
        self.city_sizes = city_sizes
        self.city_weights = city_weights
        self.save_path = save_path
        self.best = np.inf
        self.penalty_factor = penalty_factor

    def _predict_regression(self):
        """Return the validation predictions, or None when prediction fails."""
        # First try the three-input model (with city ids), then fall back to
        # the two-input variant.
        candidate_inputs = (
            [self.X_val, self.static_val, self.city_ids_val],
            [self.X_val, self.static_val],
        )
        for inputs in candidate_inputs:
            try:
                y_pred = self.model.predict(inputs, batch_size=1024, verbose=0)
            except Exception:
                continue
            # Multi-output models return a list; the first entry is used.
            if isinstance(y_pred, (list, tuple)):
                return y_pred[0]
            return y_pred
        return None

    def on_epoch_end(self, epoch, logs=None):
        # Keep the caller's dict (even when empty) so the metric written
        # below is visible to the callbacks that run afterwards.
        if logs is None:
            logs = {}
        y_pred_reg = self._predict_regression()
        if y_pred_reg is None:
            return

        start = 0
        weighted_losses = []
        total_weight = 0.0

        for i, n in enumerate(self.city_sizes):
            y_true_city = self.y_val[start:start+n]
            y_pred_city = y_pred_reg[start:start+n]
            start += n
            if y_true_city.shape[0] == 0:
                # The mean of an empty slice is NaN and would poison the
                # aggregate, so cities without samples are skipped.
                continue

            error = y_true_city - y_pred_city
            penalty = np.where(error > 0, self.penalty_factor, 1.0)
            asymmetric_loss_city = np.mean(np.square(error) * penalty)

            weight = float(self.city_weights[i])
            weighted_losses.append(asymmetric_loss_city * weight)
            total_weight += weight

        city_val_loss = float(np.sum(weighted_losses) / max(total_weight, 1e-8))
        logs["city_val_loss"] = city_val_loss

        if city_val_loss < self.best:
            self.best = city_val_loss
            try:
                self.model.save(self.save_path)
            except Exception as exc:
                # Do not claim the model was saved when the save failed.
                print(f"\nEpoch {epoch+1}: [WARN] falha ao salvar melhor modelo em {self.save_path}: {exc}")
            else:
                print(f"\nEpoch {epoch+1}: Melhor modelo salvo com city_val_loss ponderado: {city_val_loss:.6f}")


# --- Mapeamento city_to_idx utilitário --- #

def build_and_save_city_mapping(codigo_ibge_array, save_path):
    """
    Build the mapping IBGE code -> sequential index and store it as JSON.

    Indices start at 1 and follow ascending code order; 0 is left free for
    unknown cities. The mapping is written to `save_path` and also returned.
    """
    codes = np.unique(codigo_ibge_array).astype(int)
    codes.sort()

    city_to_idx = {}
    for position, code in enumerate(codes, start=1):
        city_to_idx[int(code)] = int(position)

    with open(save_path, "w") as f:
        json.dump(city_to_idx, f)
    return city_to_idx


def load_city_mapping(path):
    """Read a JSON city mapping from `path`, converting keys and values to int."""
    with open(path, "r") as f:
        raw = json.load(f)
    mapping = {}
    for code, index in raw.items():
        mapping[int(code)] = int(index)
    return mapping


def process_municipio(municipio_id):
    """
    Build the scaled train/test windows for a single municipality.

    Reads these names from the global scope:
      - df (pandas DataFrame with columns codigo_ibge, ano, semana, notificacao
        and every name in dynamic_features / static_features)
      - SEQUENCE_LENGTH, HORIZON
      - dynamic_features, static_features
      - SCALER_DIR (directory holding the three global scaler pickles)
      - city_to_idx (codigo_ibge -> sequential index; codes missing from the
        mapping get index 0)

    Returns None when the municipality has fewer than SEQUENCE_LENGTH + HORIZON
    rows, yields no windows, or when the 80/20 split would leave one side empty.
    Otherwise returns a 12-tuple:
      X_train_scaled, y_train_scaled, X_test_scaled, y_test_scaled,
      static_train_scaled, static_test_scaled, y_train_raw, y_test_raw,
      city_ids_train, city_ids_test, notif_train, notif_test

    Targets are plain counts (no log1p) scaled with the global target scaler.
    """
    import os
    import joblib

    municipio_id = int(municipio_id)
    rows = df[df["codigo_ibge"] == municipio_id].copy().sort_values(by=["ano","semana"])
    if len(rows) < SEQUENCE_LENGTH + HORIZON:
        return None

    dyn_values = rows[dynamic_features].values
    static_row = rows[static_features].iloc[0].values.reshape(1, -1)
    X_raw, y_raw = create_sequences_multi_step(dyn_values, SEQUENCE_LENGTH, HORIZON)
    n_windows = X_raw.shape[0]
    if n_windows == 0:
        return None

    # Chronological split: first 80% of the windows train, the rest test.
    cut = int(n_windows * 0.8)
    if cut in (0, n_windows):
        return None

    X_train_raw, X_test_raw = X_raw[:cut], X_raw[cut:]
    y_train_raw, y_test_raw = y_raw[:cut], y_raw[cut:]

    # The static features are constant per municipality; repeat them per window.
    static_train = np.repeat(static_row, X_train_raw.shape[0], axis=0)
    static_test = np.repeat(static_row, X_test_raw.shape[0], axis=0)

    # Global scalers are loaded from disk on every call.
    scaler_dyn = joblib.load(os.path.join(SCALER_DIR, "scaler_dyn_global.pkl"))
    scaler_static = joblib.load(os.path.join(SCALER_DIR, "scaler_static_global.pkl"))
    scaler_target = joblib.load(os.path.join(SCALER_DIR, "scaler_target_global.pkl"))

    def _scale_windows(windows):
        # Flatten to (rows, features) for the scaler, then restore the 3-D shape.
        flat = windows.reshape(-1, windows.shape[-1])
        return scaler_dyn.transform(flat).reshape(windows.shape)

    def _scale_targets(targets):
        # The target scaler works on a single column.
        return scaler_target.transform(targets.reshape(-1, 1)).reshape(targets.shape)

    X_train_scaled = _scale_windows(X_train_raw)
    X_test_scaled = _scale_windows(X_test_raw)
    static_train_scaled = scaler_static.transform(static_train)
    static_test_scaled = scaler_static.transform(static_test)
    y_train_scaled = _scale_targets(y_train_raw)
    y_test_scaled = _scale_targets(y_test_raw)

    # One city index per window, for the embedding input.
    city_idx = city_to_idx.get(municipio_id, 0)
    city_ids_train = np.repeat(city_idx, X_train_scaled.shape[0]).astype(np.int32)
    city_ids_test = np.repeat(city_idx, X_test_scaled.shape[0]).astype(np.int32)

    # Notification flag of the last input timestep of each window.
    first_flag = SEQUENCE_LENGTH - 1
    window_flags = rows["notificacao"].values[first_flag : first_flag + n_windows]
    notif_train = window_flags[:cut].astype(float)
    notif_test = window_flags[cut:].astype(float)

    return (X_train_scaled, y_train_scaled, X_test_scaled, y_test_scaled,
            static_train_scaled, static_test_scaled,
            y_train_raw, y_test_raw,
            city_ids_train, city_ids_test,
            notif_train, notif_test)

## Estadual

In [None]:
# ---------- BLOCO 2 (REVISADO): treinamento estadual alinhado ao municipal ----------
import os
import json
import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import LSTM, Dense, Input, Concatenate, Dropout, Embedding, Flatten
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam

# --- CONFIG / PATHS ---
# Every artifact of the state-level run lives under BASE_DIR on the mounted Drive.
BASE_DIR = "/content/drive/MyDrive/lstm_projeto"
DATA_PATH = os.path.join(BASE_DIR, "data", "final_training_data_estadual.parquet")
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
# Model written after every epoch (and once more at the end of training).
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "model_state_checkpoint.keras")
# Model with the best validation metric.
CHECKPOINT_BEST_STATE = os.path.join(CHECKPOINT_DIR, "model_checkpoint_best_state.keras")
SCALER_DIR = os.path.join(BASE_DIR, "scalers")
STATE_MAP_PATH = os.path.join(BASE_DIR, "state_to_idx.json")
TRAINING_HISTORY_PATH = os.path.join(BASE_DIR, "training_history_state.json")
# Text file holding the last completed epoch, stored next to the checkpoint.
EPOCH_TRACK_PATH_STATE = CHECKPOINT_PATH.replace(".keras", "_epoch.txt")

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(SCALER_DIR, exist_ok=True)

SEQUENCE_LENGTH = 12  # input window length, in rows (weeks) of the series
HORIZON = 6           # number of future steps predicted per window
EPOCHS = 120
BATCH_SIZE = 32

# policy knobs (make consistent with municipal)
NOTIF_LEARNING_RATE = 0.8   # samples flagged as notification years (2021/2022) keep 80% of their weight
POP_ALPHA = 0.1             # exponent applied to the population ratio in the sample weights
SAMPLE_WEIGHT_MINMAX = (0.1, 5.0)  # clipping bounds for the final sample weights

# ---------- LOAD DATA & FEATURES ----------
df = pd.read_parquet(DATA_PATH)
# Sort chronologically within each state so the groupby diffs and rolling means below follow time order.
df = df.sort_values(by=["estado_sigla","year","week"]).reset_index(drop=True)

df["casos_soma"] = df["casos_soma"].astype(float)
df["populacao_total"] = df["populacao_total"].astype(float)
# Flag rows from 2021/2022; later used to reduce the weight of those samples.
df["notificacao"] = df["year"].isin([2021,2022]).astype(float)

# derived features (compute once)
# First and second differences of the case counts per state (first row of each group -> 0).
df["casos_velocidade"] = df.groupby("estado_sigla")["casos_soma"].diff().fillna(0)
df["casos_aceleracao"] = df.groupby("estado_sigla")["casos_velocidade"].diff().fillna(0)
# 4-row moving average of cases per state (shorter windows at the start of each series).
df["casos_mm_4_semanas"] = df.groupby("estado_sigla")["casos_soma"].transform(lambda x: x.rolling(4, min_periods=1).mean())
# Cyclical encoding of the week number using a 52-week period.
df["week_sin"] = np.sin(2*np.pi*df["week"]/52)
df["week_cos"] = np.cos(2*np.pi*df["week"]/52)
# Year rescaled to [0, 1]; the max(1.0, ...) guards against a single-year dataset.
df["year_norm"] = (df["year"] - df["year"].min()) / max(1.0, (df["year"].max() - df["year"].min()))

# "casos_soma" must stay first: create_sequences_multi_step reads the target from column 0.
dynamic_features = [
    "casos_soma","casos_velocidade","casos_aceleracao","casos_mm_4_semanas",
    "T2M_mean","T2M_std","PRECTOTCORR_mean","PRECTOTCORR_std",
    "RH2M_mean","RH2M_std","ALLSKY_SFC_SW_DWN_mean","ALLSKY_SFC_SW_DWN_std",
    "week_sin","week_cos","year_norm","notificacao"
 ]
static_features = ["populacao_total"]

# Sorted array of the state codes present in the data.
estados = np.sort(df["estado_sigla"].unique())

# state mapping: state code -> sequential index starting at 1 (0 is left unused)
if os.path.exists(STATE_MAP_PATH):
    with open(STATE_MAP_PATH,"r") as f:
        state_to_idx = {str(k): int(v) for k, v in json.load(f).items()}
else:
    state_to_idx = {}

# A stored mapping may predate the current data. States missing from it get
# the next free indices (existing indices are never changed) and the file is
# rewritten, so later `state_to_idx[st]` lookups cannot raise KeyError.
missing_states = [str(s) for s in estados if str(s) not in state_to_idx]
if missing_states:
    next_idx = max(state_to_idx.values(), default=0) + 1
    for offset, s in enumerate(missing_states):
        state_to_idx[s] = next_idx + offset
    with open(STATE_MAP_PATH,"w") as f:
        json.dump(state_to_idx,f)

# The embedding needs input_dim > largest index, so size it from the max index.
num_states = max(state_to_idx.values(), default=0) + 1
idx_to_state = {int(v): k for k,v in state_to_idx.items()}

# ---------- SCALERS (global dyn/static/target) ----------
scaler_dyn_path = os.path.join(SCALER_DIR, "scaler_dyn_global_state.pkl")
scaler_static_path = os.path.join(SCALER_DIR, "scaler_static_global_state.pkl")
scaler_target_path = os.path.join(SCALER_DIR, "scaler_target_global_state.pkl")

# Fit the three scalers only when at least one pickle is missing; otherwise reuse the stored ones.
# NOTE(review): the fit uses every row of each state, including the rows that later
# feed the test windows — confirm this is intended.
if not (os.path.exists(scaler_dyn_path) and os.path.exists(scaler_static_path) and os.path.exists(scaler_target_path)):
    all_dyn, all_static, all_tgt = [], [], []
    for st in estados:
        df_st = df[df["estado_sigla"]==st].sort_values(by=["year","week"])
        X = df_st[dynamic_features].values
        # States too short to yield a single window do not contribute to the fit.
        if X.shape[0] < SEQUENCE_LENGTH+HORIZON:
            continue
        all_dyn.append(X)
        # Static features are taken from the first row and repeated once per timestep.
        S = df_st[static_features].iloc[0].values.reshape(1,-1)
        all_static.append(np.repeat(S, X.shape[0], axis=0))
        yvals = df_st["casos_soma"].values.reshape(-1,1)
        all_tgt.append(yvals)
    if len(all_dyn) == 0:
        raise RuntimeError("Sem sequências suficientes para ajustar scalers. Verifique dados.")
    all_dyn = np.concatenate(all_dyn, axis=0)
    all_static = np.concatenate(all_static, axis=0)
    all_tgt = np.concatenate(all_tgt, axis=0)
    scaler_dyn = MinMaxScaler().fit(all_dyn)
    scaler_static = MinMaxScaler().fit(all_static)
    scaler_target = MinMaxScaler().fit(all_tgt)
    joblib.dump(scaler_dyn, scaler_dyn_path)
    joblib.dump(scaler_static, scaler_static_path)
    joblib.dump(scaler_target, scaler_target_path)
else:
    scaler_dyn = joblib.load(scaler_dyn_path)
    scaler_static = joblib.load(scaler_static_path)
    scaler_target = joblib.load(scaler_target_path)

# ---------- Prepare sequences per state (same pattern as municipal) ----------
# Parallel lists, one entry per state; they are concatenated in the same order further down.
X_train_list, y_train_list, static_train_list, ids_train_list, notif_train_parts = [], [], [], [], []
X_test_list, y_test_list, static_test_list, ids_test_list, notif_test_parts = [], [], [], [], []

for st in tqdm(estados, desc="preparing sequences per state"):
    df_st = df[df["estado_sigla"]==st].sort_values(by=["year","week"])
    arr = df_st[dynamic_features].values
    # X_all keeps the unscaled dynamic features here; they are scaled after concatenation.
    X_all, y_all = create_sequences_multi_step(arr, SEQUENCE_LENGTH, HORIZON)
    if X_all.shape[0] == 0:
        continue
    # scale target with global scaler (consistent)
    y_all_scaled = scaler_target.transform(y_all.reshape(-1,1)).reshape(y_all.shape)

    # Chronological 80/20 split of this state's windows.
    split = int(0.8 * X_all.shape[0])
    static_rep = df_st[static_features].iloc[0].values.reshape(1,-1)

    # create notification flags aligned with sequences (flag at last input timestep)
    flags = df_st["notificacao"].values
    seq_flags_all = flags[SEQUENCE_LENGTH - 1 : SEQUENCE_LENGTH - 1 + X_all.shape[0]]
    notif_train = seq_flags_all[:split].astype(float) if split>0 else np.empty((0,))
    notif_test = seq_flags_all[split:].astype(float) if split < X_all.shape[0] else np.empty((0,))

    # train
    if split > 0:
        X_train_list.append(X_all[:split])
        y_train_list.append(y_all_scaled[:split])
        static_train_list.append(np.repeat(static_rep, split, axis=0))
        ids_train_list.append(np.repeat(state_to_idx[st], split).reshape(-1,1))
        notif_train_parts.append(notif_train)
    # test
    if split < X_all.shape[0]:
        ntest = X_all.shape[0] - split
        X_test_list.append(X_all[split:])
        y_test_list.append(y_all_scaled[split:])
        static_test_list.append(np.repeat(static_rep, ntest, axis=0))
        ids_test_list.append(np.repeat(state_to_idx[st], ntest).reshape(-1,1))
        notif_test_parts.append(notif_test)

if len(X_train_list) == 0:
    raise RuntimeError("Nenhuma sequência de treino encontrada.")

# Stack the per-state pieces; ids are reshaped to a column to match the (1,)-shaped model input.
X_train = np.concatenate(X_train_list, axis=0)
y_train = np.concatenate(y_train_list, axis=0)
static_train = np.concatenate(static_train_list, axis=0)
state_ids_train = np.concatenate(ids_train_list, axis=0).reshape(-1,1)

X_test = np.concatenate(X_test_list, axis=0)
y_test = np.concatenate(y_test_list, axis=0)
static_test = np.concatenate(static_test_list, axis=0)
state_ids_test = np.concatenate(ids_test_list, axis=0).reshape(-1,1)

# ---------- sample weights (keeps population/peak logic but apply notif multiplier like municipal) ----------
# Per-state lookups: population (first row of each state) and the peak of casos_soma.
state_pop = {st: float(df[df["estado_sigla"]==st]["populacao_total"].iloc[0]) for st in estados}
median_pop = float(np.median(list(state_pop.values())))
state_peak = {st: float(df[df["estado_sigla"]==st]["casos_soma"].max() if not df[df["estado_sigla"]==st].empty else 1.0) for st in estados}

# Expand the per-state values to one entry per training sample via the state ids.
train_state_ids = state_ids_train.flatten()
per_sample_peak = np.array([state_peak[idx_to_state[int(s)]] for s in train_state_ids], dtype=float)
per_sample_pop = np.array([state_pop[idx_to_state[int(s)]] for s in train_state_ids], dtype=float)

# Base weight grows with the state's peak incidence per 100k inhabitants;
# the population factor scales it by (population / median population) ** POP_ALPHA.
peak_per100k = per_sample_peak / (per_sample_pop / 100000.0 + 1e-8)
base_w = np.sqrt(1.0 + peak_per100k)
pop_factor = (per_sample_pop / (median_pop + 1e-8)) ** POP_ALPHA
sample_weights = base_w * pop_factor

# notif flags (concatenate in the same order used to build X_train_list)
if len(notif_train_parts) > 0:
    notif_flags_train = np.concatenate(notif_train_parts, axis=0)
else:
    notif_flags_train = np.zeros(sample_weights.shape[0], dtype=float)

# align shapes & warn if mismatch (pad/crop like municipal)
if notif_flags_train.shape[0] != sample_weights.shape[0]:
    print(f"[WARN] notif_flags_train ({notif_flags_train.shape[0]}) != sample_weights ({sample_weights.shape[0]}). Ajustando.")
    if notif_flags_train.shape[0] < sample_weights.shape[0]:
        # Missing flags are padded with 0 (treated as non-notification samples).
        pad = np.zeros(sample_weights.shape[0] - notif_flags_train.shape[0], dtype=float)
        notif_flags_train = np.concatenate([notif_flags_train, pad], axis=0)
    else:
        notif_flags_train = notif_flags_train[: sample_weights.shape[0]]

# apply reduction for notification years
sample_weights = sample_weights * np.where(notif_flags_train > 0.5, NOTIF_LEARNING_RATE, 1.0)
# Clip to the configured bounds and cast to float32 for training.
sample_weights = np.clip(sample_weights, SAMPLE_WEIGHT_MINMAX[0], SAMPLE_WEIGHT_MINMAX[1]).astype(np.float32)

# ---------- scale dynamics/static arrays ----------
# ns = number of training windows, sl = window length, nf = number of dynamic features.
ns, sl, nf = X_train.shape
# The scaler works on 2-D data: flatten windows to (rows, features), transform, then restore the 3-D shape.
X_train_2d = X_train.reshape(-1, nf)
X_test_2d = X_test.reshape(-1, nf)
X_train_scaled = scaler_dyn.transform(X_train_2d).reshape(ns, sl, nf)
X_test_scaled = scaler_dyn.transform(X_test_2d).reshape(X_test.shape[0], sl, nf)
static_train_scaled = scaler_static.transform(static_train)
static_test_scaled = scaler_static.transform(static_test)

# ---------- build model (same architecture as municipal) ----------
embed_dim = 8
# Three inputs: the dynamic window, the static features and the integer state id.
input_dyn = Input(shape=(SEQUENCE_LENGTH, len(dynamic_features)), name='dynamic_input')
input_static = Input(shape=(len(static_features),), name='static_input')
input_state = Input(shape=(1,), dtype='int32', name='state_id_input')

# Two stacked LSTM layers over the dynamic window, each followed by light dropout.
x = LSTM(128, return_sequences=True)(input_dyn)
x = Dropout(0.1)(x)
x = LSTM(128)(x)
x = Dropout(0.1)(x)

# Learned state embedding, flattened and concatenated with the LSTM summary and static features.
emb = Embedding(input_dim=num_states, output_dim=embed_dim)(input_state)
embf = Flatten()(emb)
cat = Concatenate()([x, input_static, embf])

# Linear head with one output per forecast step.
activation = 'linear'
out = Dense(HORIZON, activation=activation, name='reg_output')(cat)

model = Model([input_dyn, input_static, input_state], out)

# Adam with gradient-norm clipping.
optimizer = Adam(learning_rate=1e-4, clipnorm=1.0)

# use asymmetric_mse from bloco1 (consistent behavior with municipal)
model.compile(optimizer=optimizer, loss=asymmetric_mse)

# ---------- callbacks (make it act like municipal: weighted state validation -> persistent earlystop -> history/epoch saver) ----------
epoch_saver = EpochSaver(EPOCH_TRACK_PATH_STATE)
history_saver = HistorySaver(TRAINING_HISTORY_PATH)
checkpoint_last = ModelCheckpoint(filepath=CHECKPOINT_PATH, save_best_only=False, save_freq="epoch", verbose=1)
# NOTE(review): checkpoint_best is created but never added to `callbacks` below;
# it targets the same file that state_val_cb receives as save_path.
checkpoint_best = ModelCheckpoint(filepath=CHECKPOINT_BEST_STATE, save_best_only=True, monitor="val_loss", verbose=1)

# build weighted state-val callback using existing StateValLossWeighted
# NOTE(review): StateValLossWeighted and PersistentEarlyStoppingStateVal are not defined in
# the visible part of this notebook (only the City* variants are) — confirm they exist
# before running this cell.
state_sizes = [x.shape[0] for x in X_test_list]
# NOTE(review): y_test_list holds targets already scaled by scaler_target, so these
# weights are peaks in scaled units, not raw counts.
state_weights = [np.max(y_raw) if y_raw.size > 0 else 1.0 for y_raw in y_test_list]  # weight by test peaks (or adapt)
state_val_cb = StateValLossWeighted(
    X_val=X_test_scaled,
    static_val=static_test_scaled,
    state_ids_val=state_ids_test,
    y_val=y_test,
    state_sizes=state_sizes,
    state_weights=state_weights,
    save_path=CHECKPOINT_BEST_STATE,
    penalty_factor=2.0
)

early_stopping_state = PersistentEarlyStoppingStateVal(
    state_val_callback=state_val_cb,
    patience=10,
    restore_best_weights=True,
    state_file=os.path.join(BASE_DIR, "earlystop_state_state.json")
)

# order: state_val_cb first so it computes logs['state_val_loss'], then early stopping, then history/epoch, then checkpoint
callbacks = [state_val_cb, early_stopping_state, history_saver, epoch_saver, checkpoint_last]

# ---------- train ----------
# The test split doubles as Keras validation data; sample_weight applies only to the training samples.
history = model.fit(
    [X_train_scaled, static_train_scaled, state_ids_train],
    y_train,
    validation_data=([X_test_scaled, static_test_scaled, state_ids_test], y_test),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    sample_weight=sample_weights,
    verbose=1
)

# save scalers & model final
# These are the same paths used when the scalers were fitted/loaded, so this rewrites them unchanged.
joblib.dump(scaler_dyn, os.path.join(SCALER_DIR,"scaler_dyn_global_state.pkl"))
joblib.dump(scaler_static, os.path.join(SCALER_DIR,"scaler_static_global_state.pkl"))
joblib.dump(scaler_target, os.path.join(SCALER_DIR,"scaler_target_global_state.pkl"))
model.save(CHECKPOINT_PATH)

print("Treinamento estadual concluído.")

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/lstm_projeto/data/final_training_data_estadual.parquet'

## Municipal

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import plotly.graph_objects as go

# Form controls for the municipal forecast: IBGE code, year and week inputs,
# a trigger button, and an Output area where the handler renders its result.
ibge_widget = widgets.IntText(description="IBGE:", value=3509502)
year_widget = widgets.IntText(description="Year:", value=2025)
week_widget = widgets.IntText(description="Week:", value=20)
button = widgets.Button(description="Generate Forecast")
output = widgets.Output()

def safe_get_column(df, *names):
    """
    Return the first column of `df` matching one of `names`, ignoring case.

    Candidates are tried in the order given; falsy candidates (None, "") are
    skipped. Column labels and candidates are compared through `str(...)`, so
    non-string labels (e.g. integer columns) do not raise. Returns the
    original column label, or None when nothing matches.
    """
    by_lower = {str(col).lower(): col for col in df.columns}
    for candidate in names:
        if not candidate:
            continue
        match_key = str(candidate).lower()
        if match_key in by_lower:
            return by_lower[match_key]
    return None

def show_interactive(df_display, ibge, year, week):
    """
    Plot observed and forecast series for one municipality with Plotly.

    The x axis is the first column found among date/week/period (falling back
    to the first column of the frame), converted to strings. A green "Real"
    line and a dashed red "Forecast" line are added when a matching column
    exists. `ibge`, `year` and `week` are only used in the chart title.
    """
    frame = df_display.copy()
    x_col = safe_get_column(frame, "date", "week", "period") or frame.columns[0]
    observed_col = safe_get_column(frame, "real", "actual", "cases", "observed", "numero_casos", "casos_soma")
    forecast_col = safe_get_column(frame, "pred", "forecast", "prediction", "predicted")
    if x_col in frame.columns:
        frame[x_col] = frame[x_col].astype(str)

    # (column, legend label, line style) for each series that may be drawn.
    series_specs = (
        (observed_col, "Real", dict(color="green")),
        (forecast_col, "Forecast", dict(color="red", dash="dash")),
    )
    fig = go.Figure()
    for column, label, line_style in series_specs:
        if not column or column not in frame.columns:
            continue
        fig.add_trace(go.Scatter(x=frame[x_col], y=frame[column], mode="lines+markers", name=label, line=line_style))

    fig.update_layout(
        title=f"Real vs Forecast — {ibge} ({year} W{week})",
        hovermode="x unified",
        margin=dict(l=40, r=20, t=40, b=30),
    )
    fig.show()

def on_button_clicked(b):
    """
    Button handler: run the municipal forecast and plot it inside `output`.

    Reads the IBGE code, year and week from the widgets, calls
    `prever_para_municipio`, and plots up to the last 51 known rows of the
    municipality followed by the predicted values. Every failure path prints
    a short message in the output area instead of returning silently.
    """
    with output:
        clear_output(wait=True)
        try:
            ibge = int(ibge_widget.value)
            year = int(year_widget.value)
            week = int(week_widget.value)
        except Exception as exc:
            print(f"Entrada inválida (IBGE/ano/semana): {exc}")
            return
        try:
            result = prever_para_municipio(ibge, year, week)
        except Exception as exc:
            # Also covers NameError when the forecasting cell has not been run yet.
            print(f"Falha ao gerar previsão ({type(exc).__name__}): {exc}")
            return
        if result is None:
            print("Ponto de previsão não encontrado ou sequência insuficiente.")
            return
        last_idx = result["last_known_index"]
        hist = df[(df["codigo_ibge"] == int(ibge))].reset_index(drop=True)
        start = max(0, last_idx - 50)
        hist_slice = hist.iloc[start:last_idx+1].copy()
        pred_vals = list(result["prediction_counts"])
        if 'date' in hist_slice.columns and not hist_slice['date'].isna().all():
            # Forecast dates continue weekly from the last known date.
            dates = list(hist_slice['date']) + [pd.to_datetime(hist_slice['date'].iloc[-1]) + pd.to_timedelta(7*i, unit='d') for i in range(1, len(pred_vals)+1)]
            display_df = pd.DataFrame({"date": dates, "real": list(hist_slice['numero_casos']) + [None]*len(pred_vals), "forecast": [None]*len(hist_slice) + pred_vals})
        else:
            # Without usable dates, plot against a running integer index.
            idxs = list(range(len(hist_slice))) + list(range(len(hist_slice), len(hist_slice) + len(pred_vals)))
            display_df = pd.DataFrame({"index": idxs, "real": list(hist_slice['numero_casos']) + [None]*len(pred_vals), "forecast": [None]*len(hist_slice) + pred_vals})
        show_interactive(display_df, ibge, year, week)

# Register the click handler, then render the form and its output area.
button.on_click(on_button_clicked)

display(ibge_widget, year_widget, week_widget, button, output)


IntText(value=3509502, description='IBGE:')

IntText(value=2025, description='Year:')

IntText(value=20, description='Week:')

Button(description='Generate Forecast', style=ButtonStyle())

Output()

## Estadual

In [80]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import plotly.graph_objects as go

# Form controls for the state forecast: state code, year and week inputs,
# a trigger button, and an Output area where the handler renders its result.
state_widget = widgets.Text(description="State:", value="AC")
year_widget = widgets.IntText(description="Year:", value=2025)
week_widget = widgets.IntText(description="Week:", value=20)
button = widgets.Button(description="Generate Forecast")
output = widgets.Output()

def safe_get_column(df, *names):
    """
    Return the first column of `df` matching one of `names`, ignoring case.

    Candidates are tried in the order given; falsy candidates (None, "") are
    skipped. Column labels and candidates are compared through `str(...)`, so
    non-string labels (e.g. integer columns) do not raise. Returns the
    original column label, or None when nothing matches.
    """
    by_lower = {str(col).lower(): col for col in df.columns}
    for candidate in names:
        if not candidate:
            continue
        match_key = str(candidate).lower()
        if match_key in by_lower:
            return by_lower[match_key]
    return None

def show_interactive(df_display, state_sigla, year, week):
    """
    Plot observed and forecast series for one state with Plotly.

    The x axis is the first column found among date/week/period (falling back
    to the first column of the frame), converted to strings. A green "Real"
    series and a dashed red "Forecast" series are added when a matching
    column exists. `state_sigla`, `year` and `week` only feed the title.
    """
    frame = df_display.copy()
    x_col = safe_get_column(frame, "date", "week", "period") or frame.columns[0]
    observed_col = safe_get_column(frame, "real", "actual", "cases", "observed", "casos_soma")
    forecast_col = safe_get_column(frame, "pred", "forecast", "prediction", "predicted")
    if x_col in frame.columns:
        frame[x_col] = frame[x_col].astype(str)

    # (column, legend label, line style, marker style) per series.
    series_specs = (
        (observed_col, "Real", dict(color="green"), dict(color="green")),
        (forecast_col, "Forecast", dict(color="red", dash="dash"), dict(color="red")),
    )

    fig = go.Figure()
    for column, label, line_style, marker_style in series_specs:
        if not column or column not in frame.columns:
            continue
        trace = go.Scatter(
            x=frame[x_col],
            y=frame[column],
            mode="lines+markers",
            name=label,
            line=line_style,
            marker=marker_style,
        )
        fig.add_trace(trace)

    layout_options = dict(
        title=f"Real vs Forecast — {state_sigla} ({year} W{week})",
        hovermode="x unified",
        margin=dict(l=40, r=20, t=40, b=30),
        xaxis_title="Date / Index",
        yaxis_title="Cases",
    )
    fig.update_layout(**layout_options)
    fig.show()

def on_button_clicked(b):
    """
    Button handler: run the state forecast and plot it inside `output`.

    Reads the state code, year and week from the widgets, calls
    `prever_para_estado`, and plots up to the last 51 known rows of the state
    followed by the predicted values. Prints a message when no forecast is
    returned.
    """
    with output:
        clear_output(wait=True)
        state_sigla = str(state_widget.value).strip()
        year = int(year_widget.value)
        week = int(week_widget.value)

        result = prever_para_estado(state_sigla, year, week)
        if result is None:
            print("Ponto de previsão não encontrado ou sequência insuficiente.")
            return

        last_idx = result["last_known_index"]
        state_rows = df[(df["estado_sigla"] == state_sigla)].reset_index(drop=True)
        first_idx = max(0, last_idx - 50)
        hist_slice = state_rows.iloc[first_idx:last_idx+1].copy()
        pred_vals = result["prediction_counts"]
        n_hist = len(hist_slice)
        n_pred = len(pred_vals)

        has_dates = 'date' in hist_slice.columns and not hist_slice['date'].isna().all()
        if has_dates:
            # Forecast dates continue weekly from the last known date.
            last_date = pd.to_datetime(hist_slice['date'].iloc[-1])
            future_dates = [last_date + pd.to_timedelta(7*step, unit='d') for step in range(1, n_pred+1)]
            x_name = "date"
            x_values = list(hist_slice['date']) + future_dates
        else:
            # Without usable dates, plot against a running integer index.
            x_name = "index"
            x_values = list(range(n_hist + n_pred))

        display_df = pd.DataFrame({
            x_name: x_values,
            "real": list(hist_slice['casos_soma']) + [None]*n_pred,
            "forecast": [None]*n_hist + list(pred_vals),
        })
        show_interactive(display_df, state_sigla, year, week)

# Register the click handler, then render the form and its output area.
button.on_click(on_button_clicked)
display(state_widget, year_widget, week_widget, button, output)


Text(value='AC', description='State:')

IntText(value=2025, description='Year:')

IntText(value=20, description='Week:')

Button(description='Generate Forecast', style=ButtonStyle())

Output()