# In [None]:
class ExpectationMaximization(ParameterEstimator):
    """
    Estimate the CPDs of a BayesianNetwork with latent variables
    (``model.latents``) using the Expectation Maximization algorithm.

    Parameters
    ----------
    model: BayesianNetwork
        Model whose parameters are learned. Any other model type raises
        ``NotImplementedError``.
    data: pandas.DataFrame
        Observed data. Columns for the latent variables are generated during
        the E-step, so ``data`` is expected to hold only observed variables.
    **kwargs:
        Forwarded unchanged to ``ParameterEstimator.__init__``.
    """

    def __init__(self, model, data, **kwargs):
        if not isinstance(model, BayesianNetwork):
            raise NotImplementedError(
                "Expectation Maximization is only implemented for BayesianNetwork"
            )

        super(ExpectationMaximization, self).__init__(model, data, **kwargs)
        # All CPDs added/replaced during EM are stored on this copy rather
        # than on ``self.model``.
        self.model_copy = self.model.copy()

    def _get_likelihood(self, datapoint):
        """
        Return the product of every CPD of ``self.model_copy`` evaluated at
        ``datapoint``.

        Parameters
        ----------
        datapoint: dict
            Mapping from variable name to state. For each CPD, only the keys in
            that CPD's scope are passed to ``cpd.get_value``; others are ignored.

        Returns
        -------
        The product of the CPD values (the joint probability of the assignment).
        """
        likelihood = 1
        with warnings.catch_warnings():
            # Warnings raised while evaluating CPD values are suppressed.
            warnings.simplefilter("ignore")
            for cpd in self.model_copy.cpds:
                scope = set(cpd.scope())
                likelihood *= cpd.get_value(
                    **{key: value for key, value in datapoint.items() if key in scope}
                )
        return likelihood

    def _compute_weights(self, latent_card):
        """
        E-step: expand each unique observed row with every joint state of the
        latent variables and weight each expanded row by its posterior
        probability times the number of times the observed row occurs.

        Parameters
        ----------
        latent_card: dict
            Mapping from latent variable name to its number of states.

        Returns
        -------
        tuple(pandas.DataFrame, float)
            The expanded dataset (weights in the ``_weight`` column) and the
            log-likelihood of the observed data under the CPDs currently in
            ``self.model_copy``: sum over unique rows of count * log P(row).
        """
        cache = []
        log_likelihood = 0.0

        data_unique = self.data.drop_duplicates()
        n_counts = self.data.groupby(list(self.data.columns)).size().to_dict()

        # The grid of latent assignments is the same for every row, so it is
        # built once instead of once per unique row.
        latent_combinations = np.array(
            list(product(*[range(card) for card in latent_card.values()])), dtype=int
        )
        n_combinations = latent_combinations.shape[0]

        for i in range(data_unique.shape[0]):
            df = data_unique.iloc[[i] * n_combinations].reset_index(drop=True)
            for index, latent_var in enumerate(latent_card.keys()):
                df[latent_var] = latent_combinations[:, index]

            weights = df.apply(lambda t: self._get_likelihood(dict(t)), axis=1)
            # Marginalizing the latents gives P(observed row).
            row_likelihood = weights.sum()
            row_count = n_counts[tuple(data_unique.iloc[i])]
            df["_weight"] = (weights / row_likelihood) * row_count
            log_likelihood += row_count * np.log(row_likelihood)
            cache.append(df)

        return pd.concat(cache, copy=False), log_likelihood

    def _is_converged(self, new_cpds, atol=1e-08):
        """
        Check whether every CPD in ``new_cpds`` equals the CPD currently stored
        in ``self.model_copy`` for the same node, within absolute tolerance
        ``atol``. The node is taken from ``cpd.scope()[0]``.

        Returns
        -------
        bool: False on the first mismatching CPD, True otherwise.
        """
        return all(
            cpd.__eq__(self.model_copy.get_cpds(node=cpd.scope()[0]), atol=atol)
            for cpd in new_cpds
        )

    def get_parameters(
        self,
        latent_card=None,
        max_iter=100,
        atol=1e-08,
        n_jobs=-1,
        seed=None,
        show_progress=True,
    ):
        """
        Run EM starting from randomly initialized CPDs.

        Parameters
        ----------
        latent_card: dict or None
            Number of states of each latent variable. Defaults to 2 for every
            variable in ``self.model_copy.latents``.
        max_iter: int
            Maximum number of E-step/M-step iterations.
        atol: float
            Absolute tolerance used by the convergence check.
        n_jobs: int
            Forwarded to ``MaximumLikelihoodEstimator.get_parameters``.
        seed: int or None
            If given, seeds ``np.random`` before the random initialization.
        show_progress: bool
            Show a ``tqdm`` progress bar (also requires ``SHOW_PROGRESS``).

        Returns
        -------
        dict with keys
            "cpds": the latest estimates (from the final M-step), or the random
                initial CPDs if no iteration ran (``max_iter <= 0``).
            "iter": "converged" or "non-converged".
            "LL": log-likelihood of the observed data computed in the last
                E-step (i.e. under the parameters before the final M-step);
                None if no iteration ran.
            "n_iter": number of EM iterations performed.
        """
        # Step 1: Parameter checks
        if latent_card is None:
            latent_card = {var: 2 for var in self.model_copy.latents}

        # Step 2: Create structures/variables to be used later.
        n_states_dict = {key: len(value) for key, value in self.state_names.items()}
        n_states_dict.update(latent_card)
        for var in self.model_copy.latents:
            self.state_names[var] = list(range(n_states_dict[var]))

        # Step 3: Initialize random CPDs as the starting point of EM.
        if seed is not None:
            np.random.seed(seed)

        cpds = []
        for node in self.model_copy.nodes():
            parents = list(self.model_copy.predecessors(node))
            cpds.append(
                TabularCPD.get_random(
                    variable=node,
                    evidence=parents,
                    cardinality={
                        var: n_states_dict[var] for var in chain([node], parents)
                    },
                    state_names={
                        var: self.state_names[var] for var in chain([node], parents)
                    },
                )
            )

        self.model_copy.add_cpds(*cpds)

        pbar = tqdm(total=max_iter) if show_progress and SHOW_PROGRESS else None

        # Step 4: Run the EM algorithm.
        current_cpds = cpds
        log_lik = None
        n_iter = 0
        status = "non-converged"
        try:
            for _ in range(max_iter):
                # Step 4.1: E-step: expand the dataset and compute the weight of
                #           each possible state of the latent variables.
                weighted_data, log_lik = self._compute_weights(latent_card)
                # Step 4.2: M-step: weighted MLE on the expanded dataset.
                new_cpds = MaximumLikelihoodEstimator(
                    self.model_copy, weighted_data
                ).get_parameters(n_jobs=n_jobs, weighted=True)
                n_iter += 1
                current_cpds = new_cpds

                # Step 4.3: Check convergence; otherwise adopt the new CPDs.
                if self._is_converged(new_cpds, atol=atol):
                    status = "converged"
                    break

                self.model_copy.cpds = new_cpds
                if pbar is not None:
                    pbar.update(1)
        finally:
            # Close the bar on every exit path (converged, max_iter, or error).
            if pbar is not None:
                pbar.close()

        return {
            "cpds": current_cpds,
            "iter": status,
            "LL": log_lik,
            "n_iter": n_iter,
        }
