# Basic model on the 1991/92 Serie A season (to check agreement with the paper's reported values)

In [None]:
import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

class BayesianFootballModel:
    """
    Bayesian hierarchical model for football match prediction
    Based on Baio & Blangiardo (2010) paper

    Combines visualization capabilities and comparison table functionality.

    For game g with home team h(g) and away team a(g) the basic model is

        y_g1 ~ Poisson(theta_g1),  log(theta_g1) = home + att[h(g)] + def[a(g)]
        y_g2 ~ Poisson(theta_g2),  log(theta_g2) = att[a(g)] + def[h(g)]

    so a positive attack effect raises a team's scoring rate and a positive
    defence effect raises the rate at which it *concedes*.
    """

    # (home, away) team-name column pairs accepted in the input sheet, tried in
    # order: the 1991/92 file uses 'home team'/'away team', the 2007/08 file
    # uses 'hometeam_name'/'awayteam_name'.
    TEAM_COLUMN_CANDIDATES = (
        ('home team', 'away team'),
        ('hometeam_name', 'awayteam_name'),
    )

    # Default output file for export_table2 (this cell analyses 1991/92).
    TABLE2_PATH = "Table2_SerieA_1991-92.xlsx"

    def __init__(self, data_file):
        """Load *data_file* and prepare empty model/trace slots.

        Parameters
        ----------
        data_file : str or path-like
            Excel sheet with one row per match. It must contain a home/away
            team-name column pair (see ``TEAM_COLUMN_CANDIDATES``) and the goal
            columns ``y1`` (home) and ``y2`` (away) used by the model.

        Raises
        ------
        Exception
            Anything raised while reading/preparing the data is re-raised after
            a diagnostic message is printed.
        """
        print(f"Initializing model with data file: {data_file}")

        # Model/trace slots are filled lazily by the build_*/fit_* methods.
        self.basic_model = None
        self.mixture_model = None
        self.basic_trace = None
        self.mixture_trace = None

        # Data attributes start as None so build_basic_model's sanity check is
        # meaningful; load_and_prepare_data overwrites them.
        self.teams = None
        self.n_teams = None
        self.n_games = None
        self.home_col = None
        self.away_col = None

        try:
            self.data = self.load_and_prepare_data(data_file)
            print(" Model initialization completed successfully")
            print(f" Final check - n_games: {self.n_games}, n_teams: {self.n_teams}")
        except Exception as e:
            print(f" Error during model initialization: {e}")
            print("Please check that the data file exists and has the correct format")
            raise

    # ===== DATA LOADING =====

    def load_and_prepare_data(self, data_file):
        """Read the Excel file *data_file* and return it with team indices added."""
        df = pd.read_excel(data_file)
        return self._prepare_dataframe(df)

    @classmethod
    def _find_team_columns(cls, columns):
        """Return the first (home, away) column pair from TEAM_COLUMN_CANDIDATES present in *columns*."""
        for home_col, away_col in cls.TEAM_COLUMN_CANDIDATES:
            if home_col in columns and away_col in columns:
                return home_col, away_col
        raise KeyError(
            f"No home/away team columns found; expected one of "
            f"{list(cls.TEAM_COLUMN_CANDIDATES)}, got {list(columns)}"
        )

    def _prepare_dataframe(self, df):
        """Clean *df*, add 0-based ``home_team_idx``/``away_team_idx`` and set team attributes.

        Sets ``self.teams`` (alphabetically sorted names), ``self.n_teams``,
        ``self.n_games``, ``self.home_col`` and ``self.away_col``. The input
        frame is not modified; a cleaned copy is returned.
        """
        df = df.copy()
        df.columns = df.columns.str.strip()

        print(f"Original data shape: {df.shape}")
        print(f"Columns: {list(df.columns)}")

        home_col, away_col = self._find_team_columns(df.columns)

        # Drop missing names first: sorted() cannot compare NaN with str, and
        # rows with a missing name are reported by the warning below instead.
        teams = sorted(pd.concat([df[home_col], df[away_col]]).dropna().unique())
        team_to_id = {team: i for i, team in enumerate(teams)}

        # Team ids follow alphabetical order, so they index self.teams directly.
        df['home_team_idx'] = df[home_col].map(team_to_id)
        df['away_team_idx'] = df[away_col].map(team_to_id)

        if df['home_team_idx'].isna().any() or df['away_team_idx'].isna().any():
            print("Warning: Some teams could not be mapped!")
            print("Home team mapping issues:", df[df['home_team_idx'].isna()][home_col].unique())
            print("Away team mapping issues:", df[df['away_team_idx'].isna()][away_col].unique())

        self.home_col, self.away_col = home_col, away_col
        self.teams = teams
        self.n_teams = len(teams)
        self.n_games = len(df)

        print(f"Data loaded: {self.n_games} games, {self.n_teams} teams")
        print(f"Teams: {self.teams}")
        return df

    # ===== MODEL BUILDING / FITTING =====

    def build_basic_model(self):
        """Build the basic hierarchical model from Section 2 of the paper.

        Returns the ``pm.Model`` and stores it in ``self.basic_model``.

        Raises
        ------
        ValueError
            If the match data has not been loaded.
        """
        if self.n_games is None or self.n_teams is None:
            raise ValueError("Data not properly loaded. Please check the data file and team mappings.")

        print(f"Building model with {self.n_games} games and {self.n_teams} teams")

        home_team_idx = self.data['home_team_idx'].values
        away_team_idx = self.data['away_team_idx'].values
        y1_data = self.data['y1'].values
        y2_data = self.data['y2'].values

        print(f"Home team indices range: {home_team_idx.min()} to {home_team_idx.max()}")
        print(f"Away team indices range: {away_team_idx.min()} to {away_team_idx.max()}")
        print(f"Goals range - Home: {y1_data.min()} to {y1_data.max()}, Away: {y2_data.min()} to {y2_data.max()}")

        # Declaration order is kept fixed: it determines the free-variable
        # ordering PyMC uses, and hence seeded sampling results.
        with pm.Model() as model:
            # Vague priors; PyMC's `tau` is a precision (tau=1e-4 -> sd=100).
            home_advantage = pm.Normal("home_advantage", mu=0, tau=0.0001)

            mu_att = pm.Normal("mu_att", mu=0, tau=0.0001)
            mu_def = pm.Normal("mu_def", mu=0, tau=0.0001)
            tau_att = pm.Gamma("tau_att", alpha=0.01, beta=0.01)
            tau_def = pm.Gamma("tau_def", alpha=0.01, beta=0.01)

            att_star = pm.Normal("att_star", mu=mu_att, tau=tau_att, shape=self.n_teams)
            def_star = pm.Normal("def_star", mu=mu_def, tau=tau_def, shape=self.n_teams)

            # Sum-to-zero identifiability constraint.
            att = pm.Deterministic("att", att_star - pt.mean(att_star))
            def_ = pm.Deterministic("def", def_star - pt.mean(def_star))

            # Per-game log scoring rates (see class docstring).
            log_theta_g1 = home_advantage + att[home_team_idx] + def_[away_team_idx]
            log_theta_g2 = att[away_team_idx] + def_[home_team_idx]

            theta_g1 = pm.Deterministic("theta_g1", pt.exp(log_theta_g1))
            theta_g2 = pm.Deterministic("theta_g2", pt.exp(log_theta_g2))

            pm.Poisson("y1", mu=theta_g1, observed=y1_data)
            pm.Poisson("y2", mu=theta_g2, observed=y2_data)

        print("Model built successfully!")
        self.basic_model = model
        return model

    def fit_basic_model(self, draws=2000, tune=1000, chains=3, cores=1,
                        target_accept=0.95, random_seed=42):
        """Sample the basic model's posterior and posterior predictive.

        ``draws``, ``tune``, ``chains``, ``cores``, ``target_accept`` and
        ``random_seed`` are forwarded to ``pm.sample``. Returns the
        InferenceData, also stored in ``self.basic_trace``.
        """
        print("Fitting basic hierarchical model...")

        if self.basic_model is None:
            self.build_basic_model()

        with self.basic_model:
            self.basic_trace = pm.sample(
                draws=draws,
                tune=tune,
                chains=chains,
                cores=cores,
                random_seed=random_seed,
                return_inferencedata=True,
                target_accept=target_accept,
            )
            # extend() adds the posterior_predictive group in place.
            self.basic_trace.extend(pm.sample_posterior_predictive(self.basic_trace))

        print("Basic model fitting completed!")
        return self.basic_trace

    # ===== POSTERIOR HELPERS =====

    def _get_trace(self, model_type):
        """Return the fitted trace for *model_type*, or None (printing a hint) if unfitted."""
        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace
        if trace is None:
            print(f"Please fit the {model_type} model first!")
        return trace

    @staticmethod
    def _effect_names(model_type):
        """Names of the centred (attack, defence) variables in the posterior."""
        if model_type == 'basic':
            return 'att', 'def'
        return 'att_centered', 'def_centered'

    @staticmethod
    def _flat_samples(trace, name):
        """Return the draws of vector variable *name* as a (samples, teams) array.

        Rows are ordered chain-major (all of chain 0, then chain 1, ...).
        """
        values = np.asarray(trace.posterior[name].values)
        return values.reshape(-1, values.shape[-1])

    # ===== SUMMARIES / VISUALIZATION =====

    def get_home_advantage_summary(self, model_type='basic'):
        """Return mean/median/std/2.5%/97.5% of ``home_advantage`` as a dict, or None if unfitted."""
        trace = self._get_trace(model_type)
        if trace is None:
            return None

        samples = np.asarray(trace.posterior['home_advantage'].values).ravel()
        return {
            'parameter': 'home_advantage',
            'mean': float(np.mean(samples)),
            'median': float(np.median(samples)),
            'std': float(np.std(samples)),
            'q025': float(np.quantile(samples, 0.025)),
            'q975': float(np.quantile(samples, 0.975)),
        }

    def plot_team_effects(self, model_type='basic'):
        """Scatter posterior-mean attack vs defence effects, one labelled point per team."""
        trace = self._get_trace(model_type)
        if trace is None:
            return

        att_name, def_name = self._effect_names(model_type)
        att_means = self._flat_samples(trace, att_name).mean(axis=0)
        def_means = self._flat_samples(trace, def_name).mean(axis=0)

        fig, ax = plt.subplots(figsize=(12, 8))
        ax.scatter(att_means, def_means, s=100, alpha=0.7)

        for i, team in enumerate(self.teams):
            ax.annotate(team, (att_means[i], def_means[i]),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.8)

        ax.set_xlabel('Attack Effect')
        ax.set_ylabel('Defense Effect')
        ax.set_title(f'Team Attack vs Defense Effects ({model_type.title()} Model)')
        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        ax.axvline(x=0, color='k', linestyle='--', alpha=0.5)

        # Higher defence effect = more goals conceded, so "poor defence" is up.
        quadrants = (
            (0.02, 0.98, 'top', 'left', 'Poor Attack,\nPoor Defense', 'lightcoral'),
            (0.98, 0.98, 'top', 'right', 'Good Attack,\nPoor Defense', 'lightgray'),
            (0.02, 0.02, 'bottom', 'left', 'Poor Attack,\nGood Defense', 'lightyellow'),
            (0.98, 0.02, 'bottom', 'right', 'Good Attack,\nGood Defense', 'lightgreen'),
        )
        for x, y, va, ha, label, colour in quadrants:
            ax.text(x, y, label, transform=ax.transAxes, va=va, ha=ha,
                    bbox=dict(boxstyle='round', facecolor=colour, alpha=0.5))

        fig.tight_layout()
        plt.show()

    def get_team_summary(self, model_type='basic'):
        """Return per-team attack/defence posterior summaries, sorted by attack mean (desc).

        Columns: team, att_mean, att_median, att_q025, att_q975, def_mean,
        def_median, def_q025, def_q975. Returns None if the model is unfitted.
        """
        trace = self._get_trace(model_type)
        if trace is None:
            return None

        att_name, def_name = self._effect_names(model_type)
        att_flat = self._flat_samples(trace, att_name)
        def_flat = self._flat_samples(trace, def_name)

        rows = []
        for i, team in enumerate(self.teams):
            att_i = att_flat[:, i]
            def_i = def_flat[:, i]
            rows.append({
                'team': team,
                'att_mean':   np.mean(att_i),
                'att_median': np.quantile(att_i, 0.5),
                'att_q025':   np.quantile(att_i, 0.025),
                'att_q975':   np.quantile(att_i, 0.975),
                'def_mean':   np.mean(def_i),
                'def_median': np.quantile(def_i, 0.5),
                'def_q025':   np.quantile(def_i, 0.025),
                'def_q975':   np.quantile(def_i, 0.975),
            })

        # Sorted by attack mean, descending, like Table 2 of the paper.
        return pd.DataFrame(rows).sort_values('att_mean', ascending=False)

    def export_table2(self, output_path=None, model_type='basic'):
        """Write the team-effect table and home-advantage summary to an .xlsx file.

        Parameters
        ----------
        output_path : str, optional
            Destination file; defaults to ``TABLE2_PATH``.
        model_type : {'basic', 'mixture'}
            Which fitted model to summarise.

        Returns
        -------
        str or None
            The path written, or None if the model has not been fitted.

        Raises
        ------
        ImportError
            If the ``openpyxl`` engine is not installed.
        """
        if self._get_trace(model_type) is None:
            return None
        output_path = output_path or self.TABLE2_PATH

        team_df = self.get_team_summary(model_type=model_type)
        home_dict = self.get_home_advantage_summary(model_type=model_type)
        home_df = pd.DataFrame([{
            key: home_dict[key] for key in ('parameter', 'mean', 'median', 'q025', 'q975')
        }])

        with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
            team_df.to_excel(writer, sheet_name='Team Effects', index=False, float_format="%.4f")
            home_df.to_excel(writer, sheet_name='Home Advantage', index=False, float_format="%.4f")

        print(f" Saved Table 2 to {output_path}")
        return output_path

    def print_model_summary(self, model_type='basic', show_all_teams=True):
        """Print the home-advantage summary and the full team-effects table.

        ``show_all_teams`` is accepted for API compatibility; every team is
        always printed.
        """
        print(f"\n{model_type.upper()} MODEL SUMMARY")
        print("=" * 60)

        home_summary = self.get_home_advantage_summary(model_type)
        if home_summary is not None:
            print("\nHOME ADVANTAGE EFFECT:")
            print(f"Mean: {home_summary['mean']:.4f}")
            print(f"95% CI: [{home_summary['q025']:.4f}, {home_summary['q975']:.4f}]")
            print(f"Interpretation: Home teams score exp({home_summary['mean']:.4f}) = {np.exp(home_summary['mean']):.3f}x more goals on average")

        print(f"\nTEAM EFFECTS (all {len(self.teams)} teams):")
        team_summary = self.get_team_summary(model_type)
        if team_summary is None:
            return

        cols = [
            'team',
            'att_mean', 'att_median', 'att_q025', 'att_q975',
            'def_mean', 'def_median', 'def_q025', 'def_q975',
        ]
        print(team_summary[cols].to_string(index=False, float_format='%.4f'))

    def predict_match(self, home_team, away_team, model_type='basic', n_samples=1000):
        """Simulate one match from the posterior.

        Uses up to *n_samples* posterior draws, evenly spaced over all chains.
        For each draw it computes the two scoring rates and samples a Poisson
        score.

        Returns
        -------
        dict or None
            Keys: home_team, away_team, expected_home_goals,
            expected_away_goals, prob_home_win, prob_draw, prob_away_win,
            home_goals_samples, away_goals_samples. Returns None if the model
            is unfitted or a team name is unknown.
        """
        trace = self._get_trace(model_type)
        if trace is None:
            return None

        if home_team not in self.teams or away_team not in self.teams:
            print(f"Team not found. Available teams: {self.teams}")
            return None

        home_idx = self.teams.index(home_team)
        away_idx = self.teams.index(away_team)

        att_name, def_name = self._effect_names(model_type)
        home_adv = np.asarray(trace.posterior['home_advantage'].values).ravel()
        att_flat = self._flat_samples(trace, att_name)
        def_flat = self._flat_samples(trace, def_name)

        n_available = min(len(home_adv), len(att_flat))
        n_use = min(n_samples, n_available)
        # Draws are chain-major, so a plain prefix would come almost entirely
        # from chain 0; evenly spaced indices let every chain contribute.
        idx = np.linspace(0, n_available - 1, n_use).astype(int)

        theta1 = np.exp(home_adv[idx] + att_flat[idx, home_idx] + def_flat[idx, away_idx])
        theta2 = np.exp(att_flat[idx, away_idx] + def_flat[idx, home_idx])

        home_goals = np.random.poisson(theta1)
        away_goals = np.random.poisson(theta2)

        return {
            'home_team': home_team,
            'away_team': away_team,
            'expected_home_goals': np.mean(theta1),
            'expected_away_goals': np.mean(theta2),
            'prob_home_win': np.mean(home_goals > away_goals),
            'prob_draw': np.mean(home_goals == away_goals),
            'prob_away_win': np.mean(home_goals < away_goals),
            'home_goals_samples': home_goals,
            'away_goals_samples': away_goals,
        }

    # ===== COMBINED ANALYSIS METHODS =====

    def run_complete_analysis(self, draws_basic=3000, draws_mixture=1000, save_results=True):
        """Fit the basic model, print and plot its summaries, and optionally export Table 2.

        ``draws_mixture`` is accepted for API compatibility; this cell only
        fits the basic model.

        Returns
        -------
        dict
            ``basic_trace``, ``team_summary``, ``home_advantage`` and, when
            *save_results* is true and saving succeeded, ``table2_path``.
        """
        print("="*70)
        print("COMPLETE BAYESIAN FOOTBALL MODEL ANALYSIS")
        print("="*70)

        print("\n" + "="*50)
        print("FITTING BASIC MODEL")
        print("="*50)
        basic_trace = self.fit_basic_model(draws=draws_basic, tune=draws_basic, chains=4)

        print("\n" + "="*50)
        print("BASIC MODEL ANALYSIS")
        print("="*50)
        self.print_model_summary('basic', show_all_teams=True)

        self.plot_team_effects('basic')

        results = {
            'basic_trace': basic_trace,
            'team_summary': self.get_team_summary('basic'),
            'home_advantage': self.get_home_advantage_summary('basic'),
        }
        if save_results:
            # Saving is secondary: report a missing engine / unwritable path
            # instead of discarding the fitted results.
            try:
                results['table2_path'] = self.export_table2()
            except (ImportError, OSError) as err:
                print(f" Could not save Table 2: {err}")
        return results

