In [122]:
import numpy as np
import torch
import pandas as pd
#from models.BivariateHMM import BivariateHMM
from models.CopulaHMM import CopulaHMM
from utils.Plots import plotEPS_with_states,plotEPS_distribution

In [123]:
# Input location and the (inclusive) range of hidden-state counts compared below.
DATA_DIR = "data/"  # keep the trailing slash: file names are appended to it directly
MIN_HIDDEN_STATES = 2
MAX_HIDDEN_STATES = 5

In [124]:
# Hull measurements for matchday 2; any row with a missing value is discarded.
data = pd.read_csv(f"{DATA_DIR}hulls_df_matchday2_reduced.csv").dropna()

# Match events, split into goals and shots for each team (used as plot markers).
events = pd.read_csv(f"{DATA_DIR}matchday2_events.csv")

is_goal = events["Subtype"].isin(["ON TARGET-GOAL", "HEAD-ON TARGET-GOAL", "WOODWORK-GOAL"])
goals_info = events[is_goal]
home_goals = goals_info[goals_info["Team"] == "Home"]
away_goals = goals_info[goals_info["Team"] == "Away"]

is_shot = events["Type"] == "SHOT"
shots_info = events[is_shot]
home_shot = shots_info[shots_info["Team"] == "Home"]
away_shot = shots_info[shots_info["Team"] == "Away"]

# HMM observations: a 2-column tensor [HomeHull, AwayHull] with one row per remaining row of `data`.
# NOTE(review): the /100 rescaling presumably changes units (e.g. cm -> m or percent) — TODO confirm.
sequence_XY = torch.tensor(data[["HomeHull", "AwayHull"]].to_numpy() / 100)

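## Quantitative comparison: AIC for each number of hidden states
Lower AIC indicates a better trade-off between goodness of fit and number of parameters.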

In [125]:
# AIC of each fitted CopulaHMM, ordered by number of hidden states:
# AIC_list[i] belongs to the model with MIN_HIDDEN_STATES + i states. The report cell relies on
# exactly this indexing, so the loop must start at MIN_HIDDEN_STATES, not a hard-coded literal.
# Assuming model.AIC returns the standard AIC, lower values are better.
AIC_list = []
for n_states in range(MIN_HIDDEN_STATES, MAX_HIDDEN_STATES + 1):
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints. Recent torch
    # versions default to weights_only=True; that may reject non-tensor posterior contents — TODO confirm.
    posterior = torch.load(f"parameters/CopulaHMM_matchday2_{n_states}states.pt")
    model = CopulaHMM.from_posterior(posterior)
    AIC_list.append(model.AIC(sequence_XY).item())

In [126]:
# Print each AIC next to its hidden-state count (AIC_list[0] is the MIN_HIDDEN_STATES model).
for n_states, aic in enumerate(AIC_list, start=MIN_HIDDEN_STATES):
    print(f"AIC of the model with {n_states} hidden states-> {aic}")

AIC of the model with 2 hidden states-> 18766.37109375
AIC of the model with 3 hidden states-> 17957.654296875
AIC of the model with 4 hidden states-> 17971.048828125
AIC of the model with 5 hidden states-> 17797.033203125


## Qualitative comparison: decoded hidden states of the selected model

In [237]:
# Number of hidden states for the model inspected below. In the recorded output above, the
# 5-state model has the lowest AIC. A checkpoint "parameters/CopulaHMM_matchday2_{N}states.pt"
# must exist for this value.
HIDDEN_STATES=5

In [238]:
# Saved posterior of the selected model.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
posterior_path = f"parameters/CopulaHMM_matchday2_{HIDDEN_STATES}states.pt"
posterior = torch.load(posterior_path)

In [239]:
# Reconstruct the CopulaHMM for the selected number of states from its loaded posterior.
model=CopulaHMM.from_posterior(posterior)

In [240]:
# Viterbi decoding of the observation sequence. MLS presumably means "most likely sequence"
# of hidden-state indices, one per row of sequence_XY — TODO confirm. It is later stored as a
# column of `data`, so its length must equal len(data).
MLS=model.viterbi(sequence_XY)

In [241]:
# Attach the decoded state of each row as a "State" column. A numpy array is assigned by
# position, so the non-contiguous index left by dropna() does not matter.
data = data.assign(State=MLS.numpy())

In [242]:
# Fixed palette: decoded state k is drawn in colors_list[k] in both plots.
colors_list = ["#FD8033", "#0DC2B7", "#DAC11E", "#D964DC", "#37A010"]

# Fail early with an actionable message instead of an opaque IndexError inside the dict
# comprehension when more states are requested than the palette provides.
if HIDDEN_STATES > len(colors_list):
    raise ValueError(
        f"colors_list has {len(colors_list)} colours but HIDDEN_STATES={HIDDEN_STATES}; "
        "extend the palette"
    )

class_colors = {k: colors_list[k] for k in range(HIDDEN_STATES)}

In [243]:
# Both helpers come from utils.Plots. Their results are kept in p1/p2 so the
# (commented-out) savefig cell can write them to disk.
p1 = plotEPS_distribution(data, class_colors)
p2 = plotEPS_with_states(
    data,
    home_goals, away_goals,
    home_shot, away_shot,
    class_colors,
)

In [244]:
#p1.savefig(f"plots/EPS_distribution_matchday2_{HIDDEN_STATES}states.png",dpi=350)
#p2.savefig(f"plots/EPS_with_states_matchday2_{HIDDEN_STATES}states.png",dpi=350)