# Basic and Mixture Models - Fitted to the 2022/23 season

In [None]:
import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

class BayesianFootballModel:
    """
    Bayesian hierarchical model for football match prediction
    Based on Baio & Blangiardo (2010) paper
    """

    def __init__(self, data_file):
        """Initialize the model"""
        print(f"Initializing model with data file: {data_file}")

        # Initialize model attributes first (but not data attributes)
        self.basic_model = None
        self.mixture_model = None
        self.basic_trace = None
        self.mixture_trace = None

        try:
            self.data = self.load_and_prepare_data(data_file)
            print(f"n_games: {self.n_games}, n_teams: {self.n_teams}")
        except Exception as e:
            print(f"data errors during model initialization: {e}")
            raise

    def load_and_prepare_data(self, data_file):
        """Load and prepare the football data"""
        # Read the Excel file
        df = pd.read_excel('/data/dataset/dataset_2022-23_stadium_distance_date.xlsx')

        # Clean column names
        df.columns = df.columns.str.strip()

        print(f"Original data shape: {df.shape}")
        print(f"Columns: {list(df.columns)}")

        # Create team mappings - using correct column names
        all_teams = pd.concat([
            df['hometeam_name'],
            df['awayteam_name']
        ]).unique()

        team_to_id = {team: i for i, team in enumerate(sorted(all_teams))}
        id_to_team = {i: team for team, i in team_to_id.items()}

        # Map team names to consecutive IDs (0-based)
        df['home_team_idx'] = df['hometeam_name'].map(team_to_id)
        df['away_team_idx'] = df['awayteam_name'].map(team_to_id)

        # Check for any mapping issues
        if df['home_team_idx'].isna().any() or df['away_team_idx'].isna().any():
            print("Warning: Some teams could not be mapped!")
            print("Home team mapping issues:", df[df['home_team_idx'].isna()]['hometeam_name'].unique())
            print("Away team mapping issues:", df[df['away_team_idx'].isna()]['awayteam_name'].unique())

        # Store team information
        self.teams = sorted(all_teams)
        self.n_teams = len(self.teams)
        self.n_games = len(df)

        print(f"Data loaded: {self.n_games} games, {self.n_teams} teams")
        print(f"Teams: {self.teams}")

        # Verify no None values
        print(f"n_games type: {type(self.n_games)}, value: {self.n_games}")
        print(f"n_teams type: {type(self.n_teams)}, value: {self.n_teams}")

        return df

    def build_basic_model(self):
        """Building basic hierarchical model from Section 2 of the paper"""

        # Check if data is properly loaded
        if self.n_games is None or self.n_teams is None:
            raise ValueError("Data not properly loaded. check data file and team mappings")

        print(f"Building model with {self.n_games} games and {self.n_teams} teams")

        # Prepare data arrays
        home_team_idx = self.data['home_team_idx'].values
        away_team_idx = self.data['away_team_idx'].values
        y1_data = self.data['y1'].values
        y2_data = self.data['y2'].values

        # Verify data integrity
        print(f"Home team indices range: {home_team_idx.min()} to {home_team_idx.max()}")
        print(f"Away team indices range: {away_team_idx.min()} to {away_team_idx.max()}")
        print(f"Goals range - Home: {y1_data.min()} to {y1_data.max()}, Away: {y2_data.min()} to {y2_data.max()}")

        with pm.Model() as model:
            # Home advantage parameter
            home_advantage = pm.Normal("home_advantage", mu=0, tau=0.0001)

            # Hyperparameters for attack and defense effects
            mu_att = pm.Normal("mu_att", mu=0, tau=0.0001)
            mu_def = pm.Normal("mu_def", mu=0, tau=0.0001)
            tau_att = pm.Gamma("tau_att", alpha=0.01, beta=0.01)
            tau_def = pm.Gamma("tau_def", alpha=0.01, beta=0.01)

            # Team-specific attack and defense effects (before centering)
            att_star = pm.Normal("att_star", mu=mu_att, tau=tau_att, shape=self.n_teams)
            def_star = pm.Normal("def_star", mu=mu_def, tau=tau_def, shape=self.n_teams)

            # Sum-to-zero constraint (centering)
            att = pm.Deterministic("att", att_star - pt.mean(att_star))
            def_ = pm.Deterministic("def", def_star - pt.mean(def_star))

            log_theta_g1 = home_advantage + att[home_team_idx] + def_[away_team_idx]
            log_theta_g2 = att[away_team_idx] + def_[home_team_idx]

            theta_g1 = pm.Deterministic("theta_g1", pt.exp(log_theta_g1))
            theta_g2 = pm.Deterministic("theta_g2", pt.exp(log_theta_g2))

            # Likelihood - each game has its own theta values
            y1 = pm.Poisson("y1", mu=theta_g1, observed=y1_data)
            y2 = pm.Poisson("y2", mu=theta_g2, observed=y2_data)

        print("Successfully Built Model  ")
        self.basic_model = model
        return model

    def build_mixture_model(self):
        """Build the three-group mixture model (Section 4 of the paper).

        Each team is assigned, separately for attack and for defense, to one
        of three groups (index 0 = bottom, 1 = mid-table, 2 = top). The
        group determines the location and precision of the Student-t prior
        (4 degrees of freedom) on that team's effect. The effects are then
        centered to sum to zero and fed into the same Poisson likelihood for
        home goals (``y1``) and away goals (``y2``) as the basic model.

        Reads ``self.data`` and ``self.n_teams``. The model is stored on
        ``self.mixture_model`` and returned.
        """

        # Prepare data arrays
        home_team_idx = self.data['home_team_idx'].values
        away_team_idx = self.data['away_team_idx'].values
        y1_data = self.data['y1'].values
        y2_data = self.data['y2'].values

        with pm.Model() as model:
            # Home advantage parameter
            home_advantage = pm.Normal("home_advantage", mu=0, tau=0.0001)

            # Mixture parameters for each team
            # Prior probabilities for group membership (3 groups: bottom, mid, top)
            alpha_att = np.ones(3)  # Uniform prior over groups
            alpha_def = np.ones(3)

            # One probability vector over the three groups per team
            p_att = pm.Dirichlet("p_att", a=alpha_att, shape=(self.n_teams, 3))
            p_def = pm.Dirichlet("p_def", a=alpha_def, shape=(self.n_teams, 3))

            # Group assignment for each team (discrete values 0, 1 or 2)
            grp_att = pm.Categorical("grp_att", p=p_att, shape=self.n_teams)
            grp_def = pm.Categorical("grp_def", p=p_def, shape=self.n_teams)

            # Group-specific parameters
            # Group 1: Bottom teams (poor attack, poor defense)
            # Attack mean restricted to [-3, 0], defense mean to [0, 3]
            mu_att_1 = pm.TruncatedNormal("mu_att_1", mu=0, tau=0.001, lower=-3, upper=0)
            mu_def_1 = pm.TruncatedNormal("mu_def_1", mu=0, tau=0.001, lower=0, upper=3)
            tau_att_1 = pm.Gamma("tau_att_1", alpha=0.01, beta=0.01)
            tau_def_1 = pm.Gamma("tau_def_1", alpha=0.01, beta=0.01)

            # Group 2: Mid-table teams (average)
            # The group means are centered on 0 and share the group's precision
            tau_att_2 = pm.Gamma("tau_att_2", alpha=0.01, beta=0.01)
            tau_def_2 = pm.Gamma("tau_def_2", alpha=0.01, beta=0.01)
            mu_att_2 = pm.Normal("mu_att_2", mu=0, tau=tau_att_2)
            mu_def_2 = pm.Normal("mu_def_2", mu=0, tau=tau_def_2)

            # Group 3: Top teams (good attack, good defense)
            # Mirror image of group 1: attack mean in [0, 3], defense mean in [-3, 0]
            mu_att_3 = pm.TruncatedNormal("mu_att_3", mu=0, tau=0.001, lower=0, upper=3)
            mu_def_3 = pm.TruncatedNormal("mu_def_3", mu=0, tau=0.001, lower=-3, upper=0)
            tau_att_3 = pm.Gamma("tau_att_3", alpha=0.01, beta=0.01)
            tau_def_3 = pm.Gamma("tau_def_3", alpha=0.01, beta=0.01)

            # Stack parameters so that position k holds group k+1's value
            mu_att_groups = pt.stack([mu_att_1, mu_att_2, mu_att_3])
            mu_def_groups = pt.stack([mu_def_1, mu_def_2, mu_def_3])
            tau_att_groups = pt.stack([tau_att_1, tau_att_2, tau_att_3])
            tau_def_groups = pt.stack([tau_def_1, tau_def_2, tau_def_3])

            # Team-specific effects using t-distribution with 4 degrees of freedom (as in paper)
            att_effects = []
            def_effects = []

            for t in range(self.n_teams):
                # For each team, pick the mean/precision of the group it is
                # currently assigned to: 0 -> group 1, 1 -> group 2, anything
                # else -> group 3.
                att_mu_t = pt.switch(pt.eq(grp_att[t], 0), mu_att_groups[0],
                                   pt.switch(pt.eq(grp_att[t], 1), mu_att_groups[1], mu_att_groups[2]))
                att_tau_t = pt.switch(pt.eq(grp_att[t], 0), tau_att_groups[0],
                                    pt.switch(pt.eq(grp_att[t], 1), tau_att_groups[1], tau_att_groups[2]))

                def_mu_t = pt.switch(pt.eq(grp_def[t], 0), mu_def_groups[0],
                                   pt.switch(pt.eq(grp_def[t], 1), mu_def_groups[1], mu_def_groups[2]))
                def_tau_t = pt.switch(pt.eq(grp_def[t], 0), tau_def_groups[0],
                                    pt.switch(pt.eq(grp_def[t], 1), tau_def_groups[1], tau_def_groups[2]))

                # Use StudentT with nu=4 degrees of freedom as in the paper.
                # Each team gets its own named variable: att_raw_<t> / def_raw_<t>.
                att_t = pm.StudentT(f"att_raw_{t}", nu=4, mu=att_mu_t, lam=att_tau_t)
                def_t = pm.StudentT(f"def_raw_{t}", nu=4, mu=def_mu_t, lam=def_tau_t)

                att_effects.append(att_t)
                def_effects.append(def_t)

            # Combine the per-team scalars into one vector per effect
            att = pt.stack(att_effects)
            def_ = pt.stack(def_effects)

            # Apply sum-to-zero constraint
            att_centered = pm.Deterministic("att_centered", att - pt.mean(att))
            def_centered = pm.Deterministic("def_centered", def_ - pt.mean(def_))

            # Direct indexing for game-specific theta values (same as basic model);
            # only the home side gets the home-advantage term.
            log_theta_g1 = home_advantage + att_centered[home_team_idx] + def_centered[away_team_idx]
            log_theta_g2 = att_centered[away_team_idx] + def_centered[home_team_idx]

            # Stored under the same names as in the basic model
            theta_g1 = pm.Deterministic("theta_g1", pt.exp(log_theta_g1))
            theta_g2 = pm.Deterministic("theta_g2", pt.exp(log_theta_g2))

            # Likelihood - each game has its own theta values
            y1 = pm.Poisson("y1", mu=theta_g1, observed=y1_data)
            y2 = pm.Poisson("y2", mu=theta_g2, observed=y2_data)

        self.mixture_model = model
        return model

    def fit_basic_model(self, draws=2000, tune=2000, chains=4, cores=1):
        """Sample the basic model's posterior and posterior predictive.

        Builds the model first if it has not been built yet. The sampling
        arguments are forwarded to ``pm.sample`` together with a fixed seed
        (42) and ``target_accept=0.97``. The resulting trace, extended with
        posterior-predictive samples, is stored on ``self.basic_trace`` and
        returned.
        """
        print("Fitting basic hierarchical model...")

        if self.basic_model is None:
            self.build_basic_model()

        sampler_options = dict(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            random_seed=42,
            return_inferencedata=True,
            target_accept=0.97,  # high target acceptance for better sampling
        )

        # A single model context covers both sampling steps
        with self.basic_model:
            self.basic_trace = pm.sample(**sampler_options)

            posterior_predictive = pm.sample_posterior_predictive(self.basic_trace)
            self.basic_trace.extend(posterior_predictive)

        print("Basic model fitting completed!")
        return self.basic_trace

    def fit_mixture_model(self, draws=2000, tune=2000, chains=4, cores=1):
        """Sample the mixture model's posterior and posterior predictive.

        Builds the model first if it has not been built yet. The sampling
        arguments are forwarded to ``pm.sample`` together with a fixed seed
        (42) and ``target_accept=0.97``. The resulting trace, extended with
        posterior-predictive samples, is stored on ``self.mixture_trace`` and
        returned.
        """
        print("Fitting mixture model...")

        if self.mixture_model is None:
            self.build_mixture_model()

        sampler_options = dict(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            random_seed=42,
            return_inferencedata=True,
            target_accept=0.97,  # same target acceptance as the basic model
        )

        # A single model context covers both sampling steps
        with self.mixture_model:
            self.mixture_trace = pm.sample(**sampler_options)

            posterior_predictive = pm.sample_posterior_predictive(self.mixture_trace)
            self.mixture_trace.extend(posterior_predictive)

        print("Mixture model fitting completed!")
        return self.mixture_trace

    # ===== REALISTIC SIMULATION METHODS =====

    def get_realistic_model_predictions(self, model_type, n_simulations=1500):
        """
        Get realistic predictions by simulating actual match outcomes
        Using MEDIAN of season totals (exactly like the paper)
        """
        np.random.seed(42)
        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Warning: {model_type} model not fitted, skipping...")
            return None

        # Get posterior samples of scoring intensities (theta values)
        if 'theta_g1' in trace.posterior.data_vars and 'theta_g2' in trace.posterior.data_vars:
            theta1_samples = trace.posterior['theta_g1'].values  # [chains, draws, games]
            theta2_samples = trace.posterior['theta_g2'].values  # [chains, draws, games]
        else:
            print(f"Could not find theta variables in {model_type} model")
            print("Available variables:", list(trace.posterior.data_vars))
            return None

        # Reshape for easier handling
        n_chains, n_draws, n_games = theta1_samples.shape
        theta1_flat = theta1_samples.reshape(-1, n_games)  # [total_samples, games]
        theta2_flat = theta2_samples.reshape(-1, n_games)

        n_available_samples = len(theta1_flat)

        # If we have fewer samples than desired, resample with replacement
        if n_available_samples < n_simulations:
            print(f"Only {n_available_samples} posterior samples available, resampling to get {n_simulations}...")
            resample_indices = np.random.choice(n_available_samples, size=n_simulations, replace=True)
            theta1_sim = theta1_flat[resample_indices]
            theta2_sim = theta2_flat[resample_indices]
            n_samples = n_simulations
        else:
            theta1_sim = theta1_flat[:n_simulations]
            theta2_sim = theta2_flat[:n_simulations]
            n_samples = n_simulations

        print(f"Simulating {n_samples} scenarios for {model_type} model predictions...")

        pred_stats = []

        for team in self.teams:
            # Get indices for this team's games
            team_home_mask = (self.data['hometeam_name'] == team)
            team_away_mask = (self.data['awayteam_name'] == team)
            team_mask = team_home_mask | team_away_mask

            team_games = self.data[team_mask].copy()

            # Store season totals for each simulation
            season_points = []
            season_goals_scored = []
            season_goals_conceded = []
            season_wins = []
            season_draws = []
            season_losses = []

            # For each posterior sample, simulate a complete season
            for sim_idx in range(n_samples):
                sim_points = 0
                sim_goals_scored = 0
                sim_goals_conceded = 0
                sim_wins = 0
                sim_draws = 0
                sim_losses = 0

                # Simulate all matches for this team in this posterior sample
                for _, match in team_games.iterrows():
                    game_idx = match.name  # original index in dataset

                    # Get theta values for this specific simulation
                    game_theta1 = theta1_sim[sim_idx, game_idx]  # home team scoring intensity
                    game_theta2 = theta2_sim[sim_idx, game_idx]  # away team scoring intensity

                    # Simulate actual goals for this specific match
                    simulated_home_goals = np.random.poisson(game_theta1)
                    simulated_away_goals = np.random.poisson(game_theta2)

                    # Determine team's perspective
                    if match['hometeam_name'] == team:
                        # Team is playing at home
                        team_goals = simulated_home_goals
                        opponent_goals = simulated_away_goals
                    else:
                        # Team is playing away
                        team_goals = simulated_away_goals
                        opponent_goals = simulated_home_goals

                    # Update season totals for this simulation
                    sim_goals_scored += team_goals
                    sim_goals_conceded += opponent_goals

                    # Determine match result
                    if team_goals > opponent_goals:
                        sim_points += 3
                        sim_wins += 1
                    elif team_goals == opponent_goals:
                        sim_points += 1
                        sim_draws += 1
                    else:
                        sim_losses += 1

                # Store this simulation's season totals
                season_points.append(sim_points)
                season_goals_scored.append(sim_goals_scored)
                season_goals_conceded.append(sim_goals_conceded)
                season_wins.append(sim_wins)
                season_draws.append(sim_draws)
                season_losses.append(sim_losses)

            # Take MEDIAN of season totals (exactly like the paper)
            pred_stats.append({
                'team': team,
                f'{model_type}_points': int(np.median(season_points)),
                f'{model_type}_scored': int(np.median(season_goals_scored)),
                f'{model_type}_conceded': int(np.median(season_goals_conceded)),
                f'{model_type}_wins': int(np.median(season_wins)),
                f'{model_type}_draws': int(np.median(season_draws)),
                f'{model_type}_losses': int(np.median(season_losses))
            })

        return pred_stats

    def create_extended_comparison_table(self, save_to_file=False, filename="extended_season_comparison.csv"):
        """
        Create a comprehensive comparison table including mixture model results
        Now includes wins, draws, and losses with REALISTIC simulation
        """

        if self.basic_trace is None:
            print("Please fit the basic model first!")
            return None

        # Calculate observed statistics for each team
        observed_stats = []

        for team in self.teams:
            team_data = self.data[
                (self.data['hometeam_name'] == team) |
                (self.data['awayteam_name'] == team)
            ].copy()

            # Calculate points, goals scored, goals conceded, wins, draws, losses
            points = 0
            goals_scored = 0
            goals_conceded = 0
            wins = 0
            draws = 0
            losses = 0

            for _, match in team_data.iterrows():
                if match['hometeam_name'] == team:
                    # Team playing at home
                    goals_for = match['y1']
                    goals_against = match['y2']
                else:
                    # Team playing away
                    goals_for = match['y2']
                    goals_against = match['y1']

                # Calculate match result
                if goals_for > goals_against:
                    points += 3
                    wins += 1
                elif goals_for == goals_against:
                    points += 1
                    draws += 1
                else:
                    losses += 1

                goals_scored += goals_for
                goals_conceded += goals_against

            observed_stats.append({
                'team': team,
                'obs_points': points,
                'obs_scored': goals_scored,
                'obs_conceded': goals_conceded,
                'obs_wins': wins,
                'obs_draws': draws,
                'obs_losses': losses
            })

        # Get REALISTIC predictions from both models using simulation
        basic_predictions = self.get_realistic_model_predictions('basic')
        mixture_predictions = self.get_realistic_model_predictions('mixture') if self.mixture_trace else None

        # Combine all data
        comparison_data = []
        for i, obs in enumerate(observed_stats):
            row = obs.copy()

            # Add basic model predictions
            if basic_predictions:
                row.update(basic_predictions[i])

            # Add mixture model predictions if available
            if mixture_predictions:
                row.update(mixture_predictions[i])

            comparison_data.append(row)

        # Create DataFrame and sort by observed points
        df = pd.DataFrame(comparison_data)
        df = df.sort_values('obs_points', ascending=False)

        # Save to file if requested
        if save_to_file:
            df.to_csv(filename, index=False)
            print(f"Extended comparison table saved to {filename}")

        return df

    # ===== VISUALIZATION METHODS =====

    def get_home_advantage_summary(self, model_type='basic'):
        """Get summary statistics for home advantage effect"""

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return None

        # Extract home advantage samples
        home_adv_samples = trace.posterior['home_advantage']

        # Calculate summary statistics
        home_summary = {
            'parameter': 'home_advantage',
            'mean': float(home_adv_samples.mean()),
            'median': float(home_adv_samples.median()),
            'std': float(home_adv_samples.std()),
            'q025': float(home_adv_samples.quantile(0.025)),
            'q975': float(home_adv_samples.quantile(0.975))
        }

        return home_summary

    def plot_team_effects(self, model_type='basic'):
        """Plot attack vs defense effects for each team"""

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return

        # Get posterior means
        if model_type == 'basic':
            att_means = trace.posterior['att'].mean(dim=['chain', 'draw']).values
            def_means = trace.posterior['def'].mean(dim=['chain', 'draw']).values
        else:
            att_means = trace.posterior['att_centered'].mean(dim=['chain', 'draw']).values
            def_means = trace.posterior['def_centered'].mean(dim=['chain', 'draw']).values

        # Create plot
        plt.figure(figsize=(12, 8))
        plt.scatter(att_means, def_means, s=100, alpha=0.7)

        # Add team labels
        for i, team in enumerate(self.teams):
            plt.annotate(team, (att_means[i], def_means[i]),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.8)

        plt.xlabel('Attack Effect')
        plt.ylabel('Defense Effect')
        plt.title(f'Team Attack vs Defense Effects ({model_type.title()} Model)')
        plt.grid(True, alpha=0.3)
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        plt.axvline(x=0, color='k', linestyle='--', alpha=0.5)

        # Add quadrant labels
        plt.text(0.02, 0.98, 'Poor Attack,\nPoor Defense',
                transform=plt.gca().transAxes, va='top', ha='left',
                bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.5))
        plt.text(0.98, 0.98, 'Good Attack,\nPoor Defense',
                transform=plt.gca().transAxes, va='top', ha='right',
                bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.5))
        plt.text(0.02, 0.02, 'Poor Attack,\nGood Defense',
                transform=plt.gca().transAxes, va='bottom', ha='left',
                bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.5))
        plt.text(0.98, 0.02, 'Good Attack,\nGood Defense',
                transform=plt.gca().transAxes, va='bottom', ha='right',
                bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))

        plt.tight_layout()
        plt.show()

    def get_team_summary(self, model_type='basic'):
        """Get summary statistics for team effects"""

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return None

        # Get parameter names
        if model_type == 'basic':
            att_param = 'att'
            def_param = 'def'
        else:
            att_param = 'att_centered'
            def_param = 'def_centered'

        # Extract posterior samples
        att_samples = trace.posterior[att_param]
        def_samples = trace.posterior[def_param]

        # Calculate summary statistics
        summary_data = []
        for i, team in enumerate(self.teams):
            # Handle different indexing based on parameter dimensions
            if len(att_samples.dims) == 3:  # chain, draw, team
                att_team_samples = att_samples.isel({list(att_samples.dims)[2]: i})
                def_team_samples = def_samples.isel({list(def_samples.dims)[2]: i})
            else:  # Different structure
                att_team_samples = att_samples[..., i]
                def_team_samples = def_samples[..., i]

            att_mean = float(att_team_samples.mean())
            att_q025 = float(att_team_samples.quantile(0.025))
            att_q975 = float(att_team_samples.quantile(0.975))

            def_mean = float(def_team_samples.mean())
            def_q025 = float(def_team_samples.quantile(0.025))
            def_q975 = float(def_team_samples.quantile(0.975))

            summary_data.append({
                'team': team,
                'att_mean': att_mean,
                'att_q025': att_q025,
                'att_q975': att_q975,
                'def_mean': def_mean,
                'def_q025': def_q025,
                'def_q975': def_q975
            })

        summary_df = pd.DataFrame(summary_data)

        # Sort by attack effect (descending)
        summary_df = summary_df.sort_values('att_mean', ascending=False)

        return summary_df

    def print_model_summary(self, model_type='basic', show_all_teams=True):
        """Print comprehensive model summary including home advantage"""

        print(f"\n{model_type.upper()} MODEL SUMMARY")
        print("=" * 60)

        # Home advantage
        home_summary = self.get_home_advantage_summary(model_type)
        if home_summary:
            print(f"\nHOME ADVANTAGE EFFECT:")
            print(f"Mean: {home_summary['mean']:.4f}")
            print(f"95% CI: [{home_summary['q025']:.4f}, {home_summary['q975']:.4f}]")
            print(f"Interpretation: Home teams score exp({home_summary['mean']:.4f}) = {np.exp(home_summary['mean']):.3f}x more goals on average")

        # Team effects
        print(f"\nTEAM EFFECTS:")
        team_summary = self.get_team_summary(model_type)
        if team_summary is not None:
            print("\nTop 5 Attack (most goals scored):")
            print(team_summary.head()[['team', 'att_mean', 'att_q025', 'att_q975']].to_string(index=False))

            print(f"\nTop 5 Defense (fewest goals conceded - most negative values):")
            defense_sorted = team_summary.sort_values('def_mean', ascending=True)
            print(defense_sorted.head()[['team', 'def_mean', 'def_q025', 'def_q975']].to_string(index=False))

            if show_all_teams:
                print(f"\nBottom 5 Attack (fewest goals scored):")
                attack_sorted = team_summary.sort_values('att_mean', ascending=True)
                print(attack_sorted.head()[['team', 'att_mean', 'att_q025', 'att_q975']].to_string(index=False))

                print(f"\nBottom 5 Defense (most goals conceded - most positive values):")
                print(defense_sorted.tail()[['team', 'def_mean', 'def_q025', 'def_q975']].to_string(index=False))

    def predict_match(self, home_team, away_team, model_type='basic', n_samples=1000):
        """Predict the outcome of a specific match"""

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Please fit the {model_type} model first!")
            return None

        # Get team indices
        if home_team not in self.teams or away_team not in self.teams:
            print(f"Team not found. Available teams: {self.teams}")
            return None

        home_idx = self.teams.index(home_team)
        away_idx = self.teams.index(away_team)

        # Get parameter samples
        home_adv_samples = trace.posterior['home_advantage'].values.flatten()

        if model_type == 'basic':
            att_samples = trace.posterior['att'].values
            def_samples = trace.posterior['def'].values
        else:
            att_samples = trace.posterior['att_centered'].values
            def_samples = trace.posterior['def_centered'].values

        # Reshape samples for easier indexing
        att_flat = att_samples.reshape(-1, att_samples.shape[-1])
        def_flat = def_samples.reshape(-1, def_samples.shape[-1])
        home_adv_flat = home_adv_samples.flatten()

        # Take only n_samples
        n_available = min(len(home_adv_flat), len(att_flat))
        n_use = min(n_samples, n_available)

        # Calculate scoring intensities
        theta1_samples = np.exp(home_adv_flat[:n_use] +
                               att_flat[:n_use, home_idx] +
                               def_flat[:n_use, away_idx])

        theta2_samples = np.exp(att_flat[:n_use, away_idx] +
                               def_flat[:n_use, home_idx])

        # Generate predictions
        home_goals = np.random.poisson(theta1_samples)
        away_goals = np.random.poisson(theta2_samples)

        # Calculate probabilities
        home_win = np.mean(home_goals > away_goals)
        draw = np.mean(home_goals == away_goals)
        away_win = np.mean(home_goals < away_goals)

        # Expected goals
        exp_home_goals = np.mean(theta1_samples)
        exp_away_goals = np.mean(theta2_samples)

        return {
            'home_team': home_team,
            'away_team': away_team,
            'expected_home_goals': exp_home_goals,
            'expected_away_goals': exp_away_goals,
            'prob_home_win': home_win,
            'prob_draw': draw,
            'prob_away_win': away_win,
            'home_goals_samples': home_goals,
            'away_goals_samples': away_goals
        }

    # ===== TRACEPLOTS METHODS =====

    def plot_traceplots(self, model_type='basic', var_names=None, figsize=(15, 12)):
        """
        Generate trace plots for MCMC diagnostics

        Parameters:
        -----------
        model_type : str
            'basic' or 'mixture' to specify which model to plot
        var_names : list, optional
            List of variable names to plot. If None, plots key parameters
        figsize : tuple
            Figure size for the plots
        """

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f" fit the {model_type} model first")
            return

        print(f"\nGenerating trace plots for {model_type} model...")

        # Default variables to plot if not specified
        if var_names is None:
            if model_type == 'basic':
                var_names = ['home_advantage', 'mu_att', 'mu_def', 'tau_att', 'tau_def']
                # Add a few team-specific effects for illustration
                var_names.extend(['att', 'def'])
            else:
                var_names = ['home_advantage', 'mu_att_1', 'mu_att_2', 'mu_att_3',
                           'mu_def_1', 'mu_def_2', 'mu_def_3']

        # Filter variables that actually exist in the trace
        available_vars = list(trace.posterior.data_vars)
        var_names = [var for var in var_names if var in available_vars]

        if not var_names:
            print(f"No specified variables found in {model_type} model trace.")
            print(f"Available variables: {available_vars}")
            return

        try:
            # Create trace plots using ArviZ
            ax = az.plot_trace(
                trace,
                var_names=var_names,
                figsize=figsize,
                combined=False,
                compact=False
            )

            plt.suptitle(f'Trace Plots - {model_type.title()} Model', fontsize=16, y=0.98)
            plt.tight_layout()
            plt.show()

            # Print convergence diagnostics
            self.print_convergence_diagnostics(model_type, var_names)

        except Exception as e:
            print(f"Error generating trace plots: {e}")
            print("Trying with simpler variable selection...")

            # Fallback: just plot home_advantage if it exists
            if 'home_advantage' in available_vars:
                az.plot_trace(trace, var_names=['home_advantage'], figsize=(12, 4))
                plt.suptitle(f'Trace Plot: Home Advantage - {model_type.title()} Model')
                plt.tight_layout()
                plt.show()

    def print_convergence_diagnostics(self, model_type='basic', var_names=None):
        """
        Print convergence diagnostics for the specified model

        Parameters:
        -----------
        model_type : str
            'basic' or 'mixture' to specify which model to diagnose
        var_names : list, optional
            List of variable names to diagnose. If None, uses all variables
        """

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"Fit the {model_type} model first")
            return

        print(f"\n{'='*60}")
        print(f"CONVERGENCE DIAGNOSTICS - {model_type.upper()} MODEL")
        print(f"{'='*60}")

        try:
            # Calculate R-hat (should be < 1.1 for good convergence)
            rhat = az.rhat(trace, var_names=var_names)

            # Calculate effective sample size (should be > 400)
            ess = az.ess(trace, var_names=var_names)

            # Handle different return types from ArviZ
            if hasattr(rhat, 'values'):
                # If it's an xarray Dataset/DataArray, get the values
                rhat_values = rhat.values if hasattr(rhat.values, 'flatten') else [rhat.values]
                if hasattr(rhat_values, 'flatten'):
                    rhat_flat = rhat_values.flatten()
                else:
                    rhat_flat = np.array(rhat_values).flatten()
                max_rhat = float(np.nanmax(rhat_flat))
            else:
                # If it's already a scalar or array
                max_rhat = float(np.nanmax(rhat))

            if hasattr(ess, 'values'):
                # If it's an xarray Dataset/DataArray, get the values
                ess_values = ess.values if hasattr(ess.values, 'flatten') else [ess.values]
                if hasattr(ess_values, 'flatten'):
                    ess_flat = ess_values.flatten()
                else:
                    ess_flat = np.array(ess_values).flatten()
                min_ess = float(np.nanmin(ess_flat))
            else:
                # If it's already a scalar or array
                min_ess = float(np.nanmin(ess))

            print(f"Maximum R-hat: {max_rhat:.4f}")
            print(f"Minimum Effective Sample Size: {min_ess:.0f}")
            print()

            # Convergence assessment
            if max_rhat < 1.1:
                print("R-hat indicates good convergence (< 1.1)")
            else:
                print("R-hat indicates potential convergence issues (≥ 1.1)")
                print("Consider running more iterations or increasing tune parameter")

            if min_ess > 400:
                print("Effective sample size is adequate (> 400)")
            else:
                print("Low effective sample size (≤ 400)")
                print("Consider running more iterations")

            # Show detailed diagnostics for problematic parameters
            if max_rhat >= 1.1 or min_ess <= 400:
                print(f"\n Problematic parameters:")

                # Try to extract detailed information
                try:
                    # Convert to pandas for easier handling if possible
                    if hasattr(rhat, 'to_dataframe'):
                        rhat_df = rhat.to_dataframe().reset_index()
                        ess_df = ess.to_dataframe().reset_index()

                        # Find parameters with high R-hat
                        try:
                            value_col = [col for col in rhat_df.columns if col not in ['chain', 'draw']][-1]
                            high_rhat = rhat_df[rhat_df[value_col] >= 1.1]
                            if not high_rhat.empty:
                                print("  High R-hat (≥ 1.1):")
                                for _, row in high_rhat.head(10).iterrows():
                                    param_info = [str(row[col]) for col in rhat_df.columns[:-1]]
                                    print(f"    {param_info}: {row[value_col]:.4f}")
                        except:
                            pass

                        # Find parameters with low ESS
                        try:
                            value_col = [col for col in ess_df.columns if col not in ['chain', 'draw']][-1]
                            low_ess = ess_df[ess_df[value_col] <= 400]
                            if not low_ess.empty:
                                print("  Low ESS (≤ 400):")
                                for _, row in low_ess.head(10).iterrows():
                                    param_info = [str(row[col]) for col in ess_df.columns[:-1]]
                                    print(f"    {param_info}: {row[value_col]:.0f}")
                        except:
                            pass
                    else:
                        # Fallback: just show summary statistics
                        print("  Detailed parameter breakdown not available")
                        print(f"  Overall: Max R-hat = {max_rhat:.4f}, Min ESS = {min_ess:.0f}")

                except Exception as detail_error:
                    print(f"  Could not extract detailed parameter information: {detail_error}")
                    print(f"  Overall: Max R-hat = {max_rhat:.4f}, Min ESS = {min_ess:.0f}")

            print(f"\n{'='*60}")

        except Exception as e:
            print(f"Error calculating convergence diagnostics: {e}")
            print("This might be due to the trace structure or variable names.")

            # Fallback: try to get basic diagnostics for key parameters only
            try:
                print("Attempting basic diagnostics for key parameters...")
                key_params = ['home_advantage']
                if model_type == 'basic':
                    key_params.extend(['mu_att', 'mu_def'])
                else:
                    key_params.extend(['mu_att_1', 'mu_att_2'])

                # Filter to only available parameters
                available_params = [p for p in key_params if p in trace.posterior.data_vars]

                if available_params:
                    basic_rhat = az.rhat(trace, var_names=available_params)
                    basic_ess = az.ess(trace, var_names=available_params)

                    print(f"Basic diagnostics for {available_params}:")
                    for param in available_params:
                        if param in basic_rhat.data_vars:
                            param_rhat = float(basic_rhat[param].values)
                            param_ess = float(basic_ess[param].values)
                            status_rhat = "✓" if param_rhat < 1.1 else "⚠"
                            status_ess = "✓" if param_ess > 400 else "⚠"
                            print(f"  {param}: R-hat = {param_rhat:.4f} {status_rhat}, ESS = {param_ess:.0f} {status_ess}")

            except Exception as fallback_error:
                print(f"Fallback diagnostics also failed: {fallback_error}")
                print("Convergence diagnostics could not be calculated for this model.")

    def plot_team_effect_traceplots(self, model_type='basic', team_names=None, n_teams=5):
        """
        Plot trace plots specifically for team attack and defense effects

        Parameters:
        -----------
        model_type : str
            'basic' or 'mixture' to specify which model to plot
        team_names : list, optional
            Specific team names to plot. If None, plots first n_teams
        n_teams : int
            Number of teams to plot if team_names not specified
        """

        trace = self.basic_trace if model_type == 'basic' else self.mixture_trace

        if trace is None:
            print(f"fit the {model_type} model first!")
            return

        # Determine parameter names
        if model_type == 'basic':
            att_param = 'att'
            def_param = 'def'
        else:
            att_param = 'att_centered'
            def_param = 'def_centered'

        # Check if parameters exist
        if att_param not in trace.posterior.data_vars or def_param not in trace.posterior.data_vars:
            print(f"Team effect parameters not found in {model_type} model trace.")
            print(f"Available variables: {list(trace.posterior.data_vars)}")
            return

        # Select teams to plot
        if team_names is None:
            team_indices = list(range(min(n_teams, len(self.teams))))
            selected_teams = [self.teams[i] for i in team_indices]
        else:
            team_indices = [self.teams.index(team) for team in team_names if team in self.teams]
            selected_teams = team_names

        if not team_indices:
            print("No valid teams found for plotting.")
            return

        print(f"\nGenerating team effect trace plots for: {', '.join(selected_teams)}")

        # Create subplots
        n_teams_plot = len(team_indices)
        fig, axes = plt.subplots(n_teams_plot, 4, figsize=(16, 4*n_teams_plot))

        if n_teams_plot == 1:
            axes = axes.reshape(1, -1)

        for i, (team_idx, team_name) in enumerate(zip(team_indices, selected_teams)):
            # Attack effect traces
            att_samples = trace.posterior[att_param].isel({att_param + '_dim_0': team_idx})
            def_samples = trace.posterior[def_param].isel({def_param + '_dim_0': team_idx})

            # Plot traces for each chain
            for chain in range(att_samples.sizes['chain']):
                axes[i, 0].plot(att_samples.isel(chain=chain), alpha=0.7, label=f'Chain {chain}')
                axes[i, 2].plot(def_samples.isel(chain=chain), alpha=0.7, label=f'Chain {chain}')

            axes[i, 0].set_title(f'{team_name} - Attack Effect (Trace)')
            axes[i, 0].set_ylabel('Attack Effect')
            axes[i, 0].legend()
            axes[i, 0].grid(True, alpha=0.3)

            axes[i, 2].set_title(f'{team_name} - Defense Effect (Trace)')
            axes[i, 2].set_ylabel('Defense Effect')
            axes[i, 2].legend()
            axes[i, 2].grid(True, alpha=0.3)

            # Plot distributions
            att_flat = att_samples.values.flatten()
            def_flat = def_samples.values.flatten()

            axes[i, 1].hist(att_flat, bins=50, alpha=0.7, density=True)
            axes[i, 1].set_title(f'{team_name} - Attack Effect (Distribution)')
            axes[i, 1].set_xlabel('Attack Effect')
            axes[i, 1].set_ylabel('Density')
            axes[i, 1].grid(True, alpha=0.3)

            axes[i, 3].hist(def_flat, bins=50, alpha=0.7, density=True)
            axes[i, 3].set_title(f'{team_name} - Defense Effect (Distribution)')
            axes[i, 3].set_xlabel('Defense Effect')
            axes[i, 3].set_ylabel('Density')
            axes[i, 3].grid(True, alpha=0.3)

        plt.suptitle(f'Team Effects Trace Plots - {model_type.title()} Model', fontsize=16)
        plt.tight_layout()
        plt.show()

    def generate_all_traceplots(self, model_type='basic'):
        """
        Generate comprehensive trace plots for model diagnostics

        Parameters:
        -----------
        model_type : str
            'basic' or 'mixture' to specify which model to plot
        """

        print(f"\n{'='*70}")
        print(f"COMPREHENSIVE TRACE PLOT ANALYSIS - {model_type.upper()} MODEL")
        print(f"{'='*70}")

        # 1. Main parameter trace plots
        print("\n1. Main Parameters Trace Plots:")
        self.plot_traceplots(model_type)

        # 2. Team effects trace plots (for a subset of teams)
        print("\n2. Team Effects Trace Plots:")
        self.plot_team_effect_traceplots(model_type, n_teams=3)

        # 3. Convergence diagnostics
        print("\n3. Convergence Diagnostics:")
        self.print_convergence_diagnostics(model_type)

        print(f"\n{'='*70}")
        print("TRACE PLOT ANALYSIS COMPLETE")
        print(f"{'='*70}")

    def print_extended_comparison_table(self, show_errors=True):
        """
        Print a formatted comparison table with both basic and mixture models
        Now includes wins, draws, and losses with REALISTIC simulation
        """

        df = self.create_extended_comparison_table()
        if df is None:
            return

        print("\n" + "="*170)
        print("EXTENDED SEASON RESULTS COMPARISON - OBSERVED vs BASIC vs MIXTURE MODELS")
        print("="*170)

        # Check if mixture model results are available
        has_mixture = 'mixture_points' in df.columns

        # Print headers
        if has_mixture:
            print(f"{'team':15} {'Observed results':^45} {'Basic model (medians)':^45} {'Mixture model (medians)':^45}")
            print(f"{'':15} {'pts':>5} {'sc':>4} {'co':>4} {'W':>3} {'D':>3} {'L':>3} " +
                  f"{'pts':>5} {'sc':>4} {'co':>4} {'W':>3} {'D':>3} {'L':>3} " +
                  f"{'pts':>5} {'sc':>4} {'co':>4} {'W':>3} {'D':>3} {'L':>3}")
        else:
            print(f"{'team':15} {'Observed results':^45} {'Basic model (medians)':^45}")
            print(f"{'':15} {'pts':>5} {'sc':>4} {'co':>4} {'W':>3} {'D':>3} {'L':>3} " +
                  f"{'pts':>5} {'sc':>4} {'co':>4} {'W':>3} {'D':>3} {'L':>3}")

        print("-" * 170)

        # Print data rows
        for _, row in df.iterrows():
            print(f"{row['team']:15}", end="")
            # Observed
            print(f"{row['obs_points']:5d}", end="")
            print(f"{row['obs_scored']:4d}", end="")
            print(f"{row['obs_conceded']:4d}", end="")
            print(f"{row['obs_wins']:3d}", end="")
            print(f"{row['obs_draws']:3d}", end="")
            print(f"{row['obs_losses']:3d}", end="")
            # Basic
            print(f"{row['basic_points']:5d}", end="")
            print(f"{row['basic_scored']:4d}", end="")
            print(f"{row['basic_conceded']:4d}", end="")
            print(f"{row['basic_wins']:3d}", end="")
            print(f"{row['basic_draws']:3d}", end="")
            print(f"{row['basic_losses']:3d}", end="")

            if has_mixture:
                # Mixture
                print(f"{row['mixture_points']:5d}", end="")
                print(f"{row['mixture_scored']:4d}", end="")
                print(f"{row['mixture_conceded']:4d}", end="")
                print(f"{row['mixture_wins']:3d}", end="")
                print(f"{row['mixture_draws']:3d}", end="")
                print(f"{row['mixture_losses']:3d}", end="")
            print()

        # Print summary statistics
        if show_errors:
            print("\n" + "="*90)
            print("MODEL PERFORMANCE COMPARISON - MEAN ABSOLUTE ERROR")
            print("="*90)

            # Calculate mean absolute errors for basic model
            basic_points_mae = np.mean(np.abs(df['obs_points'] - df['basic_points']))
            basic_scored_mae = np.mean(np.abs(df['obs_scored'] - df['basic_scored']))
            basic_conceded_mae = np.mean(np.abs(df['obs_conceded'] - df['basic_conceded']))
            basic_wins_mae = np.mean(np.abs(df['obs_wins'] - df['basic_wins']))
            basic_draws_mae = np.mean(np.abs(df['obs_draws'] - df['basic_draws']))
            basic_losses_mae = np.mean(np.abs(df['obs_losses'] - df['basic_losses']))

            print(f"Basic Model Performance:")
            print(f"  Points MAE:          {basic_points_mae:.2f}")
            print(f"  Goals Scored MAE:    {basic_scored_mae:.2f}")
            print(f"  Goals Conceded MAE:  {basic_conceded_mae:.2f}")
            print(f"  Wins MAE:            {basic_wins_mae:.2f}")
            print(f"  Draws MAE:           {basic_draws_mae:.2f}")
            print(f"  Losses MAE:          {basic_losses_mae:.2f}")
            basic_total = (basic_points_mae + basic_scored_mae + basic_conceded_mae +
                          basic_wins_mae + basic_draws_mae + basic_losses_mae)
            print(f"  Total MAE:           {basic_total:.2f}")

            if has_mixture:
                mixture_points_mae = np.mean(np.abs(df['obs_points'] - df['mixture_points']))
                mixture_scored_mae = np.mean(np.abs(df['obs_scored'] - df['mixture_scored']))
                mixture_conceded_mae = np.mean(np.abs(df['obs_conceded'] - df['mixture_conceded']))
                mixture_wins_mae = np.mean(np.abs(df['obs_wins'] - df['mixture_wins']))
                mixture_draws_mae = np.mean(np.abs(df['obs_draws'] - df['mixture_draws']))
                mixture_losses_mae = np.mean(np.abs(df['obs_losses'] - df['mixture_losses']))

                print(f"\nMixture Model Performance:")
                print(f"  Points MAE:          {mixture_points_mae:.2f}")
                print(f"  Goals Scored MAE:    {mixture_scored_mae:.2f}")
                print(f"  Goals Conceded MAE:  {mixture_conceded_mae:.2f}")
                print(f"  Wins MAE:            {mixture_wins_mae:.2f}")
                print(f"  Draws MAE:           {mixture_draws_mae:.2f}")
                print(f"  Losses MAE:          {mixture_losses_mae:.2f}")
                mixture_total = (mixture_points_mae + mixture_scored_mae + mixture_conceded_mae +
                               mixture_wins_mae + mixture_draws_mae + mixture_losses_mae)
                print(f"  Total MAE:           {mixture_total:.2f}")

                # Compare models
                print(f"\n" + "="*60)
                print("MODEL COMPARISON SUMMARY")
                print("="*60)

                if basic_total < mixture_total:
                    print(f"🏆 BASIC MODEL WINS with lower total MAE ({basic_total:.2f} vs {mixture_total:.2f})")
                    improvement = ((mixture_total - basic_total) / mixture_total) * 100
                    print(f"   Basic model is {improvement:.1f}% better overall")
                else:
                    print(f"🏆 MIXTURE MODEL WINS with lower total MAE ({mixture_total:.2f} vs {basic_total:.2f})")
                    improvement = ((basic_total - mixture_total) / basic_total) * 100
                    print(f"   Mixture model is {improvement:.1f}% better overall")

                # Individual category winners
                print(f"\nCategory Winners:")
                print(f"  Points:   {'Basic' if basic_points_mae < mixture_points_mae else 'Mixture'} ({min(basic_points_mae, mixture_points_mae):.2f})")
                print(f"  Scored:   {'Basic' if basic_scored_mae < mixture_scored_mae else 'Mixture'} ({min(basic_scored_mae, mixture_scored_mae):.2f})")
                print(f"  Conceded: {'Basic' if basic_conceded_mae < mixture_conceded_mae else 'Mixture'} ({min(basic_conceded_mae, mixture_conceded_mae):.2f})")
                print(f"  Wins:     {'Basic' if basic_wins_mae < mixture_wins_mae else 'Mixture'} ({min(basic_wins_mae, mixture_wins_mae):.2f})")
                print(f"  Draws:    {'Basic' if basic_draws_mae < mixture_draws_mae else 'Mixture'} ({min(basic_draws_mae, mixture_draws_mae):.2f})")
                print(f"  Losses:   {'Basic' if basic_losses_mae < mixture_losses_mae else 'Mixture'} ({min(basic_losses_mae, mixture_losses_mae):.2f})")

        return df

    def detailed_model_analysis(self):
        """
        Perform detailed analysis comparing both models (only if both are fitted)
        Now includes wins, draws, and losses analysis
        """

        df = self.create_extended_comparison_table()
        if df is None or 'mixture_points' not in df.columns:
            print("Both models need to be fitted for detailed comparison analysis.")
            return

        print(f"\n" + "="*70)
        print("DETAILED MODEL DIFFERENCES ANALYSIS")
        print("="*70)

        # Calculate differences for all metrics
        df['points_diff'] = df['mixture_points'] - df['basic_points']
        df['scored_diff'] = df['mixture_scored'] - df['basic_scored']
        df['conceded_diff'] = df['mixture_conceded'] - df['basic_conceded']
        df['wins_diff'] = df['mixture_wins'] - df['basic_wins']
        df['draws_diff'] = df['mixture_draws'] - df['basic_draws']
        df['losses_diff'] = df['mixture_losses'] - df['basic_losses']

        # Show teams where mixture model performs significantly better
        print("\nTeams where MIXTURE model predicts closer to observed results:")
        mixture_better = df[
            (abs(df['obs_points'] - df['mixture_points']) < abs(df['obs_points'] - df['basic_points']))
        ].copy()

        if len(mixture_better) > 0:
            mixture_better['basic_error'] = abs(mixture_better['obs_points'] - mixture_better['basic_points'])
            mixture_better['mixture_error'] = abs(mixture_better['obs_points'] - mixture_better['mixture_points'])
            mixture_better['improvement'] = mixture_better['basic_error'] - mixture_better['mixture_error']
            mixture_better = mixture_better.sort_values('improvement', ascending=False)

            for _, team in mixture_better.head(5).iterrows():
                print(f"  {team['team']:15}: Basic error {team['basic_error']:.1f}, Mixture error {team['mixture_error']:.1f} (improvement: {team['improvement']:.1f})")
        else:
            print("  No teams found where mixture model significantly outperforms basic model")

        # Show teams where basic model performs significantly better
        print("\nTeams where BASIC model predicts closer to observed results:")
        basic_better = df[
            (abs(df['obs_points'] - df['basic_points']) < abs(df['obs_points'] - df['mixture_points']))
        ].copy()

        if len(basic_better) > 0:
            basic_better['basic_error'] = abs(basic_better['obs_points'] - basic_better['basic_points'])
            basic_better['mixture_error'] = abs(basic_better['obs_points'] - basic_better['mixture_points'])
            basic_better['improvement'] = basic_better['mixture_error'] - basic_better['basic_error']
            basic_better = basic_better.sort_values('improvement', ascending=False)

            for _, team in basic_better.head(5).iterrows():
                print(f"  {team['team']:15}: Mixture error {team['mixture_error']:.1f}, Basic error {team['basic_error']:.1f} (improvement: {team['improvement']:.1f})")
        else:
            print("  No teams found where basic model significantly outperforms mixture model")

        # Show largest prediction differences between models
        print(f"\nLargest differences between Basic and Mixture model predictions:")
        df['total_abs_diff'] = (abs(df['points_diff']) + abs(df['scored_diff']) + abs(df['conceded_diff']) +
                               abs(df['wins_diff']) + abs(df['draws_diff']) + abs(df['losses_diff']))
        biggest_diffs = df.nlargest(5, 'total_abs_diff')

        for _, team in biggest_diffs.iterrows():
            print(f"  {team['team']:15}: Pts {team['points_diff']:+3.0f}, W {team['wins_diff']:+2.0f}, D {team['draws_diff']:+2.0f}, L {team['losses_diff']:+2.0f}")

        print(f"\n{'='*70}")
        print("SUMMARY STATISTICS")
        print(f"{'='*70}")
        print(f"Teams where mixture model is better: {len(mixture_better)}")
        print(f"Teams where basic model is better: {len(basic_better)}")
        print(f"Average absolute difference in points: {abs(df['points_diff']).mean():.2f}")
        print(f"Average absolute difference in goals scored: {abs(df['scored_diff']).mean():.2f}")
        print(f"Average absolute difference in goals conceded: {abs(df['conceded_diff']).mean():.2f}")
        print(f"Average absolute difference in wins: {abs(df['wins_diff']).mean():.2f}")
        print(f"Average absolute difference in draws: {abs(df['draws_diff']).mean():.2f}")
        print(f"Average absolute difference in losses: {abs(df['losses_diff']).mean():.2f}")

        # Correlation between observed and predicted for key metrics
        basic_corr_points = df[['obs_points', 'basic_points']].corr().iloc[0,1]
        mixture_corr_points = df[['obs_points', 'mixture_points']].corr().iloc[0,1]

        basic_corr_wins = df[['obs_wins', 'basic_wins']].corr().iloc[0,1]
        mixture_corr_wins = df[['obs_wins', 'mixture_wins']].corr().iloc[0,1]

        print(f"\nCorrelation with observed results:")
        print(f"Points - Basic: {basic_corr_points:.3f}, Mixture: {mixture_corr_points:.3f}")
        print(f"Wins   - Basic: {basic_corr_wins:.3f}, Mixture: {mixture_corr_wins:.3f}")
        print(f"Points winner: {'Mixture' if mixture_corr_points > basic_corr_points else 'Basic'} model")
        print(f"Wins winner:   {'Mixture' if mixture_corr_wins > basic_corr_wins else 'Basic'} model")

        return df

    # ===== COMBINED ANALYSIS METHODS =====

    def run_complete_analysis(self, draws_basic=250, draws_mixture=50, save_results=True,
                              include_traceplots=True, example_matches=None):
        """
        Run complete analysis with both models including visualizations and comparisons

        Parameters:
        -----------
        draws_basic : int
            Passed as both `draws` and `tune` when fitting the basic model
        draws_mixture : int
            Passed as both `draws` and `tune` when fitting the mixture model
        save_results : bool
            When True and a comparison table was produced, the table is
            regenerated with save_to_file=True under a timestamped CSV name
        include_traceplots : bool
            When True, run generate_all_traceplots for each fitted model
        example_matches : iterable of (home, away) pairs, optional
            Fixtures to print example predictions for. If None, three
            built-in fixtures are used

        Returns a dict with keys 'basic_trace', 'mixture_trace' and
        'comparison_df'. A failure while fitting or analysing the mixture
        model is reported and the run continues with the basic model only.
        """

        if example_matches is None:
            example_matches = (
                ('Internazionale', 'Milan'),
                ('Roma', 'Lazio'),
                ('Juventus', 'Napoli'),
            )

        print("="*70)
        print("COMPLETE BAYESIAN FOOTBALL MODEL ANALYSIS")
        print("="*70)

        # Fit basic model; the traces are read back from self.basic_trace below
        print("\n" + "="*50)
        print("FITTING BASIC MODEL")
        print("="*50)
        self.fit_basic_model(draws=draws_basic, tune=draws_basic, chains=4)

        # Basic model analysis
        print("\n" + "="*50)
        print("BASIC MODEL ANALYSIS")
        print("="*50)
        self.print_model_summary('basic', show_all_teams=True)

        # Plot basic model effects
        self.plot_team_effects('basic')

        # Generate trace plots for basic model
        if include_traceplots:
            self.generate_all_traceplots('basic')

        # Fit mixture model
        print("\n" + "="*50)
        print("FITTING MIXTURE MODEL")
        print("="*50)
        try:
            self.fit_mixture_model(draws=draws_mixture, tune=draws_mixture, chains=4)
            print("Successfully Fitted Mixture model")

            # Mixture model analysis
            print("\n" + "="*50)
            print("MIXTURE MODEL ANALYSIS")
            print("="*50)
            self.print_model_summary('mixture', show_all_teams=True)

            # Plot mixture model effects
            self.plot_team_effects('mixture')

            # Generate trace plots for mixture model
            if include_traceplots:
                self.generate_all_traceplots('mixture')

        except Exception as e:
            # Best effort: the basic-model results remain usable on their own
            print(f"Failed Mixture model : {e}")
            print("Continuing with basic model only...")

        # Generate comparison tables
        print("\n" + "="*50)
        print("SEASON COMPARISON ANALYSIS")
        print("="*50)

        comparison_df = self.print_extended_comparison_table(show_errors=True)

        # Detailed analysis if both models fitted
        if self.mixture_trace is not None:
            self.detailed_model_analysis()

        # Save results
        if save_results and comparison_df is not None:
            filename = f"football_analysis_results_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv"
            self.create_extended_comparison_table(save_to_file=True, filename=filename)
            print(f"\n Results saved to: {filename}")

        # Example predictions
        print("\n" + "="*50)
        print("EXAMPLE MATCH PREDICTIONS")
        print("="*50)

        for home, away in example_matches:
            try:
                # Basic model prediction
                pred_basic = self.predict_match(home, away, 'basic')
                if pred_basic:
                    print(f"\n{home} vs {away} (Basic Model):")
                    print(f"  Expected goals: {pred_basic['expected_home_goals']:.2f} - {pred_basic['expected_away_goals']:.2f}")
                    print(f"  Probabilities: Home {pred_basic['prob_home_win']:.3f}, Draw {pred_basic['prob_draw']:.3f}, Away {pred_basic['prob_away_win']:.3f}")

                # Mixture model prediction (if available)
                if self.mixture_trace is not None:
                    pred_mixture = self.predict_match(home, away, 'mixture')
                    if pred_mixture:
                        print(f"{home} vs {away} (Mixture Model):")
                        print(f"  Expected goals: {pred_mixture['expected_home_goals']:.2f} - {pred_mixture['expected_away_goals']:.2f}")
                        print(f"  Probabilities: Home {pred_mixture['prob_home_win']:.3f}, Draw {pred_mixture['prob_draw']:.3f}, Away {pred_mixture['prob_away_win']:.3f}")

            except Exception as e:
                # One unpredictable fixture must not stop the remaining ones
                print(f"Could not predict {home} vs {away}: {e}")

        print(f"\n{'='*70}")
        print("ANALYSIS COMPLETE!")
        print(f"{'='*70}")

        return {
            'basic_trace': self.basic_trace,
            'mixture_trace': self.mixture_trace,
            'comparison_df': comparison_df
        }

