In [None]:
#default_exp ets

In [None]:
#hide
# Ignore every warning category for the rest of the kernel session. The filter
# is process-wide, so it also hides warnings raised by the library code below.
import warnings
warnings.simplefilter('ignore')

In [None]:
#export
import math
import os
import sys
import warnings
from collections import namedtuple
from functools import partial
from typing import Optional, Dict, Union, Tuple

import numpy as np
import pandas as pd
from numba import njit

In [None]:
#exporti
# Constants describing the floating-point arithmetic, used by the CPOLY port below.
eta = sys.float_info.epsilon                       # machine precision
are = eta                                          # error bound on complex addition
mre = 2. * math.sqrt(2) * sys.float_info.epsilon   # error bound on complex multiplication
infin = sys.float_info.max                         # largest finite float
smalno = sys.float_info.min                        # smallest positive normalized float
base = sys.float_info.radix                        # base of the floating-point representation

### Independent Complex Polynomial Utilities

In [None]:
#exporti
@njit
def polyev(n: int, s_r: float, s_i: float, p_r: np.ndarray,
           p_i: np.ndarray, q_r: np.ndarray, q_i: np.ndarray,
           v_r: float, v_i: float):
    """Evaluate the complex polynomial p at s by Horner's recurrence.

    Parameters
    ----------
    n : number of coefficients of p.
    s_r, s_i : real and imaginary parts of the evaluation point s.
    p_r, p_i : coefficient arrays, highest degree first.
    q_r, q_i : output arrays (length >= n) receiving the Horner partial sums;
        the last one, ``q[n - 1]``, equals p(s).
    v_r, v_i : ignored on input; kept for interface compatibility (floats
        passed as arguments cannot be updated in the caller).

    Returns
    -------
    tuple of float
        (v_r, v_i), the real and imaginary parts of p(s).
    """
    q_r[0] = p_r[0]
    q_i[0] = p_i[0]
    v_r = q_r[0]
    v_i = q_i[0]
    for i in range(1, n):
        # (v_r + i*v_i) <- (v_r + i*v_i) * s + p[i]; keep the old v_r for the
        # imaginary part before overwriting it.
        t = v_r * s_r - v_i * s_i + p_r[i]
        q_i[i] = v_i = v_r * s_i + v_i * s_r + p_i[i]
        q_r[i] = v_r = t
    return v_r, v_i

In [None]:
#exporti
@njit
def errev(n: int, qr: float, qi: float, 
          ms: float, mp: float, 
          a_re: float, m_re: float) -> float:
    """Bound the rounding error of a Horner evaluation of a polynomial.

    Parameters
    ----------
    n : number of partial sums to accumulate.
    qr, qi : arrays with the real/imaginary Horner partial sums (as filled
        by ``polyev``).
    ms : modulus of the evaluation point.
    mp : modulus of the computed polynomial value.
    a_re, m_re : error bounds on complex addition and multiplication.

    Returns
    -------
    float
        The error bound.
    """
    err_sum = a_re + m_re
    acc = math.hypot(qr[0], qi[0]) * m_re / err_sum
    # Note: qr[0]/qi[0] enter twice (above and at k == 0), as in the reference code.
    for k in range(n):
        acc = acc * ms + math.hypot(qr[k], qi[k])
    return acc * err_sum - mp * m_re

In [None]:
#exporti
@njit
def cpoly_cauchy(n: int, pot: float, q: float) -> float:
    """Lower bound on the moduli of the zeros of a polynomial.

    Finds, by bisection-like chopping followed by Newton's method (to about two
    decimal places), the positive root of
    pot[0]*x**(n-1) + ... + pot[n-2]*x - |pot[n-1]|.

    Parameters
    ----------
    n : number of coefficients.
    pot : array with the moduli of the coefficients, highest degree first.
        Mutated: ``pot[n - 1]`` is negated in place.
    q : work array of length >= n for the Newton partial sums.

    Returns
    -------
    float
        The lower bound x.
    """
    n1 = n - 1
    # Negate the constant term so the polynomial has exactly one positive root.
    pot[n1] = -pot[n1]
    # Compute upper estimate of bound 
    x = math.exp((math.log(-pot[n1]) - math.log(pot[0])) / n1)
    # If newton step at the origin is better, use it
    if pot[n1 - 1] != 0.:
        xm = -pot[n1] / pot[n1 - 1] 
        if xm < x:
            x = xm 
    # Chop the interval (0, x) by factors of 10 until f(xm) <= 0
    while True:
        xm = x * 0.1 
        f = pot[0]
        for i in range(1, n):
            f = f * xm + pot[i]
        if f <= 0.:
            break 
        x = xm
    dx = x 
    # Do Newton iteration until x converges to two decimal places
    while math.fabs(dx / x) > 0.005: 
        # q holds the Horner partial sums: q[n1] = f(x)
        q[0] = pot[0]
        for i in range(1, n):
            q[i] = q[i-1] * x + pot[i]
        f = q[n1]
        # derivative f'(x) from the partial sums q[0..n1-1]
        delf = q[0]
        for i in range(1, n1):
            delf = delf * x + q[i]
        dx = f / delf 
        x -= dx
    return x