# ===== USAGE EXAMPLE =====

def run_example_analysis(data_file):
    """
    Example function showing how to use the integrated model.

    Returns the fitted model and the results dict from run_complete_analysis.
    """
    print("Initializing Bayesian Football Model...")
    model = BayesianFootballModel(data_file)

    results = model.run_complete_analysis(
        draws_basic=1000,    # Adjust based on your computational resources
        save_results=True
    )

    return model, results

# Main execution
if __name__ == "__main__":

    print("="*70)
    print("INTEGRATED BAYESIAN FOOTBALL MODEL")
    print("Combines Visualization + Comparison Tables + Predictions")
    print("="*70)

    # 1991/92 Serie A results, the season used to compare against the paper.
    data_file = '/data/dataset/italy_serie-a_1991-1992.xlsx'

    try:
        model, results = run_example_analysis(data_file)

        print("\nSUCCESS! All analysis completed.")
        print("\nYou can now use the model object to:")
        print("- model.predict_match('Team1', 'Team2', 'basic')")
        print("- model.plot_team_effects('basic')")
        print("- model.print_model_summary('basic')")
        print("- model.export_table2()")

    except Exception as e:
        print(f"\n Error during analysis: {e}")
        print("Please check that your data file exists and has the correct format.")
        print("Required columns: 'home team'/'away team' (or hometeam_name/awayteam_name), y1, y2")

# Final code (basic + mixture models), applied to the 2007/08 season

In [None]:
import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