# ===== USAGE EXAMPLE =====

def run_example_analysis(data_file):
    """
    Build a BayesianFootballModel and run its complete analysis once

    `data_file` is handed to the BayesianFootballModel constructor.
    Returns the tuple (model, results), where results is whatever
    run_complete_analysis returned.
    """

    print("Initializing Bayesian Football Model...")
    model = BayesianFootballModel(data_file)

    # Both models use the same number of draws here; lower the values if
    # computational resources are limited
    analysis_options = {
        'draws_basic': 2000,
        'draws_mixture': 2000,
        'save_results': True,
        'include_traceplots': True,  # Set to False to skip trace plots
    }
    results = model.run_complete_analysis(**analysis_options)

    return model, results

# Main execution
if __name__ == "__main__":

    banner = "="*70
    print(banner)
    print("INTEGRATED BAYESIAN FOOTBALL MODEL")
    print("Combines Visualization + Comparison Tables + Predictions + Traceplots")
    print(banner)

    # Replace with your actual file path
    data_file = '/data/dataset/final_dataset_2022-23_stadium&distance&date.xlsx'

    try:
        # Fits both models and prints every report
        model, results = run_example_analysis(data_file)

        print("\n SUCCESS! All analysis completed.")
        print("\n You can now use the model object to:")
        # Follow-up calls available on the returned model
        follow_up_calls = (
            "- model.predict_match('Team1', 'Team2', 'basic')",
            "- model.plot_team_effects('mixture')",
            "- model.print_model_summary('basic')",
            "- model.create_extended_comparison_table()",
            "- model.generate_all_traceplots('basic')",
            "- model.plot_traceplots('mixture')",
            "- model.plot_team_effect_traceplots('basic', ['Juventus', 'Milan'])",
        )
        for hint in follow_up_calls:
            print(hint)

    except Exception as e:
        # Top-level boundary: report the failure with the expected input format
        print(f"\n Error during analysis: {e}")
        print("Check that your data file exists and has the correct format.")
        print("Required columns: hometeam_name, awayteam_name, y1, y2")