In [None]:
#exporti
@njit
def cpoly_scale(n: int, pot: float, eps: float, 
                BIG: float, small: float, base: float) -> float:
    """Scale factor (a power of ``base``) for the polynomial coefficients.

    Scaling avoids overflow and undetected underflow interfering with the
    convergence criterion. It is applied only when some coefficient modulus is
    very small (< small/eps) or very large (> sqrt(BIG)); otherwise 1.0 is
    returned.

    Parameters
    ----------
    n : number of coefficients.
    pot : array with the moduli of the coefficients.
    eps, BIG, small, base : constants describing the floating-point arithmetic.
    """
    high = math.sqrt(BIG)
    lo = small / eps
    largest = 0.
    smallest = BIG
    # Largest modulus and smallest non-zero modulus.
    for k in range(n):
        modulus = pot[k]
        if modulus > largest:
            largest = modulus
        if modulus != 0. and modulus < smallest:
            smallest = modulus

    needs_scaling = smallest < lo or largest > high
    if not needs_scaling:
        return 1.0

    ratio = lo / smallest
    if ratio <= 1.:
        factor = 1. / (math.sqrt(largest) * math.sqrt(smallest))
    elif BIG / ratio > largest:
        factor = 1.0
    else:
        factor = ratio
    # Round log_base(factor) to an integer exponent.
    exponent = int(math.log(factor) / math.log(base) + 0.5)
    return math.pow(base, exponent)

In [None]:
#exporti
@njit
def cdivid(ar: float, ai: float, br: float, bi: float, 
           cr: float, ci: float):
    """Complex division c = a / b without intermediate overflow (Smith's method).

    Computes (cr + i*ci) = (ar + i*ai) / (br + i*bi) by dividing through by
    whichever of |br|, |bi| is larger. Division by exactly zero yields +inf in
    both parts.

    Parameters
    ----------
    ar, ai : real and imaginary parts of the numerator.
    br, bi : real and imaginary parts of the denominator.
    cr, ci : ignored on input; kept for interface compatibility (floats
        passed as arguments cannot be updated in the caller).

    Returns
    -------
    tuple of float
        (cr, ci), the real and imaginary parts of the quotient.
    """
    if br == 0. and bi == 0.:
        # Division by zero: c = +infinity.
        cr = ci = np.inf
    elif math.fabs(br) >= math.fabs(bi):
        r = bi / br
        d = br + r * bi
        cr = (ar + r * ai) / d
        ci = (ai - r * ar) / d
    else:
        r = br / bi
        d = bi + r * br
        cr = (ar * r + ai) / d
        ci = (ai * r - ar) / d
    return cr, ci

### cpolyroot stages: noshft, fxshft, vrshft, calct, nexth

In [None]:
#exporti
@njit
def noshft(l1, nn, tr, ti, pr, pi, hr, hi):
    """First stage of CPOLY: ``l1`` no-shift steps on the h polynomial.

    h is initialised to p'(z) / (nn - 1), i.e. the derivative of p scaled by
    1/degree. It is then updated ``l1`` times. Each update either shifts h's
    coefficients down by one (when h's constant term is negligible relative
    to ``pr[nn - 2]``) or uses t = -p(0) / h(0).

    Parameters
    ----------
    l1 : number of no-shift steps.
    nn : number of coefficients of p (degree + 1); h has nn - 1 coefficients.
    tr, ti : ignored on input; t is computed locally at each step.
    pr, pi : coefficients of p, highest degree first.
    hr, hi : output arrays (length >= nn - 1) holding h, updated in place.
    """
    n = nn - 1
    nm1 = n - 1

    # h = p' / n (coefficients highest degree first).
    for i in range(n):
        xni = float(nn - i - 1)
        hr[i] = xni * pr[i] / n
        hi[i] = xni * pi[i] / n

    for _ in range(l1):
        constant_term = math.hypot(hr[n - 1], hi[n - 1])
        comparison = eta * 10.0 * math.hypot(pr[n - 1], pi[n - 1])
        if constant_term <= comparison:
            # The constant term of h is essentially zero: shift h coefficients.
            for i in range(1, nm1 + 1):
                j = nn - i
                hr[j - 1] = hr[j - 2]
                hi[j - 1] = hi[j - 2]
            hr[0] = 0.
            hi[0] = 0.
        else:
            # t = -p(0) / h(0), computed inline with the same arithmetic as
            # cdivid so that the quotient is actually available here.
            ar = -pr[nn - 1]
            ai = -pi[nn - 1]
            br = hr[n - 1]
            bi = hi[n - 1]
            if br == 0. and bi == 0.:
                tr = np.inf
                ti = np.inf
            elif math.fabs(br) >= math.fabs(bi):
                r = bi / br
                d = br + r * bi
                tr = (ar + r * ai) / d
                ti = (ai - r * ar) / d
            else:
                r = br / bi
                d = bi + r * br
                tr = (ar * r + ai) / d
                ti = (ai * r - ar) / d
            # h <- t * h (shifted) + p, as a complex multiply-add per coefficient.
            for i in range(1, nm1 + 1):
                j = nn - i
                t1 = hr[j - 2]
                t2 = hi[j - 2]
                hr[j - 1] = tr * t1 - ti * t2 + pr[j - 1]
                hi[j - 1] = tr * t2 + ti * t1 + pi[j - 1]
            hr[0] = pr[0]
            hi[0] = pi[0]