class BayesianFootballModel:
    """
    Bayesian hierarchical model for football match prediction
    Based on Baio & Blangiardo (2010) paper

    Combines visualization capabilities and comparison table functionality
    """

    def __init__(self, data_file):
        """Initialize the model with data"""
        print(f"Initializing model with data file: {data_file}")

        # Initialize model attributes first (but not data attributes)
        self.basic_model = None
        self.mixture_model = None
        self.basic_trace = None
        self.mixture_trace = None

        try:
            self.data = self.load_and_prepare_data(data_file)
            print(" Model initialization completed successfully")
            print(f" Final check - n_games: {self.n_games}, n_teams: {self.n_teams}")
        except Exception as e:
            print(f" Error during model initialization: {e}")
            print("Please check that the data file exists and has the correct format")
            raise

    def load_and_prepare_data(self, data_file):
        """Load and prepare the football data"""
        # Read the Excel file
        df = pd.read_excel('/data/dataset/dataset_2007-08.xlsx')

        # Clean column names
        df.columns = df.columns.str.strip()

        print(f"Original data shape: {df.shape}")
        print(f"Columns: {list(df.columns)}")

        # Create team mappings - using correct column names
        all_teams = pd.concat([
            df['hometeam_name'],
            df['awayteam_name']
        ]).unique()

        team_to_id = {team: i for i, team in enumerate(sorted(all_teams))}
        id_to_team = {i: team for team, i in team_to_id.items()}

        # Map team names to consecutive IDs (0-based)
        df['home_team_idx'] = df['hometeam_name'].map(team_to_id)
        df['away_team_idx'] = df['awayteam_name'].map(team_to_id)

        # Check for any mapping issues
        if df['home_team_idx'].isna().any() or df['away_team_idx'].isna().any():
            print("Warning: Some teams could not be mapped!")
            print("Home team mapping issues:", df[df['home_team_idx'].isna()]['hometeam_name'].unique())
            print("Away team mapping issues:", df[df['away_team_idx'].isna()]['awayteam_name'].unique())

        # Store team information
        self.teams = sorted(all_teams)
        self.n_teams = len(self.teams)
        self.n_games = len(df)

        print(f"Data loaded: {self.n_games} games, {self.n_teams} teams")
        print(f"Teams: {self.teams}")

        # Verify no None values
        print(f"n_games type: {type(self.n_games)}, value: {self.n_games}")
        print(f"n_teams type: {type(self.n_teams)}, value: {self.n_teams}")

        return df

    def build_basic_model(self):
        """Build the basic hierarchical model from Section 2 of the paper.

        Reads ``self.data`` columns ``home_team_idx``, ``away_team_idx``,
        ``y1`` (home goals) and ``y2`` (away goals). Stores the resulting
        ``pm.Model`` in ``self.basic_model`` and returns it.

        Raises
        ------
        ValueError
            If ``self.n_games`` or ``self.n_teams`` is None.
            NOTE(review): ``__init__`` never sets these to None, so if loading
            did not run they are missing entirely and this raises AttributeError.
        """

        # Check if data is properly loaded
        if self.n_games is None or self.n_teams is None:
            raise ValueError("Data not properly loaded. Please check the data file and team mappings.")

        print(f"Building model with {self.n_games} games and {self.n_teams} teams")

        # Prepare data arrays (one entry per game)
        home_team_idx = self.data['home_team_idx'].values
        away_team_idx = self.data['away_team_idx'].values
        y1_data = self.data['y1'].values
        y2_data = self.data['y2'].values

        # Verify data integrity (diagnostic output only; nothing is enforced)
        print(f"Home team indices range: {home_team_idx.min()} to {home_team_idx.max()}")
        print(f"Away team indices range: {away_team_idx.min()} to {away_team_idx.max()}")
        print(f"Goals range - Home: {y1_data.min()} to {y1_data.max()}, Away: {y2_data.min()} to {y2_data.max()}")

        # Variable declaration order below fixes PyMC's free-variable ordering;
        # keep it stable so seeded runs stay reproducible.
        with pm.Model() as model:
            # Home advantage parameter. PyMC's `tau` is a precision, so
            # tau=0.0001 is a vague Normal with sd 100.
            home_advantage = pm.Normal("home_advantage", mu=0, tau=0.0001)

            # Hyperparameters for attack and defense effects: vague Normal
            # means and vague Gamma(0.01, 0.01) precisions.
            mu_att = pm.Normal("mu_att", mu=0, tau=0.0001)
            mu_def = pm.Normal("mu_def", mu=0, tau=0.0001)
            tau_att = pm.Gamma("tau_att", alpha=0.01, beta=0.01)
            tau_def = pm.Gamma("tau_def", alpha=0.01, beta=0.01)

            # Team-specific attack and defense effects (before centering),
            # one per team, indexed by the 0-based ids from the data.
            att_star = pm.Normal("att_star", mu=mu_att, tau=tau_att, shape=self.n_teams)
            def_star = pm.Normal("def_star", mu=mu_def, tau=tau_def, shape=self.n_teams)

            # Sum-to-zero constraint (centering) for identifiability; the
            # centred effects are recorded in the trace as "att" and "def".
            att = pm.Deterministic("att", att_star - pt.mean(att_star))
            def_ = pm.Deterministic("def", def_star - pt.mean(def_star))

            # Game-specific scoring intensities via direct indexing.
            # For each game g:
            # log(theta_g1) = home + att[h(g)] + def[a(g)]  (home team scoring intensity)
            # log(theta_g2) = att[a(g)] + def[h(g)]        (away team scoring intensity)
            # A positive def effect therefore means the team concedes more.

            log_theta_g1 = home_advantage + att[home_team_idx] + def_[away_team_idx]
            log_theta_g2 = att[away_team_idx] + def_[home_team_idx]

            # Stored per game in the trace (used later for season simulation).
            theta_g1 = pm.Deterministic("theta_g1", pt.exp(log_theta_g1))
            theta_g2 = pm.Deterministic("theta_g2", pt.exp(log_theta_g2))

            # Likelihood - each game has its own theta values
            y1 = pm.Poisson("y1", mu=theta_g1, observed=y1_data)
            y2 = pm.Poisson("y2", mu=theta_g2, observed=y2_data)

        print("Model built successfully!")
        self.basic_model = model
        return model

    def build_mixture_model(self):
        """Build the three-group mixture model from Section 4 of the paper.

        Each team's attack and defence effect is drawn from one of three groups
        (index 0 = "bottom", 1 = "mid-table", 2 = "top"). Group membership is
        modelled with a per-team Dirichlet/Categorical pair. The effects
        themselves are Student-t with 4 degrees of freedom.

        Reads ``self.data`` columns ``home_team_idx``, ``away_team_idx``,
        ``y1`` and ``y2``. Stores the ``pm.Model`` in ``self.mixture_model``
        and returns it. Centred effects are recorded as ``att_centered`` and
        ``def_centered``; per-game rates as ``theta_g1``/``theta_g2``.

        NOTE(review): ``grp_att``/``grp_def`` are discrete, so PyMC cannot use
        NUTS for them and will assign a different step method to those
        variables. Confirm the sampler assignment in the ``pm.sample`` log.
        """

        # Prepare data arrays (one entry per game)
        home_team_idx = self.data['home_team_idx'].values
        away_team_idx = self.data['away_team_idx'].values
        y1_data = self.data['y1'].values
        y2_data = self.data['y2'].values

        # Variable declaration order below fixes PyMC's free-variable ordering;
        # keep it stable so seeded runs stay reproducible.
        with pm.Model() as model:
            # Home advantage parameter (tau is a precision: sd = 100)
            home_advantage = pm.Normal("home_advantage", mu=0, tau=0.0001)

            # Mixture parameters for each team
            # Prior probabilities for group membership (3 groups: bottom, mid, top)
            alpha_att = np.ones(3)  # Uniform prior over groups
            alpha_def = np.ones(3)

            # One probability vector over the 3 groups per team.
            p_att = pm.Dirichlet("p_att", a=alpha_att, shape=(self.n_teams, 3))
            p_def = pm.Dirichlet("p_def", a=alpha_def, shape=(self.n_teams, 3))

            # Group assignment for each team (values 0, 1, 2)
            grp_att = pm.Categorical("grp_att", p=p_att, shape=self.n_teams)
            grp_def = pm.Categorical("grp_def", p=p_def, shape=self.n_teams)

            # Group-specific parameters
            # Group 1 (index 0): Bottom teams (poor attack, poor defense).
            # Attack mean constrained to [-3, 0] (fewer goals scored) and
            # defence mean to [0, 3] (more goals conceded).
            mu_att_1 = pm.TruncatedNormal("mu_att_1", mu=0, tau=0.001, lower=-3, upper=0)
            mu_def_1 = pm.TruncatedNormal("mu_def_1", mu=0, tau=0.001, lower=0, upper=3)
            tau_att_1 = pm.Gamma("tau_att_1", alpha=0.01, beta=0.01)
            tau_def_1 = pm.Gamma("tau_def_1", alpha=0.01, beta=0.01)

            # Group 2 (index 1): Mid-table teams (average); the group means are
            # centred on 0 with the group's own precision.
            tau_att_2 = pm.Gamma("tau_att_2", alpha=0.01, beta=0.01)
            tau_def_2 = pm.Gamma("tau_def_2", alpha=0.01, beta=0.01)
            mu_att_2 = pm.Normal("mu_att_2", mu=0, tau=tau_att_2)
            mu_def_2 = pm.Normal("mu_def_2", mu=0, tau=tau_def_2)

            # Group 3 (index 2): Top teams (good attack, good defense) - mirror
            # image of group 1's truncation ranges.
            mu_att_3 = pm.TruncatedNormal("mu_att_3", mu=0, tau=0.001, lower=0, upper=3)
            mu_def_3 = pm.TruncatedNormal("mu_def_3", mu=0, tau=0.001, lower=-3, upper=0)
            tau_att_3 = pm.Gamma("tau_att_3", alpha=0.01, beta=0.01)
            tau_def_3 = pm.Gamma("tau_def_3", alpha=0.01, beta=0.01)

            # Stack parameters so position k holds group k's value
            mu_att_groups = pt.stack([mu_att_1, mu_att_2, mu_att_3])
            mu_def_groups = pt.stack([mu_def_1, mu_def_2, mu_def_3])
            tau_att_groups = pt.stack([tau_att_1, tau_att_2, tau_att_3])
            tau_def_groups = pt.stack([tau_def_1, tau_def_2, tau_def_3])

            # Team-specific effects using t-distribution with 4 degrees of freedom (as in paper)
            att_effects = []
            def_effects = []

            # Creates 2 * n_teams scalar random variables named att_raw_{t} /
            # def_raw_{t}.
            # NOTE(review): the nested switch is equivalent to indexing
            # mu_att_groups[grp_att[t]] etc.; a single vectorised StudentT
            # would be faster but would rename these trace variables.
            for t in range(self.n_teams):
                # For each team, determine which group they belong to and use appropriate parameters
                att_mu_t = pt.switch(pt.eq(grp_att[t], 0), mu_att_groups[0],
                                   pt.switch(pt.eq(grp_att[t], 1), mu_att_groups[1], mu_att_groups[2]))
                att_tau_t = pt.switch(pt.eq(grp_att[t], 0), tau_att_groups[0],
                                    pt.switch(pt.eq(grp_att[t], 1), tau_att_groups[1], tau_att_groups[2]))

                def_mu_t = pt.switch(pt.eq(grp_def[t], 0), mu_def_groups[0],
                                   pt.switch(pt.eq(grp_def[t], 1), mu_def_groups[1], mu_def_groups[2]))
                def_tau_t = pt.switch(pt.eq(grp_def[t], 0), tau_def_groups[0],
                                    pt.switch(pt.eq(grp_def[t], 1), tau_def_groups[1], tau_def_groups[2]))

                # Use StudentT with nu=4 degrees of freedom as in the paper
                # (`lam` is the precision, matching the group tau values).
                att_t = pm.StudentT(f"att_raw_{t}", nu=4, mu=att_mu_t, lam=att_tau_t)
                def_t = pm.StudentT(f"def_raw_{t}", nu=4, mu=def_mu_t, lam=def_tau_t)

                att_effects.append(att_t)
                def_effects.append(def_t)

            # Vectors of length n_teams, ordered by team id
            att = pt.stack(att_effects)
            def_ = pt.stack(def_effects)

            # Apply sum-to-zero constraint (identifiability)
            att_centered = pm.Deterministic("att_centered", att - pt.mean(att))
            def_centered = pm.Deterministic("def_centered", def_ - pt.mean(def_))

            # Game-specific scoring intensities via direct indexing (same as the basic model).
            # For each game g:
            # log(theta_g1) = home + att[h(g)] + def[a(g)]  (home team scoring intensity)
            # log(theta_g2) = att[a(g)] + def[h(g)]        (away team scoring intensity)

            log_theta_g1 = home_advantage + att_centered[home_team_idx] + def_centered[away_team_idx]
            log_theta_g2 = att_centered[away_team_idx] + def_centered[home_team_idx]

            theta_g1 = pm.Deterministic("theta_g1", pt.exp(log_theta_g1))
            theta_g2 = pm.Deterministic("theta_g2", pt.exp(log_theta_g2))

            # Likelihood - each game has its own theta values
            y1 = pm.Poisson("y1", mu=theta_g1, observed=y1_data)
            y2 = pm.Poisson("y2", mu=theta_g2, observed=y2_data)

        self.mixture_model = model
        return model

    def fit_basic_model(self, draws=2000, tune=2000, chains=4, cores=1,
                        target_accept=0.97, random_seed=42):
        """Sample the basic hierarchical model's posterior and posterior predictive.

        Builds the model first if ``build_basic_model`` has not been called.

        Parameters
        ----------
        draws, tune, chains, cores : int
            Forwarded to ``pm.sample``.
        target_accept : float, default 0.97
            NUTS target acceptance rate. Higher values take smaller steps,
            which is slower but typically produces fewer divergences.
        random_seed : int, default 42
            Seed forwarded to ``pm.sample`` for reproducible chains.

        Returns
        -------
        arviz.InferenceData
            Posterior draws plus the posterior-predictive group; also stored
            in ``self.basic_trace``.
        """
        print("Fitting basic hierarchical model...")

        if self.basic_model is None:
            self.build_basic_model()

        with self.basic_model:
            self.basic_trace = pm.sample(
                draws=draws,
                tune=tune,
                chains=chains,
                cores=cores,
                random_seed=random_seed,
                return_inferencedata=True,
                target_accept=target_accept,
            )
            # extend() adds the posterior_predictive group in place.
            self.basic_trace.extend(pm.sample_posterior_predictive(self.basic_trace))

        print("Basic model fitting completed!")
        return self.basic_trace

    def fit_mixture_model(self, draws=2000, tune=2000, chains=4, cores=1,
                          target_accept=0.97, random_seed=42):
        """Sample the mixture model's posterior and posterior predictive.

        Builds the model first if ``build_mixture_model`` has not been called.

        Parameters
        ----------
        draws, tune, chains, cores : int
            Forwarded to ``pm.sample``.
        target_accept : float, default 0.97
            Target acceptance rate for the gradient-based sampler assigned to
            the continuous variables. It is the same default as the basic
            model.
        random_seed : int, default 42
            Seed forwarded to ``pm.sample`` for reproducible chains.

        Returns
        -------
        arviz.InferenceData
            Posterior draws plus the posterior-predictive group; also stored
            in ``self.mixture_trace``.
        """
        print("Fitting mixture model...")

        if self.mixture_model is None:
            self.build_mixture_model()

        with self.mixture_model:
            self.mixture_trace = pm.sample(
                draws=draws,
                tune=tune,
                chains=chains,
                cores=cores,
                random_seed=random_seed,
                return_inferencedata=True,
                target_accept=target_accept,
            )
            # extend() adds the posterior_predictive group in place.
            self.mixture_trace.extend(pm.sample_posterior_predictive(self.mixture_trace))

        print("Mixture model fitting completed!")
        return self.mixture_trace

    # ===== REALISTIC SIMULATION METHODS =====

    def get_realistic_model_predictions(self, model_type, n_simulations=1500):
        """
        Get realistic predictions by simulating actual match outcomes
        Using MEDIAN of season totals (exactly like the paper)
        """
        np.random.seed(42)
        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Warning: {model_type} model not fitted, skipping...")
            return None

        # Get posterior samples of scoring intensities (theta values)
        if 'theta_g1' in trace.posterior.data_vars and 'theta_g2' in trace.posterior.data_vars:
            theta1_samples = trace.posterior['theta_g1'].values  # [chains, draws, games]
            theta2_samples = trace.posterior['theta_g2'].values  # [chains, draws, games]
        else:
            print(f"Could not find theta variables in {model_type} model")
            print("Available variables:", list(trace.posterior.data_vars))
            return None

        # Reshape for easier handling
        n_chains, n_draws, n_games = theta1_samples.shape
        theta1_flat = theta1_samples.reshape(-1, n_games)  # [total_samples, games]
        theta2_flat = theta2_samples.reshape(-1, n_games)

        n_available_samples = len(theta1_flat)

        # If we have fewer samples than desired, resample with replacement
        if n_available_samples < n_simulations:
            print(f"Only {n_available_samples} posterior samples available, resampling to get {n_simulations}...")
            resample_indices = np.random.choice(n_available_samples, size=n_simulations, replace=True)
            theta1_sim = theta1_flat[resample_indices]
            theta2_sim = theta2_flat[resample_indices]
            n_samples = n_simulations
        else:
            theta1_sim = theta1_flat[:n_simulations]
            theta2_sim = theta2_flat[:n_simulations]
            n_samples = n_simulations

        print(f"Simulating {n_samples} scenarios for {model_type} model predictions...")

        pred_stats = []

        for team in self.teams:
            # Get indices for this team's games
            team_home_mask = (self.data['hometeam_name'] == team)
            team_away_mask = (self.data['awayteam_name'] == team)
            team_mask = team_home_mask | team_away_mask

            team_games = self.data[team_mask].copy()

            # Store season totals for each simulation
            season_points = []
            season_goals_scored = []
            season_goals_conceded = []
            season_wins = []
            season_draws = []
            season_losses = []

            # For each posterior sample, simulate a complete season
            for sim_idx in range(n_samples):
                sim_points = 0
                sim_goals_scored = 0
                sim_goals_conceded = 0
                sim_wins = 0
                sim_draws = 0
                sim_losses = 0

                # Simulate all matches for this team in this posterior sample
                for _, match in team_games.iterrows():
                    game_idx = match.name  # original index in dataset

                    # Get theta values for this specific simulation
                    game_theta1 = theta1_sim[sim_idx, game_idx]  # home team scoring intensity
                    game_theta2 = theta2_sim[sim_idx, game_idx]  # away team scoring intensity

                    # Simulate actual goals for this specific match
                    simulated_home_goals = np.random.poisson(game_theta1)
                    simulated_away_goals = np.random.poisson(game_theta2)

                    # Determine team's perspective
                    if match['hometeam_name'] == team:
                        # Team is playing at home
                        team_goals = simulated_home_goals
                        opponent_goals = simulated_away_goals
                    else:
                        # Team is playing away
                        team_goals = simulated_away_goals
                        opponent_goals = simulated_home_goals

                    # Update season totals for this simulation
                    sim_goals_scored += team_goals
                    sim_goals_conceded += opponent_goals

                    # Determine match result
                    if team_goals > opponent_goals:
                        sim_points += 3
                        sim_wins += 1
                    elif team_goals == opponent_goals:
                        sim_points += 1
                        sim_draws += 1
                    else:
                        sim_losses += 1

                # Store this simulation's season totals
                season_points.append(sim_points)
                season_goals_scored.append(sim_goals_scored)
                season_goals_conceded.append(sim_goals_conceded)
                season_wins.append(sim_wins)
                season_draws.append(sim_draws)
                season_losses.append(sim_losses)

            # Take MEDIAN of season totals (exactly like the paper)
            pred_stats.append({
                'team': team,
                f'{model_type}_points': int(np.median(season_points)),
                f'{model_type}_scored': int(np.median(season_goals_scored)),
                f'{model_type}_conceded': int(np.median(season_goals_conceded)),
                f'{model_type}_wins': int(np.median(season_wins)),
                f'{model_type}_draws': int(np.median(season_draws)),
                f'{model_type}_losses': int(np.median(season_losses))
            })

        return pred_stats

    def create_extended_comparison_table(self, save_to_file=False, filename="extended_season_comparison.csv"):
        """
        Create a comprehensive comparison table including mixture model results
        Now includes wins, draws, and losses with REALISTIC simulation
        """

        if self.basic_trace is None:
            print("Please fit the basic model first!")
            return None

        # Calculate observed statistics for each team
        observed_stats = []

        for team in self.teams:
            team_data = self.data[
                (self.data['hometeam_name'] == team) |
                (self.data['awayteam_name'] == team)
            ].copy()

            # Calculate points, goals scored, goals conceded, wins, draws, losses
            points = 0
            goals_scored = 0
            goals_conceded = 0
            wins = 0
            draws = 0
            losses = 0

            for _, match in team_data.iterrows():
                if match['hometeam_name'] == team:
                    # Team playing at home
                    goals_for = match['y1']
                    goals_against = match['y2']
                else:
                    # Team playing away
                    goals_for = match['y2']
                    goals_against = match['y1']

                # Calculate match result
                if goals_for > goals_against:
                    points += 3
                    wins += 1
                elif goals_for == goals_against:
                    points += 1
                    draws += 1
                else:
                    losses += 1

                goals_scored += goals_for
                goals_conceded += goals_against

            observed_stats.append({
                'team': team,
                'obs_points': points,
                'obs_scored': goals_scored,
                'obs_conceded': goals_conceded,
                'obs_wins': wins,
                'obs_draws': draws,
                'obs_losses': losses
            })

        # Get REALISTIC predictions from both models using simulation
        basic_predictions = self.get_realistic_model_predictions('basic')
        mixture_predictions = self.get_realistic_model_predictions('mixture') if self.mixture_trace else None

        # Combine all data
        comparison_data = []
        for i, obs in enumerate(observed_stats):
            row = obs.copy()

            # Add basic model predictions
            if basic_predictions:
                row.update(basic_predictions[i])

            # Add mixture model predictions if available
            if mixture_predictions:
                row.update(mixture_predictions[i])

            comparison_data.append(row)

        # Create DataFrame and sort by observed points
        df = pd.DataFrame(comparison_data)
        df = df.sort_values('obs_points', ascending=False)

        # Save to file if requested
        if save_to_file:
            df.to_csv(filename, index=False)
            print(f"Extended comparison table saved to {filename}")

        return df

    # ===== VISUALIZATION METHODS =====

    def get_home_advantage_summary(self, model_type='basic'):
        """Get summary statistics for home advantage effect"""

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return None

        # Extract home advantage samples
        home_adv_samples = trace.posterior['home_advantage']

        # Calculate summary statistics
        home_summary = {
            'parameter': 'home_advantage',
            'mean': float(home_adv_samples.mean()),
            'median': float(home_adv_samples.median()),
            'std': float(home_adv_samples.std()),
            'q025': float(home_adv_samples.quantile(0.025)),
            'q975': float(home_adv_samples.quantile(0.975))
        }

        return home_summary

    def plot_team_effects(self, model_type='basic'):
        """Plot attack vs defense effects for each team"""

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return

        # Get posterior means
        if model_type == 'basic':
            att_means = trace.posterior['att'].mean(dim=['chain', 'draw']).values
            def_means = trace.posterior['def'].mean(dim=['chain', 'draw']).values
        else:
            att_means = trace.posterior['att_centered'].mean(dim=['chain', 'draw']).values
            def_means = trace.posterior['def_centered'].mean(dim=['chain', 'draw']).values

        # Create plot
        plt.figure(figsize=(12, 8))
        plt.scatter(att_means, def_means, s=100, alpha=0.7)

        # Add team labels
        for i, team in enumerate(self.teams):
            plt.annotate(team, (att_means[i], def_means[i]),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.8)

        plt.xlabel('Attack Effect')
        plt.ylabel('Defense Effect')
        plt.title(f'Team Attack vs Defense Effects ({model_type.title()} Model)')
        plt.grid(True, alpha=0.3)
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        plt.axvline(x=0, color='k', linestyle='--', alpha=0.5)

        # Add quadrant labels
        plt.text(0.02, 0.98, 'Poor Attack,\nPoor Defense',
                transform=plt.gca().transAxes, va='top', ha='left',
                bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.5))
        plt.text(0.98, 0.98, 'Good Attack,\nPoor Defense',
                transform=plt.gca().transAxes, va='top', ha='right',
                bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.5))
        plt.text(0.02, 0.02, 'Poor Attack,\nGood Defense',
                transform=plt.gca().transAxes, va='bottom', ha='left',
                bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.5))
        plt.text(0.98, 0.02, 'Good Attack,\nGood Defense',
                transform=plt.gca().transAxes, va='bottom', ha='right',
                bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))

        plt.tight_layout()
        plt.show()

    def get_team_summary(self, model_type='basic'):
        """Get summary statistics for team effects"""

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return None

        # Get parameter names
        if model_type == 'basic':
            att_param = 'att'
            def_param = 'def'
        else:
            att_param = 'att_centered'
            def_param = 'def_centered'

        # Extract posterior samples
        att_samples = trace.posterior[att_param]
        def_samples = trace.posterior[def_param]

        # Calculate summary statistics
        summary_data = []
        for i, team in enumerate(self.teams):
            # Handle different indexing based on parameter dimensions
            if len(att_samples.dims) == 3:  # chain, draw, team
                att_team_samples = att_samples.isel({list(att_samples.dims)[2]: i})
                def_team_samples = def_samples.isel({list(def_samples.dims)[2]: i})
            else:  # Different structure
                att_team_samples = att_samples[..., i]
                def_team_samples = def_samples[..., i]

            att_mean = float(att_team_samples.mean())
            att_q025 = float(att_team_samples.quantile(0.025))
            att_q975 = float(att_team_samples.quantile(0.975))

            def_mean = float(def_team_samples.mean())
            def_q025 = float(def_team_samples.quantile(0.025))
            def_q975 = float(def_team_samples.quantile(0.975))

            summary_data.append({
                'team': team,
                'att_mean': att_mean,
                'att_q025': att_q025,
                'att_q975': att_q975,
                'def_mean': def_mean,
                'def_q025': def_q025,
                'def_q975': def_q975
            })

        summary_df = pd.DataFrame(summary_data)

        # Sort by attack effect (descending)
        summary_df = summary_df.sort_values('att_mean', ascending=False)

        return summary_df

    def print_model_summary(self, model_type='basic', show_all_teams=True):
        """Print comprehensive model summary including home advantage"""

        print(f"\n{model_type.upper()} MODEL SUMMARY")
        print("=" * 60)

        # Home advantage
        home_summary = self.get_home_advantage_summary(model_type)
        if home_summary:
            print(f"\nHOME ADVANTAGE EFFECT:")
            print(f"Mean: {home_summary['mean']:.4f}")
            print(f"95% CI: [{home_summary['q025']:.4f}, {home_summary['q975']:.4f}]")
            print(f"Interpretation: Home teams score exp({home_summary['mean']:.4f}) = {np.exp(home_summary['mean']):.3f}x more goals on average")

        # Team effects
        print(f"\nTEAM EFFECTS:")
        team_summary = self.get_team_summary(model_type)
        if team_summary is not None:
            print("\nTop 5 Attack (most goals scored):")
            print(team_summary.head()[['team', 'att_mean', 'att_q025', 'att_q975']].to_string(index=False))

            print(f"\nTop 5 Defense (fewest goals conceded - most negative values):")
            defense_sorted = team_summary.sort_values('def_mean', ascending=True)
            print(defense_sorted.head()[['team', 'def_mean', 'def_q025', 'def_q975']].to_string(index=False))

            if show_all_teams:
                print(f"\nBottom 5 Attack (fewest goals scored):")
                attack_sorted = team_summary.sort_values('att_mean', ascending=True)
                print(attack_sorted.head()[['team', 'att_mean', 'att_q025', 'att_q975']].to_string(index=False))

                print(f"\nBottom 5 Defense (most goals conceded - most positive values):")
                print(defense_sorted.tail()[['team', 'def_mean', 'def_q025', 'def_q975']].to_string(index=False))

    def predict_match(self, home_team, away_team, model_type='basic', n_samples=1000):
        """Predict the outcome of a specific match"""

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return None

        # Get team indices
        if home_team not in self.teams or away_team not in self.teams:
            print(f"Team not found. Available teams: {self.teams}")
            return None

        home_idx = self.teams.index(home_team)
        away_idx = self.teams.index(away_team)

        # Get parameter samples
        home_adv_samples = trace.posterior['home_advantage'].values.flatten()

        if model_type == 'basic':
            att_samples = trace.posterior['att'].values
            def_samples = trace.posterior['def'].values
        else:
            att_samples = trace.posterior['att_centered'].values
            def_samples = trace.posterior['def_centered'].values

        # Reshape samples for easier indexing
        att_flat = att_samples.reshape(-1, att_samples.shape[-1])
        def_flat = def_samples.reshape(-1, def_samples.shape[-1])
        home_adv_flat = home_adv_samples.flatten()

        # Take only n_samples
        n_available = min(len(home_adv_flat), len(att_flat))
        n_use = min(n_samples, n_available)

        # Calculate scoring intensities
        theta1_samples = np.exp(home_adv_flat[:n_use] +
                               att_flat[:n_use, home_idx] +
                               def_flat[:n_use, away_idx])

        theta2_samples = np.exp(att_flat[:n_use, away_idx] +
                               def_flat[:n_use, home_idx])

        # Generate predictions
        home_goals = np.random.poisson(theta1_samples)
        away_goals = np.random.poisson(theta2_samples)

        # Calculate probabilities
        home_win = np.mean(home_goals > away_goals)
        draw = np.mean(home_goals == away_goals)
        away_win = np.mean(home_goals < away_goals)

        # Expected goals
        exp_home_goals = np.mean(theta1_samples)
        exp_away_goals = np.mean(theta2_samples)

        return {
            'home_team': home_team,
            'away_team': away_team,
            'expected_home_goals': exp_home_goals,
            'expected_away_goals': exp_away_goals,
            'prob_home_win': home_win,
            'prob_draw': draw,
            'prob_away_win': away_win,
            'home_goals_samples': home_goals,
            'away_goals_samples': away_goals
        }

    # ===== TRACEPLOTS METHODS =====

    def plot_traceplots(self, model_type='basic', var_names=None, figsize=(15, 12)):
        """
        Generate trace plots for MCMC diagnostics

        Parameters:
        -----------
        model_type : str
            'basic' or 'mixture' to specify which model to plot
        var_names : list, optional
            List of variable names to plot. If None, plots key parameters
        figsize : tuple
            Figure size for the plots
        """

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return

        print(f"\nGenerating trace plots for {model_type} model...")

        # Default variables to plot if not specified
        if var_names is None:
            if model_type == 'basic':
                var_names = ['home_advantage', 'mu_att', 'mu_def', 'tau_att', 'tau_def']
                # Add a few team-specific effects for illustration
                var_names.extend(['att', 'def'])
            else:
                var_names = ['home_advantage', 'mu_att_1', 'mu_att_2', 'mu_att_3',
                           'mu_def_1', 'mu_def_2', 'mu_def_3']

        # Filter variables that actually exist in the trace
        available_vars = list(trace.posterior.data_vars)
        var_names = [var for var in var_names if var in available_vars]

        if not var_names:
            print(f"No specified variables found in {model_type} model trace.")
            print(f"Available variables: {available_vars}")
            return

        try:
            # Create trace plots using ArviZ
            ax = az.plot_trace(
                trace,
                var_names=var_names,
                figsize=figsize,
                combined=False,
                compact=False
            )

            plt.suptitle(f'Trace Plots - {model_type.title()} Model', fontsize=16, y=0.98)
            plt.tight_layout()
            plt.show()

            # Print convergence diagnostics
            self.print_convergence_diagnostics(model_type, var_names)

        except Exception as e:
            print(f"Error generating trace plots: {e}")
            print("Trying with simpler variable selection...")

            # Fallback: just plot home_advantage if it exists
            if 'home_advantage' in available_vars:
                az.plot_trace(trace, var_names=['home_advantage'], figsize=(12, 4))
                plt.suptitle(f'Trace Plot: Home Advantage - {model_type.title()} Model')
                plt.tight_layout()
                plt.show()

    def print_convergence_diagnostics(self, model_type='basic', var_names=None):
        """
        Print convergence diagnostics for the specified model

        Parameters:
        -----------
        model_type : str
            'basic' or 'mixture' to specify which model to diagnose
        var_names : list, optional
            List of variable names to diagnose. If None, uses all variables
        """

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return

        print(f"\n{'='*60}")
        print(f"CONVERGENCE DIAGNOSTICS - {model_type.upper()} MODEL")
        print(f"{'='*60}")

        try:
            # Calculate R-hat (should be < 1.1 for good convergence)
            rhat = az.rhat(trace, var_names=var_names)

            # Calculate effective sample size (should be > 400)
            ess = az.ess(trace, var_names=var_names)

            # Handle different return types from ArviZ
            if hasattr(rhat, 'values'):
                # If it's an xarray Dataset/DataArray, get the values
                rhat_values = rhat.values if hasattr(rhat.values, 'flatten') else [rhat.values]
                if hasattr(rhat_values, 'flatten'):
                    rhat_flat = rhat_values.flatten()
                else:
                    rhat_flat = np.array(rhat_values).flatten()
                max_rhat = float(np.nanmax(rhat_flat))
            else:
                # If it's already a scalar or array
                max_rhat = float(np.nanmax(rhat))

            if hasattr(ess, 'values'):
                # If it's an xarray Dataset/DataArray, get the values
                ess_values = ess.values if hasattr(ess.values, 'flatten') else [ess.values]
                if hasattr(ess_values, 'flatten'):
                    ess_flat = ess_values.flatten()
                else:
                    ess_flat = np.array(ess_values).flatten()
                min_ess = float(np.nanmin(ess_flat))
            else:
                # If it's already a scalar or array
                min_ess = float(np.nanmin(ess))

            print(f"Maximum R-hat: {max_rhat:.4f}")
            print(f"Minimum Effective Sample Size: {min_ess:.0f}")
            print()

            # Convergence assessment
            if max_rhat < 1.1:
                print(" R-hat indicates good convergence (< 1.1)")
            else:
                print("R-hat indicates potential convergence issues (≥ 1.1)")
                print("Consider running more iterations or increasing tune parameter")

            if min_ess > 400:
                print(" Effective sample size is adequate (> 400)")
            else:
                print("Low effective sample size (≤ 400)")
                print("Consider running more iterations")

            # Show detailed diagnostics for problematic parameters
            if max_rhat >= 1.1 or min_ess <= 400:
                print(f"\nProblematic parameters:")

                # Try to extract detailed information
                try:
                    # Convert to pandas for easier handling if possible
                    if hasattr(rhat, 'to_dataframe'):
                        rhat_df = rhat.to_dataframe().reset_index()
                        ess_df = ess.to_dataframe().reset_index()

                        # Find parameters with high R-hat
                        try:
                            value_col = [col for col in rhat_df.columns if col not in ['chain', 'draw']][-1]
                            high_rhat = rhat_df[rhat_df[value_col] >= 1.1]
                            if not high_rhat.empty:
                                print("  High R-hat (≥ 1.1):")
                                for _, row in high_rhat.head(10).iterrows():
                                    param_info = [str(row[col]) for col in rhat_df.columns[:-1]]
                                    print(f"    {param_info}: {row[value_col]:.4f}")
                        except:
                            pass

                        # Find parameters with low ESS
                        try:
                            value_col = [col for col in ess_df.columns if col not in ['chain', 'draw']][-1]
                            low_ess = ess_df[ess_df[value_col] <= 400]
                            if not low_ess.empty:
                                print("  Low ESS (≤ 400):")
                                for _, row in low_ess.head(10).iterrows():
                                    param_info = [str(row[col]) for col in ess_df.columns[:-1]]
                                    print(f"    {param_info}: {row[value_col]:.0f}")
                        except:
                            pass
                    else:
                        # Fallback: just show summary statistics
                        print("  Detailed parameter breakdown not available")
                        print(f"  Overall: Max R-hat = {max_rhat:.4f}, Min ESS = {min_ess:.0f}")

                except Exception as detail_error:
                    print(f"  Could not extract detailed parameter information: {detail_error}")
                    print(f"  Overall: Max R-hat = {max_rhat:.4f}, Min ESS = {min_ess:.0f}")

            print(f"\n{'='*60}")

        except Exception as e:
            print(f"Error calculating convergence diagnostics: {e}")
            print("This might be due to the trace structure or variable names.")

            # Fallback: try to get basic diagnostics for key parameters only
            try:
                print("Attempting basic diagnostics for key parameters...")
                key_params = ['home_advantage']
                if model_type == 'basic':
                    key_params.extend(['mu_att', 'mu_def'])
                else:
                    key_params.extend(['mu_att_1', 'mu_att_2'])

                # Filter to only available parameters
                available_params = [p for p in key_params if p in trace.posterior.data_vars]

                if available_params:
                    basic_rhat = az.rhat(trace, var_names=available_params)
                    basic_ess = az.ess(trace, var_names=available_params)

                    print(f"Basic diagnostics for {available_params}:")
                    for param in available_params:
                        if param in basic_rhat.data_vars:
                            param_rhat = float(basic_rhat[param].values)
                            param_ess = float(basic_ess[param].values)
                            status_rhat = "" if param_rhat < 1.1 else "⚠"
                            status_ess = "" if param_ess > 400 else "⚠"
                            print(f"  {param}: R-hat = {param_rhat:.4f} {status_rhat}, ESS = {param_ess:.0f} {status_ess}")

            except Exception as fallback_error:
                print(f"Fallback diagnostics also failed: {fallback_error}")
                print("Convergence diagnostics could not be calculated for this model.")

    def plot_team_effect_traceplots(self, model_type='basic', team_names=None, n_teams=5):
        """
        Plot trace plots specifically for team attack and defense effects

        Parameters:
        -----------
        model_type : str
            'basic' or 'mixture' to specify which model to plot
        team_names : list, optional
            Specific team names to plot. If None, plots first n_teams
        n_teams : int
            Number of teams to plot if team_names not specified
        """

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return

        # Determine parameter names
        if model_type == 'basic':
            att_param = 'att'
            def_param = 'def'
        else:
            att_param = 'att_centered'
            def_param = 'def_centered'

        # Check if parameters exist
        if att_param not in trace.posterior.data_vars or def_param not in trace.posterior.data_vars:
            print(f"Team effect parameters not found in {model_type} model trace.")
            print(f"Available variables: {list(trace.posterior.data_vars)}")
            return

        # Select teams to plot
        if team_names is None:
            team_indices = list(range(min(n_teams, len(self.teams))))
            selected_teams = [self.teams[i] for i in team_indices]
        else:
            team_indices = [self.teams.index(team) for team in team_names if team in self.teams]
            selected_teams = team_names

        if not team_indices:
            print("No valid teams found for plotting.")
            return

        print(f"\nGenerating team effect trace plots for: {', '.join(selected_teams)}")

        # Create subplots
        n_teams_plot = len(team_indices)
        fig, axes = plt.subplots(n_teams_plot, 4, figsize=(16, 4*n_teams_plot))

        if n_teams_plot == 1:
            axes = axes.reshape(1, -1)

        for i, (team_idx, team_name) in enumerate(zip(team_indices, selected_teams)):
            # Attack effect traces
            att_samples = trace.posterior[att_param].isel({att_param + '_dim_0': team_idx})
            def_samples = trace.posterior[def_param].isel({def_param + '_dim_0': team_idx})

            # Plot traces for each chain
            for chain in range(att_samples.sizes['chain']):
                axes[i, 0].plot(att_samples.isel(chain=chain), alpha=0.7, label=f'Chain {chain}')
                axes[i, 2].plot(def_samples.isel(chain=chain), alpha=0.7, label=f'Chain {chain}')

            axes[i, 0].set_title(f'{team_name} - Attack Effect (Trace)')
            axes[i, 0].set_ylabel('Attack Effect')
            axes[i, 0].legend()
            axes[i, 0].grid(True, alpha=0.3)

            axes[i, 2].set_title(f'{team_name} - Defense Effect (Trace)')
            axes[i, 2].set_ylabel('Defense Effect')
            axes[i, 2].legend()
            axes[i, 2].grid(True, alpha=0.3)

            # Plot distributions
            att_flat = att_samples.values.flatten()
            def_flat = def_samples.values.flatten()

            axes[i, 1].hist(att_flat, bins=50, alpha=0.7, density=True)
            axes[i, 1].set_title(f'{team_name} - Attack Effect (Distribution)')
            axes[i, 1].set_xlabel('Attack Effect')
            axes[i, 1].set_ylabel('Density')
            axes[i, 1].grid(True, alpha=0.3)

            axes[i, 3].hist(def_flat, bins=50, alpha=0.7, density=True)
            axes[i, 3].set_title(f'{team_name} - Defense Effect (Distribution)')
            axes[i, 3].set_xlabel('Defense Effect')
            axes[i, 3].set_ylabel('Density')
            axes[i, 3].grid(True, alpha=0.3)

        plt.suptitle(f'Team Effects Trace Plots - {model_type.title()} Model', fontsize=16)
        plt.tight_layout()
        plt.show()

    def generate_all_traceplots(self, model_type='basic'):
        """
        Generate comprehensive trace plots for model diagnostics

        Parameters:
        -----------
        model_type : str
            'basic' or 'mixture' to specify which model to plot
        """

        print(f"\n{'='*70}")
        print(f"COMPREHENSIVE TRACE PLOT ANALYSIS - {model_type.upper()} MODEL")
        print(f"{'='*70}")

        # 1. Main parameter trace plots
        print("\n1. Main Parameters Trace Plots:")
        self.plot_traceplots(model_type)

        # 2. Team effects trace plots (for a subset of teams)
        print("\n2. Team Effects Trace Plots:")
        self.plot_team_effect_traceplots(model_type, n_teams=3)

        # 3. Convergence diagnostics
        print("\n3. Convergence Diagnostics:")
        self.print_convergence_diagnostics(model_type)

        print(f"\n{'='*70}")
        print("TRACE PLOT ANALYSIS COMPLETE")
        print(f"{'='*70}")

    def print_extended_comparison_table(self, show_errors=True):
        """
        Print observed vs. simulated season results for every team.

        The table comes from ``create_extended_comparison_table``. Mixture
        columns are shown only when that table contains ``mixture_points``.
        Each model group lists points, goals scored/conceded and
        wins/draws/losses. Simulated medians need not be whole numbers, so
        every cell is rounded to the nearest integer for display. Missing
        values print as '-'.

        Parameters
        ----------
        show_errors : bool, default True
            Also print the per-metric mean absolute error (MAE) of each model
            against the observed results. When both models are present, also
            print which model is closer overall and per metric (exact ties
            are reported as ties).

        Returns
        -------
        pandas.DataFrame or None
            The comparison table, or None when it could not be built.
        """
        df = self.create_extended_comparison_table()
        if df is None:
            return

        # (column suffix, header label, cell width, MAE label)
        metrics = [
            ('points', 'pts', 5, 'Points'),
            ('scored', 'sc', 4, 'Goals Scored'),
            ('conceded', 'co', 4, 'Goals Conceded'),
            ('wins', 'W', 3, 'Wins'),
            ('draws', 'D', 3, 'Draws'),
            ('losses', 'L', 3, 'Losses'),
        ]

        # Check if mixture model results are available
        has_mixture = 'mixture_points' in df.columns
        groups = [('obs', 'Observed results'), ('basic', 'Basic model (medians)')]
        if has_mixture:
            groups.append(('mixture', 'Mixture model (medians)'))

        # Every cell is rendered as one space followed by the value
        # right-aligned in its width. Group titles, column labels and data
        # rows therefore share exactly the same column grid.
        group_width = sum(width + 1 for _key, _label, width, _name in metrics)
        title = "EXTENDED SEASON RESULTS COMPARISON - OBSERVED vs BASIC vs MIXTURE MODELS"
        rule_width = max(15 + group_width * len(groups), len(title))

        def _cell(value):
            # Medians of integer counts can be fractional (e.g. 40.5), and
            # ':d' formatting rejects floats, so round explicitly for display.
            if pd.isna(value):
                return '-'
            return str(int(round(float(value))))

        print("\n" + "=" * rule_width)
        print(title)
        print("=" * rule_width)

        # Print headers
        print(f"{'team':15}" + "".join(f" {name:^{group_width - 1}}" for _prefix, name in groups))
        print(f"{'':15}" + "".join(
            f" {label:>{width}}" for _group in groups for _key, label, width, _name in metrics
        ))
        print("-" * rule_width)

        # Print data rows
        for _, row in df.iterrows():
            cells = "".join(
                f" {_cell(row[f'{prefix}_{key}']):>{width}}"
                for prefix, _name in groups
                for key, _label, width, _mae in metrics
            )
            print(f"{str(row['team']):15}{cells}")

        # Print summary statistics
        if show_errors:
            print("\n" + "=" * 90)
            print("MODEL PERFORMANCE COMPARISON - MEAN ABSOLUTE ERROR")
            print("=" * 90)

            def _errors(prefix):
                # Mean absolute error of one model vs. observed, per metric.
                return {
                    key: np.mean(np.abs(df[f'obs_{key}'] - df[f'{prefix}_{key}']))
                    for key, _label, _width, _name in metrics
                }

            def _report(heading, errors):
                # Print one model's MAE block and return its total MAE.
                print(heading)
                for key, _label, _width, name in metrics:
                    print(f"  {name + ' MAE:':<21}{errors[key]:.2f}")
                total = sum(errors.values())
                print(f"  {'Total MAE:':<21}{total:.2f}")
                return total

            basic_errors = _errors('basic')
            basic_total = _report("Basic Model Performance:", basic_errors)

            if has_mixture:
                mixture_errors = _errors('mixture')
                mixture_total = _report("\nMixture Model Performance:", mixture_errors)

                # Compare models
                print("\n" + "=" * 60)
                print("MODEL COMPARISON SUMMARY")
                print("=" * 60)

                # In a strict win the losing total is > the winning total >= 0,
                # so the percentage division below never divides by zero.
                if basic_total < mixture_total:
                    print(f" BASIC MODEL WINS with lower total MAE ({basic_total:.2f} vs {mixture_total:.2f})")
                    improvement = ((mixture_total - basic_total) / mixture_total) * 100
                    print(f"   Basic model is {improvement:.1f}% better overall")
                elif mixture_total < basic_total:
                    print(f" MIXTURE MODEL WINS with lower total MAE ({mixture_total:.2f} vs {basic_total:.2f})")
                    improvement = ((basic_total - mixture_total) / basic_total) * 100
                    print(f"   Mixture model is {improvement:.1f}% better overall")
                else:
                    print(f" MODELS TIE with equal total MAE ({basic_total:.2f})")

                # Individual category winners
                print("\nCategory Winners:")
                for key, _label, _width, _name in metrics:
                    basic_err, mixture_err = basic_errors[key], mixture_errors[key]
                    if basic_err < mixture_err:
                        winner = 'Basic'
                    elif mixture_err < basic_err:
                        winner = 'Mixture'
                    else:
                        winner = 'Tie'
                    print(f"  {key.capitalize() + ':':<10}{winner} ({min(basic_err, mixture_err):.2f})")

        return df

    def detailed_model_analysis(self):
        """
        Compare the fitted basic and mixture models team by team.

        The method first adds mixture-minus-basic difference columns
        (``<metric>_diff``) for points, goals scored/conceded and
        wins/draws/losses. It then lists, separately for each model, up to
        five teams whose observed points that model predicts strictly more
        closely. Next it lists the five teams on which the two models
        disagree most (``total_abs_diff``). Finally it prints average
        absolute differences and observed-vs-predicted correlations for
        points and wins.

        Returns
        -------
        pandas.DataFrame or None
            The comparison table with the added columns, or None when the
            table lacks mixture-model results.
        """
        table = self.create_extended_comparison_table()
        if table is None or 'mixture_points' not in table.columns:
            print("Both models need to be fitted for detailed comparison analysis.")
            return

        wide_rule = "=" * 70
        print(f"\n" + wide_rule)
        print("DETAILED MODEL DIFFERENCES ANALYSIS")
        print(wide_rule)

        metrics = ('points', 'scored', 'conceded', 'wins', 'draws', 'losses')

        # Mixture-minus-basic difference for every metric
        for metric in metrics:
            table[f'{metric}_diff'] = table[f'mixture_{metric}'] - table[f'basic_{metric}']

        def _closer_on_points(winner, loser):
            # Print and return teams whose observed points `winner` predicts
            # strictly more closely than `loser` (a copy with error columns).
            print(f"\nTeams where {winner.upper()} model predicts closer to observed results:")
            observed = table['obs_points']
            subset = table[
                abs(observed - table[f'{winner}_points']) < abs(observed - table[f'{loser}_points'])
            ].copy()

            if len(subset) == 0:
                print(f"  No teams found where {winner} model significantly outperforms {loser} model")
                return subset

            subset['basic_error'] = abs(subset['obs_points'] - subset['basic_points'])
            subset['mixture_error'] = abs(subset['obs_points'] - subset['mixture_points'])
            subset['improvement'] = subset[f'{loser}_error'] - subset[f'{winner}_error']
            subset = subset.sort_values('improvement', ascending=False)

            for _, team in subset.head(5).iterrows():
                print(f"  {team['team']:15}: {loser.capitalize()} error {team[f'{loser}_error']:.1f}, "
                      f"{winner.capitalize()} error {team[f'{winner}_error']:.1f} "
                      f"(improvement: {team['improvement']:.1f})")
            return subset

        mixture_better = _closer_on_points('mixture', 'basic')
        basic_better = _closer_on_points('basic', 'mixture')

        # Teams on which the two models disagree the most
        print(f"\nLargest differences between Basic and Mixture model predictions:")
        table['total_abs_diff'] = sum(abs(table[f'{metric}_diff']) for metric in metrics)
        for _, team in table.nlargest(5, 'total_abs_diff').iterrows():
            print(f"  {team['team']:15}: Pts {team['points_diff']:+3.0f}, W {team['wins_diff']:+2.0f}, D {team['draws_diff']:+2.0f}, L {team['losses_diff']:+2.0f}")

        print(f"\n{wide_rule}")
        print("SUMMARY STATISTICS")
        print(f"{wide_rule}")
        print(f"Teams where mixture model is better: {len(mixture_better)}")
        print(f"Teams where basic model is better: {len(basic_better)}")
        readable = {
            'points': 'points', 'scored': 'goals scored', 'conceded': 'goals conceded',
            'wins': 'wins', 'draws': 'draws', 'losses': 'losses',
        }
        for metric in metrics:
            print(f"Average absolute difference in {readable[metric]}: {abs(table[f'{metric}_diff']).mean():.2f}")

        def _corr(first, second):
            # Pearson correlation between two table columns.
            return table[[first, second]].corr().iloc[0, 1]

        basic_corr_points = _corr('obs_points', 'basic_points')
        mixture_corr_points = _corr('obs_points', 'mixture_points')
        basic_corr_wins = _corr('obs_wins', 'basic_wins')
        mixture_corr_wins = _corr('obs_wins', 'mixture_wins')

        print(f"\nCorrelation with observed results:")
        print(f"Points - Basic: {basic_corr_points:.3f}, Mixture: {mixture_corr_points:.3f}")
        print(f"Wins   - Basic: {basic_corr_wins:.3f}, Mixture: {mixture_corr_wins:.3f}")
        print(f"Points winner: {'Mixture' if mixture_corr_points > basic_corr_points else 'Basic'} model")
        print(f"Wins winner:   {'Mixture' if mixture_corr_wins > basic_corr_wins else 'Basic'} model")

        return table

    # ===== COMBINED ANALYSIS METHODS =====

    def run_complete_analysis(self, draws_basic=250, draws_mixture=50, save_results=True, include_traceplots=True):
        """
        Fit both models and run every report, plot and example prediction.

        Steps:
        1. Fit and summarise the basic model.
        2. Do the same for the mixture model. Any exception in this step is
           reported and the run continues with the basic model only.
        3. Print the season comparison table and, when a mixture trace
           exists, the detailed model comparison.
        4. Optionally save the comparison table.
        5. Print example match predictions.

        Parameters
        ----------
        draws_basic : int, default 250
            Posterior draws for the basic model; also used as its number of
            tuning steps. Sampled with 4 chains.
        draws_mixture : int, default 50
            Posterior draws and tuning steps for the mixture model, 4 chains.
        save_results : bool, default True
            Write the printed comparison table to a timestamped CSV file in
            the current working directory.
        include_traceplots : bool, default True
            Run ``generate_all_traceplots`` for each fitted model.

        Returns
        -------
        dict
            Keys ``'basic_trace'`` and ``'mixture_trace'`` hold the traces as
            stored on ``self`` after fitting. Key ``'comparison_df'`` holds
            the table printed by ``print_extended_comparison_table`` (None if
            no table was built).
        """
        section_rule = "=" * 50

        def _section(title):
            # Sub-heading framed by 50-character rules.
            print("\n" + section_rule)
            print(title)
            print(section_rule)

        def _show_prediction(header, pred):
            # Print expected goals and outcome probabilities for one match.
            print(header)
            print(f"  Expected goals: {pred['expected_home_goals']:.2f} - {pred['expected_away_goals']:.2f}")
            print(f"  Probabilities: Home {pred['prob_home_win']:.3f}, Draw {pred['prob_draw']:.3f}, Away {pred['prob_away_win']:.3f}")

        print("=" * 70)
        print("COMPLETE BAYESIAN FOOTBALL MODEL ANALYSIS")
        print("=" * 70)

        # ----- Basic model -----
        _section("FITTING BASIC MODEL")
        # The return value is not kept: the result dict reports
        # self.basic_trace, which fit_basic_model presumably sets — confirm.
        self.fit_basic_model(draws=draws_basic, tune=draws_basic, chains=4)

        _section("BASIC MODEL ANALYSIS")
        self.print_model_summary('basic', show_all_teams=True)
        self.plot_team_effects('basic')
        if include_traceplots:
            self.generate_all_traceplots('basic')

        # ----- Mixture model (failures here do not abort the run) -----
        _section("FITTING MIXTURE MODEL")
        try:
            self.fit_mixture_model(draws=draws_mixture, tune=draws_mixture, chains=4)
            print(" Mixture model fitted successfully!")

            _section("MIXTURE MODEL ANALYSIS")
            self.print_model_summary('mixture', show_all_teams=True)
            self.plot_team_effects('mixture')
            if include_traceplots:
                self.generate_all_traceplots('mixture')

        except Exception as e:
            print(f" Mixture model failed: {e}")
            print("Continuing with basic model only...")

        # ----- Season comparison -----
        _section("SEASON COMPARISON ANALYSIS")
        comparison_df = self.print_extended_comparison_table(show_errors=True)

        # Detailed analysis if both models fitted
        if self.mixture_trace is not None:
            self.detailed_model_analysis()

        if save_results and comparison_df is not None:
            filename = f"football_analysis_results_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv"
            # Write the exact table printed above. Rebuilding it through
            # create_extended_comparison_table would repeat its work, and the
            # comparison is described as coming from a simulation, so a
            # rebuilt table could disagree with the printed numbers.
            comparison_df.to_csv(filename, index=False)
            print(f"\n Results saved to: {filename}")

        # ----- Example predictions -----
        _section("EXAMPLE MATCH PREDICTIONS")
        example_matches = [
            ('Internazionale', 'Milan'),
            ('Roma', 'Lazio'),
            ('Juventus', 'Napoli'),
        ]

        for home, away in example_matches:
            try:
                pred_basic = self.predict_match(home, away, 'basic')
                if pred_basic:
                    _show_prediction(f"\n{home} vs {away} (Basic Model):", pred_basic)

                # Mixture model prediction (if available)
                if self.mixture_trace is not None:
                    pred_mixture = self.predict_match(home, away, 'mixture')
                    if pred_mixture:
                        _show_prediction(f"{home} vs {away} (Mixture Model):", pred_mixture)

            except Exception as e:
                print(f"Could not predict {home} vs {away}: {e}")

        print(f"\n{'='*70}")
        print("ANALYSIS COMPLETE!")
        print(f"{'='*70}")

        return {
            'basic_trace': self.basic_trace,
            'mixture_trace': self.mixture_trace,
            'comparison_df': comparison_df
        }

