# ============================================================
# Clean TimeGAN-VAE SCSI Pipeline for Supply Chain Risk Analysis
# Version: 2.0 - Production Ready
# ============================================================

# %% [markdown]
# # TimeGAN-VAE Supply Chain Stress Index (SCSI) Pipeline
#
# This notebook generates a Supply Chain Stress Index using:
# - **TimeGAN**: For market-driven features (prices, freight rates)
# - **VAE**: For regulated features (fuel prices)
# - **PCA**: For dimensionality reduction into SCSI

In [1]:
# ============================================================
# 1. INSTALLATION & IMPORTS
# ============================================================

# Install required packages (for Google Colab)
!pip -q install pytorch-lightning==2.2.5 joblib tqdm openpyxl

import os
import sys
import json
import time
import random
import zipfile
import warnings
import shutil
import datetime as dt
from functools import reduce

warnings.filterwarnings("ignore")

# Data manipulation
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Image, HTML

# Machine Learning
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl

import tensorflow as tf
from tensorflow.keras import layers

# Preprocessing
from sklearn.impute import KNNImputer
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import mean_squared_error, mean_absolute_error
from joblib import dump, load

print("✅ All packages imported successfully")


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/802.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m802.3/802.3 kB[0m [31m37.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/983.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.0/983.0 kB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[?25h✅ All packages imported successfully


In [2]:
import os, json, math, pathlib
import datetime as dt
import numpy as np
import pandas as pd

def to_jsonable(x):
    """
    Recursively convert objects to JSON-serializable types.
    Covers numpy, pandas, datetime, pathlib, mixed containers.
    """
    if x is None:
        return None

    # pandas NA first (works for scalars/Series cells)
    try:
        if pd.isna(x):
            return None
    except (TypeError, ValueError):
        pass

    # numpy scalars
    if hasattr(x, "item"):
        try:
            return x.item()
        except Exception:
            pass

    # explicit numpy types
    if isinstance(x, np.bool_):
        return bool(x)
    if isinstance(x, (np.integer,)):
        return int(x)
    if isinstance(x, (np.floating,)):
        val = float(x)
        return None if (math.isnan(val) or math.isinf(val)) else val

    # numpy arrays
    if isinstance(x, np.ndarray):
        return [to_jsonable(v) for v in x.tolist()]

    # pandas structures
    if isinstance(x, pd.Series):
        return [to_jsonable(v) for v in x.tolist()]
    if isinstance(x, pd.DataFrame):
        return [to_jsonable(rec) for rec in x.to_dict(orient="records")]
    if isinstance(x, (pd.Index, pd.MultiIndex)):
        return [to_jsonable(v) for v in x.tolist()]
    if isinstance(x, (pd.Timestamp, dt.datetime, dt.date)):
        return x.isoformat()
    if isinstance(x, pd.Timedelta):
        return str(x)

    # paths
    if isinstance(x, pathlib.Path):
        return str(x)

    # mappings/sequences/sets
    if isinstance(x, dict):
        return {str(k): to_jsonable(v) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return [to_jsonable(v) for v in x]
    if isinstance(x, set):
        # convert elements first, then sort by string rep to avoid type compare errors
        conv = [to_jsonable(v) for v in x]
        return sorted(conv, key=lambda z: str(z))

    # basic py types
    if isinstance(x, (str, int, float, bool)):
        if isinstance(x, float) and (math.isnan(x) or math.isinf(x)):
            return None
        return x

    # fallback
    try:
        return str(x)
    except Exception:
        return f"<non-serializable: {type(x).__name__}>"

def save_json_safely(payload, json_path):
    """
    Sanitize and save JSON. On failure, writes a minimal error JSON.
    """
    try:
        os.makedirs(os.path.dirname(json_path) or ".", exist_ok=True)
        sanitized = to_jsonable(payload)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(sanitized, f, ensure_ascii=False, indent=2)
        print(f"✅ JSON saved successfully to {json_path}")
        return True
    except Exception as e:
        print(f"❌ JSON serialization failed: {e}")
        try:
            minimal = {
                "created_at": dt.datetime.now().isoformat(),
                "error": "Full payload could not be serialized",
                "error_details": str(e),
            }
            with open(json_path, "w", encoding="utf-8") as f:
                json.dump(minimal, f, ensure_ascii=False, indent=2)
            print(f"⚠️ Saved minimal JSON to {json_path}")
        except Exception as e2:
            print(f"❌ Even minimal JSON save failed: {e2}")
        return False


In [3]:
# ============================================================
# 2. CONFIGURATION & PARAMETERS
# ============================================================

class Config:
    """Centralized configuration for the pipeline"""

    # Paths
    BASE = "/content"
    OUTPUT_DIR = "/content/scsi_output"

    # Date range
    START_DATE = "2019-01-01"
    END_DATE = "2024-12-01"

    # Model parameters (optimized for small dataset)
    SEED = 42
    SEQ_LEN = 24  # Reduced from 60 for 72 data points

    # TimeGAN
    TIMEGAN_HIDDEN = 16  # Reduced from 32
    TIMEGAN_PRETRAIN_EPOCHS = 100
    TIMEGAN_EPOCHS = 500  # Reduced from 1000
    TIMEGAN_BATCH = 16
    TIMEGAN_LR = 3e-4
    TIMEGAN_RECON_WEIGHT = 10.0
    TIMEGAN_G_STEPS = 2
    TIMEGAN_D_STEPS = 1

    # VAE
    VAE_LATENT = 8  # Reduced from 12
    VAE_EPOCHS = 100
    VAE_BATCH = 16
    VAE_PATIENCE = 15
    VAE_BETA_START = 0.0
    VAE_BETA_END = 0.05
    VAE_BETA_WARMUP = 20
    VAE_LR = 1e-3

    # Feature columns
    REGULATED = ['ron95', 'ron97', 'diesel', 'diesel_eastmsia']
    MARKET = ['AsiaPacific', 'AsiaPacific_rescaled', 'AirFreightRate_Weekly', 'AirFreightRate_Annual']

    # Feature selection
    MAX_GPR_FEATURES = 10  # Limit GPR features to avoid overfitting
    MAX_TOTAL_FEATURES = 40  # Maximum total features to use

    # Test mode for quick development
    TEST_MODE = False  # Set True for quick testing with reduced epochs

    def __init__(self, test_mode=False):
        self.TEST_MODE = test_mode
        if test_mode:
            self.TIMEGAN_PRETRAIN_EPOCHS = 10
            self.TIMEGAN_EPOCHS = 50
            self.VAE_EPOCHS = 20
            print("⚠️ TEST MODE: Using reduced epochs for quick testing")

        # Create output directory
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)

        # Define output files
        self.FILES = {
            'raw_data': f"{self.BASE}/raw_merged_unfiltered.csv",
            'final_output': f"{self.OUTPUT_DIR}/final_data_with_scsi.csv",
            'scsi_plot': f"{self.OUTPUT_DIR}/scsi_timeseries.png",
            'timegan_out': f"{self.OUTPUT_DIR}/timegan_synthetic.csv",
            'vae_out': f"{self.OUTPUT_DIR}/vae_synthetic.csv",
            'scsi_features': f"{self.OUTPUT_DIR}/scsi_features.csv",
            'validation_report': f"{self.OUTPUT_DIR}/validation_report.json"
        }

        # Model artifacts
        self.ARTIFACTS = {
            'ret_scaler': f"{self.OUTPUT_DIR}/timegan_scaler.joblib",
            'vae_scaler': f"{self.OUTPUT_DIR}/vae_scaler.joblib",
            'std_scaler': f"{self.OUTPUT_DIR}/pca_scaler.joblib",
            'pca_model': f"{self.OUTPUT_DIR}/pca_model.joblib",
            'config': f"{self.OUTPUT_DIR}/config.json"
        }

# Initialize configuration
cfg = Config(test_mode=False)  # Set to True for testing


In [4]:
# ============================================================
# 3. UTILITY FUNCTIONS
# ============================================================

def setup_seeds(seed=42):
    """Set all random seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    tf.random.set_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    pl.seed_everything(seed, workers=True)
    print(f"🎲 Random seeds set to {seed}")

def setup_gpu():
    """Configure GPU settings for TensorFlow"""
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            try:
                tf.config.experimental.set_memory_growth(gpu, True)
            except:
                pass
        print(f"🖥️ GPU configured: {len(gpus)} device(s) available")
    else:
        print("💻 No GPU found, using CPU")

def print_header(title, level=1):
    """Print formatted section headers"""
    if level == 1:
        print("\n" + "="*70)
        print(f"  {title}")
        print("="*70)
    elif level == 2:
        print(f"\n📍 {title}")
        print("-"*50)
    else:
        print(f"\n  ▶ {title}")

def format_time(seconds):
    """Format seconds into readable time"""
    if seconds < 60:
        return f"{seconds:.1f}s"
    elif seconds < 3600:
        return f"{seconds/60:.1f}min"
    else:
        return f"{seconds/3600:.1f}h"


In [5]:
# ============================================================
# 4. DATA LOADING & PREPROCESSING (FIXED)
# ============================================================

class DataManager:
    """Handles all data loading and preprocessing operations"""

    def __init__(self, config):
        self.cfg = config

    def upload_files(self):
        """Handle file uploads in Google Colab"""
        from google.colab import files

        print_header("FILE UPLOAD", 2)

        required_files = {
            'VAE_files.zip': 'ZIP containing asia/fax/diesel CSVs',
            'gscpi_data.xlsx': 'GSCPI data',
            'data_gpr_export.xlsx': 'GPR data',
            'monthly_trade_dummy.csv': 'Trade data',
            'monthly_air_cargo_from_spline.csv': 'Air cargo data'
        }

        for filename, description in required_files.items():
            filepath = os.path.join(self.cfg.BASE, filename)
            if not os.path.exists(filepath):
                print(f"\n📁 Please upload: {description}")
                uploaded = files.upload()
                if uploaded:
                    # Move to correct location
                    for fname in uploaded.keys():
                        shutil.move(fname, filepath)
                    print(f"✅ Saved: {filename}")
            else:
                print(f"✅ Found: {filename}")

    @staticmethod
    def safe_monthly_impute(
        df: pd.DataFrame,
        expected_cols: list | None = None,
        span_start: str | None = None,   # e.g. "2019-01-01"
        span_end: str | None = None,     # e.g. "2024-12-01"
        date_col: str | None = None      # if df has a 'Date' column instead of DateTimeIndex
    ) -> pd.DataFrame:
        """
        Shape-safe monthly imputation:
        - Enforces monthly DatetimeIndex window [span_start, span_end]
        - Removes duplicate column names
        - Locks expected column set (adds missing as NaN, drops extras)
        - Converts to numeric (errors -> NaN)
        - Column-wise time interpolation + ffill/bfill
        """
        dbg = lambda *a: print("🔧", *a)
        df = df.copy()

        # ---- ensure DateTimeIndex
        if date_col and (date_col in df.columns):
            df[date_col] = pd.to_datetime(df[date_col], errors="coerce")
            df = df.dropna(subset=[date_col]).set_index(date_col)

        if not isinstance(df.index, pd.DatetimeIndex):
            raise ValueError("safe_monthly_impute requires a DatetimeIndex or a date_col to set one.")

        df = df.sort_index()

        # ---- dedup columns
        if df.columns.duplicated().any():
            dups = df.columns[df.columns.duplicated()].tolist()
            dbg(f"Duplicate columns found (keeping first): {dups}")
            df = df.loc[:, ~df.columns.duplicated()]

        # ---- lock expected columns
        if expected_cols is None:
            expected_cols = df.columns.tolist()

        # add missing, drop extras
        missing = [c for c in expected_cols if c not in df.columns]
        if missing:
            dbg(f"Adding missing expected columns with NaN: {missing}")
            for c in missing:
                df[c] = np.nan

        extras = [c for c in df.columns if c not in expected_cols]
        if extras:
            dbg(f"Dropping unexpected extra columns: {extras}")
            df = df.drop(columns=extras)

        # numeric coercion
        for c in expected_cols:
            df[c] = pd.to_numeric(df[c], errors="coerce")

        # ---- enforce continuous monthly index
        start = (pd.to_datetime(span_start).to_period("M").to_timestamp()
                 if span_start else df.index.min().to_period("M").to_timestamp())
        end   = (pd.to_datetime(span_end).to_period("M").to_timestamp()
                 if span_end else df.index.max().to_period("M").to_timestamp())
        monthly_idx = pd.date_range(start=start, end=end, freq="MS")
        df = df.reindex(monthly_idx)

        # ---- column-wise interpolation (preserves (n_rows, n_cols))
        try:
            imputed = (df[expected_cols]
                       .interpolate(method="time", limit_direction="both")
                       .ffill().bfill())
        except Exception as e:
            dbg(f"time interpolation failed ({e}); falling back to linear")
            imputed = (df[expected_cols]
                       .interpolate(method="linear", limit_direction="both")
                       .ffill().bfill())

        # extra safety
        if list(imputed.columns) != expected_cols or imputed.shape[1] != len(expected_cols):
            raise RuntimeError("Imputation changed columns/shape unexpectedly.")

        df[expected_cols] = imputed
        return df

    def process_vae_files(self):
        """Extract and process VAE zip file"""
        zip_path = f"{self.cfg.BASE}/VAE_files.zip"
        if not os.path.exists(zip_path):
            raise FileNotFoundError("VAE_files.zip not found")

        print("  ▶ Processing VAE files...")

        # Extract ZIP
        extract_dir = f"{self.cfg.BASE}/VAE_temp"
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(extract_dir)

        # Find CSV files
        csv_files = {}
        for root, dirs, files in os.walk(extract_dir):
            for file in files:
                if 'asia_pacific' in file.lower():
                    csv_files['asia'] = os.path.join(root, file)
                elif 'fax_annual' in file.lower():
                    csv_files['fax_annual'] = os.path.join(root, file)
                elif 'fax_weekly' in file.lower():
                    csv_files['fax_weekly'] = os.path.join(root, file)
                elif 'diesel' in file.lower():
                    csv_files['diesel'] = os.path.join(root, file)

        # Load and merge
        dfs = []
        if 'asia' in csv_files:
            dfs.append(pd.read_csv(csv_files['asia'], parse_dates=['Date']))
        if 'fax_annual' in csv_files:
            df = pd.read_csv(csv_files['fax_annual'], parse_dates=['Date'])
            df = df.rename(columns={'AirFreightRate': 'AirFreightRate_Annual'})
            dfs.append(df)
        if 'fax_weekly' in csv_files:
            df = pd.read_csv(csv_files['fax_weekly'], parse_dates=['Date'])
            df = df.rename(columns={'AirFreightRate': 'AirFreightRate_Weekly'})
            dfs.append(df)
        if 'diesel' in csv_files:
            df = pd.read_csv(csv_files['diesel'], parse_dates=['date'])
            df = df.rename(columns={'date': 'Date'})
            if 'series_type' in df.columns:
                df = df[df['series_type'].str.lower() == 'level']
            dfs.append(df)

        # Merge all
        raw_merged = reduce(lambda l, r: pd.merge(l, r, on='Date', how='outer'), dfs)
        raw_merged = raw_merged.sort_values('Date')

        # Remove duplicate columns if any
        raw_merged = raw_merged.loc[:, ~raw_merged.columns.duplicated()]

        raw_merged.to_csv(self.cfg.FILES['raw_data'], index=False)

        # Cleanup
        shutil.rmtree(extract_dir)

        print(f"    ✔ Created raw data: {len(raw_merged)} rows")
        return raw_merged

    def to_monthly(self, df, date_col='Date'):
        """Convert data to monthly frequency"""
        df = df.copy()
        if date_col in df.columns:
            df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
            df = df.dropna(subset=[date_col]).set_index(date_col)
        monthly = df.resample('MS').mean(numeric_only=True)
        # Remove duplicate columns after resampling
        monthly = monthly.loc[:, ~monthly.columns.duplicated()]
        return monthly

    def load_external_data(self, df_monthly):
        """Load and merge external datasets"""
        print("  ▶ Loading external datasets...")

        # GSCPI
        if os.path.exists(f"{self.cfg.BASE}/gscpi_data.xlsx"):
            gscpi = pd.read_excel(f"{self.cfg.BASE}/gscpi_data.xlsx")
            date_col = [c for c in gscpi.columns if 'date' in c.lower()]
            if date_col:
                gscpi = gscpi.rename(columns={date_col[0]: 'Date'})
            elif len(gscpi.columns) > 0:
                gscpi = gscpi.rename(columns={gscpi.columns[0]: 'Date'})
            gscpi_monthly = self.to_monthly(gscpi)
            if len(gscpi_monthly.columns) == 1:
                gscpi_monthly.columns = ['GSCPI_NYFED']
            # Ensure no duplicate columns before joining
            for col in gscpi_monthly.columns:
                if col in df_monthly.columns:
                    df_monthly = df_monthly.drop(columns=[col])
            df_monthly = df_monthly.join(gscpi_monthly, how='outer')
            print(f"    ✔ GSCPI: {gscpi_monthly.shape}")

        # GPR
        if os.path.exists(f"{self.cfg.BASE}/data_gpr_export.xlsx"):
            gpr = pd.read_excel(f"{self.cfg.BASE}/data_gpr_export.xlsx")
            date_col = [c for c in gpr.columns if 'date' in c.lower()]
            if date_col:
                gpr = gpr.rename(columns={date_col[0]: 'Date'})
            elif len(gpr.columns) > 0:
                gpr = gpr.rename(columns={gpr.columns[0]: 'Date'})
            gpr_monthly = self.to_monthly(gpr)
            # Limit GPR features
            if len(gpr_monthly.columns) > self.cfg.MAX_GPR_FEATURES:
                # Select top features by variance
                variances = gpr_monthly.var().sort_values(ascending=False)
                top_features = variances.head(self.cfg.MAX_GPR_FEATURES).index
                gpr_monthly = gpr_monthly[top_features]
            # Ensure unique column names
            for col in gpr_monthly.columns:
                if col in df_monthly.columns:
                    df_monthly = df_monthly.drop(columns=[col])
            df_monthly = df_monthly.join(gpr_monthly, how='outer')
            print(f"    ✔ GPR: {gpr_monthly.shape}")

        # Trade
        if os.path.exists(f"{self.cfg.BASE}/monthly_trade_dummy.csv"):
            trade = pd.read_csv(f"{self.cfg.BASE}/monthly_trade_dummy.csv", parse_dates=['Date'])
            trade_monthly = self.to_monthly(trade).add_suffix('_trade')
            # Ensure unique column names
            for col in trade_monthly.columns:
                if col in df_monthly.columns:
                    df_monthly = df_monthly.drop(columns=[col])
            df_monthly = df_monthly.join(trade_monthly, how='outer')
            print(f"    ✔ Trade: {trade_monthly.shape}")

        # Air cargo
        if os.path.exists(f"{self.cfg.BASE}/monthly_air_cargo_from_spline.csv"):
            air = pd.read_csv(f"{self.cfg.BASE}/monthly_air_cargo_from_spline.csv", parse_dates=['Date'])
            air_monthly = self.to_monthly(air).add_suffix('_air')
            # Ensure unique column names
            for col in air_monthly.columns:
                if col in df_monthly.columns:
                    df_monthly = df_monthly.drop(columns=[col])
            df_monthly = df_monthly.join(air_monthly, how='outer')
            print(f"    ✔ Air cargo: {air_monthly.shape}")

        # Final cleanup of duplicate columns
        df_monthly = df_monthly.loc[:, ~df_monthly.columns.duplicated()]

        return df_monthly

    def impute_missing(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Robust imputation:
        1) Shape/column-safe monthly interpolation (safe_monthly_impute)
        2) KNNImputer for any leftovers (numeric only), assigned in-bulk
        3) Median fallback
        """
        print("  ▶ Imputing missing values...")

        # Lock expected feature set before any transforms
        expected_cols = df.columns.tolist()

        # 1) time-based interpolation that preserves shape
        df = self.safe_monthly_impute(
            df,
            expected_cols=expected_cols,
            span_start=self.cfg.START_DATE,
            span_end=self.cfg.END_DATE
        )

        initial_missing = int(df.isna().sum().sum())

        # 2) KNN for remaining NaNs (if any)
        if df.isna().any().any():
            # Numeric columns only, preserve order & uniqueness
            num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            num_cols = pd.Index(num_cols).unique().tolist()

            if num_cols:
                X = df[num_cols].to_numpy(dtype=float)
                n_rows, n_cols = X.shape
                print(f"    ℹ️ KNN stage on {n_cols} numeric cols, {n_rows} rows")

                n_neighbors = max(2, min(5, len(df)//3))
                imputer = KNNImputer(n_neighbors=n_neighbors)
                X_imp = imputer.fit_transform(X)

                # Guard against any unexpected width drift
                if X_imp.shape[1] != len(num_cols):
                    print(f"    ⚠️ Column count mismatch after KNN: "
                          f"num_cols={len(num_cols)}, X_imp={X_imp.shape[1]}. Trimming to min.")
                    keep = min(len(num_cols), X_imp.shape[1])
                    num_cols = num_cols[:keep]
                    X_imp = X_imp[:, :keep]

                # Bulk assign back
                df.loc[:, num_cols] = X_imp

        # 3) Median fallback
        df = df.fillna(df.median(numeric_only=True))

        final_missing = int(df.isna().sum().sum())
        print(f"    ✔ Imputed {initial_missing - final_missing} values (remaining: {final_missing})")
        return df

    def add_temporal_features(self, df):
        """Add temporal and policy features"""
        df['sin_month'] = np.sin(2 * np.pi * df.index.month / 12)
        df['cos_month'] = np.cos(2 * np.pi * df.index.month / 12)
        df['policy_dummy'] = 0.0
        df.loc['2020-04-01':'2020-12-01', 'policy_dummy'] = 1.0
        return df

    def select_features(self, df):
        """Select most relevant features to avoid overfitting"""
        print("  ▶ Selecting relevant features...")

        # Ensure no duplicate columns before selection
        df = df.loc[:, ~df.columns.duplicated()]

        essential = self.cfg.MARKET + self.cfg.REGULATED
        essential += ['sin_month', 'cos_month', 'policy_dummy', 'GSCPI_NYFED']

        # Add available trade and cargo features
        essential += [c for c in df.columns if '_trade' in c or '_air' in c]

        # Add top GPR features
        gpr_cols = [c for c in df.columns if 'GPR' in c.upper()]
        if len(gpr_cols) > self.cfg.MAX_GPR_FEATURES:
            variances = df[gpr_cols].var().sort_values(ascending=False)
            essential += variances.head(self.cfg.MAX_GPR_FEATURES).index.tolist()
        else:
            essential += gpr_cols

        # Limit total features and ensure uniqueness
        essential = list(set([c for c in essential if c in df.columns]))
        if len(essential) > self.cfg.MAX_TOTAL_FEATURES:
            # Prioritize by variance
            variances = df[essential].var().sort_values(ascending=False)
            essential = variances.head(self.cfg.MAX_TOTAL_FEATURES).index.tolist()

        df_selected = df[essential]
        print(f"    ✔ Selected {len(essential)} features (from {len(df.columns)})")
        return df_selected

    def load_and_prepare_data(self):
        """Main data loading and preparation pipeline"""
        print_header("DATA LOADING & PREPARATION", 1)

        # Process VAE files or load existing
        if os.path.exists(self.cfg.FILES['raw_data']):
            raw_data = pd.read_csv(self.cfg.FILES['raw_data'], parse_dates=['Date'])
        else:
            raw_data = self.process_vae_files()

        # Convert to monthly
        df_monthly = self.to_monthly(raw_data)

        # Load external data
        df_monthly = self.load_external_data(df_monthly)

        # Filter date range
        df_monthly = df_monthly.loc[self.cfg.START_DATE:self.cfg.END_DATE]

        # Impute missing values (shape-safe)
        df_monthly = self.impute_missing(df_monthly)

        # Add temporal features
        df_monthly = self.add_temporal_features(df_monthly)

        # Select features
        df_final = self.select_features(df_monthly)

        print(f"\n✅ Data prepared: {df_final.shape[0]} rows × {df_final.shape[1]} features")
        return df_final


In [6]:
# ============================================================
# 5. TIMEGAN IMPLEMENTATION
# ============================================================

# Configure TensorFlow
tf.keras.backend.clear_session()
GRU_CONFIG = dict(
    return_sequences=True,
    activation='tanh',
    recurrent_activation='sigmoid',
    reset_after=False,  # Disable CuDNN
    implementation=2,
    recurrent_dropout=0.1
)

class TimeGANModel:
    """TimeGAN implementation for market features"""

    def __init__(self, config):
        self.cfg = config

    def build_models(self, hidden_dim, data_dim):
        """Build TimeGAN component models"""

        # Embedder
        embedder = tf.keras.Sequential([
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.LayerNormalization(),
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.Dense(hidden_dim, activation='sigmoid')
        ])

        # Recovery
        recovery = tf.keras.Sequential([
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.LayerNormalization(),
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.Dense(data_dim)
        ])

        # Generator
        generator = tf.keras.Sequential([
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.LayerNormalization(),
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.Dense(hidden_dim)
        ])

        # Supervisor
        supervisor = tf.keras.Sequential([
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.LayerNormalization(),
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.Dense(hidden_dim)
        ])

        # Discriminator
        discriminator = tf.keras.Sequential([
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.LayerNormalization(),
            layers.GRU(hidden_dim, **GRU_CONFIG),
            layers.Dense(1)
        ])

        return embedder, recovery, generator, supervisor, discriminator

    def make_sequences(self, data, seq_len):
        """Create sequences for training"""
        if len(data) <= seq_len:
            return np.empty((0, seq_len, data.shape[1]), dtype=np.float32)

        sequences = []
        for i in range(len(data) - seq_len):
            sequences.append(data[i:i+seq_len])

        return np.array(sequences, dtype=np.float32)

    def train_and_generate(self, df, market_cols, cond_cols):
        """Train TimeGAN and generate synthetic data"""
        print_header("TimeGAN Training", 2)

        # Get available columns
        avail_market = [c for c in market_cols if c in df.columns]
        avail_cond = [c for c in cond_cols if c in df.columns]

        if not avail_market:
            print("  ⚠️ No market columns available")
            return pd.DataFrame(index=df.index)

        print(f"  Market features: {len(avail_market)}")
        print(f"  Conditional features: {len(avail_cond)}")

        # Prepare data
        df_values = df[avail_market].copy()

        # Convert to log-returns for stability
        df_ret = df_values.apply(lambda s: 100 * np.log(s.clip(lower=1e-6)).diff()).dropna()

        if df_ret.empty:
            return pd.DataFrame(index=df.index)

        # Get conditional data
        df_cond = df.loc[df_ret.index, avail_cond] if avail_cond else pd.DataFrame(index=df_ret.index)

        # Scale data
        ret_scaler = MinMaxScaler()
        X_ret = ret_scaler.fit_transform(df_ret)
        dump(ret_scaler, self.cfg.ARTIFACTS['ret_scaler'])

        # Combine features
        if not df_cond.empty:
            X_all = np.concatenate([X_ret, df_cond.values.astype(np.float32)], axis=1)
        else:
            X_all = X_ret

        X_all = X_all.astype(np.float32)

        # Create sequences
        T_eff = min(self.cfg.SEQ_LEN, max(8, len(X_all) - 1))
        seqs = self.make_sequences(X_all, T_eff)

        if seqs.shape[0] == 0:
            return pd.DataFrame(index=df.index)

        print(f"  Sequences: {seqs.shape}")

        # Build models
        H = self.cfg.TIMEGAN_HIDDEN
        emb, rec, gen, sup, dis = self.build_models(H, X_all.shape[1])

        # Optimizers
        e_opt = tf.keras.optimizers.Adam(self.cfg.TIMEGAN_LR, clipnorm=1.0)
        g_opt = tf.keras.optimizers.Adam(self.cfg.TIMEGAN_LR, clipnorm=1.0)
        d_opt = tf.keras.optimizers.Adam(self.cfg.TIMEGAN_LR, clipnorm=1.0)

        # Training functions
        @tf.function
        def step_embed(X):
            with tf.GradientTape() as tape:
                X_rec = rec(emb(X))
                loss = tf.reduce_mean(tf.square(X - X_rec))
            vars_ = emb.trainable_variables + rec.trainable_variables
            grads = tape.gradient(loss, vars_)
            e_opt.apply_gradients(zip(grads, vars_))
            return loss

        @tf.function
        def step_gen(Z, X):
            with tf.GradientTape() as tape:
                H_ = gen(Z)
                H_hat = sup(H_)
                logits_fake = dis(H_hat)
                adv = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    tf.ones_like(logits_fake) * 0.9, logits_fake))
                recon = tf.reduce_mean(tf.square(X - rec(H_hat)))
                total = adv + self.cfg.TIMEGAN_RECON_WEIGHT * recon
            vars_ = gen.trainable_variables + sup.trainable_variables
            grads = tape.gradient(total, vars_)
            g_opt.apply_gradients(zip(grads, vars_))
            return total

        @tf.function
        def step_disc(X, Z):
            with tf.GradientTape() as tape:
                logits_real = dis(emb(X))
                H_ = gen(Z)
                H_hat = sup(H_)
                logits_fake = dis(H_hat)
                loss_real = tf.nn.sigmoid_cross_entropy_with_logits(
                    tf.ones_like(logits_real) * 0.9, logits_real)
                loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(
                    tf.zeros_like(logits_fake) + 0.1, logits_fake)
                loss = tf.reduce_mean(loss_real) + tf.reduce_mean(loss_fake)
            grads = tape.gradient(loss, dis.trainable_variables)
            d_opt.apply_gradients(zip(grads, dis.trainable_variables))
            return loss

        def batchify(arr, bs):
            n = len(arr)
            idx = np.arange(n)
            np.random.shuffle(idx)
            for i in range(0, n, bs):
                yield tf.convert_to_tensor(arr[idx[i:i+bs]])

        # Pretraining
        print("  Training embedder...")
        for epoch in range(self.cfg.TIMEGAN_PRETRAIN_EPOCHS):
            for Xb in batchify(seqs, self.cfg.TIMEGAN_BATCH):
                step_embed(Xb)
            if epoch % 50 == 0:
                print(f"    Epoch {epoch}/{self.cfg.TIMEGAN_PRETRAIN_EPOCHS}")

        # Adversarial training
        print("  Adversarial training...")
        for epoch in range(self.cfg.TIMEGAN_EPOCHS):
            for Xb in batchify(seqs, self.cfg.TIMEGAN_BATCH):
                for _ in range(self.cfg.TIMEGAN_G_STEPS):
                    Zb = tf.random.normal([len(Xb), T_eff, H], seed=self.cfg.SEED)
                    step_gen(Zb, Xb)
                for _ in range(self.cfg.TIMEGAN_D_STEPS):
                    Zb = tf.random.normal([len(Xb), T_eff, H], seed=self.cfg.SEED)
                    step_disc(Xb, Zb)
            if epoch % 100 == 0:
                print(f"    Epoch {epoch}/{self.cfg.TIMEGAN_EPOCHS}")

        # Generate synthetic data
        print("  Generating synthetic data...")
        Z = tf.random.normal([1, T_eff, H], seed=self.cfg.SEED)
        X_hat = rec(sup(gen(Z))).numpy()[0]

        # Convert back to levels
        R_hat = ret_scaler.inverse_transform(X_hat[:, :len(avail_market)])

        idx = df_ret.index[-T_eff:]
        start = df.loc[idx[0], avail_market].values
        levels = np.zeros_like(R_hat)
        levels[0] = start * np.exp(R_hat[0] / 100.0)
        for t in range(1, len(R_hat)):
            levels[t] = levels[t-1] * np.exp(R_hat[t] / 100.0)

        out = pd.DataFrame(levels, index=idx, columns=avail_market).clip(lower=0)
        out.to_csv(self.cfg.FILES['timegan_out'])

        # Interpolate to full range
        return out.reindex(df.index).interpolate(method='linear', limit_direction='both')


In [7]:
# ============================================================
# 6. VAE IMPLEMENTATION
# ============================================================

class DatasetFromTensor(Dataset):
    def __init__(self, data):
        self.data = data
    def __getitem__(self, i):
        return self.data[i]
    def __len__(self):
        return self.data.size(0)

class TimeSeriesVAE(pl.LightningModule):
    """VAE for regulated features"""

    def __init__(self, input_dim, config):
        super().__init__()
        self.save_hyperparameters()
        self.cfg = config
        self.beta = config.VAE_BETA_START

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        self.fc_mu = nn.Linear(32, config.VAE_LATENT)
        self.fc_logv = nn.Linear(32, config.VAE_LATENT)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(config.VAE_LATENT, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, input_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logv = self.fc_mu(h), self.fc_logv(h)
        std = torch.exp(0.5 * logv)
        z = mu + torch.randn_like(std) * std
        return self.decoder(z), mu, logv

    def vae_loss(self, x, x_hat, mu, logv):
        recon = nn.functional.mse_loss(x_hat, x, reduction='mean')
        kl = -0.5 * torch.mean(1 + logv - mu.pow(2) - logv.exp())
        return recon + self.beta * kl

    def on_train_epoch_start(self):
        # Beta annealing
        e = self.current_epoch
        if e < self.cfg.VAE_BETA_WARMUP:
            self.beta = self.cfg.VAE_BETA_START + \
                       (self.cfg.VAE_BETA_END - self.cfg.VAE_BETA_START) * \
                       (e / max(1, self.cfg.VAE_BETA_WARMUP))
        else:
            self.beta = self.cfg.VAE_BETA_END

    def training_step(self, batch, batch_idx):
        x = batch
        x_hat, mu, logv = self(x)
        loss = self.vae_loss(x, x_hat, mu, logv)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch
        x_hat, mu, logv = self(x)
        loss = self.vae_loss(x, x_hat, mu, logv)
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.cfg.VAE_LR)

class VAEModel:
    """VAE wrapper for training and generation"""

    def __init__(self, config):
        self.cfg = config

    def train_and_generate(self, df, reg_cols, aux_cols):
        """Train VAE and generate synthetic data"""
        print_header("VAE Training", 2)

        # Get available columns
        avail_reg = [c for c in reg_cols if c in df.columns]
        avail_aux = [c for c in aux_cols if c in df.columns]

        if not avail_reg:
            print("  ⚠️ No regulated columns available")
            return pd.DataFrame(index=df.index)

        print(f"  Regulated features: {len(avail_reg)}")
        print(f"  Auxiliary features: {len(avail_aux)}")

        # Prepare data with lags
        base = df[avail_reg + avail_aux].copy()
        max_lag = 2

        lagged = pd.concat([
            base.shift(L).add_suffix(f"_lag{L}")
            for L in range(max_lag + 1)
        ], axis=1).dropna()

        if len(lagged) < 10:
            print("  ⚠️ Insufficient data for VAE")
            return pd.DataFrame(index=df.index)

        # Scale data
        scaler = MinMaxScaler()
        X = scaler.fit_transform(lagged)
        dump(scaler, self.cfg.ARTIFACTS['vae_scaler'])

        # Create dataset
        ds = DatasetFromTensor(torch.tensor(X, dtype=torch.float32))
        n_train = max(1, int(0.8 * len(ds)))
        train_ds, val_ds = random_split(ds, [n_train, len(ds) - n_train])

        print(f"  Training samples: {n_train}, Validation: {len(ds) - n_train}")

        # Initialize VAE
        vae = TimeSeriesVAE(X.shape[1], self.cfg)

        # Setup trainer
        trainer = pl.Trainer(
            max_epochs=self.cfg.VAE_EPOCHS,
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
            devices=1 if torch.cuda.is_available() else None,
            gradient_clip_val=1.0,
            deterministic=True,
            callbacks=[
                pl.callbacks.EarlyStopping(monitor='val_loss', patience=self.cfg.VAE_PATIENCE)
            ],
            logger=False,
            enable_progress_bar=True,
            enable_checkpointing=False
        )

        # Train
        print("  Training VAE...")
        trainer.fit(
            vae,
            DataLoader(train_ds, batch_size=self.cfg.VAE_BATCH, shuffle=True),
            DataLoader(val_ds, batch_size=self.cfg.VAE_BATCH, shuffle=False)
        )

        # Generate synthetic data
        print("  Generating synthetic data...")
        vae.eval()
        with torch.no_grad():
            X_hat, _, _ = vae(torch.tensor(X, dtype=torch.float32, device=vae.device))

        X_inv = scaler.inverse_transform(X_hat.cpu().numpy())

        # Extract current values (lag0)
        reg0_cols = [f"{c}_lag0" for c in avail_reg]
        out = pd.DataFrame(X_inv, index=lagged.index, columns=lagged.columns)[reg0_cols]
        out.columns = avail_reg
        out.to_csv(self.cfg.FILES['vae_out'])

        # Interpolate to full range
        return out.reindex(df.index).interpolate(method='linear', limit_direction='both')


In [8]:
# ============================================================
# 7. SCSI COMPUTATION
# ============================================================

class SCSIComputer:
    """Compute Supply Chain Stress Index from synthetic features"""

    def __init__(self, config):
        self.cfg = config

    def compute_scsi(self, timegan_out, vae_out, df_full):
        """Blend outputs and compute SCSI using PCA"""
        print_header("SCSI Computation", 2)

        # Create date index
        idx = pd.date_range(self.cfg.START_DATE, self.cfg.END_DATE, freq='MS')

        # Prepare synthetic data
        tg = timegan_out.reindex(idx) if not timegan_out.empty else pd.DataFrame(index=idx)
        vae = vae_out.reindex(idx) if not vae_out.empty else pd.DataFrame(index=idx)

        # Initialize features dataframe
        synth = pd.DataFrame(index=idx)

        # Add TimeGAN features
        for col in tg.columns:
            if not tg[col].isna().all():
                synth[f"tg_{col}"] = tg[col]

        # Add VAE features
        for col in vae.columns:
            if not vae[col].isna().all():
                synth[f"vae_{col}"] = vae[col]

        # Add external features
        external_cols = [c for c in df_full.columns if any(x in c for x in
                        ['GSCPI', 'GPR', '_trade', '_air', 'sin_month', 'cos_month', 'policy_dummy'])]

        for col in external_cols:
            if col in df_full.columns:
                synth[col] = df_full.reindex(idx)[col]

        # Clean data
        synth = synth.dropna(axis=1, how='all')
        synth = synth.fillna(method='ffill').fillna(method='bfill')
        synth = synth.dropna()

        if synth.empty:
            print("  ⚠️ No valid data for SCSI computation")
            return synth

        print(f"  Features for SCSI: {synth.shape}")

        # Standardize and compute PCA
        scaler = StandardScaler()
        X = scaler.fit_transform(synth)
        dump(scaler, self.cfg.ARTIFACTS['std_scaler'])

        # PCA for dimensionality reduction
        pca = PCA(n_components=1, random_state=self.cfg.SEED)
        scsi_values = pca.fit_transform(X).flatten()
        dump(pca, self.cfg.ARTIFACTS['pca_model'])

        # Ensure positive orientation
        if scsi_values.mean() < 0:
            scsi_values *= -1

        synth['SCSI'] = scsi_values

        print(f"  Explained variance: {pca.explained_variance_ratio_[0]:.3f}")

        # Save features
        synth.to_csv(self.cfg.FILES['scsi_features'])

        return synth


In [9]:
# ============================================================
# 8. VISUALIZATION & REPORTING
# ============================================================
from datetime import datetime as dt_class


class Visualizer:
    """Create visualizations and reports"""

    def __init__(self, config):
        self.cfg = config

    def plot_scsi(self, scsi_df):
        """Create comprehensive SCSI visualization"""
        import os
        import numpy as np
        import pandas as pd
        import matplotlib.pyplot as plt

        if 'SCSI' not in scsi_df.columns:
            print("  ⚠️ SCSI column not found")
            return

        # Ensure proper index type & order
        scsi_df = scsi_df.copy()
        scsi_df.index = pd.to_datetime(scsi_df.index)
        scsi_df = scsi_df.sort_index()

        # Make sure output folder exists
        os.makedirs(os.path.dirname(self.cfg.FILES['scsi_plot']), exist_ok=True)

        # Create figure with subplots
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Main SCSI plot
        ax1 = axes[0, 0]
        ax1.plot(scsi_df.index, scsi_df['SCSI'], linewidth=2, color='steelblue')
        ax1.fill_between(scsi_df.index, scsi_df['SCSI'], alpha=0.3, color='steelblue')
        ax1.set_title('Supply Chain Stress Index (SCSI)', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Date')
        ax1.set_ylabel('SCSI Value')
        ax1.grid(True, alpha=0.3)
        ax1.axhline(y=0, color='red', linestyle='--', alpha=0.5)

        # COVID period shading (overlap-safe)
        covid_start = pd.Timestamp('2020-03-01')
        covid_end   = pd.Timestamp('2022-04-01')
        idx0, idx1  = scsi_df.index[0], scsi_df.index[-1]
        shade_start = max(covid_start, idx0)
        shade_end   = min(covid_end, idx1)
        if shade_start < shade_end:
            ax1.axvspan(shade_start, shade_end, alpha=0.2, color='red', label='COVID Period')
            ax1.legend()

        # Distribution
        ax2 = axes[0, 1]
        ax2.hist(scsi_df['SCSI'], bins=30, edgecolor='black', alpha=0.7, color='steelblue')
        ax2.set_title('SCSI Distribution', fontsize=14, fontweight='bold')
        ax2.set_xlabel('SCSI Value')
        ax2.set_ylabel('Frequency')
        ax2.axvline(x=scsi_df['SCSI'].mean(), color='red', linestyle='--',
                    label=f'Mean: {scsi_df['SCSI'].mean():.3f}')
        ax2.legend()

        # Rolling statistics (avoid NaNs at start)
        ax3 = axes[1, 0]
        rolling_mean = scsi_df['SCSI'].rolling(window=6, min_periods=1).mean()
        rolling_std  = scsi_df['SCSI'].rolling(window=6, min_periods=1).std()
        ax3.plot(scsi_df.index, rolling_mean, label='6-Month Mean', linewidth=2)
        ax3.fill_between(scsi_df.index,
                         rolling_mean - rolling_std,
                         rolling_mean + rolling_std,
                         alpha=0.3, label='±1 Std Dev')
        ax3.set_title('SCSI Trend Analysis', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Date')
        ax3.set_ylabel('SCSI Value')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Stress levels
        ax4 = axes[1, 1]
        stress_levels = pd.cut(scsi_df['SCSI'],
                               bins=[-np.inf, -1, 0, 1, np.inf],
                               labels=['Low', 'Normal', 'Elevated', 'High'])
        stress_counts = stress_levels.value_counts()
        colors = ['green', 'blue', 'orange', 'red']
        ax4.pie(stress_counts, labels=stress_counts.index, autopct='%1.1f%%',
                colors=colors, startangle=90)
        ax4.set_title('Stress Level Distribution', fontsize=14, fontweight='bold')

        plt.tight_layout()
        plt.savefig(self.cfg.FILES['scsi_plot'], dpi=150, bbox_inches='tight')
        plt.close(fig)

        print(f"  ✅ Visualization saved to {self.cfg.FILES['scsi_plot']}")

    def generate_report(self, final_df, runtime_info):
        """Generate validation report with proper JSON serialization"""
        import os, json, math
        import numpy as np
        import pandas as pd

        report = {
            'timestamp': dt_class.now().isoformat(),
            'configuration': {
                'start_date': self.cfg.START_DATE,
                'end_date': self.cfg.END_DATE,
                'test_mode': bool(self.cfg.TEST_MODE),
                'features_used': int(len(final_df.columns))
            },
            'data_quality': {
                'rows': int(len(final_df)),
                'columns': int(len(final_df.columns)),
                'missing_values': int(final_df.isna().sum().sum()),
                'completeness': float(
                    1 - final_df.isna().sum().sum() / max(1, (len(final_df) * len(final_df.columns)))
                )
            },
            'scsi_statistics': {},
            'runtime': {},
            'validation': {}
        }

        # Convert runtime_info to safe types
        for key, value in (runtime_info or {}).items():
            if isinstance(value, (np.number, int, float)):
                val = float(value)
                report['runtime'][str(key)] = None if (math.isnan(val) or math.isinf(val)) else val
            else:
                report['runtime'][str(key)] = str(value)

        # SCSI statistics
        if 'SCSI' in final_df.columns:
            scsi = final_df['SCSI']
            report['scsi_statistics'] = {
                'mean': float(scsi.mean()),
                'std': float(scsi.std()),
                'min': float(scsi.min()),
                'max': float(scsi.max()),
                'median': float(scsi.median()),
                'skewness': float(scsi.skew()),
                'kurtosis': float(scsi.kurt())
            }

            # Validation checks
            std_check = scsi.std() > 0.01
            range_check = (scsi.min() >= -10) and (scsi.max() <= 10)
            outlier_check = len(scsi[(scsi < -5) | (scsi > 5)]) < len(scsi) * 0.1

            report['validation']['scsi_variability'] = bool(getattr(std_check, 'item', lambda: std_check)())
            report['validation']['scsi_range_valid'] = bool(getattr(range_check, 'item', lambda: range_check)())
            report['validation']['no_extreme_outliers'] = bool(getattr(outlier_check, 'item', lambda: outlier_check)())

        # More validation checks
        missing_check = final_df.isna().sum().sum() == 0
        expected_len = len(pd.date_range(self.cfg.START_DATE, self.cfg.END_DATE, freq='MS'))
        continuity_check = len(final_df) == expected_len

        report['validation']['no_missing_values'] = bool(getattr(missing_check, 'item', lambda: missing_check)())
        report['validation']['date_continuity'] = bool(getattr(continuity_check, 'item', lambda: continuity_check)())

        # Save report using safe JSON serialization
        def simple_to_jsonable(obj):
            if isinstance(obj, dict):
                return {str(k): simple_to_jsonable(v) for k, v in obj.items()}
            if isinstance(obj, (list, tuple)):
                return [simple_to_jsonable(item) for item in obj]
            if hasattr(obj, 'item'):               # numpy scalars
                return obj.item()
            if isinstance(obj, (np.bool_, bool)):
                return bool(obj)
            if isinstance(obj, (np.integer, int)):
                return int(obj)
            if isinstance(obj, (np.floating, float)):
                val = float(obj)
                return None if (math.isnan(val) or math.isinf(val)) else val
            return obj

        sanitized_report = simple_to_jsonable(report)
        os.makedirs(os.path.dirname(self.cfg.FILES['validation_report']), exist_ok=True)
        with open(self.cfg.FILES['validation_report'], 'w', encoding='utf-8') as f:
            json.dump(sanitized_report, f, indent=2, ensure_ascii=False)
        print(f"✅ Validation report saved to {self.cfg.FILES['validation_report']}")

        return report  # Return the original report for use in Pipeline


In [None]:
# Reporting
# def generate_report(self, final_df, runtime_info):
    """Generate validation report with proper JSON serialization"""
    from datetime import datetime as dt_class

    report = {
        'timestamp': dt_class.now().isoformat(),
        'configuration': {
            'start_date': self.cfg.START_DATE,
            'end_date': self.cfg.END_DATE,
            'test_mode': bool(self.cfg.TEST_MODE),  # Ensure Python bool
            'features_used': int(len(final_df.columns))
        },
        'data_quality': {
            'rows': int(len(final_df)),
            'columns': int(len(final_df.columns)),
            'missing_values': int(final_df.isna().sum().sum()),
            'completeness': float(1 - final_df.isna().sum().sum() / (len(final_df) * len(final_df.columns)))
        },
        'scsi_statistics': {},
        'runtime': {},
        'validation': {}
    }

    # Convert runtime_info to safe types
    # for key, value in runtime_info.items():
        if isinstance(value, (int, float, np.number)):
            report['runtime'][str(key)] = float(value)
        else:
            report['runtime'][str(key)] = str(value)

    # SCSI statistics
    # if 'SCSI' in final_df.columns:
        scsi = final_df['SCSI']
        report['scsi_statistics'] = {
            'mean': float(scsi.mean()),
            'std': float(scsi.std()),
            'min': float(scsi.min()),
            'max': float(scsi.max()),
            'median': float(scsi.median()),
            'skewness': float(scsi.skew()),
            'kurtosis': float(scsi.kurt())
        }

        # Validation checks - FIXED: Proper boolean conversion
        std_check = scsi.std() > 0.01
        range_check = (scsi.min() >= -10) and (scsi.max() <= 10)
        outlier_check = len(scsi[(scsi < -5) | (scsi > 5)]) < len(scsi) * 0.1

        # Convert to Python bool properly
        report['validation']['scsi_variability'] = bool(std_check.item() if hasattr(std_check, 'item') else std_check)
        report['validation']['scsi_range_valid'] = bool(range_check.item() if hasattr(range_check, 'item') else range_check)
        report['validation']['no_extreme_outliers'] = bool(outlier_check.item() if hasattr(outlier_check, 'item') else outlier_check)

    # More validation checks
    # missing_check = final_df.isna().sum().sum() == 0
    # continuity_check = len(final_df) == len(pd.date_range(self.cfg.START_DATE, self.cfg.END_DATE, freq='MS'))

    # report['validation']['no_missing_values'] = bool(missing_check.item() if hasattr(missing_check, 'item') else missing_check)
    # report['validation']['date_continuity'] = bool(continuity_check.item() if hasattr(continuity_check, 'item') else continuity_check)

    # Save report using safe JSON serialization
    # try:
        # Use the Pipeline's to_jsonable method if available, or create a simple converter
        # def simple_to_jsonable(obj):
            if isinstance(obj, dict):
                return {str(k): simple_to_jsonable(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [simple_to_jsonable(item) for item in obj]
            elif hasattr(obj, 'item'):  # numpy scalars
                return obj.item()
            elif isinstance(obj, (np.bool_, bool)):
                return bool(obj)
            elif isinstance(obj, (np.integer, int)):
                return int(obj)
            elif isinstance(obj, (np.floating, float)):
                return float(obj)
            else:
                return obj

        # sanitized_report = simple_to_jsonable(report)

        # with open(self.cfg.FILES['validation_report'], 'w') as f:
            # json.dump(sanitized_report, f, indent=2)
        # print(f"✅ Validation report saved to {self.cfg.FILES['validation_report']}")

    # except Exception as e:
        # print(f"❌ Failed to save validation report: {e}")
        # Try saving a minimal report
        # minimal_report = {
            'timestamp': dt_class.now().isoformat(),
            'error': f'Full report could not be saved: {str(e)}',
            'data_shape': [int(final_df.shape[0]), int(final_df.shape[1])]
        # }
        # with open(self.cfg.FILES['validation_report'], 'w') as f:
            json.dump(minimal_report, f, indent=2)


In [12]:
# ============================================================
# 9. MAIN PIPELINE
# ============================================================

import os, sys, time, json
import numpy as np
import pandas as pd

from datetime import datetime as dt_class

class Pipeline:
    """Main pipeline orchestrator"""

    def __init__(self, config=None, test_mode=False):
        self.cfg = config or Config(test_mode=test_mode)
        self.data_manager = DataManager(self.cfg)
        self.timegan = TimeGANModel(self.cfg)
        self.vae = VAEModel(self.cfg)
        self.scsi_computer = SCSIComputer(self.cfg)
        self.visualizer = Visualizer(self.cfg)

    def run(self):
        """Execute the complete pipeline"""

        print("\n" + "🚀" * 35)
        print("     TIMEGAN-VAE SCSI PIPELINE")
        print("🚀" * 35)
        print(f"Started: {dt_class.now().strftime('%Y-%m-%d %H:%M:%S')}")

        pipeline_start = time.time()
        runtime_info = {}

        try:
            # Setup
            setup_seeds(self.cfg.SEED)
            setup_gpu()

            # Step 1: Data Loading
            print_header("STEP 1/5: DATA LOADING", 1)
            t0 = time.time()

            # Upload files if in Colab
            if 'google.colab' in sys.modules:
                self.data_manager.upload_files()

            df_full = self.data_manager.load_and_prepare_data()
            runtime_info['data_loading'] = time.time() - t0

            # Step 2: TimeGAN
            print_header("STEP 2/5: TIMEGAN TRAINING", 1)
            t0 = time.time()

            market_features = self.cfg.MARKET + [c for c in df_full.columns
                                                 if 'freight' in c.lower() or 'cargo' in c.lower()]
            market_features = sorted(set([c for c in market_features if c in df_full.columns]))

            aux_features = ['sin_month', 'cos_month', 'policy_dummy', 'GSCPI_NYFED']
            aux_features = [c for c in aux_features if c in df_full.columns]

            timegan_out = self.timegan.train_and_generate(df_full, market_features, aux_features)
            runtime_info['timegan'] = time.time() - t0

            # Step 3: VAE
            print_header("STEP 3/5: VAE TRAINING", 1)
            t0 = time.time()

            aux_vae = ['sin_month', 'cos_month', 'policy_dummy'] + \
                     [c for c in df_full.columns if 'trade' in c.lower()]
            aux_vae = [c for c in aux_vae if c in df_full.columns]

            vae_out = self.vae.train_and_generate(df_full, self.cfg.REGULATED, aux_vae)
            runtime_info['vae'] = time.time() - t0

            # Step 4: SCSI Computation
            print_header("STEP 4/5: SCSI COMPUTATION", 1)
            t0 = time.time()

            scsi_df = self.scsi_computer.compute_scsi(timegan_out, vae_out, df_full)
            runtime_info['scsi'] = time.time() - t0

            # Step 5: Final Output
            print_header("STEP 5/5: FINAL OUTPUT", 1)
            t0 = time.time()

            # Create final dataset
            final_df = df_full.copy()
            if 'SCSI' in scsi_df.columns:
                final_df['SCSI'] = scsi_df['SCSI']

            # Ensure no NaNs
            final_df = final_df.fillna(method='ffill').fillna(method='bfill')
            if final_df.isna().any().any():
                numeric_cols = final_df.select_dtypes(include=[np.number]).columns
                final_df[numeric_cols] = final_df[numeric_cols].interpolate(method='linear')

            # Save final output
            os.makedirs(os.path.dirname(self.cfg.FILES['final_output']) or ".", exist_ok=True)  # NEW
            final_df.to_csv(self.cfg.FILES['final_output'])

            # Visualization
            self.visualizer.plot_scsi(scsi_df)

            # Generate report (this call already saves the validation report internally)
            runtime_info['output'] = time.time() - t0
            runtime_info['total']  = time.time() - pipeline_start

            report = self.visualizer.generate_report(final_df, runtime_info)

            # ---- Save a separate run summary (do NOT overwrite the validation report) ----
            # Reuse the validation report directory to avoid missing Config.OUT
            val_report_path = self.cfg.FILES.get('validation_report', './scsi_output/validation_report.json')
            out_dir = os.path.dirname(val_report_path) or "./scsi_output"
            os.makedirs(out_dir, exist_ok=True)

            summary_path = os.path.join(self.cfg.OUT, "run_summary.json")
            payload = {
            "created_at": dt_class.now().isoformat(),
            "shapes": {
                "final_df": [int(final_df.shape[0]), int(final_df.shape[1])],
            },
            "files": {**self.cfg.FILES},  # paths as strings
            "runtime": {k: float(v) if isinstance(v, (int, float, np.number)) else str(v)
                        for k, v in runtime_info.items()},
            "report": report,  # already JSON-safe via our helper
                    }

            # Save summary safely
            save_json_safely(payload, summary_path)
            print(f"📝 Run summary saved to {summary_path}")

            # Display results
            self._display_results(final_df, report, runtime_info)
            return final_df


        except Exception as e:
            print(f"\n❌ Pipeline failed: {str(e)}")
            raise

    def _display_results(self, final_df, report, runtime_info):
        """Display pipeline results"""

        print("\n" + "="*70)
        print("  ✅ PIPELINE COMPLETED SUCCESSFULLY!")
        print("="*70)

        # Runtime summary
        print("\n⏱️  RUNTIME SUMMARY:")
        for step, duration in runtime_info.items():
            print(f"  {step.replace('_', ' ').title():20} {format_time(duration)}")

        # Data summary
        print("\n📊 DATA SUMMARY:")
        print(f"  Final shape:         {final_df.shape[0]} rows × {final_df.shape[1]} columns")
        print(f"  Date range:          {final_df.index[0].strftime('%Y-%m')} to {final_df.index[-1].strftime('%Y-%m')}")
        print(f"  Missing values:      {final_df.isna().sum().sum()}")

        # SCSI summary
        if 'scsi_statistics' in report:
            print("\n📈 SCSI STATISTICS:")
            stats = report['scsi_statistics']
            print(f"  Range:               [{stats['min']:.3f}, {stats['max']:.3f}]")
            print(f"  Mean ± Std:          {stats['mean']:.3f} ± {stats['std']:.3f}")
            print(f"  Median:              {stats['median']:.3f}")

        # Validation summary
        if 'validation' in report:
            print("\n✓ VALIDATION:")
            all_passed = True
            for check, passed in report['validation'].items():
                status = "✅" if passed else "❌"
                print(f"  {status} {check.replace('_', ' ').title()}")
                if not passed:
                    all_passed = False

            if all_passed:
                print("\n🎉 All validation checks passed!")
            else:
                print("\n⚠️ Some validation checks failed. Review the report for details.")

        # File outputs
        print("\n💾 OUTPUT FILES:")
        for desc, path in self.cfg.FILES.items():
            if os.path.exists(path):
                size = os.path.getsize(path) / 1024  # KB
                print(f"  {desc:20} {path} ({size:.1f} KB)")

        # Preview
        print("\n📋 DATA PREVIEW:")
        if 'SCSI' in final_df.columns:
            preview_cols = ['SCSI'] + [c for c in self.cfg.MARKET if c in final_df.columns][:3]
            print(final_df[preview_cols].head())
        else:
            print(final_df.head())

        print("\n" + "🎉"*35)
        print("  Your Supply Chain Stress Index is ready!")
        print("🎉"*35)


In [13]:
# ============================================================
# 10. EXECUTION
# ============================================================

def main(test_mode=False):
    """Main execution function"""

    # Initialize and run pipeline
    pipeline = Pipeline(test_mode=test_mode)
    final_result = pipeline.run()

    return final_result

# Execute pipeline
if __name__ == "__main__":
    # Set test_mode=True for quick testing with reduced epochs
    final_result = main(test_mode=False)

    print("\n✅ Results stored in 'final_result' variable")
    print("📊 Use final_result.to_csv('your_file.csv') to save for MR-TFT")


INFO:lightning_fabric.utilities.seed:Seed set to 42



🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀
     TIMEGAN-VAE SCSI PIPELINE
🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀
Started: 2025-08-20 21:39:38
🎲 Random seeds set to 42
🖥️ GPU configured: 1 device(s) available

  STEP 1/5: DATA LOADING

📍 FILE UPLOAD
--------------------------------------------------
✅ Found: VAE_files.zip
✅ Found: gscpi_data.xlsx
✅ Found: data_gpr_export.xlsx
✅ Found: monthly_trade_dummy.csv
✅ Found: monthly_air_cargo_from_spline.csv

  DATA LOADING & PREPARATION
  ▶ Loading external datasets...
    ✔ GSCPI: (325, 1)
    ✔ GPR: (1501, 10)
    ✔ Trade: (49, 4)
    ✔ Air cargo: (60, 1)
  ▶ Imputing missing values...
    ℹ️ KNN stage on 24 numeric cols, 72 rows
    ⚠️ Column count mismatch after KNN: num_cols=24, X_imp=23. Trimming to min.
    ✔ Imputed 72 values (remaining: 0)
  ▶ Selecting relevant features...
    ✔ Selected 25 features (from 27)

✅ Data prepared: 72 rows × 25 features

  STEP 2/5: TIMEGAN TRAINING

📍 TimeGAN Training
---------------------------------------------

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 4.3 K 
1 | fc_mu   | Linear     | 264   
2 | fc_logv | Linear     | 264   
3 | decoder | Sequential | 4.5 K 
---------------------------------------
9.3 K     Trainable params
0         Non-trainable params
9.3 K     Total params
0.037     Total estimated model params size (MB)



  STEP 3/5: VAE TRAINING

📍 VAE Training
--------------------------------------------------
  Regulated features: 4
  Auxiliary features: 7
  Training samples: 56, Validation: 14
  Training VAE...


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


  Generating synthetic data...

  STEP 4/5: SCSI COMPUTATION

📍 SCSI Computation
--------------------------------------------------
  Features for SCSI: (72, 26)
  Explained variance: 0.514

  STEP 5/5: FINAL OUTPUT
  ✅ Visualization saved to /content/scsi_output/scsi_timeseries.png
✅ Validation report saved to /content/scsi_output/validation_report.json

❌ Pipeline failed: 'Config' object has no attribute 'OUT'


AttributeError: 'Config' object has no attribute 'OUT'