In [None]:
# =====================================================================
# Approach 3 (variant): INFERENCE on a held-out NON-BUBBLE CSV
# - Uses unified package. Reports p(bubble) on the LAST 24 months.
# =====================================================================

import torch, torch.nn as nn, numpy as np, pandas as pd, warnings, os, time, math
# Suppress every warning process-wide for the rest of the session.
warnings.simplefilter("ignore")

# Colab detection: IN_COLAB is True exactly when `google.colab` can be imported.
try:
    import google.colab
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

# Device string used for loading the package and placing the model.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---- Encoders & model (same as training)
class EncoderBiLSTM(nn.Module):
    """Two-layer bidirectional LSTM encoder that emits one unit-norm vector per sequence.

    Args:
        in_dim: Number of input features per time step.
        emb: LSTM hidden size per direction and size of the output embedding.

    The attribute names ``lstm`` and ``fc`` form the state-dict keys and must
    stay stable so saved packages keep loading.
    """

    def __init__(self, in_dim, emb=128):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=emb,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        # Forward and backward final states are concatenated, hence 2 * emb inputs.
        self.fc = nn.Linear(2 * emb, emb)

    def forward(self, x):
        """Encode ``x`` (batch-first) into L2-normalised embeddings of shape ``(B, emb)``."""
        _, (h_n, _c_n) = self.lstm(x)
        # h_n stacks (layer, direction); the last two rows are the top layer's
        # forward and backward final hidden states.
        top_forward, top_backward = h_n[-2], h_n[-1]
        joined = torch.cat([top_forward, top_backward], dim=1)
        projected = self.fc(joined)
        return nn.functional.normalize(projected, dim=1)

class SinusoidalPE(nn.Module):
    """Add fixed sine/cosine positional encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``. The table is
    registered as the buffer ``pe`` of shape ``(1, max_len, d_model)``, so it
    is part of the state dict.

    Args:
        d_model: Feature dimension of the sequences this module is applied to.
            Odd values are supported.
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one fewer odd column than there are
        # frequencies. Trim `div` so the cosine assignment shapes match; this
        # is a no-op for even d_model.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe`` for the first ``x.size(1)`` positions.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            # Without this check the addition fails with an opaque broadcast error.
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {table_len}"
            )
        return x + self.pe[:, :seq_len]

class EncoderTransformer(nn.Module):
    """Pre-LayerNorm (``norm_first=True``) Transformer encoder returning unit-norm embeddings.

    Args:
        in_dim: Number of input features per time step.
        emb: Model width (``d_model``) and output embedding size.
        nhead: Attention heads per layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout inside each encoder layer.
        pool: How a sequence is summarised:
            ``"cls"`` prepends a learnable token and reads it back,
            ``"mean"`` averages over time, and any other value takes the last
            time step.

    The attribute names (``input``, ``enc``, ``pos``, ``cls``) form the
    state-dict keys.
    """

    def __init__(self, in_dim, emb=128, nhead=4, num_layers=2, dropout=0.1, pool="last"):
        super().__init__()
        self.input = nn.Linear(in_dim, emb)
        block = nn.TransformerEncoderLayer(
            d_model=emb,
            nhead=nhead,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(block, num_layers=num_layers)
        self.pos = SinusoidalPE(emb)
        self.pool = pool
        # The learnable summary token exists (and appears in the state dict)
        # only for cls pooling.
        self.cls = None if pool != "cls" else nn.Parameter(torch.zeros(1, 1, emb))

    def forward(self, x):
        """Encode batch-first ``x`` into L2-normalised vectors of shape ``(B, emb)``."""
        tokens = self.input(x)
        if self.cls is not None:
            batch = x.size(0)
            tokens = torch.cat([self.cls.expand(batch, -1, -1), tokens], dim=1)
        hidden = self.enc(self.pos(tokens))
        mode = self.pool
        if mode == "cls":
            summary = hidden[:, 0, :]
        elif mode == "mean":
            summary = hidden.mean(dim=1)
        else:
            summary = hidden[:, -1, :]
        return nn.functional.normalize(summary, dim=1)

class BubbleVsNonBubble(nn.Module):
    """Sequence encoder followed by an MLP head that outputs p(bubble) in (0, 1).

    ``kind == "bilstm"`` builds an ``EncoderBiLSTM``. Every other value builds
    an ``EncoderTransformer``, which is the only encoder that uses ``pool``,
    ``nhead``, ``num_layers`` and ``dropout``. The attribute names and the
    layer order inside ``classifier`` determine the state-dict keys.
    """

    def __init__(self, in_dim, emb=128, kind="bilstm", pool="last", nhead=4, num_layers=2, dropout=0.1):
        super().__init__()
        if kind == "bilstm":
            self.encoder = EncoderBiLSTM(in_dim, emb)
        else:
            self.encoder = EncoderTransformer(in_dim, emb, nhead, num_layers, dropout, pool)
        # emb -> 64 -> 32 -> 1, squashed into (0, 1) by the final Sigmoid.
        head_layers = [
            nn.Linear(emb, 64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(32, 1), nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return ``(embedding, p_bubble)``.

        ``p_bubble`` is the head output with every size-1 dimension squeezed
        away: shape ``(B,)`` for a batch of B > 1, and a 0-d tensor when B == 1.
        """
        embedding = self.encoder(x)
        probability = self.classifier(embedding).squeeze()
        return embedding, probability

    @torch.no_grad()
    def get_bubble_probability(self, x):
        """Return only ``p_bubble``, without tracking gradients.

        This does not change train/eval mode. Call ``.eval()`` first so the
        head's Dropout layers are inactive.
        """
        return self.forward(x)[1]