# ===== USAGE EXAMPLE =====

def run_example_analysis(data_file, draws_basic=2500, draws_mixture=2000,
                         save_results=True, include_traceplots=True):
    """
    Build a BayesianFootballModel for *data_file* and run the full analysis.

    Parameters
    ----------
    data_file : str
        Path passed to the ``BayesianFootballModel`` constructor.
    draws_basic : int, default 2500
        Forwarded to ``run_complete_analysis``. Adjust it to your
        computational resources.
    draws_mixture : int, default 2000
        Forwarded to ``run_complete_analysis``. By default the mixture model
        gets fewer draws than the basic model.
    save_results : bool, default True
        Forwarded to ``run_complete_analysis``.
    include_traceplots : bool, default True
        Forwarded to ``run_complete_analysis``. Set to False to skip trace
        plots.

    Returns
    -------
    tuple
        ``(model, results)``: the model object and whatever
        ``run_complete_analysis`` returned.
    """
    # Initialize the model
    print("Initializing Bayesian Football Model...")
    model = BayesianFootballModel(data_file)

    # Run complete analysis; the defaults reproduce the previous fixed settings.
    results = model.run_complete_analysis(
        draws_basic=draws_basic,
        draws_mixture=draws_mixture,
        save_results=save_results,
        include_traceplots=include_traceplots,
    )

    return model, results