In [None]:
#exporti
@njit
def fxshft(nn, tr, ti, hr, hi, qhr, qhi, 
           sr, si, pr, pi, 
           qpr, qpi, pvr, pvi, 
           shr, shi,
           l2: int, zr: float, zi: float) -> bool:
    """Second stage of CPOLY: up to ``l2`` fixed-shift steps at s = sr + i*si.

    After each step a weak convergence test compares successive values of t.
    When the test passes twice, the current h polynomial and shift are saved,
    and the third stage (``vrshft`` with 10 steps) is tried. If that fails,
    testing is turned off, h and s are restored, and the loop continues. After
    the loop a final ``vrshft`` attempt is made.

    Parameters
    ----------
    nn : number of coefficients of p (degree + 1).
    tr, ti : t = -p(s) / h(s).
    hr, hi : h polynomial (modified in place by ``nexth``).
    qhr, qhi, qpr, qpi : Horner partial-sum work arrays for h and p.
    sr, si : the fixed shift s.
    pr, pi : coefficients of p.
    pvr, pvi : p(s).
    shr, shi : storage for the saved h polynomial.
    l2 : limit of fixed-shift steps.
    zr, zi : approximate zero, set to s + t inside the loop.

    Returns
    -------
    bool
        The result of ``vrshft``, i.e. whether the stage-three iteration converged.

    NOTE(review): tr/ti, pvr/pvi, zr/zi and ``boolean`` are plain floats/bools
    passed by value. ``polyev`` and ``calct`` assign only to their own copies,
    and this function does not read any return value from them, so those values
    are not updated here. The zero found by ``vrshft`` is also not reported back
    to the caller. In the C original these were globals or pointer arguments.
    TODO: return and reassign them explicitly.
    """
    n = nn - 1
    # Evaluate p at s 
    polyev(nn, sr, si, pr, pi, qpr, qpi, pvr, pvi)
    # test: weak-convergence testing still enabled; pasd: test passed on the previous step
    test = True
    pasd = False 
    # calculate first t = -p(s) / h(s)
    boolean = False
    calct(nn, tr, ti, sr, si, hr, hi, pvr, pvi, qhr, qhi, boolean)
    # main loop for one second stage step. 
    for j in range(1, l2 + 1):
        otr = tr
        oti = ti
        # compute next h polynomial and new t 
        nexth(nn, tr, ti, hr, hi, qhr, qhi, qpr, qpi, boolean)
        calct(nn, tr, ti, sr, si, hr, hi, pvr, pvi, qhr, qhi, boolean)

        zr = sr + tr
        zi = si + ti
        # test for convergence unless stage 3 has failed once or 
        # this is the last h polynomial.
        if ((not boolean) and test and j != l2):
            # weak test: change in t is less than half of |z|
            if math.hypot(tr - otr, ti - oti) >= math.hypot(zr, zi) * 0.5:
                pasd = False 
            elif (not pasd):
                pasd = True 
            else:
                # The weak convergence test has been passed twice, start the third 
                # stage iteration, after saving the current h polynomial and shift. 
                for i in range(n):
                    shr[i] = hr[i]
                    shi[i] = hi[i]
                svsr = sr
                svsi = si 
                if vrshft(nn, tr, ti, hr, hi, qhr, qhi, 
                          pr, pi, 
                          qpr, qpi, pvr, pvi, 
                          shr, shi, 10, zr, zi):
                    return True 
                # The iteration failed to converge. 
                # Turn off testing and restore h, s, pv, and t 
                test = False 
                for i in range(1, n + 1):
                    hr[i - 1] = shr[i - 1]
                    hi[i - 1] = shi[i - 1]
                sr = svsr
                si = svsi 
                polyev(nn, sr, si, pr, pi, qpr, qpi, pvr, pvi)
                calct(nn, tr, ti, sr, si, hr, hi, pvr, pvi, qhr, qhi, boolean)
    # Attempt an iteration with final h polynomial from second stage. 
    return vrshft(nn, tr, ti, hr, hi, qhr, qhi, 
                  pr, pi, 
                  qpr, qpi, pvr, pvi, 
                  shr, shi, 10, zr, zi)