def risk_label(p_bub):
    """Map a bubble probability to a human-readable risk band.

    The bands are closed on the left: >= 0.8, >= 0.6, >= 0.4 and >= 0.2.
    Anything below 0.2 is "Very Low", and so is NaN, because every comparison
    with NaN is False.
    """
    bands = (
        (0.8, "üî¥ Very High Bubble Risk"),
        (0.6, "üü† High Bubble Risk"),
        (0.4, "üü° Moderate Bubble Risk"),
        (0.2, "üü¢ Low Bubble Risk"),
    )
    for floor, label in bands:
        if p_bub >= floor:
            return label
    return "üîµ Very Low Bubble Risk"

def analyze_last_24(csv_path, model, scalers, info):
    """Return the model's p(bubble) for the most recent ``window`` rows of a CSV.

    The CSV must contain a ``Date`` column plus every column listed in
    ``scalers["need_cols"]``. A ``PPIACO`` column is accepted as an alias for
    ``PPI``. Rows are sorted by ``Date`` before the trailing window is taken,
    so files stored newest-first still score the chronologically last months.

    Args:
        csv_path: Path of the CSV file to score.
        model: Module exposing ``get_bubble_probability(seq)``. ``seq`` is built
            with shape ``(1, window, n_macro + n_dow)``. This function does not
            switch the model to eval mode; do that beforehand.
        scalers: Mapping with ``need_cols``, ``macro_cols``, ``dow_cols`` and
            objects ``sc_macro`` / ``sc_dow`` exposing ``transform``.
        info: Config mapping. ``info["window"]`` (default 24) is the number of
            trailing rows to score.

    Returns:
        p(bubble) as a Python float, or ``None`` if no row survives the
        missing-value filter. When fewer than ``window`` clean rows exist, a
        warning is printed and all clean rows are used.
    """
    df = pd.read_csv(csv_path, parse_dates=["Date"])
    if "PPIACO" in df.columns and "PPI" not in df.columns:
        df = df.rename(columns={"PPIACO": "PPI"})
    # iloc[-window:] below only means "the last N months" if rows are in date
    # order. A stable sort leaves already-ascending files untouched. Undated
    # (NaT) rows are placed first so they are never treated as the most recent.
    df = df.sort_values("Date", kind="stable", na_position="first")

    need_cols, macro_cols, dow_cols = scalers["need_cols"], scalers["macro_cols"], scalers["dow_cols"]
    df_clean = df.dropna(subset=need_cols).reset_index(drop=True)

    window = info.get("window", 24)
    if len(df_clean) < window:
        print(f"‚ùó Warning: only {len(df_clean)} rows (need {window}). Using all rows.")
        window = len(df_clean)
        if window == 0:
            return None

    # Use the LAST `window` months to match the paper's evaluation protocol.
    if len(df_clean) > window:
        df_clean = df_clean.iloc[-window:].reset_index(drop=True)

    Xm = scalers["sc_macro"].transform(df_clean[macro_cols]).astype("float32")
    Xd = scalers["sc_dow"].transform(df_clean[dow_cols]).astype("float32")
    seq = np.hstack([Xm, Xd])

    # Put the batch on the device holding the model's weights. Fall back to
    # the global DEVICE only for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    seq = torch.tensor(seq).unsqueeze(0).to(device)
    with torch.no_grad():
        p_bub = float(model.get_bubble_probability(seq).cpu().item())
    return p_bub