# Main execution: when run directly, perform the full analysis and then list
# the model methods available for follow-up work.
if __name__ == "__main__":

    print("\n".join([
        "=" * 70,
        "INTEGRATED BAYESIAN FOOTBALL MODEL",
        "Combines Visualization + Comparison Tables + Predictions + Traceplots",
        "=" * 70,
    ]))

    # Replace with your actual file path
    data_file = '/content/final dataset 2007-08.xlsx'

    try:
        # Run the complete analysis
        model, results = run_example_analysis(data_file)

        print("\n SUCCESS! All analysis completed.")
        print("\n".join([
            "\nYou can now use the model object to:",
            "- model.predict_match('Team1', 'Team2', 'basic')",
            "- model.plot_team_effects('mixture')",
            "- model.print_model_summary('basic')",
            "- model.create_extended_comparison_table()",
            "- model.generate_all_traceplots('basic')",
            "- model.plot_traceplots('mixture')",
            "- model.plot_team_effect_traceplots('basic', ['Juventus', 'Milan'])",
        ]))

    except Exception as e:
        # NOTE(review): the loader in this notebook's first cell reads the
        # 'home team' / 'away team' columns from a hard-coded path, so the
        # column list printed below may not match it. Confirm.
        print("\n".join([
            f"\n Error during analysis: {e}",
            "Please check that your data file exists and has the correct format.",
            "Required columns: hometeam_name, awayteam_name, y1, y2",
        ]))

# Contribution - First Attempt: Enhanced Model with Stadium, Distance and Temporal Covariates

In [None]:
import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')