In [None]:
#exporti
@njit
def vrshft(nn, tr, ti, hr, hi, qhr, qhi, 
           pr, pi, 
           qpr, qpi, pvr, pvi, 
           shr, shi,
           l3: int, zr: float, zi: float) -> bool:
    """Third stage of CPOLY: variable-shift iteration starting from zr + i*zi.

    Parameters
    ----------
    nn : number of coefficients of p (degree + 1).
    tr, ti : t = -p(s) / h(s).
    hr, hi : h polynomial (modified in place by ``nexth``).
    qhr, qhi, qpr, qpi : Horner partial-sum work arrays for h and p.
    pr, pi : coefficients of p, highest degree first.
    pvr, pvi : p(s); recomputed here after every evaluation of p.
    shr, shi : unused here (kept for a uniform signature with ``fxshft``).
    l3 : limit of steps in stage 3.
    zr, zi : initial iterate.

    Returns
    -------
    bool
        True if |p(s)| fell below 20 times the rounding-error bound from
        ``errev`` (converged), False if the iteration diverged or ran out of steps.

    NOTE(review): tr/ti and ``boolean`` are passed by value to ``calct``/``nexth``
    and are not reassigned here, so steps use a stale t unless ``calct``'s
    results are assigned back. The converged iterate is stored only in the
    local zr/zi and is not visible to the caller. TODO.
    """
    omp = 0.
    relstp = 0.
    boolean = False
    b = False  # True once the stall-recovery shifts have been applied
    sr = zr
    si = zi
    # Main loop for stage three 
    for i in range(1, l3 + 1):
        # Evaluate p at s. polyev leaves p(s) in the last partial sum.
        polyev(nn, sr, si, pr, pi, qpr, qpi, pvr, pvi)
        pvr = qpr[nn - 1]
        pvi = qpi[nn - 1]
        mp = math.hypot(pvr, pvi)
        ms = math.hypot(sr, si)
        # Polynomial value is smaller than a bound on the error in evaluating p:
        # terminate the iteration.
        if mp <= 20. * errev(nn, qpr, qpi, ms, mp, eta, mre):
            zr = sr
            zi = si
            return True

        stalled = False
        # omp and relstp are only meaningful from the second step on.
        if i != 1:
            if (not b) and mp >= omp and relstp < 0.05:
                # Iteration has stalled, probably at a cluster of zeros.
                # Perturb s and do 5 fixed shift steps into the cluster to
                # force one zero to dominate.
                stalled = True
                b = True
                tp = relstp
                if relstp < eta:
                    tp = eta
                r1 = math.sqrt(tp)
                r2 = sr * (r1 + 1.) - si * r1
                si = sr * r1 + si * (r1 + 1.)
                sr = r2
                polyev(nn, sr, si, pr, pi, qpr, qpi, pvr, pvi)
                pvr = qpr[nn - 1]
                pvi = qpi[nn - 1]
                for _ in range(5):
                    calct(nn, tr, ti, sr, si, hr, hi, pvr, pvi, qhr, qhi, boolean)
                    nexth(nn, tr, ti, hr, hi, qhr, qhi, qpr, qpi, boolean)
                omp = infin
            elif mp * 0.1 > omp:
                # Exit if the polynomial value increases significantly.
                return False
        if not stalled:
            omp = mp

        # Calculate the next iterate (runs on every step).
        calct(nn, tr, ti, sr, si, hr, hi, pvr, pvi, qhr, qhi, boolean)
        nexth(nn, tr, ti, hr, hi, qhr, qhi, qpr, qpi, boolean)
        calct(nn, tr, ti, sr, si, hr, hi, pvr, pvi, qhr, qhi, boolean)
        if not boolean:
            relstp = math.hypot(tr, ti) / math.hypot(sr, si)
            sr += tr
            si += ti
    return False


In [None]:
#exporti
@njit
def calct(nn, tr, ti, sr, si, hr, hi, pvr, pvi, qhr, qhi, boolean: bool):
    """Compute t = -p(s) / h(s) for the shift s = sr + i*si.

    Parameters
    ----------
    nn : number of coefficients of p; h has nn - 1 coefficients.
    tr, ti : ignored on input; kept for interface compatibility.
    sr, si : the shift s.
    hr, hi : coefficients of h, highest degree first.
    pvr, pvi : p(s), as supplied by the caller.
    qhr, qhi : work arrays receiving the Horner partial sums of h at s.
    boolean : ignored on input; kept for interface compatibility.

    Returns
    -------
    tuple
        (tr, ti, boolean). boolean is True when h(s) is essentially zero,
        in which case t is 0.
    """
    n = nn - 1
    hvr, hvi = 0., 0.
    # Evaluate h(s). The last Horner partial sum written to qhr/qhi is h(s)
    # itself, so read it from there; the float arguments are not updated.
    polyev(n, sr, si, hr, hi, qhr, qhi, hvr, hvi)
    hvr = qhr[n - 1]
    hvi = qhi[n - 1]
    boolean = math.hypot(hvr, hvi) <= are * 10. * math.hypot(hr[n - 1], hi[n - 1])
    tr = 0.
    ti = 0.
    if not boolean:
        # t = -p(s) / h(s), with the same overflow-avoiding arithmetic as cdivid.
        ar = -pvr
        ai = -pvi
        if hvr == 0. and hvi == 0.:
            tr = np.inf
            ti = np.inf
        elif math.fabs(hvr) >= math.fabs(hvi):
            r = hvi / hvr
            d = hvr + r * hvi
            tr = (ar + r * ai) / d
            ti = (ai - r * ar) / d
        else:
            r = hvr / hvi
            d = hvi + r * hvr
            tr = (ar * r + ai) / d
            ti = (ai * r - ar) / d
    return tr, ti, boolean