# ---- Load package
if IN_COLAB:
    from google.colab import files
    print("üì¶ Upload 'approach3_joint_model_package.pth'")
    up = files.upload()
    if not up:
        # next(iter({})) would otherwise raise a bare, uninformative StopIteration.
        raise FileNotFoundError("No model package was uploaded.")
    package_path = next(iter(up))
else:
    package_path = "approach3_joint_model_package.pth"

# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects.
# The package stores non-tensor objects (e.g. the fitted scalers), which is
# presumably why it is needed. Only load packages from a trusted source.
package = torch.load(package_path, map_location=DEVICE, weights_only=False)
cfg, scalers = package["model_config"], package["scalers"]
if cfg.get("predicts") != "bubble":
    # Explicit raise instead of assert, so the check survives `python -O`.
    raise ValueError("Loaded package must have predicts='bubble'.")
# 'pool' is optional (the model rebuild defaults it to "last"), so do not index it directly.
print(f"‚úÖ Loaded | encoder={cfg['encoder_kind']} | emb={cfg['emb']} | pool={cfg.get('pool', 'last')}")

# Rebuild the network from the hyper-parameters stored in the package, load the
# trained weights, and switch to eval mode so Dropout is disabled for inference.
_tf_cfg = cfg.get("transformer", {})
model = BubbleVsNonBubble(
    in_dim=cfg.get("in_dim", 6),
    emb=cfg["emb"],
    kind=cfg["encoder_kind"],
    pool=cfg.get("pool", "last"),
    nhead=_tf_cfg.get("nhead", 4),
    num_layers=_tf_cfg.get("num_layers", 2),
    dropout=_tf_cfg.get("dropout", 0.1),
).to(DEVICE)
model.load_state_dict(package["model_state_dict"])
model.eval()

# ---- Upload ONE held-out NON-BUBBLE CSV and evaluate
if IN_COLAB:
    print("\nüìÇ Upload the HELD-OUT NON-BUBBLE CSV to evaluate:")
    up2 = files.upload()
    if not up2:
        # next(iter({})) would otherwise raise a bare, uninformative StopIteration.
        raise FileNotFoundError("No CSV file was uploaded.")
    csv_path = next(iter(up2))
else:
    csv_path = "YOUR_HELD_OUT_NONBUBBLE.csv"

# perf_counter() is monotonic and high-resolution. time.time() follows the wall
# clock and can jump, which would corrupt the latency measurement.
t0 = time.perf_counter()
p_bub = analyze_last_24(csv_path, model, scalers, cfg)
dt = (time.perf_counter() - t0) * 1000

print("\n" + "=" * 70)
print(f"üîé File: {os.path.basename(csv_path)}")
if p_bub is None:
    print("‚ùå Not enough clean rows to analyze.")
else:
    # Report the configured window (analyze_last_24 reads cfg["window"], default 24)
    # instead of a hard-coded 24.
    print(f"üìà p(bubble) on LAST {cfg.get('window', 24)} months: {p_bub:.4f}")
    print(f"üí¨ Interpretation: {risk_label(p_bub)}  (should be LOW for non-bubble)")
    print(f"‚è±Ô∏è Processing Time: {dt:.2f} ms")
print("=" * 70)

📦 Upload 'approach3_joint_model_package.pth'


KeyboardInterrupt: 