# In [8]:
from tensorflow.keras.models import Model
import logging
from typing import Tuple, Union, Dict, Annotated
import numpy as np

# In [10]:
class ModelEvaluate:
    """
    Class for model evaluation and calculating predictions.

    Wraps ``model.predict`` and ``model.evaluate`` on a fixed test split.
    Failures from the underlying model calls are logged (with traceback)
    and re-raised unchanged.
    """

    # ``Model`` is quoted so the annotation is not evaluated when the class
    # is defined; behavior is unchanged.
    def __init__(self,
                 model: "Model",
                 test_dataset: Dict[str, Dict[str, np.ndarray]]
                 ) -> None:
        """
        Args:
            model: Trained tensorflow.keras model.
            test_dataset: Test split with an ``"input_features"`` entry (fed to
                the model) and a ``"labels"`` entry (targets), each a mapping
                of name -> array per the annotation.

        Raises:
            KeyError: If ``"labels"`` or ``"input_features"`` is missing.
        """
        self.model = model
        self.test_labels = test_dataset["labels"]
        self.test_features = test_dataset["input_features"]

    def model_predict(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Makes predictions on the test features.

        Returns:
            Tuple[np.ndarray, np.ndarray]:
                - 'home_score' predictions
                - 'away_score' predictions

        Raises:
            ValueError: If the model does not produce exactly two outputs.
            KeyError: If the model returns a dict without the
                ``"home_score"`` / ``"away_score"`` keys.
            Exception: Anything raised by ``model.predict`` is logged and
                re-raised unchanged.
        """
        try:
            predictions = self.model.predict(self.test_features)
        except Exception as err:
            # logging.exception also records the traceback.
            logging.exception("Error in model prediction: %s", err)
            raise  # bare raise preserves the original traceback

        if isinstance(predictions, dict):
            # Unpacking a dict would silently yield its *keys*. The names are
            # taken from this docstring and assumed to match the model's
            # output names. TODO confirm against the model definition.
            return predictions["home_score"], predictions["away_score"]

        if len(predictions) != 2:
            raise ValueError(
                "Expected model.predict to return 2 outputs "
                f"(home_score, away_score), got {len(predictions)}"
            )
        home_score_predictions, away_score_predictions = predictions
        return home_score_predictions, away_score_predictions

    def calculate_metrics(self) -> Tuple[
        Annotated[float, "loss"],
        Annotated[float, "home_sc_loss"],
        Annotated[float, "away_sc_loss"],
        Annotated[float, "home_sc_rmse"],
        Annotated[float, "away_sc_rmse"]
    ]:
        """
        Evaluates the model on the test set and returns loss and RMSE values.

        Assumes the model was compiled so that ``model.evaluate`` returns
        exactly these five values, in this order.

        Returns:
            Tuple of five floats:
            - loss: total loss value
            - home_sc_loss: 'home_score' loss value
            - away_sc_loss: 'away_score' loss value
            - home_sc_rmse: 'home_score' Root Mean Squared Error value
            - away_sc_rmse: 'away_score' Root Mean Squared Error value

        Raises:
            ValueError: If ``model.evaluate`` does not return five values.
            Exception: Anything raised by ``model.evaluate`` is logged and
                re-raised unchanged.
        """
        try:
            results = self.model.evaluate(self.test_features, self.test_labels)
        except Exception as err:
            logging.exception("Error during model evaluation: %s", err)
            raise

        try:
            loss, home_sc_loss, away_sc_loss, home_sc_rmse, away_sc_rmse = results
        except (TypeError, ValueError) as err:
            # For example, a scalar (loss only) or a different metric count.
            raise ValueError(
                "Expected model.evaluate to return 5 values (loss, 2 output "
                f"losses, 2 RMSEs), got {results!r}"
            ) from err
        # Coerce numpy/tensor scalars to plain floats, matching the annotation.
        return (float(loss), float(home_sc_loss), float(away_sc_loss),
                float(home_sc_rmse), float(away_sc_rmse))