In [None]:
#exporti
@njit
def nexth(nn, tr, ti, hr, hi, qhr, qhi, qpr, qpi, boolean: bool):
    """Compute the next shifted h polynomial, writing it into hr/hi in place.

    Parameters
    ----------
    nn : number of coefficients of p; h has nn - 1 coefficients.
    tr, ti : t = -p(s) / h(s).
    hr, hi : h polynomial, overwritten.
    qhr, qhi : Horner partial sums of h at s.
    qpr, qpi : Horner partial sums of p at s.
    boolean : True if h(s) is essentially zero; then h is replaced by the
        shifted partial sums qh with a zero leading coefficient.
    """
    last = nn - 1
    if boolean:
        # h(s) is (nearly) zero: replace h with qh shifted by one.
        for k in range(1, last):
            hr[k] = qhr[k - 1]
            hi[k] = qhi[k - 1]
        hr[0] = 0.
        hi[0] = 0.
        return
    # h <- t * qh (shifted) + qp, as a complex multiply-add per coefficient.
    for k in range(1, last):
        re_prev = qhr[k - 1]
        im_prev = qhi[k - 1]
        hr[k] = tr * re_prev - ti * im_prev + qpr[k]
        hi[k] = tr * im_prev + ti * re_prev + qpi[k]
    hr[0] = qpr[0]
    hi[0] = qpi[0]


### cpolyroot

In [None]:
#exporti
@njit
def cpolyroot(opr, opi, degree, zeror, zeroi, fail):
    """Find the zeros of a complex polynomial with the Jenkins-Traub CPOLY algorithm.

    Parameters
    ----------
    opr, opi : arrays (length ``degree + 1``) with the real and imaginary parts
        of the coefficients, highest degree first.
    degree : degree of the polynomial.
    zeror, zeroi : output arrays (length ``degree``) receiving the real and
        imaginary parts of the zeros.
    fail : ignored on input; kept for interface compatibility. The failure flag
        is returned instead, because a bool argument cannot be updated in the caller.

    Returns
    -------
    bool
        True if the algorithm failed (zero leading coefficient, or no
        convergence after two major passes of shifts), otherwise False.

    NOTE(review): ``fxshft`` receives tr/ti, pvr/pvi and zr/zi by value and does
    not return the zero it finds. As a result zr/zi below keep their initial
    value unless ``fxshft`` is changed to return them. TODO.
    """
    sr = 0.
    si = 0.
    tr = 0.
    ti = 0.
    pvr = 0.
    pvi = 0.
    zr = 0.
    zi = 0.

    cosr = -0.06975647374412529990  # cos(94 degrees)
    sinr = 0.99756405025982424767   # sin(94 degrees)

    xx = 1 / math.sqrt(2)
    yy = -xx

    nn = degree
    d1 = nn - 1
    # The algorithm fails if the leading coefficient is zero.
    if opr[0] == 0. and opi[0] == 0.:
        return True

    # Remove the zeros at the origin (trailing zero coefficients), if any.
    while opr[nn] == 0. and opi[nn] == 0.:
        d_n = d1 - nn + 1
        zeror[d_n] = 0.
        zeroi[d_n] = 0.
        nn -= 1
    nn += 1
    # From here on nn is the number of relevant coefficients (degree + 1).
    if nn == 1:
        return False

    # One allocation, split into ten non-overlapping views of length nn.
    tmp = np.zeros(10 * nn)
    pr = tmp[0:nn]
    pi = tmp[nn:2 * nn]
    hr = tmp[2 * nn:3 * nn]
    hi = tmp[3 * nn:4 * nn]
    qpr = tmp[4 * nn:5 * nn]
    qpi = tmp[5 * nn:6 * nn]
    qhr = tmp[6 * nn:7 * nn]
    qhi = tmp[7 * nn:8 * nn]
    shr = tmp[8 * nn:9 * nn]
    shi = tmp[9 * nn:10 * nn]

    # Make a copy of the coefficients and set shr[] = |p[]|.
    for i in range(nn):
        pr[i] = opr[i]
        pi[i] = opi[i]
        shr[i] = math.hypot(pr[i], pi[i])

    # Scale the polynomial by a power of the floating-point base.
    bnd = cpoly_scale(nn, shr, eta, infin, smalno, base)
    if bnd != 1.:
        for i in range(nn):
            pr[i] *= bnd
            pi[i] *= bnd

    # Find one zero per pass of this loop, deflating p each time.
    while nn > 2:
        # bnd: a lower bound on the modulus of the zeros.
        for i in range(nn):
            shr[i] = math.hypot(pr[i], pi[i])
        bnd = cpoly_cauchy(nn, shr, shi)

        conv = False
        # Two major passes with different sequences of shifts.
        for _ in range(2):
            # First stage calculation, no shift.
            noshft(5, nn, tr, ti, pr, pi, hr, hi)
            for i2 in range(1, 10):
                # The shift has modulus bnd and is rotated by 94 degrees
                # from the previous shift.
                xxx = cosr * xx - sinr * yy
                yy = sinr * xx + cosr * yy
                xx = xxx
                sr = bnd * xx
                si = bnd * yy
                # Second stage calculation, fixed shift (may run stage three).
                conv = fxshft(nn, tr, ti, hr, hi, qhr, qhi,
                              sr, si, pr, pi, qpr, qpi, pvr, pvi,
                              shr, shi,
                              i2 * 10, zr, zi)
                if conv:
                    break
            if conv:
                break

        if not conv:
            # The zero finder failed on both major passes.
            return True

        # Store the zero and deflate: continue with the partial sums qp.
        d_n = d1 + 2 - nn
        zeror[d_n] = zr
        zeroi[d_n] = zi
        nn -= 1
        for i in range(nn):
            pr[i] = qpr[i]
            pi[i] = qpi[i]

    # The remaining polynomial is linear: its zero is -pr[1] / pr[0] (complex),
    # computed with the same overflow-avoiding arithmetic as cdivid.
    ar = -pr[1]
    ai = -pi[1]
    br = pr[0]
    bi = pi[0]
    if br == 0. and bi == 0.:
        cr = np.inf
        ci = np.inf
    elif math.fabs(br) >= math.fabs(bi):
        r = bi / br
        d = br + r * bi
        cr = (ar + r * ai) / d
        ci = (ai - r * ar) / d
    else:
        r = br / bi
        d = bi + r * br
        cr = (ar * r + ai) / d
        ci = (ai * r - ar) / d
    zeror[d1] = cr
    zeroi[d1] = ci
    return False

