In [1]:
from exam_functions import *
import exam_toolbox as et
import numpy as np

In [47]:
def adaboost(delta, rounds=1, weights=None):
    """
    Calculates AdaBoost using weights and considering the order of misclassified observations.

    delta: list of misclassified observations, 0 = correctly classified, 1 = misclassified

    rounds: int, the number of rounds to run, default is 1

    weights: list of weights, default is 1/n

    Returns:
        alpha: the alpha of the classifier
        weights: the updated weights
    """

    delta = np.array(delta)
    n = len(delta)
    if weights is None:
        weights = np.ones(n) / n

    # Initialize error and alpha
    error = np.sum(weights[delta == 1]) / np.sum(weights)
    alpha = 0.5 * np.log((1 - error) / error)

    # Iterate over rounds
    for _ in range(rounds):
        # Update weights
        for idx, misclassified in enumerate(delta):
            if misclassified == 1:
                # Add weights of misclassified observations
                for other_idx in np.where(delta == 1)[0]:
                    if idx != other_idx:
                        weights[idx] += weights[other_idx]

        # Normalize weights
        weights /= np.sum(weights)

    return alpha, weights

In [52]:
w = np.array(
    [
        [0.1000, 0.0714, 0.0469, 0.0319],
        [0.1000, 0.0714, 0.0469, 0.0319],
        [0.1000, 0.1667, 0.1094, 0.2059],
        [0.1000, 0.0714, 0.0469, 0.0319],
        [0.1000, 0.1667, 0.1094, 0.2059],
        [0.1000, 0.0714, 0.0469, 0.0882],
        [0.1000, 0.0714, 0.0469, 0.0319],
        [0.1000, 0.1667, 0.3500, 0.2383],
        [0.1000, 0.0714, 0.1500, 0.1021],
        [0.1000, 0.0714, 0.0469, 0.0319],
    ]
)

w1 = w[:, 0]
w2 = w[:, 1]
w3 = w[:, 2]
w4 = w[:, 3]

miss1 = [0, 0, 1, 0, 1, 0, 0, 1, 0, 0]
miss2 = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0]
miss3 = [0, 0, 1, 0, 1, 1, 0, 0, 0, 0]
miss4 = [0, 0, 1, 0, 0, 1, 0, 0, 0, 0]
adaboost(miss3, weights=w3)

(0.5084110714240822,
 array([0.02174518, 0.02174518, 0.12319177, 0.02174518, 0.19566024,
        0.34059718, 0.02174518, 0.16227745, 0.06954748, 0.02174518]))