In [None]:
# =====================================================================
# INFERENCE SCRIPT FOR BUBBLE DETECTOR (FINAL REVISION)
# =====================================================================
import torch, torch.nn as nn, numpy as np, pandas as pd
import warnings, sys, os, time

# --- Environment Detection ---
# Probe for Google Colab by attempting the import: inside Colab the package
# exists, anywhere else the ImportError marks this as a plain CLI run.
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# argparse is only needed for the command-line path; in Colab the file paths
# come from upload widgets instead (see main()).
if not IN_COLAB:
    import argparse

# Silences every warning process-wide, not just those raised by this script.
warnings.filterwarnings("ignore")
# Device string used for torch.load's map_location, model placement and input tensors.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- 1. Model Class Definitions (Must match training script) ---
class Encoder(nn.Module):
    """Sequence encoder: 2-layer bidirectional LSTM followed by a linear projection.

    The attribute names ``lstm`` and ``fc`` are part of the state_dict keys,
    so they must stay as they are for saved checkpoints to load.

    Args:
        in_dim: Number of features per time step.
        emb: LSTM hidden size and size of the returned embedding.
    """

    def __init__(self, in_dim, emb=128):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=emb,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        # Input is twice the hidden size: forward and backward states concatenated.
        self.fc = nn.Linear(in_features=emb * 2, out_features=emb)

    def forward(self, x):
        """Return one L2-normalised embedding per sequence in the batch.

        Args:
            x: Batched input; the LSTM is built with ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, emb)`` whose rows have unit L2 norm.
        """
        _, (hidden, _) = self.lstm(x)
        # The last two slices of h_n are the final hidden states of the top
        # layer's forward and backward directions.
        top_forward, top_backward = hidden[-2], hidden[-1]
        merged = torch.cat((top_forward, top_backward), dim=1)
        projected = self.fc(merged)
        return nn.functional.normalize(projected, dim=1)