In [None]:
# Example polynomial of degree m + 1, coefficients stored highest degree first.
# NOTE(review): this mirrors the characteristic polynomial of R forecast's
# `admissible()`, which has m + 2 coefficients with the middle coefficient
# repeated m - 2 times.
m = 12
phi = 0.9
alpha = beta = gamma = 0.5
opr = np.empty(m + 2)
opr[0] = 1.
opr[1] = alpha + beta - phi
opr[2:-2] = alpha + beta - alpha * phi  # the m - 2 repeated middle coefficients
opr[-2] = alpha + beta - alpha * phi + gamma - 1
opr[-1] = phi * (1 - alpha - gamma)
degree = opr.size - 1
opi = np.zeros_like(opr)
zeror = np.zeros(degree)
zeroi = np.zeros(degree)
fail = False
cpolyroot(opr, opi, degree, zeror, zeroi, fail)

## etscalc

In [None]:
#export
# Component codes compared against the error/trend/season arguments below.
NONE, ADD, MULT = 0, 1, 2
DAMPED = 1
TOL = 1.0e-10    # tolerance for "is zero" and "phi == 1" tests
HUGEN = 1.0e10   # stand-in ratio used when a divisor is below TOL
NA = -99999.0    # sentinel for an invalid forecast / likelihood

In [None]:
#exporti
@njit
def etscalc(y, n, x, m, 
            error, trend, season, 
            alpha, beta, 
            gamma, phi, e, 
            lik, amse, nmse):
    """Run the ETS recursions over ``y`` and compute the likelihood term.

    Parameters
    ----------
    y : observations (length >= n).
    n : number of observations to use.
    x : state array. On entry x[0:nstates] holds the initial states: the level,
        then the trend if ``trend > NONE``, then ``m`` seasonal states if
        ``season > NONE``. The state after observation i is written starting
        at x[nstates * (i + 1)].
    m : seasonal period (the call is a no-op if m > 24 with a seasonal
        component; values < 1 are treated as 1).
    error, trend, season : component types (NONE / ADD / MULT).
    alpha, beta, gamma, phi : smoothing and damping parameters.
    e : output array (length >= n) of one-step errors; relative errors when
        ``error`` is not ADD.
    lik : returned unchanged when the model is rejected up front (m > 24).
    amse : output array receiving the running mean of squared j-step errors
        for j < min(nmse, 30).
    nmse : number of forecast horizons for ``amse`` (capped at 30).

    Returns
    -------
    float
        ``n * log(sum(e**2))``, plus ``2 * sum(log|one-step forecast|)`` for
        multiplicative errors; ``NA`` if a forecast was invalid.
    """
    oldb = 0.
    # Initialised so the update call below is valid for non-trended models,
    # where no trend state is read from x.
    b = 0.
    olds = np.zeros(24)
    s = np.zeros(24)
    f = np.zeros(30)
    denom = np.zeros(30)
    if m > 24 and season > NONE:
        return lik
    elif m < 1:
        m = 1
    if nmse > 30:
        nmse = 30
    nstates = m * (season > NONE) + 1 + (trend > NONE)
    # Copy initial state components
    l = x[0]
    if trend > NONE:
        b = x[1]
    if season > NONE:
        for j in range(m):
            s[j] = x[(trend > NONE) + j + 1]
    lik = 0.
    lik2 = 0.
    for j in range(nmse):
        amse[j] = 0.
        denom[j] = 0.
    for i in range(n):
        # Copy previous state
        oldl = l
        if trend > NONE:
            oldb = b
        if season > NONE:
            for j in range(m):
                olds[j] = s[j]
        # Forecasts for horizons 1..nmse from the previous state
        forecast(oldl, oldb, olds, m, trend, season, phi, f, nmse)
        if math.fabs(f[0] - NA) < TOL:
            return NA
        if error == ADD:
            e[i] = y[i] - f[0]
        else:
            e[i] = (y[i] - f[0]) / f[0]
        # Running mean of squared j-step-ahead errors
        for j in range(nmse):
            if (i + j) < n:
                denom[j] += 1.
                tmp = y[i + j] - f[j]
                amse[j] = (amse[j] * (denom[j] - 1.0) + (tmp * tmp)) / denom[j]
        # Update state.
        # NOTE(review): l and b are floats passed by value; unless update()'s
        # new level/trend are assigned back here, they keep their initial
        # values (only the seasonal array s is updated in place).
        update(oldl, l, oldb, b, olds, s, m, trend, season, alpha, beta, gamma, phi, y[i])
        # Store new state
        x[nstates * (i + 1)] = l
        if trend > NONE:
            x[nstates * (i + 1) + 1] = b
        if season > NONE:
            for j in range(m):
                x[nstates * (i + 1) + (trend > NONE) + j + 1] = s[j]
        lik = lik + e[i] * e[i]
        lik2 += math.log(math.fabs(f[0]))
    lik = n * math.log(lik)
    if error == MULT:
        lik += 2 * lik2
    return lik

