In [5]:
import os

import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
#from models.BivariateHMM import BivariateHMM
from models.CopulaHMM import CopulaHMM
from utils.Plots import plotEPS_with_states, plotEPS_distribution, plotEPS_hist

In [6]:
# Prefix for the input CSVs. File names are appended by plain string concatenation,
# so the trailing slash is required.
DATA_DIR = "data/"

# Smallest and largest number of hidden states whose fitted models are compared.
MIN_HIDDEN_STATES = 2
MAX_HIDDEN_STATES = 5

In [7]:
# Hull time series for matchday 2. Rows with a missing value in ANY column are dropped.
data = pd.read_csv(f"{DATA_DIR}hulls_df_matchday2_reduced.csv").dropna()

# Event log for the same match; goals and shots are later overlaid on the state plots.
events = pd.read_csv(f"{DATA_DIR}matchday2_events.csv")

goal_subtypes = ["ON TARGET-GOAL", "HEAD-ON TARGET-GOAL", "WOODWORK-GOAL"]
goals_info = events.loc[events["Subtype"].isin(goal_subtypes)]
shots_info = events.loc[events["Type"] == "SHOT"]

# Split goals and shots by team.
home_goals = goals_info.loc[goals_info["Team"] == "Home"]
away_goals = goals_info.loc[goals_info["Team"] == "Away"]
home_shot = shots_info.loc[shots_info["Team"] == "Home"]
away_shot = shots_info.loc[shots_info["Team"] == "Away"]

# Bivariate observation sequence for the HMMs: (HomeHull, AwayHull) per row, divided by 100.
sequence_XY = torch.tensor(data[["HomeHull", "AwayHull"]].to_numpy() / 100)

## Quantitative comparison

In [8]:
# Score each saved CopulaHMM (one checkpoint per hidden-state count) by AIC on the
# observed sequence.
# The loop starts at MIN_HIDDEN_STATES rather than a hard-coded 2. The report cell
# labels AIC_list[i] as the model with MIN_HIDDEN_STATES + i states, so both must
# share the same starting point.
# NOTE: torch.load relies on pickle; only load checkpoint files from a trusted source.
AIC_list = []
for state in range(MIN_HIDDEN_STATES, MAX_HIDDEN_STATES + 1):
    posterior = torch.load(f"parameters/CopulaHMM_matchday2_{state}states.pt")
    model = CopulaHMM.from_posterior(posterior)
    AIC_list.append(model.AIC(sequence_XY).item())

In [9]:
# Report each model's AIC. AIC_list[i] belongs to the model with MIN_HIDDEN_STATES + i
# hidden states, so enumeration starts at MIN_HIDDEN_STATES.
for n_states, aic in enumerate(AIC_list, start=MIN_HIDDEN_STATES):
    print(f"AIC of the model with {n_states} hidden states-> {aic}")

AIC of the model with 2 hidden states-> 18766.857421875
AIC of the model with 3 hidden states-> 17959.25390625
AIC of the model with 4 hidden states-> 17970.115234375
AIC of the model with 5 hidden states-> 17792.74609375


## Qualitative comparison

In [10]:
# One hex color per hidden state: entry k is used for state k, so five entries
# support models with up to five states.
colors_list = [
    "#FD8033",
    "#0DC2B7",
    "#DAC11E",
    "#D964DC",
    "#37A010",
]

In [11]:
# For every fitted model, decode the most likely hidden-state path (Viterbi) and draw
# the three EPS plots colored by state.
# The loop starts at MIN_HIDDEN_STATES (not a hard-coded 2) to match the configured range.
# Set SAVE_PLOTS = True to also write each figure to PLOTS_DIR.
SAVE_PLOTS = False
PLOTS_DIR = "plots/"

# Each state k needs colors_list[k]; fail early with a clear message instead of an
# IndexError halfway through the loop.
if MAX_HIDDEN_STATES > len(colors_list):
    raise ValueError(
        f"colors_list has {len(colors_list)} colors but models with up to "
        f"{MAX_HIDDEN_STATES} hidden states are plotted; add more colors."
    )

for state in range(MIN_HIDDEN_STATES, MAX_HIDDEN_STATES + 1):
    posterior = torch.load(f"parameters/CopulaHMM_matchday2_{state}states.pt")
    model = CopulaHMM.from_posterior(posterior)
    MLS = model.viterbi(sequence_XY)
    # Overwrites data["State"] on every iteration; presumably the plot helpers read
    # this column (TODO confirm in utils.Plots). After the loop it holds the path of
    # the largest model.
    data["State"] = MLS.numpy()
    class_colors = {k: colors_list[k] for k in range(state)}
    p1 = plotEPS_distribution(data, class_colors)
    p2 = plotEPS_with_states(data, home_goals, away_goals, home_shot, away_shot, class_colors)
    p3 = plotEPS_hist(data, class_colors)
    if SAVE_PLOTS:
        os.makedirs(PLOTS_DIR, exist_ok=True)
        p1.savefig(f"{PLOTS_DIR}EPS_distribution_matchday2_{state}states.png", dpi=350)
        p2.savefig(f"{PLOTS_DIR}EPS_with_states_matchday2_{state}states.png", dpi=350)
        p3.savefig(f"{PLOTS_DIR}EPS_hist_matchday2_{state}states.png", dpi=350)