class BayesianFootballModel:
    """
    Enhanced Bayesian hierarchical model for football match prediction
    Based on Baio & Blangiardo (2010) paper with team-specific covariates

    Includes stadium capacity, attendance, distance effects, and temporal factors
    """

    def __init__(self, data_file):
        """Initialize the model with data"""
        print(f"Initializing enhanced model with data file: {data_file}")

        # Initialize model attributes
        self.basic_model = None
        self.mixture_model = None
        self.enhanced_model = None
        self.full_covariate_model = None
        self.basic_trace = None
        self.mixture_trace = None
        self.enhanced_trace = None
        self.full_trace = None

        try:
            self.data = self.load_and_prepare_data(data_file)
            print(" Model initialization completed successfully")
            print(f" Final check - n_games: {self.n_games}, n_teams: {self.n_teams}")
        except Exception as e:
            print(f" Error during model initialization: {e}")
            print("Please check that the data file exists and has the correct format")
            raise

    def load_and_prepare_data(self, data_file):
        """Load and prepare the enhanced football data with all covariates"""

        try:
            # Try to read the new dataset
            df = pd.read_excel('/data/dataset/dataset_2007-08_stadium_distance_date.xlsx')
        except:
            try:
                df = pd.read_excel('/data/dataset/dataset_2007-08_stadium_distance_date.xlsx')
            except:
                # Fallback to basic dataset
                df = pd.read_excel('/data/dataset/dataset_2007-08.xlsx')
                print("Warning: Using basic dataset, creating sample covariates")

        # Clean column names
        df.columns = df.columns.str.strip()

        print(f"Original data shape: {df.shape}")
        print(f"Columns: {list(df.columns)}")

        # Check for required columns
        required_cols = ['hometeam_name', 'awayteam_name', 'y1', 'y2']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            print(f"Missing required columns: {missing_cols}")
            raise ValueError(f"Dataset must contain columns: {required_cols}")

        # Create team mappings
        all_teams = pd.concat([
            df['hometeam_name'],
            df['awayteam_name']
        ]).unique()

        team_to_id = {team: i for i, team in enumerate(sorted(all_teams))}
        id_to_team = {i: team for team, i in team_to_id.items()}

        # Map team names to consecutive IDs (0-based)
        df['home_team_idx'] = df['hometeam_name'].map(team_to_id)
        df['away_team_idx'] = df['awayteam_name'].map(team_to_id)

        # Store basic team information
        self.teams = sorted(all_teams)
        self.n_teams = len(self.teams)
        self.n_games = len(df)

        print(f"Data loaded: {self.n_games} games, {self.n_teams} teams")

        # Process and prepare covariates
        self._prepare_covariates(df)

        return df

    def _prepare_covariates(self, df):
        """Prepare all covariates for modeling"""

        print("\n" + "="*60)
        print("PREPARING COVARIATES")
        print("="*60)

        # Initialize covariate dictionaries
        self.team_covariates = {}
        self.game_covariates = {}
        self.standardized_team_covariates = {}
        self.standardized_game_covariates = {}

        # 1. STADIUM CHARACTERISTICS (team-specific)
        self._prepare_stadium_covariates(df)

        # 2. DISTANCE EFFECTS (game-specific)
        self._prepare_distance_covariates(df)

        # 3. TEMPORAL EFFECTS (game-specific)
        self._prepare_temporal_covariates(df)

        # 4. STANDARDIZE ALL COVARIATES
        self._standardize_covariates()

        print("\n All covariates prepared successfully")

    def _prepare_stadium_covariates(self, df):
        """Prepare stadium-related covariates (team-specific)"""

        print("\n1. STADIUM CHARACTERISTICS:")

        # Check what stadium columns are available
        stadium_cols = [col for col in df.columns if any(keyword in col.lower()
                       for keyword in ['stadium', 'capacity', 'attendance', 'utilization'])]

        print(f"   Available stadium columns: {stadium_cols}")

        # Initialize team covariates dictionary
        for i, team in enumerate(self.teams):
            self.team_covariates[i] = {'team_name': team}

        # Extract stadium information for each team
        for team_idx, team in enumerate(self.teams):
            # Get home games for this team
            home_games = df[df['hometeam_name'] == team]

            if len(home_games) > 0:
                # Stadium capacity
                if 'stadium_capacity' in df.columns:
                    capacity = home_games['stadium_capacity'].iloc[0]
                    self.team_covariates[team_idx]['stadium_capacity'] = float(capacity) if pd.notna(capacity) else 40000.0
                else:
                    # Create sample capacity based on team name (for demonstration)
                    capacity = np.random.randint(20000, 80000)
                    self.team_covariates[team_idx]['stadium_capacity'] = float(capacity)

                # Average attendance
                if 'average_attendance' in df.columns:
                    avg_att = home_games['average_attendance'].mean()
                    self.team_covariates[team_idx]['average_attendance'] = float(avg_att) if pd.notna(avg_att) else 30000.0
                elif 'attendance' in df.columns:
                    avg_att = home_games['attendance'].mean()
                    self.team_covariates[team_idx]['average_attendance'] = float(avg_att) if pd.notna(avg_att) else 30000.0
                else:
                    # Sample attendance
                    capacity = self.team_covariates[team_idx]['stadium_capacity']
                    avg_att = capacity * np.random.uniform(0.4, 0.85)
                    self.team_covariates[team_idx]['average_attendance'] = float(avg_att)

                # Capacity utilization rate
                if 'capacity_utilization' in df.columns:
                    util = home_games['capacity_utilization'].mean()
                    self.team_covariates[team_idx]['capacity_utilization'] = float(util) if pd.notna(util) else 0.75
                else:
                    # Calculate from attendance and capacity
                    capacity = self.team_covariates[team_idx]['stadium_capacity']
                    attendance = self.team_covariates[team_idx]['average_attendance']
                    util = min(attendance / capacity, 1.0)
                    self.team_covariates[team_idx]['capacity_utilization'] = float(util)

      

        print(f"    Stadium characteristics prepared for {self.n_teams} teams")

        # Print summary statistics
        capacities = [self.team_covariates[i]['stadium_capacity'] for i in range(self.n_teams)]
        utilizations = [self.team_covariates[i]['capacity_utilization'] for i in range(self.n_teams)]

        print(f"   Stadium capacity range: {min(capacities):.0f} - {max(capacities):.0f}")
        print(f"   Capacity utilization range: {min(utilizations):.3f} - {max(utilizations):.3f}")

    def _prepare_distance_covariates(self, df):
        """Prepare distance-related covariates (game-specific)"""

        print("\n2. DISTANCE EFFECTS:")

        # Check for distance columns
        distance_cols = [col for col in df.columns if 'distance' in col.lower() or 'km' in col.lower()]
        print(f"   Available distance columns: {distance_cols}")

        if distance_cols and len(distance_cols) > 0:
            # Use the first distance column found
            distance_col = distance_cols[0]
            distances = df[distance_col].values

            # Handle missing values
            distances = np.where(pd.isna(distances), np.median(distances[~pd.isna(distances)]), distances)

            self.game_covariates['travel_distance'] = distances.astype(float)

            print(f"    Using column '{distance_col}' for travel distances")
            print(f"   Distance range: {distances.min():.1f} - {distances.max():.1f} km")


            self.game_covariates['travel_distance'] = np.array(distances)
            print(f"    Sample distances created, range: {min(distances):.1f} - {max(distances):.1f} km")

    def _prepare_temporal_covariates(self, df):
        """Prepare temporal/seasonal covariates (game-specific)"""

        print("\n3. TEMPORAL EFFECTS:")

        # Check for date columns
        date_cols = [col for col in df.columns if 'date' in col.lower() or 'weekday' in col.lower()]
        print(f"   Available date columns: {date_cols}")

        if date_cols and len(date_cols) > 0:
            # Use the first date column
            date_col = date_cols[0]

            try:
                # Try to parse dates
                dates = pd.to_datetime(df[date_col])

                # Extract temporal features
                self.game_covariates['month'] = dates.dt.month.values
                self.game_covariates['day_of_week'] = dates.dt.dayofweek.values  # 0=Monday
                self.game_covariates['is_weekend'] = (dates.dt.dayofweek >= 5).astype(int).values

                # Season phase (early, mid, late season)
                months = dates.dt.month.values
                season_phase = np.where(months <= 10, 0,  # Early season (Aug-Oct)
                               np.where(months <= 2, 1,   # Mid season (Nov-Feb)
                                       2))               # Late season (Mar-May)
                self.game_covariates['season_phase'] = season_phase

                print(f"    Using column '{date_col}' for temporal effects")
                print(f"   Date range: {dates.min()} to {dates.max()}")

            except Exception as e:
                print(f"   Warning: Could not parse dates from '{date_col}': {e}")
                self._create_sample_temporal_data()


    def _standardize_covariates(self):
        """Standardize all covariates (z-score normalization)"""

        print("\n4. STANDARDIZING COVARIATES:")

        # Standardize team-specific covariates
        team_cov_names = ['stadium_capacity', 'average_attendance', 'capacity_utilization']

        for cov_name in team_cov_names:
            if cov_name in [list(self.team_covariates[0].keys())][0]:
                values = [self.team_covariates[i][cov_name] for i in range(self.n_teams)]
                mean_val = np.mean(values)
                std_val = np.std(values)

                if std_val > 0:
                    standardized_values = [(val - mean_val) / std_val for val in values]
                else:
                    standardized_values = [0.0] * len(values)

                self.standardized_team_covariates[cov_name] = {
                    'values': standardized_values,
                    'mean': mean_val,
                    'std': std_val
                }

                print(f"   {cov_name}: mean={mean_val:.2f}, std={std_val:.2f}")

        # Standardize game-specific covariates
        game_cov_names = ['travel_distance']

        for cov_name in game_cov_names:
            if cov_name in self.game_covariates:
                values = self.game_covariates[cov_name]
                mean_val = np.mean(values)
                std_val = np.std(values)

                if std_val > 0:
                    standardized_values = (values - mean_val) / std_val
                else:
                    standardized_values = np.zeros_like(values)

                self.standardized_game_covariates[cov_name] = {
                    'values': standardized_values,
                    'mean': mean_val,
                    'std': std_val
                }

                print(f"   {cov_name}: mean={mean_val:.2f}, std={std_val:.2f}")

        print("    All covariates standardized")

    def build_basic_model(self):
        """
        Build the basic hierarchical Poisson model (same as original, no covariates).

        Home goals (``y1``) and away goals (``y2``) of each game are Poisson
        with log-rates

            log theta_g1 = home_advantage + att[home] + def[away]
            log theta_g2 =                  att[away] + def[home]

        The team attack/defence effects are normal around shared means with
        Gamma-distributed precisions. They are centred by subtracting their
        mean, so each set sums to zero.

        Returns
        -------
        pymc.Model
            The model, also stored on ``self.basic_model``.
        """
        print("Building basic hierarchical model...")

        games = self.data
        home_idx = games['home_team_idx'].values
        away_idx = games['away_team_idx'].values
        home_goals = games['y1'].values
        away_goals = games['y2'].values

        flat_tau = 0.0001            # precision of the vague normal priors
        gamma_shape = gamma_rate = 0.01

        # Random variables are created in a fixed order: home advantage,
        # hyperparameters, raw team effects, centred effects, rates, likelihood.
        with pm.Model() as model:
            # Common home advantage shared by every team
            home_advantage = pm.Normal("home_advantage", mu=0, tau=flat_tau)

            # Population-level means and precisions of the team effects
            mu_att = pm.Normal("mu_att", mu=0, tau=flat_tau)
            mu_def = pm.Normal("mu_def", mu=0, tau=flat_tau)
            tau_att = pm.Gamma("tau_att", alpha=gamma_shape, beta=gamma_rate)
            tau_def = pm.Gamma("tau_def", alpha=gamma_shape, beta=gamma_rate)

            # Raw team effects, then centred so each set sums to zero
            att_raw = pm.Normal("att_star", mu=mu_att, tau=tau_att, shape=self.n_teams)
            def_raw = pm.Normal("def_star", mu=mu_def, tau=tau_def, shape=self.n_teams)
            attack = pm.Deterministic("att", att_raw - pt.mean(att_raw))
            defence = pm.Deterministic("def", def_raw - pt.mean(def_raw))

            # Scoring intensities (exponentiated linear predictors)
            rate_home = pm.Deterministic(
                "theta_g1", pt.exp(home_advantage + attack[home_idx] + defence[away_idx]))
            rate_away = pm.Deterministic(
                "theta_g2", pt.exp(attack[away_idx] + defence[home_idx]))

            # Likelihood of the observed goal counts
            pm.Poisson("y1", mu=rate_home, observed=home_goals)
            pm.Poisson("y2", mu=rate_away, observed=away_goals)

        print(" Basic model built successfully")
        self.basic_model = model
        return model

    def build_enhanced_stadium_model(self):
        """Assemble the stadium-covariate variant of the hierarchical model.

        Each team gets its own home advantage: a common ``home_base`` plus
        linear terms in its standardized stadium capacity, capacity
        utilization and average attendance, and a capacity x utilization
        interaction (``beta_interaction``). Attack and defence effects follow
        the basic model (centred Normal draws).

        The model is stored on ``self.enhanced_model`` and returned.
        """
        print("Building enhanced model with stadium-based home advantage...")

        games = self.data
        home_idx = games['home_team_idx'].values
        away_idx = games['away_team_idx'].values
        goals_home = games['y1'].values
        goals_away = games['y2'].values

        team_covs = self.standardized_team_covariates
        capacity, utilization, attendance = (
            np.array(team_covs[key]['values'])
            for key in ('stadium_capacity', 'capacity_utilization', 'average_attendance')
        )

        with pm.Model() as stadium_model:
            # ---- home advantage driven by stadium characteristics ----
            home_base = pm.Normal("home_base", mu=0, tau=0.0001)
            b_cap = pm.Normal("beta_capacity", mu=0, tau=0.001)
            b_util = pm.Normal("beta_utilization", mu=0, tau=0.001)
            b_att = pm.Normal("beta_attendance", mu=0, tau=0.001)
            b_inter = pm.Normal("beta_interaction", mu=0, tau=0.002)

            team_home = pm.Deterministic(
                "home_advantage_team",
                home_base
                + b_cap * capacity
                + b_util * utilization
                + b_att * attendance
                + b_inter * capacity * utilization,
            )

            # ---- exchangeable attack / defence effects ----
            mu_att = pm.Normal("mu_att", mu=0, tau=0.0001)
            mu_def = pm.Normal("mu_def", mu=0, tau=0.0001)
            tau_att = pm.Gamma("tau_att", alpha=0.01, beta=0.01)
            tau_def = pm.Gamma("tau_def", alpha=0.01, beta=0.01)

            att_raw = pm.Normal("att_star", mu=mu_att, tau=tau_att, shape=self.n_teams)
            def_raw = pm.Normal("def_star", mu=mu_def, tau=tau_def, shape=self.n_teams)

            attack = pm.Deterministic("att", att_raw - pt.mean(att_raw))
            defence = pm.Deterministic("def", def_raw - pt.mean(def_raw))

            # ---- per-game scoring rates; home side uses its own advantage ----
            rate_home = pm.Deterministic(
                "theta_g1",
                pt.exp(team_home[home_idx] + attack[home_idx] + defence[away_idx]),
            )
            rate_away = pm.Deterministic(
                "theta_g2",
                pt.exp(attack[away_idx] + defence[home_idx]),
            )

            pm.Poisson("y1", mu=rate_home, observed=goals_home)
            pm.Poisson("y2", mu=rate_away, observed=goals_away)

        print(" Enhanced stadium model built successfully")
        self.enhanced_model = stadium_model
        return stadium_model

    def build_full_covariate_model(self):
        """Build the model with stadium, travel-distance and calendar covariates.

        Per game (home team h, away team a), on the log scale:

            log theta_g1 = home_advantage_team[h] + beta_distance * distance
                           + beta_weekend * is_weekend + beta_season[phase]
                           + att[h] + def[a]
            log theta_g2 = att[a] + def[h] + beta_distance_att * distance

        ``home_advantage_team`` is the stadium-based home advantage (its
        capacity x utilization interaction is named ``beta_capacity_util``);
        ``distance`` is the standardized travel distance.

        Identifiability fixes relative to the earlier formulation:

        * ``beta_distance_att * distance`` and ``beta_distance_def * distance``
          both entered ``log theta_g2``, so only their sum reached the
          likelihood and each coefficient was pinned down by its prior alone.
          A travel penalty on the away *defence* would act on home scoring
          (``log theta_g1``), where it coincides with ``beta_distance``, so
          ``beta_distance_def`` is no longer part of the model.
        * ``beta_season`` had a free level per phase in addition to the
          ``home_base`` intercept; it is now a zero-sum effect.

        The model is stored on ``self.full_covariate_model`` and returned.
        """
        print("Building full covariate model with stadium, distance, and temporal effects...")

        # Prepare data arrays
        home_team_idx = self.data['home_team_idx'].values
        away_team_idx = self.data['away_team_idx'].values
        y1_data = self.data['y1'].values
        y2_data = self.data['y2'].values

        # Standardized stadium covariates (one value per team)
        team_covs = self.standardized_team_covariates
        capacity_std = np.array(team_covs['stadium_capacity']['values'])
        utilization_std = np.array(team_covs['capacity_utilization']['values'])
        attendance_std = np.array(team_covs['average_attendance']['values'])

        # Game covariates as plain numpy arrays: season_phase is used as an
        # integer index into beta_season, the others multiply per game.
        distance_std = np.asarray(
            self.standardized_game_covariates['travel_distance']['values'], dtype=float
        )
        is_weekend = np.asarray(self.game_covariates['is_weekend'], dtype=float)
        season_phase = np.asarray(self.game_covariates['season_phase'], dtype=int)
        n_phases = 3  # early / mid / late; season_phase values must lie in 0..2

        with pm.Model() as model:
            # ========== TEAM-SPECIFIC HOME ADVANTAGE ==========

            home_base = pm.Normal("home_base", mu=0, tau=0.0001)

            # Stadium effects
            beta_capacity = pm.Normal("beta_capacity", mu=0, tau=0.001)
            beta_utilization = pm.Normal("beta_utilization", mu=0, tau=0.001)
            beta_attendance = pm.Normal("beta_attendance", mu=0, tau=0.001)
            beta_capacity_util = pm.Normal("beta_capacity_util", mu=0, tau=0.002)

            home_advantage_team = pm.Deterministic(
                "home_advantage_team",
                home_base
                + beta_capacity * capacity_std
                + beta_utilization * utilization_std
                + beta_attendance * attendance_std
                + beta_capacity_util * capacity_std * utilization_std,
            )

            # ========== GAME-SPECIFIC EFFECTS ON HOME SCORING ==========

            # Travel distance of the away side (sign learned from the data);
            # also absorbs any travel-induced weakening of the away defence.
            beta_distance = pm.Normal("beta_distance", mu=0, tau=0.002)

            # Weekend dummy
            beta_weekend = pm.Normal("beta_weekend", mu=0, tau=0.002)

            # Season phase: zero-sum so the levels do not compete with
            # home_base for the overall level (same prior scale as tau=0.002).
            beta_season = pm.ZeroSumNormal(
                "beta_season", sigma=1 / np.sqrt(0.002), shape=n_phases
            )

            home_advantage_game = pm.Deterministic(
                "home_advantage_game",
                home_advantage_team[home_team_idx]
                + beta_distance * distance_std
                + beta_weekend * is_weekend
                + beta_season[season_phase],
            )

            # ========== TRAVEL EFFECT ON AWAY SCORING ==========

            # Travel fatigue acting on the away team's attack
            beta_distance_att = pm.Normal("beta_distance_att", mu=0, sigma=0.1)

            # ========== STANDARD TEAM EFFECTS ==========

            mu_att = pm.Normal("mu_att", mu=0, tau=0.0001)
            mu_def = pm.Normal("mu_def", mu=0, tau=0.0001)
            tau_att = pm.Gamma("tau_att", alpha=0.01, beta=0.01)
            tau_def = pm.Gamma("tau_def", alpha=0.01, beta=0.01)

            att_star = pm.Normal("att_star", mu=mu_att, tau=tau_att, shape=self.n_teams)
            def_star = pm.Normal("def_star", mu=mu_def, tau=tau_def, shape=self.n_teams)

            att = pm.Deterministic("att", att_star - pt.mean(att_star))
            def_ = pm.Deterministic("def", def_star - pt.mean(def_star))

            # ========== SCORING INTENSITIES ==========

            log_theta_g1 = (home_advantage_game
                            + att[home_team_idx]
                            + def_[away_team_idx])

            log_theta_g2 = (att[away_team_idx]
                            + def_[home_team_idx]
                            + beta_distance_att * distance_std)

            theta_g1 = pm.Deterministic("theta_g1", pt.exp(log_theta_g1))
            theta_g2 = pm.Deterministic("theta_g2", pt.exp(log_theta_g2))

            pm.Poisson("y1", mu=theta_g1, observed=y1_data)
            pm.Poisson("y2", mu=theta_g2, observed=y2_data)

        print(" Full covariate model built successfully")
        self.full_covariate_model = model
        return model

    def fit_basic_model(self, draws=2000, tune=1000, chains=3, cores=1, random_seed=42):
        """Sample the basic hierarchical model and its posterior predictive.

        Builds the model first if needed. The pointwise log-likelihood is
        stored in the returned InferenceData (``log_likelihood`` group). PyMC 5
        no longer computes it by default, and ``az.waic`` / ``az.loo`` need it.

        Args:
            draws, tune, chains, cores: forwarded to ``pm.sample``.
            random_seed: seed for both posterior and posterior-predictive sampling.

        Returns:
            The InferenceData, also stored on ``self.basic_trace``.
        """
        print("Fitting basic hierarchical model...")

        if self.basic_model is None:
            self.build_basic_model()

        with self.basic_model:
            self.basic_trace = pm.sample(
                draws=draws, tune=tune, chains=chains, cores=cores,
                random_seed=random_seed,
                return_inferencedata=True,
                target_accept=0.95,
                idata_kwargs={"log_likelihood": True},
            )

            # Sample posterior predictive
            self.basic_trace.extend(
                pm.sample_posterior_predictive(self.basic_trace, random_seed=random_seed)
            )

        print("Basic model fitting completed!")
        return self.basic_trace

    def fit_enhanced_model(self, draws=2000, tune=1000, chains=3, cores=1, random_seed=42):
        """Sample the stadium-covariate model and its posterior predictive.

        Builds the model on first use. The pointwise log-likelihood is kept in
        the InferenceData so WAIC/LOO can be computed for model comparison.
        PyMC 5 does not store it by default.

        Args:
            draws, tune, chains, cores: forwarded to ``pm.sample``.
            random_seed: seed for both posterior and posterior-predictive sampling.

        Returns:
            The InferenceData, also stored on ``self.enhanced_trace``.
        """
        print("Fitting enhanced model with stadium covariates...")

        # getattr: tolerate an instance whose __init__ did not pre-set the attribute
        if getattr(self, 'enhanced_model', None) is None:
            self.build_enhanced_stadium_model()

        with self.enhanced_model:
            self.enhanced_trace = pm.sample(
                draws=draws, tune=tune, chains=chains, cores=cores,
                random_seed=random_seed, return_inferencedata=True, target_accept=0.9,
                idata_kwargs={"log_likelihood": True},
            )

            self.enhanced_trace.extend(
                pm.sample_posterior_predictive(self.enhanced_trace, random_seed=random_seed)
            )

        print("Enhanced model fitting completed!")
        return self.enhanced_trace

    def fit_full_model(self, draws=2000, tune=1000, chains=3, cores=1, random_seed=42):
        """Sample the full covariate model and its posterior predictive.

        Builds the model on first use. The pointwise log-likelihood is kept in
        the InferenceData so WAIC/LOO can be computed for model comparison.
        PyMC 5 does not store it by default.

        Args:
            draws, tune, chains, cores: forwarded to ``pm.sample``.
            random_seed: seed for both posterior and posterior-predictive sampling.

        Returns:
            The InferenceData, also stored on ``self.full_trace``.
        """
        print("Fitting full covariate model...")

        # getattr: tolerate an instance whose __init__ did not pre-set the attribute
        if getattr(self, 'full_covariate_model', None) is None:
            self.build_full_covariate_model()

        with self.full_covariate_model:
            self.full_trace = pm.sample(
                draws=draws, tune=tune, chains=chains, cores=cores,
                random_seed=random_seed, return_inferencedata=True, target_accept=0.9,
                idata_kwargs={"log_likelihood": True},
            )

            self.full_trace.extend(
                pm.sample_posterior_predictive(self.full_trace, random_seed=random_seed)
            )

        print("Full covariate model fitting completed!")
        return self.full_trace

    def check_convergence(self, trace, model_name="Model"):
        """Print R-hat and effective-sample-size diagnostics for an MCMC trace.

        Args:
            trace: ArviZ InferenceData with a ``posterior`` group.
            model_name: label used in the printed banner.

        Returns:
            True. Problems are reported as printed warnings, not raised.
        """

        def _overall(diagnostic, reducer):
            """Reduce an ArviZ diagnostic Dataset (one variable per parameter) to a float.

            ``float()`` cannot be applied to an xarray Dataset, so the
            per-variable arrays are pooled first. ``reducer`` is np.nanmax or
            np.nanmin, so NaN entries are ignored.
            """
            pooled = np.concatenate([
                np.asarray(var.values, dtype=float).ravel()
                for var in diagnostic.data_vars.values()
            ])
            return float(reducer(pooled))

        print(f"\n{'='*60}")
        print(f"CONVERGENCE DIAGNOSTICS FOR {model_name.upper()}")
        print(f"{'='*60}")

        # 1. R-hat diagnostics
        try:
            max_rhat = _overall(az.rhat(trace), np.nanmax)

            print(f"\n1. R-HAT DIAGNOSTICS:")
            print(f"   Max R-hat: {max_rhat:.4f}")

            if max_rhat < 1.01:
                print("    EXCELLENT: All parameters have R-hat < 1.01")
            elif max_rhat < 1.1:
                print("    GOOD: All parameters have R-hat < 1.1")
            else:
                print("     WARNING: Some parameters have R-hat >= 1.1")
        except Exception as e:
            print(f"   Error calculating R-hat: {e}")

        # 2. Effective Sample Size
        try:
            min_ess_bulk = _overall(az.ess(trace, kind="bulk"), np.nanmin)
            min_ess_tail = _overall(az.ess(trace, kind="tail"), np.nanmin)

            print(f"\n2. EFFECTIVE SAMPLE SIZE:")
            print(f"   Min ESS (bulk): {min_ess_bulk:.0f}")
            print(f"   Min ESS (tail): {min_ess_tail:.0f}")

            # .sizes is the stable mapping of dimension lengths (.dims is
            # becoming a set of names in newer xarray releases)
            posterior_sizes = trace.posterior.sizes
            total_samples = posterior_sizes['draw'] * posterior_sizes['chain']
            min_recommended = max(100, total_samples // 10)

            if min_ess_bulk >= min_recommended and min_ess_tail >= min_recommended:
                print(f"    GOOD: ESS values above recommended minimum ({min_recommended})")
            else:
                print(f"     WARNING: Some ESS values below recommended minimum ({min_recommended})")
        except Exception as e:
            print(f"   Error calculating ESS: {e}")

        return True

    def analyze_covariate_effects(self, model_type='enhanced'):
        """Print and return posterior summaries of the covariate effects.

        For each covariate coefficient present in the chosen model's trace,
        this reports the posterior mean and the 2.5% / 97.5% quantiles. An
        effect is flagged as significant when that 95% interval excludes zero.

        Args:
            model_type: 'enhanced' (stadium model) or 'full' (stadium +
                distance + temporal model).

        Returns:
            A dict mapping each effect name to
            {'mean', 'ci_low', 'ci_high', 'significant'}.
            Season-phase effects are stored per phase as
            'beta_season[<Phase>]'. The team table, when available, is stored
            under 'team_home_advantages'.
            Returns None if ``model_type`` is unknown or the model has not
            been fitted.
        """
        if model_type == 'enhanced':
            trace = getattr(self, 'enhanced_trace', None)
            model_name = "Enhanced Stadium Model"
        elif model_type == 'full':
            trace = getattr(self, 'full_trace', None)
            model_name = "Full Covariate Model"
        else:
            print("Please specify model_type as 'enhanced' or 'full'")
            return None

        if trace is None:
            print(f"Please fit the {model_name} first!")
            return None

        print(f"\n{'='*70}")
        print(f"COVARIATE EFFECTS ANALYSIS - {model_name.upper()}")
        print(f"{'='*70}")

        posterior = trace.posterior
        results = {}

        def summarize(samples):
            """Posterior mean, central 95% interval and whether it excludes 0."""
            mean_val = float(samples.mean())
            ci_low = float(samples.quantile(0.025))
            ci_high = float(samples.quantile(0.975))
            return {
                'mean': mean_val,
                'ci_low': ci_low,
                'ci_high': ci_high,
                'significant': ci_low > 0 or ci_high < 0,
            }

        def record(label, samples):
            """Store the summary of ``samples`` under ``label`` and print one line."""
            summary = summarize(samples)
            results[label] = summary
            significance = " SIGNIFICANT" if summary['significant'] else "• Not significant"
            print(f"   {label}: {summary['mean']:.4f} "
                  f"[{summary['ci_low']:.4f}, {summary['ci_high']:.4f}] {significance}")

        def record_present(names):
            """Record every scalar effect in ``names`` that exists in the trace."""
            for name in names:
                if name in posterior.data_vars:
                    record(name, posterior[name])

        section = 1
        print(f"\n{section}. STADIUM EFFECTS ON HOME ADVANTAGE:")
        # The capacity x utilization interaction is named 'beta_interaction' in
        # the enhanced model and 'beta_capacity_util' in the full model.
        record_present(['beta_capacity', 'beta_utilization', 'beta_attendance',
                        'beta_interaction', 'beta_capacity_util'])

        # Distance and temporal effects (if full model)
        if model_type == 'full':
            section += 1
            print(f"\n{section}. DISTANCE EFFECTS:")
            record_present(['beta_distance', 'beta_distance_att', 'beta_distance_def'])

            section += 1
            print(f"\n{section}. TEMPORAL EFFECTS:")
            record_present(['beta_weekend'])

            if 'beta_season' in posterior.data_vars:
                season = posterior['beta_season']
                if season.ndim == 3:  # (chain, draw, phase)
                    phase_names = ['Early', 'Mid', 'Late']
                    for i in range(season.shape[-1]):
                        phase = phase_names[i] if i < len(phase_names) else str(i)
                        record(f"beta_season[{phase}]", season[..., i])
                else:
                    record('beta_season', season)

        # Team-specific home advantages
        section += 1
        print(f"\n{section}. TEAM-SPECIFIC HOME ADVANTAGES:")

        if 'home_advantage_team' in posterior.data_vars:
            home_means = posterior['home_advantage_team'].mean(dim=['chain', 'draw']).values

            home_df = pd.DataFrame({
                'team': self.teams,
                'home_advantage': home_means
            })

            # Stadium characteristics; assumes team_covariates is indexed by
            # team id in the same order as self.teams.
            for column, key in (('capacity', 'stadium_capacity'),
                                ('utilization', 'capacity_utilization'),
                                ('attendance', 'average_attendance')):
                home_df[column] = [self.team_covariates[i][key] for i in range(len(home_df))]

            home_df = home_df.sort_values('home_advantage', ascending=False)

            print("\n   Top 5 teams with highest home advantage:")
            print(home_df.head().round(4).to_string(index=False))

            print("\n   Bottom 5 teams with lowest home advantage:")
            print(home_df.tail().round(4).to_string(index=False))

            results['team_home_advantages'] = home_df

        return results

    def plot_covariate_effects(self, model_type='enhanced'):
        """Draw posterior plots of the covariate effects for one fitted model.

        Panels, filled in order until the grid runs out:
          * histograms of the stadium coefficients (capacity, utilization,
            attendance), each with a dashed zero line;
          * a scatter of every team's posterior-mean home advantage against its
            capacity utilization, labelling teams outside the 20th-80th
            percentile band;
          * for the full model only, histograms of the distance coefficients.
        Unused panels are hidden and the figure is displayed with ``plt.show()``.
        """
        if model_type not in ('enhanced', 'full'):
            print("Please specify model_type as 'enhanced' or 'full'")
            return

        trace = self.enhanced_trace if model_type == 'enhanced' else self.full_trace
        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return

        n_rows, height = (2, 12) if model_type == 'enhanced' else (3, 18)
        _, grid = plt.subplots(n_rows, 2, figsize=(15, height))
        panels = list(grid.flatten())
        posterior = trace.posterior

        def histogram(ax, name, title):
            draws = posterior[name].values.flatten()
            ax.hist(draws, bins=50, alpha=0.7, density=True)
            ax.axvline(0, color='red', linestyle='--', alpha=0.7)
            ax.set_title(title)
            ax.set_xlabel('Effect Size')
            ax.set_ylabel('Density')

        used = 0

        for name in ('beta_capacity', 'beta_utilization', 'beta_attendance'):
            if name in posterior.data_vars and used < len(panels):
                histogram(panels[used], name, f'{name.replace("beta_", "").title()} Effect')
                used += 1

        if 'home_advantage_team' in posterior.data_vars and used < len(panels):
            ax = panels[used]
            home_adv = posterior['home_advantage_team'].mean(dim=['chain', 'draw']).values
            util = [self.team_covariates[k]['capacity_utilization'] for k in range(self.n_teams)]

            ax.scatter(util, home_adv, alpha=0.7, s=60)
            ax.set_xlabel('Capacity Utilization')
            ax.set_ylabel('Home Advantage')
            ax.set_title('Home Advantage vs Stadium Utilization')
            ax.grid(True, alpha=0.3)

            # Only label the extremes (top and bottom quintile)
            upper = np.percentile(home_adv, 80)
            lower = np.percentile(home_adv, 20)
            for k, team in enumerate(self.teams):
                if home_adv[k] > upper or home_adv[k] < lower:
                    ax.annotate(team[:8], (util[k], home_adv[k]),
                                xytext=(3, 3), textcoords='offset points', fontsize=8)
            used += 1

        if model_type == 'full':
            for name in ('beta_distance', 'beta_distance_att'):
                if name in posterior.data_vars and used < len(panels):
                    pretty = name.replace("beta_", "").replace("_", " ").title()
                    histogram(panels[used], name, f'{pretty} Effect')
                    used += 1

        for ax in panels[used:]:
            ax.set_visible(False)

        plt.tight_layout()
        plt.show()

    def compare_all_models(self):
        """Compare the fitted models with WAIC and PSIS-LOO.

        Both criteria are computed on the deviance scale (-2 * elpd), so lower
        is better; this matches the "best model" selection below. Each match
        contributes one pointwise term: the sum of its home-goal and away-goal
        log-likelihoods.

        Returns:
            A dict mapping model key ('basic' / 'enhanced' / 'full') to
            {'waic', 'waic_se', 'loo', 'loo_se'}, for every model whose
            criteria could be computed.
            Returns None when fewer than two models are fitted.
        """
        print(f"\n{'='*80}")
        print("COMPREHENSIVE MODEL COMPARISON")
        print(f"{'='*80}")

        candidates = (
            ('basic', 'basic_trace', 'basic_model'),
            ('enhanced', 'enhanced_trace', 'enhanced_model'),
            ('full', 'full_trace', 'full_covariate_model'),
        )
        fitted_models = []
        for key, trace_attr, model_attr in candidates:
            trace = getattr(self, trace_attr, None)
            if trace is not None:
                fitted_models.append((key, trace, getattr(self, model_attr, None)))

        if len(fitted_models) < 2:
            print("Need at least 2 fitted models for comparison")
            return None

        print(f"\n1. MODEL SELECTION CRITERIA:")

        model_comparison = {}

        for model_name, trace, model in fitted_models:
            try:
                var_name = self._add_match_log_likelihood(trace, model)
                waic = az.waic(trace, var_name=var_name, scale="deviance")
                loo = az.loo(trace, var_name=var_name, scale="deviance")

                # ArviZ ELPDData stores the criterion under 'elpd_*' for every scale
                model_comparison[model_name] = {
                    'waic': float(waic['elpd_waic']),
                    'waic_se': float(waic['se']),
                    'loo': float(loo['elpd_loo']),
                    'loo_se': float(loo['se'])
                }

                print(f"\n   {model_name.upper()} MODEL:")
                print(f"     WAIC: {model_comparison[model_name]['waic']:.2f} ± {model_comparison[model_name]['waic_se']:.2f}")
                print(f"     LOO:  {model_comparison[model_name]['loo']:.2f} ± {model_comparison[model_name]['loo_se']:.2f}")

            except Exception as e:
                print(f"   Error calculating criteria for {model_name}: {e}")

        # Determine best models (deviance scale: lower is better)
        if model_comparison:
            best_waic = min(model_comparison, key=lambda x: model_comparison[x]['waic'])
            best_loo = min(model_comparison, key=lambda x: model_comparison[x]['loo'])

            print(f"\n    BEST MODEL BY WAIC: {best_waic.upper()}")
            print(f"    BEST MODEL BY LOO:  {best_loo.upper()}")

        return model_comparison

    @staticmethod
    def _add_match_log_likelihood(trace, model, var_name="match_total"):
        """Ensure ``trace`` has a per-match log-likelihood variable; return its name.

        PyMC 5 stores pointwise log-likelihoods only on request, so the group
        is computed from ``model`` when it is missing. ArviZ's WAIC/LOO refuse
        to choose between several log-likelihood variables. The home-goal
        (``y1``) and away-goal (``y2``) terms of each match are therefore
        summed into ``var_name``, which is added to ``trace`` in place.

        Raises:
            ValueError: if ``trace`` has no log-likelihood group and ``model`` is None.
        """
        if "log_likelihood" not in trace.groups():
            if model is None:
                raise ValueError("trace has no log_likelihood group and no model to compute it from")
            pm.compute_log_likelihood(trace, var_names=["y1", "y2"], model=model)

        log_lik = trace.log_likelihood
        if var_name not in log_lik.data_vars:
            home = log_lik["y1"]
            # y1 and y2 carry differently named observation dims, so add the raw
            # same-shaped arrays instead of letting xarray broadcast them.
            log_lik[var_name] = home.copy(data=home.values + log_lik["y2"].values)
        return var_name

    def create_enhanced_comparison_table(self):
        """Build a per-team table of observed vs model-predicted season totals.

        Observed points and goals come from ``self.data``. For each fitted
        model, predictions use the per-game median of its posterior-predictive
        ``y1`` (home goals) and ``y2`` (away goals) samples. A predicted win
        scores 3 points; a predicted draw (medians within 0.1) scores 1.

        Returns:
            A DataFrame sorted by observed points (descending). Its columns are
            ``team``, ``obs_points/obs_scored/obs_conceded``, and
            ``<model>_points/_scored/_conceded`` for every model whose
            predictions are available.
            Returns None if the basic model has not been fitted.
        """
        if getattr(self, 'basic_trace', None) is None:
            print("Please fit at least the basic model first!")
            return None

        home_names = self.data['hometeam_name'].to_numpy()
        away_names = self.data['awayteam_name'].to_numpy()

        def season_totals(team, goals_home, goals_away, is_draw):
            """Return (points, scored, conceded) for ``team`` from per-game goal arrays."""
            points = 0
            scored = 0
            conceded = 0
            # Positional game indices, so lookups into the per-game prediction
            # arrays stay correct even when self.data has a non-RangeIndex.
            for g in np.flatnonzero((home_names == team) | (away_names == team)):
                if home_names[g] == team:
                    goals_for, goals_against = goals_home[g], goals_away[g]
                else:
                    goals_for, goals_against = goals_away[g], goals_home[g]

                if goals_for > goals_against:
                    points += 3
                elif is_draw(goals_for, goals_against):
                    points += 1

                scored += goals_for
                conceded += goals_against
            return points, scored, conceded

        # Observed statistics
        observed_goals_home = self.data['y1'].to_numpy()
        observed_goals_away = self.data['y2'].to_numpy()
        observed_stats = []
        for team in self.teams:
            points, scored, conceded = season_totals(
                team, observed_goals_home, observed_goals_away, lambda a, b: a == b
            )
            observed_stats.append({
                'team': team,
                'obs_points': points,
                'obs_scored': scored,
                'obs_conceded': conceded
            })

        def get_model_predictions(trace, model_name):
            """Per-team predicted totals from a trace's posterior predictive, or None."""
            if trace is None:
                return None

            try:
                predictive = trace.posterior_predictive
                if 'y1' not in predictive.data_vars:
                    print(f"No posterior predictive samples for {model_name}")
                    return None

                # Median over chains and draws -> one value per game
                y1_median = np.median(predictive['y1'].values, axis=(0, 1))
                y2_median = np.median(predictive['y2'].values, axis=(0, 1))

                pred_stats = []
                for team in self.teams:
                    points, scored, conceded = season_totals(
                        team, y1_median, y2_median, lambda a, b: abs(a - b) < 0.1
                    )
                    pred_stats.append({
                        'team': team,
                        f'{model_name}_points': int(round(points)),
                        f'{model_name}_scored': int(round(scored)),
                        f'{model_name}_conceded': int(round(conceded))
                    })
                return pred_stats

            except Exception as e:
                print(f"Error getting predictions for {model_name}: {e}")
                return None

        all_predictions = [
            get_model_predictions(getattr(self, attr, None), name)
            for attr, name in (('basic_trace', 'basic'),
                               ('enhanced_trace', 'enhanced'),
                               ('full_trace', 'full'))
        ]

        # Combine observed and predicted rows (all lists follow self.teams order)
        comparison_data = []
        for i, obs in enumerate(observed_stats):
            row = dict(obs)
            for preds in all_predictions:
                if preds:
                    row.update(preds[i])
            comparison_data.append(row)

        df = pd.DataFrame(comparison_data)
        return df.sort_values('obs_points', ascending=False)

    def print_enhanced_comparison_table(self):
        """Print observed vs predicted season totals and per-model MAE.

        Calls ``create_enhanced_comparison_table`` and prints one column group
        (pts / scored / conceded) for the observed data. A further group is
        printed for every model whose ``<model>_points`` column is present.
        Then it prints each model's mean absolute error against the observed
        totals and the model with the lowest summed MAE.

        Returns:
            The comparison DataFrame, or None if it could not be built.
        """
        df = self.create_enhanced_comparison_table()
        if df is None:
            return None

        print(f"\n{'='*140}")
        print("ENHANCED SEASON RESULTS COMPARISON - ALL MODELS")
        print(f"{'='*140}")

        # Only models whose predictions made it into the table are shown; even
        # the basic model can be missing (e.g. no posterior predictive samples).
        model_labels = {'basic': 'Basic', 'enhanced': 'Enhanced', 'full': 'Full'}
        models = [m for m in model_labels if f'{m}_points' in df.columns]
        groups = [('obs', 'Observed')] + [(m, model_labels[m]) for m in models]
        stats = ('points', 'scored', 'conceded')

        # Every cell is " " + a right-aligned value, so header, subheader and
        # data rows line up column for column.
        col_width = 8
        group_width = 3 * (col_width + 1)

        header = f"{'team':15}" + "".join(
            f" {label:^{group_width - 1}}" for _, label in groups
        )
        subheader = f"{'':15}" + "".join(
            f" {name:>{col_width}}" for _ in groups for name in ('pts', 'scored', 'conceded')
        )
        print(header)
        print(subheader)
        print("-" * 140)

        def cell(value):
            """Format one integer cell; missing values print as '-'."""
            if pd.isna(value):
                return f" {'-':>{col_width}}"
            return f" {int(value):>{col_width}d}"

        for _, row in df.iterrows():
            print(f"{row['team']:15}" + "".join(
                cell(row[f'{prefix}_{stat}']) for prefix, _ in groups for stat in stats
            ))

        # Calculate and print MAE
        print(f"\n{'='*60}")
        print("MEAN ABSOLUTE ERROR COMPARISON")
        print(f"{'='*60}")

        if not models:
            print("No model predictions available")
            return df

        mae_results = {}

        for model in models:
            maes = {
                stat: float(np.mean(np.abs(df[f'obs_{stat}'] - df[f'{model}_{stat}'])))
                for stat in stats
            }
            maes['total'] = maes['points'] + maes['scored'] + maes['conceded']
            mae_results[model] = maes

            print(f"\n{model.upper()} MODEL:")
            print(f"  Points MAE:   {maes['points']:.2f}")
            print(f"  Scored MAE:   {maes['scored']:.2f}")
            print(f"  Conceded MAE: {maes['conceded']:.2f}")
            print(f"  Total MAE:    {maes['total']:.2f}")

        # Find best model
        best_model = min(mae_results, key=lambda x: mae_results[x]['total'])
        print(f"\n BEST PREDICTIVE MODEL: {best_model.upper()}")
        print(f"   Total MAE: {mae_results[best_model]['total']:.2f}")

        return df

# ===== ENHANCED USAGE EXAMPLES =====

def run_enhanced_analysis(data_file, *, draws=1000, tune=1000, chains=4):
    """
    Run comprehensive analysis with all models.

    Builds a ``BayesianFootballModel`` and then runs the steps below in order.
    Each fit is followed by a convergence check. The enhanced and full fits
    are also followed by a covariate-effect summary and plot.

    1. Fit the basic model.
    2. Fit the enhanced stadium model.
    3. Attempt the full covariate model.
    4. Print the cross-model comparison and the comparison table.

    Parameters
    ----------
    data_file : str
        Path handed to ``BayesianFootballModel``.
    draws, tune, chains : int, optional (keyword-only)
        Sampler settings forwarded to every ``fit_*_model`` call. The
        defaults (1000 / 1000 / 4) are the values this function previously
        hard-coded.

    Returns
    -------
    BayesianFootballModel
        The model object, holding whichever traces were fitted successfully.

    Raises
    ------
    Exception
        Any error raised while constructing the model, or while fitting or
        analysing the basic or enhanced model, propagates to the caller.
        Errors from the full covariate model are caught, reported and
        skipped.
    """
    # Same sampler configuration for every model so the fits are comparable.
    sampler_kwargs = dict(draws=draws, tune=tune, chains=chains)

    def _print_section(title):
        # Sub-section banner (60-char rules around the title).
        print(f"\n{'='*60}")
        print(title)
        print(f"{'='*60}")

    print("="*80)
    print("ENHANCED BAYESIAN FOOTBALL MODEL WITH COVARIATES")
    print("Stadium Capacity + Attendance + Distance + Temporal Effects")
    print("="*80)

    # Initialize model
    model = BayesianFootballModel(data_file)

    # Fit basic model
    _print_section("FITTING BASIC MODEL")
    model.fit_basic_model(**sampler_kwargs)
    model.check_convergence(model.basic_trace, "Basic Model")

    # Fit enhanced stadium model
    _print_section("FITTING ENHANCED STADIUM MODEL")
    model.fit_enhanced_model(**sampler_kwargs)
    model.check_convergence(model.enhanced_trace, "Enhanced Stadium Model")

    # Analyze stadium effects
    model.analyze_covariate_effects('enhanced')
    model.plot_covariate_effects('enhanced')

    # Fit full covariate model. This step is best-effort: a failure here is
    # reported and the comparison proceeds with the models already fitted.
    _print_section("FITTING FULL COVARIATE MODEL")
    try:
        model.fit_full_model(**sampler_kwargs)
        model.check_convergence(model.full_trace, "Full Covariate Model")

        # Analyze all effects
        model.analyze_covariate_effects('full')
        model.plot_covariate_effects('full')

    except Exception as e:
        print(f"Full model failed: {e}")
        print("Continuing with basic and enhanced models...")

    # Compare all models
    model.compare_all_models()
    model.print_enhanced_comparison_table()

    return model

# Main execution
if __name__ == "__main__":
    # Local import: only needed for the error report at the bottom of this block.
    import traceback

    print("="*80)
    print("ENHANCED BAYESIAN FOOTBALL MODEL")
    print("With Stadium, Distance, and Temporal Covariates")
    print("="*80)

    # Use the new dataset.
    # NOTE(review): the doubled ".xlsx.xlsx" extension looks like a typo, but it
    # may match the actual file on disk. Confirm before changing it. Also note
    # that load_and_prepare_data currently reads a hard-coded path and does not
    # use this value.
    data_file = 'data/dataset/dataset_2007-08_stadium_distance_date.xlsx.xlsx'

    try:
        model = run_enhanced_analysis(data_file)
    except Exception as e:
        print(f"\n Error during analysis: {e}")
        print("Please check that your data file exists and has the correct format.")
        # The one-line message alone hides where the failure happened
        # (data loading vs. sampling vs. plotting). Show the full traceback.
        traceback.print_exc()
    else:
        print("\n SUCCESS! Enhanced analysis completed.")
        print("\nAvailable methods:")
        print("- model.analyze_covariate_effects('enhanced' or 'full')")
        print("- model.plot_covariate_effects('enhanced' or 'full')")
        print("- model.compare_all_models()")
        print("- model.print_enhanced_comparison_table()")