class BubbleDetector(nn.Module):
    """Encoder plus a small MLP head that outputs a bubble probability.

    The attribute names ``encoder`` and ``classifier`` (and the layer order
    inside ``classifier``) determine the state_dict keys and therefore must
    match the checkpoint being loaded.

    Args:
        in_dim: Number of features per time step, forwarded to ``Encoder``.
        emb: Embedding size produced by the encoder and consumed by the head.
    """

    def __init__(self, in_dim, emb=128):
        super().__init__()
        self.encoder = Encoder(in_dim, emb)
        self.classifier = nn.Sequential(
            nn.Linear(emb, 64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(32, 1), nn.Sigmoid()
        )

    def forward(self, x):
        """Encode ``x`` and classify the embedding.

        Returns:
            Tuple ``(z, prob)``: ``z`` is the encoder output and ``prob`` the
            sigmoid output with its trailing size-1 dimension removed, i.e.
            one probability per batch element.
        """
        z = self.encoder(x)
        prob = self.classifier(z)
        # Squeeze only the trailing feature dimension. A bare ``squeeze()``
        # would also drop the batch dimension when the batch holds a single
        # sequence, making the output rank depend on the batch size.
        return z, prob.squeeze(-1)

    def get_probability(self, x):
        """Return the probabilities for ``x`` without tracking gradients.

        This does not switch the module to eval mode; call ``model.eval()``
        beforehand so the dropout layers are inactive.
        """
        with torch.no_grad():
            _, prob = self.forward(x)
        return prob

# --- 2. Utility and Analysis Functions ---
def calculate_all_probabilities(csv_path, model, scalers, info):
    """Score every sliding window of a CSV and return the probabilities.

    Args:
        csv_path: CSV file to read; its "Date" column is parsed as dates.
        model: Object exposing ``get_probability(tensor)``.
        scalers: Mapping with the keys ``need_cols``, ``macro_cols``,
            ``dow_cols``, ``sc_macro`` and ``sc_dow``; the two ``sc_*``
            entries must provide ``transform``.
        info: Mapping whose ``window`` entry is the nominal window length.

    Returns:
        A list with one float per window start position, in row order.
        The list is empty when no rows survive the NaN filter.
    """
    frame = pd.read_csv(csv_path, parse_dates=["Date"])
    # Accept the alternative column name "PPIACO" when "PPI" itself is absent.
    needs_rename = "PPIACO" in frame.columns and "PPI" not in frame.columns
    if needs_rename:
        frame.rename(columns={"PPIACO": "PPI"}, inplace=True)

    usable = frame.dropna(subset=scalers['need_cols']).reset_index(drop=True)
    n_rows = len(usable)

    # Shrink the window to the available rows instead of failing on short inputs.
    window = info['window']
    if n_rows < window:
        print(f"❗ Warning: Input data has only {n_rows} rows but model requires {window}. Analyzing with all available data.")
        window = n_rows
        if window == 0:
            return []

    macro_scaled = scalers['sc_macro'].transform(usable[scalers['macro_cols']]).astype("float32")
    dow_scaled = scalers['sc_dow'].transform(usable[scalers['dow_cols']]).astype("float32")
    # Macro columns first, then the "dow" columns, side by side for every row.
    features = np.hstack([macro_scaled, dow_scaled])

    def score_window(start):
        # One window becomes a batch of size 1 on the configured device.
        batch = torch.tensor(features[start:start + window]).unsqueeze(0).to(DEVICE)
        return model.get_probability(batch).cpu().item()

    return [score_window(start) for start in range(n_rows - window + 1)]

def get_risk_interpretation(probability):
    """Translate a bubble probability into its risk label.

    The bands are checked from highest to lowest and each lower bound is
    inclusive; anything below 0.2 falls through to the lowest label.
    """
    risk_bands = (
        (0.8, "🔴 Very High Risk"),
        (0.6, "🟠 High Risk"),
        (0.4, "🟡 Moderate Risk"),
        (0.2, "🟢 Low Risk"),
    )
    for lower_bound, label in risk_bands:
        if probability >= lower_bound:
            return label
    return "🔵 Very Low Risk"

# --- 3. Main Execution Block ---
def main():
    """Load the model package, score one CSV and print a summary report.

    In Colab both files are obtained through upload prompts; otherwise they
    come from the ``--model_package`` and ``--input_csv`` command-line
    options, and the process exits with status 1 if either path is missing.
    """
    print("📈 Bubble Detector - Probability Analysis")
    print("="*70)

    # --- Get file paths based on environment ---
    if IN_COLAB:
        # `files` is bound only on this branch; it is reused further down
        # under the same IN_COLAB condition.
        from google.colab import files
        print("📂 [Colab] Upload 'bubble_model_package.pth':")
        uploaded_package = files.upload()
        # Use the first key of the upload result as the path to load.
        package_path = next(iter(uploaded_package))
    else:
        parser = argparse.ArgumentParser(description="Bubble Probability Analysis")
        parser.add_argument("--model_package", type=str, default="bubble_model_package.pth", help="Path to the model package file.")
        parser.add_argument("--input_csv", type=str, required=True, help="Path to the input CSV file for analysis.")
        args = parser.parse_args()
        package_path = args.model_package
        # `csv_path_arg` exists only on this (non-Colab) branch.
        csv_path_arg = args.input_csv
        # NOTE(review): the message does not say which of the two paths is missing.
        if not os.path.exists(package_path) or not os.path.exists(csv_path_arg):
             print(f"❌ Error: One or more file paths are invalid.")
             sys.exit(1)

    # --- Load Model and Scalers from package ---
    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
    # arbitrary Python objects, which can execute code embedded in the file.
    # Only load packages from a trusted source. Presumably the flag is needed
    # because the package stores the scaler objects alongside the weights —
    # TODO confirm.
    package = torch.load(package_path, map_location=DEVICE, weights_only=False)
    model_config = package['model_config']
    scalers = package['scalers']

    model = BubbleDetector(in_dim=model_config['in_dim'], emb=model_config['emb']).to(DEVICE)
    model.load_state_dict(package['model_state_dict'])
    # Eval mode disables the classifier's dropout layers for inference.
    model.eval()
    print("✅ Model and scalers loaded successfully!")

    # --- Get CSV for analysis ---
    if IN_COLAB:
        print("\n📂 [Colab] Upload the CSV file to analyze:")
        uploaded_csv = files.upload()
        csv_path = next(iter(uploaded_csv))
    else:
        csv_path = csv_path_arg

    # --- Run analysis and display results ---
    # The timing covers CSV reading, scaling and every model call.
    t0 = time.time()
    # model_config doubles as the `info` argument; it must contain 'window'.
    probabilities = calculate_all_probabilities(csv_path, model, scalers, model_config)
    t1 = time.time()
    processing_time_ms = (t1 - t0) * 1000

    print("\n" + "="*70)
    print(f"🔎 Analysis Results for: '{os.path.basename(csv_path)}'")
    print("="*70)

    if not probabilities:
        print("❌ Could not perform analysis due to insufficient data.")
    else:
        mean_prob = np.mean(probabilities)
        # The last window in row order is reported as the "latest" value.
        latest_prob = probabilities[-1]

        trend_direction = "N/A"
        if len(probabilities) > 1: # Trend requires at least 2 points
            # Slope of a degree-1 least-squares fit over the window index,
            # i.e. the change in probability per window step.
            trend_slope = np.polyfit(range(len(probabilities)), probabilities, 1)[0]
            # Slopes within ±0.01 per step are reported as stable.
            if trend_slope > 0.01: trend_direction = "Increasing ↗️"
            elif trend_slope < -0.01: trend_direction = "Decreasing ↘️"
            else: trend_direction = "Stable ↔️"

        print(f"📈 Mean Bubble Probability:   {mean_prob:.4f}")
        print(f"📉 Latest Bubble Probability: {latest_prob:.4f}")
        print(f"💬 Interpretation:            {get_risk_interpretation(latest_prob)}")
        print(f"📊 Probability Trend:         {trend_direction}")
        print(f"⏱️ Processing Time:           {processing_time_ms:.2f} ms")

    print("="*70)

# Run the analysis only when executed as a script or notebook cell, not on import.
if __name__ == "__main__":
    main()

📈 Bubble Detector - Probability Analysis
📂 [Colab] Upload 'bubble_model_package.pth':


Saving bubble_model_package.pth to bubble_model_package (3).pth
✅ Model and scalers loaded successfully!

📂 [Colab] Upload the CSV file to analyze:


Saving Merged_Black_Monday_Last_24_months.csv to Merged_Black_Monday_Last_24_months.csv

🔎 Analysis Results for: 'Merged_Black_Monday_Last_24_months.csv'
📈 Mean Bubble Probability:   0.8716
📉 Latest Bubble Probability: 0.8716
💬 Interpretation:            🔴 Very High Risk
📊 Probability Trend:         N/A
⏱️ Processing Time:           13.66 ms