In [None]:
#exporti
@njit
def etssimulate(x, m, error, trend, 
                season, alpha, beta, 
                gamma, phi, h, 
                y, e):
    """Simulate ``h`` observations from an ETS model.

    Parameters
    ----------
    x : initial states: the level, then the trend if ``trend > NONE``, then
        ``m`` seasonal states if ``season > NONE``.
    m : seasonal period (the call is a no-op if m > 24 with a seasonal
        component; values < 1 are treated as 1).
    error, trend, season : component types (NONE / ADD / MULT).
    alpha, beta, gamma, phi : smoothing and damping parameters.
    h : number of observations to simulate.
    y : output array (length >= h). If a forecast is invalid, ``y[0]`` is
        set to ``NA`` and simulation stops.
    e : errors, added to the forecast (ADD) or applied as relative errors.
    """
    oldb = 0.
    # Initialised so the update call below is valid for non-trended models,
    # where no trend state is read from x.
    b = 0.
    olds = np.zeros(24)
    s = np.zeros(24)
    f = np.zeros(10)
    if m > 24 and season > NONE:
        return
    elif m < 1:
        m = 1
    # Copy initial state components
    l = x[0]
    if trend > NONE:
        b = x[1]
    if season > NONE:
        for j in range(m):
            s[j] = x[(trend > NONE) + j + 1]
    for i in range(h):
        # Copy previous state
        oldl = l
        if trend > NONE:
            oldb = b
        if season > NONE:
            for j in range(m):
                olds[j] = s[j]
        # One-step forecast
        forecast(oldl, oldb, olds, m, trend, season, phi, f, 1)
        if math.fabs(f[0] - NA) < TOL:
            y[0] = NA
            return
        if error == ADD:
            y[i] = f[0] + e[i]
        else:
            y[i] = f[0] * (1.0 + e[i])
        # Update state.
        # NOTE(review): l and b are floats passed by value; unless update()'s
        # new level/trend are assigned back here, they keep their initial
        # values (only the seasonal array s is updated in place).
        update(oldl, l, oldb, b, olds, s, m, trend, season, alpha, beta, gamma, phi, y[i])

In [None]:
@njit
def etsforecast(x, m, trend, season, 
                phi, h, f):
    """Write ``h`` point forecasts from the state vector ``x`` into ``f``.

    Parameters
    ----------
    x : states: the level, then the trend if ``trend > NONE``, then ``m``
        seasonal states if ``season > NONE``.
    m : seasonal period (the call is a no-op if m > 24 with a seasonal
        component; values < 1 are treated as 1).
    trend, season : component types (NONE / ADD / MULT).
    phi : damping parameter.
    h : forecast horizon.
    f : output array (length >= h).
    """
    if m > 24 and season > NONE:
        return
    if m < 1:
        m = 1
    seasonal = np.zeros(24)
    level = x[0]
    slope = x[1] if trend > NONE else 0.0
    if season > NONE:
        # Seasonal states follow the level and (optional) trend.
        offset = 1 + (trend > NONE)
        for k in range(m):
            seasonal[k] = x[offset + k]
    forecast(level, slope, seasonal, m, trend, season, phi, f, h)

