# ExpertFusion: Mixture of Experts Model

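This notebook implements a Mixture of Experts (MoE) model for financial market analysis: five GPT-backed domain experts (macro, fundamental, news, technical and risk) each score a stock, and a learned gating network combines their scores into a return prediction that is compared against a single-GPT baseline.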

# 1. Setup & Keys

In [1]:
import os
import re
import datetime
import time
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Union
from bs4 import BeautifulSoup
from sklearn.metrics import mean_squared_error, mean_absolute_error
import openai
import sqlite3
import yfinance as yf
import wrds
import pickle
from functools import lru_cache
import json

# Opt out of tokenizer-library parallelism via its environment switch.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Global variables
DB_FILE = "sp500_daily_news.db"  # SQLite news database used by the news helpers; update if needed
_LAST_QUERY_TIME = 0  # time.time() stamp of the last WRDS query, used for throttling
_MIN_QUERY_INTERVAL = 1.0  # minimum seconds between consecutive WRDS queries

# Initialize the OpenAI client, failing fast when no API key is configured.
openai_api_key = os.environ.get('OPENAI_API_KEY')
if openai_api_key is None or openai_api_key == "":
    raise ValueError("Please set the OPENAI_API_KEY environment variable")

client = openai.OpenAI(api_key=openai_api_key)

# 2. GPT Factor & Explanation

Helper that queries GPT-4o-mini and parses a numeric factor score and a free-text explanation from its reply.

In [None]:
# Accepts "Factor: 0.4", "Factor: +1", "**Factor:** -0.25" and "Factor: **.5**".
_FACTOR_RE = re.compile(
    r"Factor\**\s*:\s*\**\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+))", re.IGNORECASE
)
# Everything after the first "Explanation:" header (markdown bold tolerated).
_EXPL_RE = re.compile(r"Explanation\**\s*:\s*\**\s*(.+)", re.IGNORECASE | re.DOTALL)


def _parse_factor_and_expl(text: str) -> Tuple[float, str]:
    """Parse ``(factor, explanation)`` from a GPT reply; missing parts become 0.0 / ""."""
    factor_match = _FACTOR_RE.search(text)
    expl_match = _EXPL_RE.search(text)
    factor = float(factor_match.group(1)) if factor_match else 0.0
    # Every prompt asks for a value in [-1, 1]; clamp so an out-of-range reply
    # cannot dominate the downstream mixture.
    factor = max(-1.0, min(1.0, factor))
    explanation = expl_match.group(1).strip() if expl_match else ""
    return factor, explanation


def call_gpt_factor_and_expl(system_msg: str, user_msg: str) -> Tuple[float, str]:
    """Ask GPT-4o-mini for a factor score and explanation.

    The reply is expected to contain ``Factor: <float>`` and
    ``Explanation: <text>``. The factor may carry an explicit sign
    (``+0.7``) and is clamped to [-1, 1].

    Args:
        system_msg: System prompt describing the expert's task and output format.
        user_msg: User message with the data to evaluate.

    Returns:
        ``(factor, explanation)``. Missing parts default to ``0.0`` / ``""``.
        Any exception (API error, unexpected response shape) is logged and
        yields ``(0.0, "")`` — this function never raises.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg}
            ],
            temperature=0.7,
            max_tokens=150
        )
        # Message content may be None; treat that as an empty reply.
        text = response.choices[0].message.content or ""
        # Debug: print raw GPT response
        print(f"[DEBUG] Raw GPT response:\n{text}\n")
        factor, explanation = _parse_factor_and_expl(text)
        if not explanation:
            print("[WARN] No explanation parsed from GPT response.")
        return factor, explanation
    except Exception as e:
        print(f"[ERROR] GPT call failed: {e}")
        return 0.0, ""

# 3. Database Access & WRDS Data

In [None]:
def get_wrds_connection():
    """Open a WRDS connection, or return None (after logging) if that fails.

    When the ``WRDS_USERNAME`` environment variable is set it is passed to
    ``wrds.Connection`` as ``wrds_username``; otherwise ``wrds.Connection()``
    is called with no arguments.
    """
    try:
        user = os.getenv('WRDS_USERNAME')
        connect_kwargs = {}
        if user:
            print(f"[INFO] Using WRDS username: {user}")
            connect_kwargs["wrds_username"] = user
        connection = wrds.Connection(**connect_kwargs)
        print("[INFO] Connected to WRDS.")
        return connection
    except Exception as err:
        print(f"[ERROR] Failed to connect to WRDS: {err}")
        return None

def get_tickers_from_db(skip_wrds=False) -> List[str]:
    """Return S&P 500 tickers, optionally mapped through CRSP ``msenames`` on WRDS.

    Tickers are read from ``sp500_list_wiki.json`` (keys ``tickers`` and
    ``fetch_time``). The file's list is returned unchanged when ``skip_wrds``
    is true, when it is empty, or when no WRDS connection can be opened.
    Otherwise the tickers are looked up in ``crsp.msenames`` (name records
    active today). The distinct tickers of the newest name record per PERMNO
    are returned. The WRDS connection is always closed.

    Raises:
        FileNotFoundError: the ticker list file does not exist.
        ValueError: WRDS returned no matching ticker mappings.
    """
    SP500_LIST_FILE = "sp500_list_wiki.json"
    db = None
    try:
        if not os.path.exists(SP500_LIST_FILE):
            raise FileNotFoundError(f"{SP500_LIST_FILE} not found. Please run fetch_sp500_list.py with VPN first.")
        with open(SP500_LIST_FILE, 'r') as f:
            sp500_data = json.load(f)
        tickers = sp500_data['tickers']
        fetch_time = sp500_data['fetch_time']
        print(f"[INFO] Loaded {len(tickers)} tickers from {SP500_LIST_FILE} (fetched at {fetch_time})")
        if skip_wrds:
            print("[INFO] Skipping WRDS connection, using tickers directly from file")
            return tickers
        if not tickers:
            # An empty list would render as "IN ()", which is invalid SQL.
            print("[WARN] Ticker list is empty; skipping WRDS lookup")
            return tickers
        db = get_wrds_connection()
        if not db:
            return tickers
        print("[INFO] Getting ticker mappings from WRDS...")
        today_str = pd.Timestamp.today().strftime('%Y-%m-%d')
        # Double embedded single quotes so a malformed ticker cannot break the SQL literal.
        ticker_literals = ",".join("'" + str(t).replace("'", "''") + "'" for t in tickers)
        ticker_query = f"""
            SELECT DISTINCT n.permno, n.ticker, n.namedt, n.nameendt
            FROM crsp.msenames n
            WHERE n.ticker IN ({ticker_literals})
              AND n.namedt <= '{today_str}'
              AND (n.nameendt >= '{today_str}' OR n.nameendt IS NULL)
            ORDER BY n.namedt DESC
        """
        ticker_df = db.raw_sql(ticker_query, date_cols=['namedt', 'nameendt'])
        print(f"[INFO] Retrieved {len(ticker_df)} ticker mappings from msenames")
        if ticker_df.empty:
            raise ValueError("No ticker mappings found in WRDS.")
        print("[INFO] Sample of ticker mappings:")
        print(ticker_df.head())
        # Keep only the newest name record for each PERMNO.
        ticker_df = ticker_df.sort_values('namedt', ascending=False).drop_duplicates('permno')
        mapped_tickers = ticker_df['ticker'].unique().tolist()
        print(f"[INFO] Retrieved {len(mapped_tickers)} unique tickers from WRDS")
        return mapped_tickers
    except Exception as e:
        print(f"[ERROR] Failed to get tickers from WRDS: {e}")
        raise
    finally:
        if db is not None:
            db.close()

def rate_limited_query(query: str, date_cols=None, params=None) -> pd.DataFrame:
    """Run ``query`` on a fresh WRDS connection, at most once per ``_MIN_QUERY_INTERVAL``.

    Sleeps if the previous query finished less than ``_MIN_QUERY_INTERVAL``
    seconds ago. Then it opens a connection, runs ``db.raw_sql`` and closes
    the connection whether or not the query succeeded.

    Args:
        query: SQL text, forwarded to ``raw_sql``.
        date_cols: Columns ``raw_sql`` should parse as dates.
        params: Query parameters forwarded to ``raw_sql``.

    Raises:
        ConnectionError: no WRDS connection could be opened.
        Exception: whatever ``raw_sql`` raised (logged first).
    """
    global _LAST_QUERY_TIME
    wait = _MIN_QUERY_INTERVAL - (time.time() - _LAST_QUERY_TIME)
    if wait > 0:
        time.sleep(wait)
    db = get_wrds_connection()
    if not db:
        raise ConnectionError("No WRDS connection available.")
    try:
        return db.raw_sql(query, params=params, date_cols=date_cols)
    except Exception as e:
        print(f"[ERROR] WRDS query failed: {e}")
        raise
    finally:
        # Stamp failed queries too, so an immediate retry is still throttled.
        _LAST_QUERY_TIME = time.time()
        try:
            db.close()
        except Exception as close_err:
            print(f"[WARN] Failed to close WRDS connection: {close_err}")

@lru_cache(maxsize=100)
def get_fundamentals_wrds(ticker: str, date: datetime.date) -> Dict:
    """Look up cached WRDS fundamentals for ``ticker`` on ``date``.

    Reads the ``wrds`` cache for ``date`` through ``load_cached_data`` and
    returns ``cache[ticker][date.isoformat()]``, or ``{}`` when the cache is
    empty or has no entry for the ticker or date. Results are memoised per
    ``(ticker, date)``, so callers share the returned dict. Load errors are
    logged and re-raised.
    """
    try:
        cache = load_cached_data("wrds", date=date.strftime('%Y-%m-%d'))
        if not cache or ticker not in cache:
            return {}
        per_date = cache[ticker]
        return per_date.get(date.isoformat(), {})
    except Exception as err:
        print(f"[ERROR] Failed to get fundamental data: {err}")
        raise

def cache_wrds_data(date: str) -> Dict:
    """Fetch fundamentals for every ticker from WRDS and pickle them under ``date``.

    Writes ``wrds_cache_{date}.pkl`` in the shape read back by
    ``get_fundamentals_wrds``: ``{ticker: {date: fundamentals}}``. Tickers
    whose lookup fails are logged and skipped. Nothing is written when no
    ticker succeeds.

    NOTE(review): ``get_fundamentals_av`` returns the most recent quarter
    available on WRDS, not the quarter as of ``date``. For historical dates
    this may introduce look-ahead; confirm before backtesting.

    Args:
        date: ``YYYY-MM-DD`` string used for the cache file name and inner keys.

    Returns:
        The cached mapping (possibly empty).

    Raises:
        ValueError: ``date`` is not in ``YYYY-MM-DD`` format.
        ConnectionError: WRDS is unreachable.
    """
    print("[INFO] Caching WRDS data...")
    data = {}
    try:
        # Validate once up front rather than failing identically for every ticker.
        datetime.datetime.strptime(date, "%Y-%m-%d")
        db = get_wrds_connection()
        if not db:
            raise ConnectionError("No WRDS connection available for caching.")
        # This connection is only a reachability check; the per-ticker
        # queries below open their own connections.
        db.close()
        tickers = get_tickers_from_db()
        for ticker in tickers:
            try:
                # Query WRDS directly. get_fundamentals_wrds would read the very
                # cache file this function is meant to create.
                fundamentals = get_fundamentals_av(ticker)
                if fundamentals:
                    data[ticker] = {date: fundamentals}
            except Exception as e:
                print(f"[ERROR] Failed to get WRDS data for {ticker}: {e}")
                continue
        if data:
            print(f"[INFO] Saving WRDS data for {len(data)} tickers")
            with open(f'wrds_cache_{date}.pkl', 'wb') as f:
                pickle.dump(data, f)
        else:
            print("[WARN] No WRDS data to cache")
        return data
    except Exception as e:
        print(f"[ERROR] Failed to cache WRDS data: {e}")
        raise

def cache_yfinance_data(tickers: List[str], start_date: str, end_date: str) -> Dict:
    """Download daily bars per ticker and pickle the non-empty frames.

    Each ticker goes through ``get_market_data``. Failures are logged and
    skipped, and empty frames are ignored. The collected mapping is written
    to ``yfinance_cache_{start_date}_{end_date}.pkl`` only when it is
    non-empty.

    Returns:
        ``{ticker: DataFrame}`` for every ticker that produced data.
    """
    print("[INFO] Caching yfinance data...")
    frames = {}
    for symbol in tickers:
        try:
            frame = get_market_data(symbol, start_date, end_date)
            if frame.empty:
                continue
            frames[symbol] = frame
        except Exception as err:
            print(f"[ERROR] yfinance caching failed for {symbol}: {err}")
    if not frames:
        print("[WARN] No yfinance data to cache")
        return frames
    print(f"[INFO] Saving yfinance data for {len(frames)} tickers")
    with open(f'yfinance_cache_{start_date}_{end_date}.pkl', 'wb') as fh:
        pickle.dump(frames, fh)
    return frames

def load_cached_data(data_type: str, date: str = None, start_date: str = None, end_date: str = None):
    """Load a pickle written by one of the ``cache_*`` helpers.

    Args:
        data_type: ``"wrds"`` (file keyed by ``date``), ``"yfinance"`` or
            ``"macro_uncertainty"`` (both keyed by ``start_date``/``end_date``).

    For ``macro_uncertainty`` the frame's index is checked so callers can
    compare it with dates:

    - a ``DatetimeIndex``/``PeriodIndex``, or an index of ``datetime.date``
      objects, is kept as-is;
    - a bare ``RangeIndex`` is replaced by one calendar day per row, from
      ``start_date`` to ``end_date``.

    NOTE: ``pickle.load`` executes arbitrary code from the file. Only load
    cache files this project wrote itself.

    Raises:
        ValueError: unknown ``data_type``, or a RangeIndex whose length does
            not match the number of days in the requested span.
        FileNotFoundError: the cache file does not exist.
        TypeError: the macro frame has any other index type.
    """
    if data_type == "wrds":
        cache_file = f'wrds_cache_{date}.pkl'
    elif data_type == "yfinance":
        cache_file = f'yfinance_cache_{start_date}_{end_date}.pkl'
    elif data_type == "macro_uncertainty":
        cache_file = f'macro_uncertainty_cache_{start_date}_{end_date}.pkl'
    else:
        raise ValueError(f"Unknown data type: {data_type}")
    if not os.path.exists(cache_file):
        raise FileNotFoundError(f"Cache file not found: {cache_file}")
    try:
        with open(cache_file, 'rb') as f:
            data = pickle.load(f)
        print(f"[INFO] Loaded {data_type} data from cache: {cache_file}")
        if data_type == "macro_uncertainty":
            index = data.index
            if isinstance(index, (pd.DatetimeIndex, pd.PeriodIndex)):
                pass
            elif isinstance(index, pd.RangeIndex):
                print(f"[DEBUG] Reconstructing macro_uncertainty index using start_date {start_date} and end_date {end_date}")
                new_index = pd.date_range(start=start_date, end=end_date, freq='D').date
                if len(new_index) != len(data):
                    raise ValueError(
                        f"Cannot rebuild macro_uncertainty index: {len(data)} rows but "
                        f"{len(new_index)} days between {start_date} and {end_date}"
                    )
                data.index = new_index
            elif len(index) > 0 and all(isinstance(x, datetime.date) for x in index):
                # Already a date index, e.g. a previously rebuilt frame that was re-pickled.
                pass
            else:
                raise TypeError(f"Unexpected macro_df index type: {type(data.index)}. Full index: {data.index}")
        return data
    except Exception as e:
        print(f"[ERROR] Failed to load {data_type} cache: {e}")
        raise

def get_news_from_db(ticker: str, date: datetime.date) -> List[str]:
    """Return the stored headlines for ``ticker`` on ``date``.

    Reads ``news_storage.headlines`` from the SQLite file ``DB_FILE``. The
    column holds headlines joined by ``" || "``.

    Returns:
        The non-blank headlines, or ``[]`` when there is no row or the stored
        value is NULL or empty.

    Raises:
        sqlite3.Error (or any other failure), after logging it.
    """
    try:
        conn = sqlite3.connect(DB_FILE)
        try:
            row = conn.execute(
                "SELECT headlines FROM news_storage WHERE ticker = ? AND date = ?",
                (ticker, date.isoformat()),
            ).fetchone()
        finally:
            # Close even if the query raises, so failed lookups don't leak connections.
            conn.close()
        if not row or not row[0]:
            return []
        return [h for h in row[0].split(" || ") if h.strip()]
    except Exception as e:
        print(f"[ERROR] Failed to get news from DB: {e}")
        raise

def get_date_range_from_db() -> Tuple[datetime.date, datetime.date]:
    """Return the earliest and latest ``news_storage`` dates in ``DB_FILE``.

    If the table has no dates, returns ``(today - 30 days, today)`` instead.

    Raises:
        sqlite3.Error (or any other failure), after logging it.
    """
    try:
        conn = sqlite3.connect(DB_FILE)
        try:
            start_date_str, end_date_str = conn.execute(
                "SELECT MIN(date), MAX(date) FROM news_storage"
            ).fetchone()
        finally:
            # Close even if the query raises.
            conn.close()
        if start_date_str and end_date_str:
            return (datetime.date.fromisoformat(start_date_str),
                    datetime.date.fromisoformat(end_date_str))
        end_date = datetime.date.today()
        start_date = end_date - datetime.timedelta(days=30)
        return start_date, end_date
    except Exception as e:
        print(f"[ERROR] Failed to get date range from DB: {e}")
        raise

def get_fundamentals_av(ticker: str) -> Dict:
    """Fetch the latest ``pe_ratio`` / ``roe`` figures for ``ticker`` from WRDS.

    Runs one throttled query through ``rate_limited_query``. The query joins
    Compustat ``comp.fundq`` to CRSP names by ticker and keeps the row with
    the latest ``datadate`` where ``ceqq > 0``.

    Returns:
        ``{"pe_ratio": str, "roe": str}``. The values are ``str()`` of the
        query results.

    Raises:
        ConnectionError: no WRDS connection could be opened.
        ValueError: the query returned no rows for ``ticker``.
    """
    try:
        # No separate connection is opened here: rate_limited_query opens (and
        # closes) its own, and raises ConnectionError when WRDS is unavailable.
        print(f"[INFO] Querying fundamental data for {ticker} from WRDS...")
        # NOTE(review): "pe_ratio" is computed as prccm / ceqq, i.e. price over
        # total common equity, not price over earnings. Confirm the intended
        # metric, and that comp.fundq exposes prccm.
        # NOTE(review): the link-table and dsenames joins have no linktype /
        # linkprim / date-range filters, so a ticker reused by another company
        # could match. Verify.
        query = """
            SELECT c.gvkey, c.datadate, 
                   c.prccm/c.ceqq AS pe_ratio,
                   c.ibq/c.ceqq AS roe
            FROM comp.fundq AS c
            JOIN crsp.ccmxpf_linktable AS l ON c.gvkey = l.gvkey
            JOIN crsp.dsenames AS d ON l.lpermno = d.permno
            WHERE d.ticker = %s AND c.ceqq > 0
            ORDER BY c.datadate DESC
            LIMIT 1
        """
        fund_data = rate_limited_query(query, date_cols=['datadate'], params=(ticker,))
        if fund_data.empty:
            raise ValueError(f"No fundamental data found for {ticker}")
        latest = fund_data.iloc[0]
        return {"pe_ratio": str(latest['pe_ratio']), "roe": str(latest['roe'])}
    except Exception as e:
        print(f"[ERROR] Failed to get fundamental data from WRDS: {e}")
        raise

# 4. FOMC Data

In [None]:
class FOMCDateBasedFetcher:
    """Scrape the Federal Reserve FOMC calendar and fetch statements on demand.

    Construction downloads the calendars page once (network I/O; raises if it
    fails) and records every meeting it can parse in ``all_meetings``, sorted
    by ``end_date``. Each meeting is a dict with these keys:

    - ``start_date``, ``end_date``: ``datetime.date``;
    - ``link``: absolute statement URL, or "" when the meeting has none;
    - ``snippet``: first 120 characters of the meeting text;
    - ``statement_text``: lazily filled cache, initially None.
    """
    BASE_URL = "https://www.federalreserve.gov"
    CAL_URL  = f"{BASE_URL}/monetarypolicy/fomccalendars.htm"
    # Seconds to wait for the Fed site; without a timeout requests.get can block forever.
    REQUEST_TIMEOUT = 30
    # Full month names plus the abbreviations used for cross-month meetings such as "Oct/Nov 31-1".
    _MONTH_NUMBERS = {
        'January': 1, 'February': 2, 'March': 3, 'April': 4, 'May': 5, 'June': 6,
        'July': 7, 'August': 8, 'September': 9, 'October': 10, 'November': 11, 'December': 12,
        'Jan': 1, 'Feb': 2, 'Mar': 3, 'Apr': 4, 'Jun': 6, 'Jul': 7, 'Aug': 8,
        'Sep': 9, 'Sept': 9, 'Oct': 10, 'Nov': 11, 'Dec': 12,
    }
    _MONTH_ALT = "|".join(sorted(_MONTH_NUMBERS, key=len, reverse=True))
    # "<Month>[/<Month>] <day>[-<day>]"; (?!\d) stops "December 2024" matching day 20.
    _MEETING_DATE_RE = re.compile(
        r"\b(" + _MONTH_ALT + r")(?:\s*/\s*(" + _MONTH_ALT + r"))?"
        r"\s+(\d{1,2})(?!\d)(?:\s*[-\u2013]\s*(\d{1,2})(?!\d))?"
    )

    def __init__(self):
        self.all_meetings = self.parse_calendars_for_all_years()
        self.all_meetings.sort(key=lambda x: x["end_date"])
        for m in self.all_meetings:
            m["statement_text"] = None

    def fetch_fomc_calendars_page(self) -> str:
        """Return the HTML of the FOMC calendars page; raise on a non-OK response."""
        try:
            resp = requests.get(self.CAL_URL, timeout=self.REQUEST_TIMEOUT)
            if not resp.ok:
                raise ValueError(f"FOMC calendars fetch failed with HTTP {resp.status_code}")
            return resp.text
        except Exception as e:
            print("[ERROR] fetch_fomc_calendars_page =>", e)
            raise

    def fetch_fomc_statement_text(self, link_url: str) -> str:
        """Download a statement page and return the text of its main content div."""
        try:
            resp = requests.get(link_url, timeout=self.REQUEST_TIMEOUT)
            if not resp.ok:
                raise ValueError(f"Statement fetch failed with HTTP {resp.status_code}")
            soup = BeautifulSoup(resp.text, "html.parser")
            main_div = soup.find("div", class_="col-xs-12 col-sm-8 col-md-8")
            if not main_div:
                raise ValueError("FOMC statement div not found")
            return main_div.get_text(separator="\n", strip=True)
        except Exception as e:
            print("[ERROR] fetch_fomc_statement_text =>", e)
            raise

    @classmethod
    def _parse_meeting_dates(cls, raw_text: str, year: int):
        """Return ``(start_date, end_date)`` parsed from a meeting's text, or None if no date is found.

        Handles "January 30-31", "March 15" and cross-month forms such as
        "April/May 30-1" or "July 31-1". An impossible calendar date raises
        ValueError.
        """
        m = cls._MEETING_DATE_RE.search(raw_text)
        if not m:
            return None
        start_month = cls._MONTH_NUMBERS[m.group(1)]
        end_month = cls._MONTH_NUMBERS[m.group(2)] if m.group(2) else start_month
        day1 = int(m.group(3))
        day2 = int(m.group(4)) if m.group(4) else day1
        if not m.group(2) and day2 < day1:
            # A day range that wraps (e.g. "31-1") without an explicit second month ends next month.
            end_month = start_month % 12 + 1
        start_dt = datetime.date(year, start_month, day1)
        end_year = year + 1 if end_month < start_month else year
        end_dt = datetime.date(end_year, end_month, day2)
        return start_dt, end_dt

    def parse_meeting_divs_for_year(self, html: str, year: int) -> list:
        """Parse every ``fomc-meeting`` div in ``html`` into a meeting dict for ``year``."""
        soup = BeautifulSoup(html, "html.parser")
        meeting_divs = soup.find_all("div", class_=re.compile(r"fomc-meeting"))
        results = []
        for div in meeting_divs:
            raw_text = div.get_text(" ", strip=True)
            try:
                dates = self._parse_meeting_dates(raw_text, year)
            except ValueError as e:
                # One malformed entry should not abort the whole calendar.
                print(f"[WARN] Skipping FOMC meeting with invalid date {raw_text[:40]!r}: {e}")
                continue
            if dates is None:
                continue
            start_dt, end_dt = dates
            link_tag = div.find("a", href=re.compile(r"/newsevents/pressreleases/monetary.*\.htm"))
            link_href = ""
            if link_tag and link_tag.has_attr("href"):
                link_href = link_tag["href"]
                if link_href.startswith("/"):
                    link_href = self.BASE_URL + link_href
            snippet = raw_text[:120]
            results.append({"start_date": start_dt, "end_date": end_dt, "link": link_href, "snippet": snippet})
        return results

    def parse_calendars_for_all_years(self) -> list:
        """Fetch the calendars page and collect meetings from every "<year> FOMC Meetings" panel."""
        out = []
        full_html = self.fetch_fomc_calendars_page()
        if not full_html:
            raise ValueError("No HTML received from FOMC calendars page.")
        soup = BeautifulSoup(full_html, "html.parser")
        panels = soup.find_all("div", class_="panel panel-default")
        for panel in panels:
            heading = panel.find("div", class_="panel-heading")
            if not heading:
                continue
            heading_txt = heading.get_text(" ", strip=True)
            m = re.search(r"(\d{4}) FOMC Meetings", heading_txt)
            if not m:
                continue
            year = int(m.group(1))
            out.extend(self.parse_meeting_divs_for_year(str(panel), year))
        return out

    def get_most_recent_fomc_for(self, date_val: datetime.date) -> str:
        """Return the text of the latest statement whose meeting ended on or before ``date_val``.

        Meetings without a statement link are skipped. Fetched texts are
        cached on the meeting dict. A fetch failure propagates.

        Raises:
            ValueError: no meeting on or before ``date_val``, or none with a retrievable statement.
        """
        valid = [m for m in self.all_meetings if m["end_date"] <= date_val]
        if not valid:
            raise ValueError(f"No FOMC meeting found before {date_val}")
        valid_desc = sorted(valid, key=lambda x: x["end_date"], reverse=True)
        for meeting in valid_desc:
            link = meeting["link"]
            if not link:
                continue
            if meeting["statement_text"]:
                return meeting["statement_text"]
            text = self.fetch_fomc_statement_text(link)
            if text:
                meeting["statement_text"] = text
                return text
        raise ValueError(f"No FOMC statement could be retrieved for date {date_val}")

# 5. Market Data

In [None]:
def flatten_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse MultiIndex column labels into ``"_"``-joined strings, in place.

    Tuple labels become their parts joined by ``"_"`` (``("Close", "AAPL")``
    becomes ``"Close_AAPL"``); any other label becomes ``str(label)``. Frames
    without MultiIndex columns are returned untouched. The same object is
    returned.
    """
    if not isinstance(df.columns, pd.MultiIndex):
        return df
    flat = []
    for label in df.columns:
        if isinstance(label, tuple):
            flat.append("_".join(str(part) for part in label))
        else:
            flat.append(str(label))
    df.columns = flat
    return df

def rename_yf_columns(df: pd.DataFrame, ticker: str) -> pd.DataFrame:
    """Strip a trailing ``_{ticker}`` suffix from column names, in place.

    Intended for columns flattened as ``<field>_<ticker>`` (e.g.
    ``Close_AAPL`` becomes ``Close``). Only the trailing suffix is removed,
    so ``Adj_Close_C`` with ticker ``C`` becomes ``Adj_Close``, not
    ``Adjlose``. Non-string labels and labels without the suffix are left
    unchanged.

    Returns:
        The same DataFrame object.
    """
    suffix = f"_{ticker}"
    rename_dict = {
        c: c[:-len(suffix)]
        for c in df.columns
        if isinstance(c, str) and c.endswith(suffix)
    }
    if rename_dict:
        df.rename(columns=rename_dict, inplace=True)
    return df

def _to_yahoo_symbol(ticker: str) -> str:
    """Convert share-class tickers to Yahoo's hyphenated form (BRK.B or BRK/B becomes BRK-B)."""
    return ticker.replace('.', '-').replace('/', '-')


def _normalize_yf_frame(df: pd.DataFrame) -> pd.DataFrame:
    """Flatten columns, map Adj Close/Volume to close/volume and move the date index into a 'date' column."""
    if isinstance(df.columns, pd.MultiIndex):
        if df.columns.nlevels == 2 and df.columns.get_level_values(1).nunique() <= 1:
            # Single-ticker download with a (field, ticker) header: the second
            # level only repeats the ticker, so keep just the field names.
            df.columns = df.columns.get_level_values(0)
        else:
            df.columns = [col[0] + ' ' + col[1] if col[1] else col[0] for col in df.columns]
    col_map = {'Adj Close': 'close', 'Adj_Close': 'close', 'adj_close': 'close',
               'Volume': 'volume', 'volume': 'volume'}
    df = df.rename(columns=col_map)
    df = df.reset_index()
    if 'index' in df.columns:
        df = df.rename(columns={'index': 'date'})
    elif 'Date' in df.columns:
        df = df.rename(columns={'Date': 'date'})
    return df


def get_market_data(ticker: str, start_date: str, end_date: str) -> pd.DataFrame:
    """Download daily bars for ``ticker`` via ``yf.download``.

    The ticker is converted to Yahoo's share-class form first (``BRK.B``
    becomes ``BRK-B``). The result is reset to a RangeIndex with a ``date``
    column. ``Adj Close``/``Volume`` are renamed to ``close``/``volume``,
    and other columns keep their names.

    Returns:
        The normalized frame, or an empty DataFrame when yfinance returns no rows.

    Raises:
        Any exception from the download or normalization, after logging it.
    """
    try:
        yf_ticker = _to_yahoo_symbol(ticker)
        print(f"[INFO] Getting market data for {ticker} from {start_date} to {end_date}")
        df = yf.download(yf_ticker, start=start_date, end=end_date, progress=False)
        if df.empty:
            print(f"[WARN] No data returned for {ticker}")
            return pd.DataFrame()
        return _normalize_yf_frame(df)
    except Exception as e:
        print(f"[ERROR] Failed to get market data for {ticker}: {e}")
        raise

# 6. Domain Experts

In [None]:
class MacroExpert:
    """Expert for macro-economic factors."""
    def __init__(self, macro_df: pd.DataFrame, fomcFetcher: "FOMCDateBasedFetcher"):
        # macro_df must have a 'value' column and an index of pd.Timestamp or datetime.date.
        self.macro_df = macro_df
        # Any object exposing get_most_recent_fomc_for(date) -> str.
        self.fomcFetcher = fomcFetcher

    def produce_factor(self, dt_val: datetime.date, ticker: str) -> float:
        """Score the macro backdrop for ``ticker`` as of ``dt_val`` with GPT.

        Sends GPT the two most recent ``value`` observations on or before
        ``dt_val``, plus the first 200 characters of the latest FOMC
        statement. The prompt labels the values "GDP".

        NOTE(review): build_dataset loads a "macro_uncertainty" series. If that
        is what is passed here, the "GDP" label in the prompt is misleading.
        Confirm.

        Returns:
            The factor parsed by ``call_gpt_factor_and_expl`` (its explanation is discarded).

        Raises:
            ValueError: empty macro data, or fewer than two observations up to ``dt_val``.
            TypeError: unsupported index type.
        """
        try:
            if self.macro_df is None or self.macro_df.empty:
                raise ValueError("Macro data is missing or empty.")
            first_index = self.macro_df.index[0]
            print(f"[DEBUG] Macro data index type: {type(first_index)}, value: {first_index}")
            if isinstance(first_index, pd.Timestamp):
                mask = self.macro_df.index.map(lambda x: x.date()) <= dt_val
            elif isinstance(first_index, datetime.date):
                mask = self.macro_df.index <= dt_val
            else:
                raise TypeError(f"Unexpected macro_df index type: {type(first_index)}. Full index: {self.macro_df.index}")
            # Sort chronologically so iloc[-1] / iloc[-2] really are the two latest observations.
            hist_data = self.macro_df[mask].sort_index()
            if len(hist_data) < 2:
                raise ValueError(f"Not enough macro data history before {dt_val}.")
            current = float(hist_data.iloc[-1]['value'])
            previous = float(hist_data.iloc[-2]['value'])
            fomc_text = self.fomcFetcher.get_most_recent_fomc_for(dt_val)
            fomc_snippet = fomc_text[:200] if fomc_text else "No FOMC statement"
            sys_msg = (
                "You are a macro analysis expert. Consider how GDP and monetary policy specifically impact this company:\n"
                "1. Analyze the company's industry, business model, and macro sensitivities, considering both short and long-term impacts.\n"
                "2. Evaluate how current GDP and FOMC stance affect this specific business.\n"
                "3. Consider factors like discretionary vs essential products, consumer financing, interest rates, "
                "supply chain, global trade, currency exposure, and international revenue.\n"
                "4. Based on these factors, determine whether the overall macro environment is bullish or bearish for this company.\n"
                "5. Output a factor between -1 (very negative) and +1 (very positive), reflecting the bullish or bearish nature.\n"
                "Format => Factor: <float>\n"
                "Explanation: <text>"
            )
            usr = f"Company: {ticker}\nGDP={current:.2f} vs {previous:.2f}\nFOMC: {fomc_snippet}"
            fac, _ = call_gpt_factor_and_expl(sys_msg, usr)
            return fac
        except Exception as e:
            print(f"[ERROR] MacroExpert failed: {e}")
            raise

class FundamentalExpert:
    """Expert for fundamental analysis."""
    def produce_factor(self, fundamentals: Dict) -> float:
        """Ask GPT to score a company from its P/E and ROE figures.

        Reads P/E from ``pe_ratio`` (the key produced by
        ``get_fundamentals_av``), falling back to a legacy ``pe`` key, and
        ROE from ``roe``. Missing values are sent as empty strings.

        Returns:
            The factor parsed from GPT's reply.

        Raises:
            ValueError: ``fundamentals`` is empty or None.
        """
        try:
            if not fundamentals:
                raise ValueError("No fundamental data provided.")
            pe = fundamentals.get('pe_ratio', fundamentals.get('pe', ''))
            roe = fundamentals.get('roe', '')
            sys_msg = (
                "You are an expert in fundamental analysis. Evaluate the company's financial health based on the following:\n"
                "1. P/E ratio: Is the company undervalued (bullish) or overvalued (bearish)? Consider industry averages and market conditions.\n"
                "2. ROE: Is the company generating strong returns on equity? Does it indicate a healthy business model?\n"
                "3. Based on these financial metrics, determine if the stock is bullish or bearish.\n"
                "Output a factor between -1 (very negative) and +1 (very positive), with an explanation based on the metrics.\n"
                "Format => Factor: <float>\n"
                "Explanation: <text>"
            )
            usr = f"P/E={pe}, ROE={roe}"
            fac, _ = call_gpt_factor_and_expl(sys_msg, usr)
            return fac
        except Exception as e:
            print(f"[ERROR] FundamentalExpert failed: {e}")
            raise

class NewsExpert:
    """Turn a ticker's stored daily headlines into a GPT sentiment factor."""
    _SYSTEM_PROMPT = (
        "You are an expert in news sentiment analysis. Consider the news headlines for the company and evaluate:\n"
        "1. Are the headlines generally positive, neutral, or negative in terms of market sentiment?\n"
        "2. How does the news affect investor sentiment towards the company? Is the news likely to drive the stock price up or down?\n"
        "3. Based on your analysis, determine whether the overall sentiment is bullish or bearish for the stock.\n"
        "4. Output a factor between -1 (very negative) and +1 (very positive) reflecting the sentiment.\n"
        "Format => Factor: <float>\n"
        "Explanation: <text>"
    )

    def produce_factor(self, ticker: str, date: datetime.date) -> float:
        """Return a sentiment factor for ``ticker`` on ``date``.

        Headlines come from ``get_news_from_db``. With none, a warning is
        printed and 0.0 is returned without calling GPT. Otherwise the first
        three headlines, newline-joined, are sent to
        ``call_gpt_factor_and_expl`` and its factor is returned. Errors are
        logged and re-raised.
        """
        try:
            headlines = get_news_from_db(ticker, date)
            if headlines:
                top_three = "\n".join(headlines[:3])
                factor, _explanation = call_gpt_factor_and_expl(self._SYSTEM_PROMPT, top_three)
                return factor
            print(f"[WARN] No news found for {ticker} on {date}. Returning default factor of 0.0.")
            return 0.0
        except Exception as err:
            print(f"[ERROR] NewsExpert failed: {err}")
            raise

class TechnicalExpert:
    """Expert that scores a ticker's recent daily returns with GPT."""
    def produce_factor(self, df: pd.DataFrame, global_idx: int) -> float:
        """Score the five daily returns ending at row ``global_idx``.

        ``df`` needs ``ticker``, ``date`` and ``close`` columns and may hold
        several tickers. The row's ticker series is sorted by date, and its
        last five ``close`` percentage changes up to and including the row's
        date are sent to GPT.

        Returns:
            The GPT factor, or 0.0 (with a warning) when fewer than six bars
            (i.e. five returns) exist up to that date.

        Raises:
            ValueError: the row cannot be located within its ticker's series.
        """
        try:
            row = df.iloc[global_idx]
            ticker = row['ticker']
            # Chronological per-ticker series; a stable sort keeps the input order for equal dates.
            group = (df[df['ticker'] == ticker]
                     .sort_values('date', kind='stable')
                     .reset_index(drop=True))
            positions = group.index[group['date'] == row['date']]
            if len(positions) == 0:
                raise ValueError(f"Could not find row for ticker {ticker} on {row['date']} in its group.")
            pos = positions[0]
            # Five returns need six closes (rows pos-5 .. pos); pct_change at row 0 is NaN.
            if pos < 5:
                print(f"[WARN] Not enough bars for technical analysis for ticker {ticker} on {row['date']}. Returning default factor of 0.0.")
                return 0.0
            pct_changes = group['close'].pct_change().values
            w5 = pct_changes[pos-4:pos+1]
            s = ", ".join(f"{x:.4f}" for x in w5 if not np.isnan(x))
            sys_msg = (
                "You are an expert in technical analysis. Evaluate the following technical indicators for the stock:\n"
                "1. How do recent price and volume patterns reflect the overall trend in the market?\n"
                "2. Is the stock showing bullish or bearish signals based on technical indicators such as moving averages, momentum, or volume trends?\n"
                "3. Based on this analysis, is the stock in a bullish or bearish trend, and what is your overall sentiment?\n"
                "Output a factor between -1 (very negative) and +1 (very positive), representing the technical sentiment.\n"
                "Format => Factor: <float>\n"
                "Explanation: <text>"
            )
            usr = f"Last5 returns => {s}"
            fac, _ = call_gpt_factor_and_expl(sys_msg, usr)
            return fac
        except Exception as e:
            print(f"[ERROR] TechnicalExpert failed: {e}")
            raise

class RiskExpert:
    """Expert that scores a ticker's recent return volatility with GPT."""
    def produce_factor(self, df: pd.DataFrame, global_idx: int) -> float:
        """Score the 5-day return standard deviation ending at row ``global_idx``.

        ``df`` must have a ``close`` column. When it also has ``ticker``
        (and ``date``), only the row's own ticker series is used, sorted by
        date. The five returns end at the row itself and never include later
        rows.

        Returns:
            The factor parsed from GPT's reply.

        Raises:
            ValueError: ``global_idx < 5``, no ``close`` column, or fewer than
                five prior returns in the row's series.
        """
        try:
            if global_idx < 5 or 'close' not in df.columns:
                raise ValueError("Not enough data for risk analysis.")
            # Positional labels 0..n-1, so global_idx doubles as a row label below.
            series_df = df.reset_index(drop=True)
            if 'ticker' in series_df.columns:
                series_df = series_df[series_df['ticker'] == series_df.at[global_idx, 'ticker']]
            if 'date' in series_df.columns:
                series_df = series_df.sort_values('date', kind='stable')
            pos = series_df.index.get_loc(global_idx)
            if pos < 5:
                raise ValueError("Not enough data for risk analysis.")
            # No dropna(): keep returns aligned with rows so the window ends at this row
            # (dropping the leading NaN would shift it one day into the future).
            rets = series_df['close'].pct_change().values
            w_5 = rets[pos-4:pos+1]
            stv = np.nanstd(w_5)
            sys_msg = (
                "You are an expert in risk analysis. Evaluate the following risk metrics for the stock:\n"
                "1. What does the recent volatility and beta tell you about the stock's risk profile?\n"
                "2. Is the stock showing signs of higher volatility (bearish) or stability (bullish)? Consider if the risk is manageable or excessive.\n"
                "3. Based on these risk factors, do you consider the stock to be a higher or lower risk, and how does this affect its bullish or bearish nature?\n"
                "4. Output a factor between -1 (very negative) and +1 (very positive), reflecting the risk sentiment.\n"
                "Format => Factor: <float>\n"
                "Explanation: <text>"
            )
            usr = f"5-day std dev = {stv:.4f}"
            fac, _ = call_gpt_factor_and_expl(sys_msg, usr)
            return fac
        except Exception as e:
            print(f"[ERROR] RiskExpert failed: {e}")
            raise

# 7. MoE Model

In [None]:
class GatingNetwork(nn.Module):
    """Map per-expert predictions to softmax mixing weights.

    Expects input of shape ``(batch, num_experts)``. The input is
    batch-normalised, passed through a square linear layer, and turned into
    weights with a temperature-0.5 softmax over dim 1, so each row of the
    output sums to 1.
    """
    def __init__(self, num_experts=5):
        super().__init__()
        self.input_norm = nn.BatchNorm1d(num_experts)
        self.fc = nn.Linear(num_experts, num_experts)

    def forward(self, expert_preds: torch.Tensor) -> torch.Tensor:
        normalised = self.input_norm(expert_preds)
        return self.fc(normalised).div(0.5).softmax(dim=1)

class MoEModel(nn.Module):
    """Gated mixture of expert predictions, squashed to a bounded return.

    Output per row: ``scale * tanh(sum_i w_i * x_i)``, where the weights
    ``w`` come from ``GatingNetwork`` and ``scale`` is a learnable scalar
    initialised to 0.05.
    """
    def __init__(self, num_experts=5):
        super().__init__()
        self.gate = GatingNetwork(num_experts)
        self.scale = nn.Parameter(torch.tensor([0.05]))

    def forward(self, expert_preds: torch.Tensor) -> torch.Tensor:
        gate_weights = self.gate(expert_preds)
        mixed = torch.sum(gate_weights * expert_preds, dim=1)
        return self.scale * mixed.tanh()

def combined_loss(pred: torch.Tensor, target: torch.Tensor, model: nn.Module,
                  direction_temperature: float = 0.01) -> torch.Tensor:
    """Huber error + directional agreement + a penalty on ``model.scale``.

    total = huber(delta=0.1) + 0.3 * direction_loss + 0.1 * ||model.scale||^2

    ``direction_loss = -mean(tanh(pred / T) * sign(target))``. ``torch.sign``
    has zero gradient everywhere, so a pure sign-agreement term never affects
    training. ``tanh(pred / T)`` is a smooth stand-in for ``sign(pred)`` that
    still rewards matching the target's direction. ``T`` (the
    ``direction_temperature``) should be small relative to typical predictions
    (MoEModel outputs lie within +/- its ``scale``).

    Args:
        pred: Predictions, same shape as ``target``.
        target: Realised values.
        model: Module exposing a ``scale`` parameter.
        direction_temperature: Softness of the sign surrogate.

    Returns:
        A 0-dim loss tensor.
    """
    huber = F.huber_loss(pred, target, reduction='mean', delta=0.1)
    soft_direction = torch.tanh(pred / direction_temperature)
    direction_loss = -torch.mean(soft_direction * torch.sign(target))
    # .sum() makes the regulariser (and so the loss) a true scalar.
    scale_reg = 0.1 * model.scale.pow(2).sum()
    return huber + 0.3 * direction_loss + scale_reg

def train_moe_model(recs: List[Dict], epochs=200):
    """Train a ``MoEModel`` on expert factor vectors.

    Each record needs ``expert_predictions`` (list of floats, same length
    for every record) and ``target`` (float). Records are split randomly
    80/20 into train/validation. Training uses AdamW (lr 1e-3), halves the
    learning rate on a validation plateau (patience 10) and stops early after
    20 epochs without improvement.

    NOTE(review): the split is random rather than chronological. If records
    are time-ordered, validation may see dates earlier than training.

    Returns:
        The model loaded with the weights from the epoch with the lowest
        validation loss, left in eval mode.

    Raises:
        ValueError: fewer than 3 records (at least 2 training samples are
            needed for BatchNorm in train mode).
    """
    X = torch.tensor([r['expert_predictions'] for r in recs], dtype=torch.float32)
    y = torch.tensor([r['target'] for r in recs], dtype=torch.float32)
    dataset = torch.utils.data.TensorDataset(X, y)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size < 2:
        raise ValueError("Need at least 3 records to train the MoE model.")
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    batch_size = 32
    # BatchNorm1d in train mode rejects a one-sample batch, so drop such a trailing batch.
    drop_last = train_size > batch_size and train_size % batch_size == 1
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)
    model = MoEModel(num_experts=X.shape[1])
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
    best_val_loss = float('inf')
    best_state = None
    patience = 20
    patience_counter = 0
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch_X, batch_y in train_loader:
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = combined_loss(outputs, batch_y, model)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                outputs = model(batch_X)
                val_loss += combined_loss(outputs, batch_y, model).item()
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Snapshot the best weights so early stopping can restore them.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
        scheduler.step(val_loss)
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Train Loss = {train_loss/max(len(train_loader), 1):.4f}, "
                  f"Val Loss = {val_loss/max(len(val_loader), 1):.4f}")
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

def process_single_prediction(prediction: float) -> float:
    """Squash a raw prediction into (-0.05, 0.05) via ``0.05 * tanh(prediction)``."""
    squashed = np.tanh(prediction)
    return 0.05 * squashed

# 8. Build Dataset

In [None]:
def _price_frame_dates(df: pd.DataFrame) -> np.ndarray:
    """Return each row's trading date as a ``datetime.date``, from a 'date' column or else the index."""
    # get_market_data resets the index, so cached frames carry dates in a 'date'
    # column over a RangeIndex. Converting that RangeIndex would give 1970-01-01.
    source = df['date'] if 'date' in df.columns else df.index
    return pd.DatetimeIndex(pd.to_datetime(source)).date


def _macro_value_series(macro_data: pd.DataFrame) -> pd.Series:
    """Return ``macro_data['value']`` on a sorted DatetimeIndex, ready for ``Series.asof``."""
    index = macro_data.index
    if isinstance(index, pd.PeriodIndex):
        index = index.to_timestamp()
    values = pd.Series(macro_data['value'].to_numpy(), index=pd.DatetimeIndex(pd.to_datetime(index)))
    return values.sort_index()


def build_dataset(tickers: List[str] = None, start_date: str = None, end_date: str = None, fomcFetcher=None):
    """Assemble one row per (ticker, trading day) from the cached inputs.

    Loads the ``wrds``, ``yfinance`` and ``macro_uncertainty`` caches for the
    given dates. For every cached ticker frame it adds these columns:

    - ``ticker`` and ``date`` (``datetime.date``);
    - ``pe_ratio``: that day's WRDS fundamentals dict, or ``{}``;
    - ``news``: headline list;
    - ``macro_data``: the latest macro value on or before the date, NaN
      before the series starts, or 0 when there is no macro data;
    - ``fomc``: the latest FOMC statement text.

    NOTE(review): the ``tickers`` argument is ignored. The list is re-read
    from ``sp500_list_wiki.json`` only for logging, and rows come from every
    ticker in the yfinance cache.

    Raises:
        ValueError: the yfinance cache contains no tickers.
    """
    print("[INFO] Building dataset...")
    print("[INFO] Loading tickers from file...")
    with open("sp500_list_wiki.json", "r") as f:
        data = json.load(f)
        tickers = data["tickers"]
        print(f"[INFO] Loaded {len(tickers)} tickers from sp500_list_wiki.json (fetched at {data['fetch_time']})")
    wrds_data = load_cached_data("wrds", date=start_date)
    yf_data = load_cached_data("yfinance", start_date=start_date, end_date=end_date)
    macro_data = load_cached_data("macro_uncertainty", start_date=start_date, end_date=end_date)
    if fomcFetcher is None:
        fomcFetcher = FOMCDateBasedFetcher()
    macro_series = None
    if isinstance(macro_data, pd.DataFrame) and not macro_data.empty:
        macro_series = _macro_value_series(macro_data)
    dfs = []
    for ticker in yf_data:
        df = yf_data[ticker].copy()
        if 'close' not in df.columns:
            if 'Close' in df.columns:
                df = df.rename(columns={'Close': 'close'})
            elif 'Adj Close' in df.columns:
                df = df.rename(columns={'Adj Close': 'close'})
        if 'volume' not in df.columns:
            if 'Volume' in df.columns:
                df = df.rename(columns={'Volume': 'volume'})
        df['ticker'] = ticker
        df['date'] = _price_frame_dates(df)
        df['pe_ratio'] = df['date'].apply(lambda x: wrds_data.get(ticker, {}).get(x.strftime('%Y-%m-%d'), {}))
        df['news'] = df['date'].apply(lambda x: get_news_from_db(ticker, x))
        if macro_series is not None:
            # As-of lookup: never use an observation dated after the trading day.
            df['macro_data'] = [macro_series.asof(pd.Timestamp(x)) for x in df['date']]
        else:
            df['macro_data'] = 0
        if fomcFetcher is not None:
            df['fomc'] = df['date'].apply(lambda x: fomcFetcher.get_most_recent_fomc_for(x))
        else:
            df['fomc'] = ''
        dfs.append(df)
    if not dfs:
        raise ValueError("No price data in the yfinance cache; cannot build dataset.")
    combined_df = pd.concat(dfs, ignore_index=True)
    print("\n[DEBUG] DataFrame info:")
    # DataFrame.info() prints itself and returns None.
    combined_df.info()
    print("\n[DEBUG] First few rows:")
    print(combined_df.head())
    return combined_df

# 9. Compare MoE vs Single GPT

In [None]:
def date_based_train_test_split(rows: List[Dict], train_ratio=0.7) -> Tuple[List[Dict], List[Dict]]:
    """Split records chronologically by distinct date into train and test sets.

    All rows that share a date land in the same split, so no date appears in
    both sets. The earliest ``int(n_distinct_dates * train_ratio)`` dates go to
    train; the remaining (later) dates go to test.

    Args:
        rows: Records, each with a ``'date'`` value parseable by
            ``pd.to_datetime``.
        train_ratio: Fraction of distinct dates assigned to train, in [0, 1].

    Returns:
        ``(train_records, test_records)``, each a list of dicts in the input
        row order. ``'date'`` values in the returned dicts are pandas
        timestamps (the column is converted with ``pd.to_datetime``).
        Either list may be empty (e.g. a single distinct date, or
        ``train_ratio`` of 0 or 1).

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio}")
    if not rows:
        print("[WARN] No rows to split.")
        return [], []

    df = pd.DataFrame(rows)
    df['date'] = pd.to_datetime(df['date'])
    # Sorted distinct dates as a Series: unlike calling .sort() on the result of
    # .unique(), this works whether pandas hands back an ndarray or an extension
    # array (e.g. tz-aware dates).
    unique_dates = df['date'].drop_duplicates().sort_values()
    n_train_dates = int(len(unique_dates) * train_ratio)
    train_dates = unique_dates.iloc[:n_train_dates]
    test_dates = unique_dates.iloc[n_train_dates:]

    # Report the date span of each split; an empty split has no first/last date.
    for label, dates in (("Train", train_dates), ("Test", test_dates)):
        if dates.empty:
            print(f"[INFO] {label} dates: none")
        else:
            print(f"[INFO] {label} dates: {dates.iloc[0]} to {dates.iloc[-1]}")

    train_df = df[df['date'].isin(train_dates)]
    test_df = df[df['date'].isin(test_dates)]
    return train_df.to_dict('records'), test_df.to_dict('records')

def compare_moe_vs_single(records: List[Dict], plot: bool = True):
    """Report RMSE, MAE and directional accuracy of MoE vs. single-GPT predictions.

    Each record must contain ``'true_return'``, ``'moe_prediction'`` and
    ``'single_prediction'``. The metrics and the relative improvement of the
    MoE model over the GPT baseline are printed. When ``plot`` is true, the
    three series are also plotted in record order.

    Args:
        records: Prediction records (as assembled by ``run_model``).
        plot: Whether to draw the matplotlib comparison chart.

    Returns:
        Dict of the computed metrics, or ``{}`` when ``records`` is empty.
        An improvement percentage is ``nan`` when the GPT baseline metric is
        0, because the relative change is undefined there.
    """
    if not records:
        print("[WARN] No records to compare.")
        return {}

    actual = np.array([r['true_return'] for r in records], dtype=float)
    moe_pred = np.array([r['moe_prediction'] for r in records], dtype=float)
    gpt_pred = np.array([r['single_prediction'] for r in records], dtype=float)

    # Plain-numpy equivalents of sklearn's mean_squared_error / mean_absolute_error
    # for 1-D inputs (uniform weights).
    def _rmse(pred):
        return float(np.sqrt(np.mean((actual - pred) ** 2)))

    def _mae(pred):
        return float(np.mean(np.abs(actual - pred)))

    def _direction_hit_rate(pred):
        # A value <= 0 counts as "not up" for both actual and predicted.
        return float(np.mean((actual > 0) == (pred > 0)))

    def _pct_improvement(baseline, candidate, higher_is_better):
        # Guard the division: a perfect (0) baseline error, or a 0% baseline
        # hit rate, would otherwise raise ZeroDivisionError or yield inf.
        if baseline == 0:
            return float('nan')
        gain = candidate - baseline if higher_is_better else baseline - candidate
        return gain / baseline * 100

    moe_rmse, gpt_rmse = _rmse(moe_pred), _rmse(gpt_pred)
    moe_mae, gpt_mae = _mae(moe_pred), _mae(gpt_pred)
    moe_dir, gpt_dir = _direction_hit_rate(moe_pred), _direction_hit_rate(gpt_pred)
    rmse_gain = _pct_improvement(gpt_rmse, moe_rmse, higher_is_better=False)
    mae_gain = _pct_improvement(gpt_mae, moe_mae, higher_is_better=False)
    dir_gain = _pct_improvement(gpt_dir, moe_dir, higher_is_better=True)

    print("\nModel Performance Comparison:")
    print("\nRoot Mean Squared Error (RMSE):")
    print(f"MoE Model: {moe_rmse:.4f}")
    print(f"GPT Model: {gpt_rmse:.4f}")
    print(f"Improvement: {rmse_gain:.2f}%")
    print("\nMean Absolute Error (MAE):")
    print(f"MoE Model: {moe_mae:.4f}")
    print(f"GPT Model: {gpt_mae:.4f}")
    print(f"Improvement: {mae_gain:.2f}%")
    print("\nDirectional Accuracy:")
    print(f"MoE Model: {moe_dir:.2%}")
    print(f"GPT Model: {gpt_dir:.2%}")
    print(f"Improvement: {dir_gain:.2f}%")

    if plot:
        plt.figure(figsize=(12, 6))
        plt.plot(actual, label='Actual Returns', alpha=0.5)
        plt.plot(moe_pred, label='MoE Predictions', alpha=0.5)
        plt.plot(gpt_pred, label='GPT Predictions', alpha=0.5)
        plt.title('Model Predictions vs Actual Returns')
        plt.legend()
        plt.show()

    return {
        'moe_rmse': moe_rmse,
        'gpt_rmse': gpt_rmse,
        'rmse_improvement_pct': rmse_gain,
        'moe_mae': moe_mae,
        'gpt_mae': gpt_mae,
        'mae_improvement_pct': mae_gain,
        'moe_directional_accuracy': moe_dir,
        'gpt_directional_accuracy': gpt_dir,
        'directional_improvement_pct': dir_gain,
    }

# 10. Run Model

In [None]:
_SINGLE_GPT_SYSTEM_MSG = (
    "You are a trading assistant that analyzes market data and provides a factor score between -1 and 1 indicating bullishness/bearishness.\n"
    "Your output must be exactly two lines: 'Factor: <float>' and 'Explanation: <text>'.\n"
    "Consider all available data including fundamentals, technicals, news, and macro factors."
)


def _standardize_price_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return ``df`` with lower-case ``close`` / ``volume`` columns.

    When the lower-case name is absent, ``Close`` (or, failing that,
    ``Adj Close``) is renamed to ``close``, and ``Volume`` to ``volume``.
    """
    if 'close' not in df.columns:
        if 'Close' in df.columns:
            df = df.rename(columns={'Close': 'close'})
        elif 'Adj Close' in df.columns:
            df = df.rename(columns={'Adj Close': 'close'})
    if 'volume' not in df.columns and 'Volume' in df.columns:
        df = df.rename(columns={'Volume': 'volume'})
    return df


def _extract_fundamental(cell, key: str):
    """Pull one fundamental metric (``key``) out of a ``pe_ratio`` cell.

    Accepted cell shapes:
    - a bare float is returned unchanged, for every ``key``;
    - a non-empty dict has its FIRST value inspected: if that value is a
      dict, ``value.get(key)`` is returned, otherwise the value itself.

    Anything else yields ``None``.
    """
    if isinstance(cell, float):
        return cell
    if isinstance(cell, dict) and cell:
        first_value = next(iter(cell.values()))
        if isinstance(first_value, dict):
            return first_value.get(key, None)
        return first_value
    return None


def _rolling_beta(df: pd.DataFrame, market_ticker: str = 'SPY',
                  window: int = 30, min_periods: int = 5) -> pd.Series:
    """Rolling beta of each ticker's ``returns`` against ``market_ticker``.

    The market return is looked up by each row's ``date``, then rolling
    cov / var is computed over the ticker's own rows. Aligning by row-index
    labels instead would never match SPY's rows to another ticker's rows, so
    every non-market beta would come out NaN. Betas are NaN where the market
    ticker has no row for that date (all NaN if ``market_ticker`` is not in
    ``df``). Assumes ``df`` has a unique index.
    """
    market_rows = df.loc[df['ticker'] == market_ticker, ['date', 'returns']]
    market_by_date = market_rows.drop_duplicates('date').set_index('date')['returns']

    def _beta(stock_returns: pd.Series) -> pd.Series:
        if len(stock_returns) < min_periods:
            return pd.Series([None] * len(stock_returns), index=stock_returns.index, dtype=float)
        market_returns = df.loc[stock_returns.index, 'date'].map(market_by_date).astype(float)
        rolling_cov = stock_returns.rolling(window=window, min_periods=min_periods).cov(market_returns)
        rolling_market_var = market_returns.rolling(window=window, min_periods=min_periods).var()
        return rolling_cov / rolling_market_var

    return df.groupby('ticker')['returns'].transform(_beta)


def _optional_float(value):
    """Return ``float(value)``, or ``None`` when pandas considers ``value`` missing."""
    return float(value) if pd.notnull(value) else None


def _build_row_payload(row: pd.Series) -> Dict:
    """Assemble the JSON-serialisable snapshot of one row sent to the single-GPT baseline."""
    return {
        'ticker': row['ticker'],
        'date': str(row['date']),
        'macro_data': {
            'uncertainty': row['macro_data'],
            'fomc_statement': row['fomc']
        },
        'fundamental_data': {
            'pe': _optional_float(row['pe']),
            'roe': _optional_float(row['roe'])
        },
        'technical_data': {
            'current_price': float(row['close']),
            # Guarded like the other numeric fields: int(NaN) raises ValueError.
            'volume': int(row['volume']) if pd.notnull(row['volume']) else None,
            'price_ma5': _optional_float(row['price_ma5']),
            'price_ma10': _optional_float(row['price_ma10']),
            'volume_ma5': _optional_float(row['volume_ma5'])
        },
        'news_data': row['news'],
        'risk_data': {
            'volatility': _optional_float(row['volatility']),
            'beta': _optional_float(row['beta']),
            'liquidity_ratio': _optional_float(row['liquidity_ratio'])
        }
    }


def run_model(df: pd.DataFrame, start_date: str, end_date: str):
    """Build features, query every expert and the single-GPT baseline, then train and evaluate the MoE.

    Args:
        df: One row per (ticker, date), as produced by ``build_dataset``.
            The code reads 'ticker', 'date', close/volume (any capitalisation
            handled by ``_standardize_price_columns``), 'pe_ratio', 'news',
            'macro_data' and 'fomc'. The caller's frame is not modified.
        start_date, end_date: Range used to load the cached macro-uncertainty data.

    Returns:
        List of per-row record dicts. Each holds the expert predictions, the
        single-GPT prediction and explanation, 'target'/'true_return', and
        the trained MoE's 'moe_prediction'.

    Raises:
        ValueError: If the macro-uncertainty data cannot be loaded or is empty.
    """
    print("[INFO] Running model...")
    # Work on a copy so the feature columns below never leak into the caller's frame.
    df = df.copy()

    print("\nDataset Summary:")
    print(f"Number of stocks: {df['ticker'].nunique(dropna=False)}")
    print("\nSample of data:")
    print(df.head())
    print("\nColumns:")
    print(df.columns.tolist())

    df = _standardize_price_columns(df)

    print("\n[DEBUG] Calculating returns...")
    df['returns'] = df.groupby('ticker')['close'].pct_change()
    print("Sample returns:")
    print(df[['ticker', 'date', 'close', 'returns']].head())

    print("\n[DEBUG] Creating technical features...")
    df['volume_ma5'] = df.groupby('ticker')['volume'].transform(lambda s: s.rolling(window=5).mean())
    df['price_ma5'] = df.groupby('ticker')['close'].transform(lambda s: s.rolling(window=5).mean())
    df['price_ma10'] = df.groupby('ticker')['close'].transform(lambda s: s.rolling(window=10).mean())
    print("Sample technical features:")
    print(df[['ticker', 'date', 'volume_ma5', 'price_ma5', 'price_ma10']].head())

    print("\n[DEBUG] Extracting fundamental data...")
    df['pe'] = df['pe_ratio'].apply(lambda cell: _extract_fundamental(cell, 'pe_ratio'))
    df['roe'] = df['pe_ratio'].apply(lambda cell: _extract_fundamental(cell, 'roe'))
    df[['pe', 'roe']] = df.groupby('ticker')[['pe', 'roe']].ffill()
    print("Extracted PE and ROE (after forward-fill):")
    print(df[['ticker', 'date', 'pe', 'roe']].head(10))

    # Risk metrics: 30-row rolling windows per ticker (at least 5 observations).
    df['volatility'] = df.groupby('ticker')['returns'].transform(lambda x: x.rolling(window=30, min_periods=5).std())
    df['beta'] = _rolling_beta(df, market_ticker='SPY', window=30, min_periods=5)
    df['liquidity_ratio'] = df.groupby('ticker')['volume'].transform(lambda x: x / x.rolling(window=30, min_periods=5).mean())
    print("Sample risk metrics:")
    print(df[['ticker', 'date', 'volatility', 'beta', 'liquidity_ratio']].head())

    # Count news items per row; cells without a length (e.g. None) count as 0.
    df['news_count'] = df['news'].apply(lambda items: len(items) if hasattr(items, '__len__') else 0)
    print("\n[DEBUG] Processing news data...")
    print("News counts:")
    print(df[['ticker', 'date', 'news_count']].head())

    macro_df = load_cached_data("macro_uncertainty", start_date=start_date, end_date=end_date)
    if macro_df is None or len(macro_df) == 0:
        raise ValueError("Macro uncertainty data could not be loaded.")
    macro_start = macro_df.index[0]
    print(f"[DEBUG] Filtering out rows with date equal to macro start date: {macro_start}")
    # Compare calendar dates: df['date'] may hold datetime.date objects, while the
    # macro index may hold strings or Timestamps. Those mixed types never compare
    # equal, which would make this filter a silent no-op.
    macro_start_ts = pd.to_datetime(macro_start, errors='coerce')
    if pd.notna(macro_start_ts):
        row_days = pd.to_datetime(df['date']).dt.date
        df = df[row_days != macro_start_ts.date()].copy()

    macro_expert = MacroExpert(macro_df=macro_df, fomcFetcher=FOMCDateBasedFetcher())
    fundamental_expert = FundamentalExpert()
    news_expert = NewsExpert()
    technical_expert = TechnicalExpert()
    risk_expert = RiskExpert()

    records = []
    for idx, row in df.iterrows():
        print(f"\n[DEBUG] Processing {row['ticker']} {str(row['date'])}:")
        # Order matters for the MoE input vector: macro, fundamental, news, technical, risk.
        expert_predictions = [
            macro_expert.produce_factor(row['date'], row['ticker']),
            fundamental_expert.produce_factor({'pe': row['pe'], 'roe': row['roe']}),
            news_expert.produce_factor(row['ticker'], row['date']),
        ]
        try:
            tech_pred = technical_expert.produce_factor(df, idx)
        except Exception as e:
            print(f"[WARN] TechnicalExpert: {e} Returning default factor of 0.0.")
            tech_pred = 0.0
        expert_predictions.append(tech_pred)
        try:
            risk_pred = risk_expert.produce_factor(df, idx)
        except Exception as e:
            print(f"[WARN] RiskExpert: {e} Returning default factor of 0.0.")
            risk_pred = 0.0
        expert_predictions.append(risk_pred)

        # Serialise once; the same JSON is logged and embedded in the prompt.
        payload_json = json.dumps(_build_row_payload(row), indent=2, default=str)
        print("\n[DEBUG] Formatted data:")
        print(payload_json)
        user_msg = (
            "Please analyze this data and provide a factor score between -1 "
            "(extremely bearish) and 1 (extremely bullish):\n" + payload_json
        )
        single_pred, explanation = call_gpt_factor_and_expl(_SINGLE_GPT_SYSTEM_MSG, user_msg)

        # NOTE(review): 'returns' is the close-to-close return ENDING on this row's
        # date, while the payload above includes that same close (and the day's
        # news), so the target overlaps the information used to predict it. A
        # forward (next-day) return is probably intended; confirm before trusting
        # backtest numbers.
        realised = row['returns'] if pd.notnull(row['returns']) else 0.0
        records.append({
            'ticker': row['ticker'],
            'date': row['date'],
            'expert_predictions': expert_predictions,
            'single_prediction': single_pred,
            'explanation': explanation,
            'target': realised,
            'true_return': realised
        })
        print(f"Expert predictions: {expert_predictions}")
        print(f"Single prediction: {single_pred}; Explanation: {explanation}.")

    moe_model = train_moe_model(records)
    # Inference only: disable dropout/batch-norm updates (if the model has any) and autograd tracking.
    if isinstance(moe_model, nn.Module):
        moe_model.eval()
    with torch.no_grad():
        for record in records:
            # Pin the dtype to torch's default float, which is what a list of plain
            # Python floats would get, so numpy-scalar expert outputs (float64)
            # don't cause a dtype mismatch with the model weights.
            expert_preds = torch.tensor(record['expert_predictions'], dtype=torch.get_default_dtype()).unsqueeze(0)
            record['moe_prediction'] = moe_model(expert_preds).item()

    compare_moe_vs_single(records)
    compare_portfolio_performance(records, df)
    save_prediction_results(records, 'prediction_results.csv')

    return records

def compare_portfolio_performance(records: List[Dict], df: pd.DataFrame,
                                  quantile: float = 0.2, plot: bool = True):
    """
    Compare long/short portfolios built from MoE vs. single-GPT predictions.

    Rows of ``df`` are matched to ``records`` on (ticker, date). If several
    records share a key, the first one wins. For each model, the rows with the
    highest ``quantile`` share of predictions are bought and the lowest share
    are shorted. The portfolio return is ``mean(long returns) - mean(short
    returns)``, pooled over all rows. SPY's mean return is the benchmark; it
    is NaN if SPY is not in ``df``.

    Side effects: adds 'moe_pred', 'gpt_pred', 'moe_rank' and 'gpt_rank'
    columns to ``df`` (rank 1 = most bullish prediction). Rows with no
    matching record get a prediction of 0.0.

    Args:
        records: Prediction records carrying 'ticker', 'date',
            'moe_prediction' and 'single_prediction'.
        df: Frame with 'ticker', 'date' and 'returns' columns.
        quantile: Share of rows in each leg, in (0, 0.5]. Each leg has at
            least one row when ``df`` is non-empty.
        plot: Whether to draw the matplotlib bar chart.

    Returns:
        Dict with the average returns under keys 'moe', 'gpt' and 'sp500'.

    Raises:
        ValueError: If ``quantile`` is outside (0, 0.5].
    """
    if not 0.0 < quantile <= 0.5:
        raise ValueError(f"quantile must be in (0, 0.5], got {quantile}")

    # Index predictions once by (ticker, date) instead of scanning every record
    # for every row (O(rows * records)).
    moe_by_key: Dict = {}
    gpt_by_key: Dict = {}
    for record in records:
        key = (record['ticker'], record['date'])
        moe_by_key.setdefault(key, record['moe_prediction'])
        gpt_by_key.setdefault(key, record['single_prediction'])

    row_keys = list(zip(df['ticker'], df['date']))
    df['moe_pred'] = [moe_by_key.get(key, 0.0) for key in row_keys]
    df['gpt_pred'] = [gpt_by_key.get(key, 0.0) for key in row_keys]
    df['moe_rank'] = df['moe_pred'].rank(ascending=False)
    df['gpt_rank'] = df['gpt_pred'].rank(ascending=False)

    n_per_leg = max(1, int(len(df) * quantile)) if len(df) else 0

    def _long_short_return(pred_col: str) -> float:
        # Select on the predictions themselves. The rank columns run 1 = most
        # bullish, so calling nlargest() on a rank would pick the most *bearish* rows.
        longs = df.nlargest(n_per_leg, pred_col)
        shorts = df.nsmallest(n_per_leg, pred_col)
        return longs['returns'].mean() - shorts['returns'].mean()

    # NOTE(review): 'returns' is the same-day return of each row, not a forward
    # return; confirm this is the intended evaluation horizon.
    moe_returns = _long_short_return('moe_pred')
    gpt_returns = _long_short_return('gpt_pred')
    sp500_returns = df.loc[df['ticker'] == 'SPY', 'returns'].mean()

    if plot:
        plt.figure(figsize=(10, 6))
        plt.bar(['MoE Portfolio', 'GPT Portfolio', 'S&P 500'], [moe_returns, gpt_returns, sp500_returns])
        plt.title(f'Portfolio Performance Comparison (Best {quantile:.0%} / Worst {quantile:.0%})')
        plt.ylabel('Average Return')
        plt.show()

    return {'moe': float(moe_returns), 'gpt': float(gpt_returns), 'sp500': float(sp500_returns)}


def save_prediction_results(records: List[Dict], filename: str):
    """
    Persist prediction records as a CSV file.

    Each record dict becomes one row and its keys become the columns. The
    pandas index is not written. A confirmation line naming ``filename`` is
    printed once the file has been written.
    """
    results_table = pd.DataFrame(records)
    results_table.to_csv(filename, index=False)
    confirmation = f"[INFO] Prediction results saved to {filename}"
    print(confirmation)

# 11. Main

In [None]:
def main(start_date: str = "2024-12-18", end_date: str = "2025-01-16"):
    """Run the full pipeline over ``[start_date, end_date]`` and print sample records.

    The defaults reproduce the window this notebook was built around. The same
    range is passed to both ``build_dataset`` and ``run_model``, so cached
    data and processing stay consistent.

    Returns:
        The list of prediction records produced by ``run_model``.

    Raises:
        ValueError: If ``start_date`` is after ``end_date``.
    """
    if pd.Timestamp(start_date) > pd.Timestamp(end_date):
        raise ValueError(f"start_date {start_date} is after end_date {end_date}")
    print(f"[INFO] Using date range: {start_date} to {end_date}")
    df = build_dataset(start_date=start_date, end_date=end_date)
    records = run_model(df, start_date, end_date)
    print("\n[DEBUG] Final records sample:")
    for rec in records[:3]:
        print(json.dumps(rec, indent=2, default=str))
    return records


# Runs when executed as a script; a Jupyter kernel also sets __name__ to
# "__main__", so executing this cell runs the whole pipeline.
if __name__ == "__main__":
    main()