In [None]:
@njit
def forecast(l, b, s, m, 
             trend, season, phi, f, h):
    """Write ``h`` point forecasts from level ``l``, trend ``b`` and seasonal states ``s`` into ``f``.

    Parameters
    ----------
    l, b : current level and trend.
    s : seasonal states. Step k (k = 0, 1, ...) uses s[j] with j = m - 1 - k
        wrapped into range by adding m.
    m : seasonal period.
    trend, season : component types (NONE / ADD / MULT).
    phi : damping parameter. The trend multiplier starts at phi and, after
        step k, grows by 1.0 (when phi == 1 within TOL) or by phi**(k + 1).
    f : output array (length >= h).
    h : forecast horizon.

    A multiplicative trend with negative ``b`` yields ``NA`` before the
    seasonal component is applied.

    NOTE(review): the usual damped-trend multiplier for step k is
    phi + phi**2 + ... + phi**(k + 1), which would need an increment of
    phi**(k + 2). Confirm the increment against the reference implementation.
    """
    phi_is_one = math.fabs(phi - 1.0) < TOL
    phistar = phi
    for k in range(h):
        # Level and trend.
        if trend == NONE:
            f[k] = l
        elif trend == ADD:
            f[k] = l + phistar * b
        elif b < 0:
            f[k] = NA
        else:
            f[k] = l * math.pow(b, phistar)
        # Seasonal index, wrapped into range.
        idx = m - 1 - k
        while idx < 0:
            idx += m
        if season == ADD:
            f[k] += s[idx]
        elif season == MULT:
            f[k] *= s[idx]
        if k < h - 1:
            phistar = phistar + (1.0 if phi_is_one else math.pow(phi, k + 1))

In [None]:
@njit
def update(oldl, l, oldb, b, 
           olds, s, 
           m, trend, season, 
           alpha, beta, gamma, 
           phi, y):
    """Perform one ETS state update after observing ``y``.

    Parameters
    ----------
    oldl, oldb : previous level and trend.
    l : ignored on input; kept for interface compatibility.
    b : returned unchanged when the model has no trend.
    olds : previous seasonal states (olds[m - 1] is s[t - m]).
    s : output array receiving the new seasonal states (updated in place).
    m : seasonal period.
    trend, season : component types (NONE / ADD / MULT).
    alpha, beta, gamma, phi : smoothing and damping parameters.
    y : new observation.

    Returns
    -------
    tuple
        (l, b), the new level and trend. They are returned because floats
        passed as arguments cannot be updated in the caller.
    """
    # New Level 
    if trend == NONE:
        q = oldl            # l(t - 1)
        phib = 0 
    elif trend == ADD:
        phib = phi * oldb
        q = oldl + phib     # l(t - 1) + phi * b(t - 1)
    elif math.fabs(phi - 1.0) < TOL:
        phib = oldb 
        q = oldl * oldb     # l(t - 1) * b(t - 1)
    else:
        phib = math.pow(oldb, phi)
        q = oldl * phib     # l(t - 1) * b(t - 1)^phi
    # season
    if season == NONE:
        p = y 
    elif season == ADD:
        p = y - olds[m - 1]  # y[t] - s[t - m]
    else:
        if math.fabs(olds[m - 1]) < TOL:
            p = HUGEN 
        else:
            p = y / olds[m - 1]  # y[t] / s[t - m]
    l = q + alpha * (p - q)
    # New Growth 
    if trend > NONE:
        if trend == ADD:
            r = l - oldl    # l[t] - l[t-1]
        else:  # trend == MULT: growth rate, guarded against division by ~0
            if math.fabs(oldl) < TOL:
                r = HUGEN
            else:
                r = l / oldl  # l[t] / l[t-1]
        b = phib + (beta / alpha) * (r - phib) 
        # b[t] = phi*b[t-1] + beta*(r - phi*b[t-1])
        # b[t] = b[t-1]^phi + beta*(r - b[t-1]^phi)
    # New Seasonal
    if season > NONE:
        if season == ADD:
            t = y - q 
        else:  # season == MULT
            if math.fabs(q) < TOL:
                t = HUGEN 
            else:
                t = y / q 
        s[0] = olds[m - 1] + gamma * (t - olds[m - 1])  # s[t] = s[t - m] + gamma * (t - s[t - m])
        for j in range(1, m):
            s[j] = olds[j - 1]  # shift the remaining seasonal states by one
    return l, b

In [None]:
from statsforecast.utils import AirPassengers as ap
# Parameters for this demo, defined here so that the cell does not depend on
# variables left over from the cpolyroot example above.
alpha = beta = gamma = 0.5
phi = 0.9
season_length = 12
n_obs = len(ap)
# Additive trend and season: one level, one trend and season_length seasonal
# states per time point, plus the initial state vector.
n_states = season_length + 2
init_states = np.ones(n_states * (n_obs + 1))
nmse_ = n_obs
amse_ = np.zeros(30)
lik_ = 0.
e_ = np.zeros(n_obs)
# The first call also triggers numba's JIT compilation of etscalc.
etscalc(ap, n_obs, 
        init_states, season_length, 1, 1, 1, 
        alpha, beta, gamma, phi, 
        e_, lik_, amse_, nmse_)

In [None]:
%%time
# Times a second etscalc call. The previous cell already compiled etscalc for
# these argument types, so this measures execution only. etscalc writes only
# from init_states[nstates] onward, so the initial states it reads are the same.
etscalc(ap, len(ap), 
        init_states, 12, 1, 1, 1, 
        alpha, beta, gamma, phi, 
        e_, lik_, amse_, nmse_)