# Test

In [48]:
class OnlineDRO:
    """Groups the online Cressie-Read bound estimator with its smoke test."""

    class OnlineCressieReadLB:
        """Streaming bound on a reward computed from (count, w, r) samples.

        The instance keeps exponentially-decayed sufficient statistics
        (``n``, ``sumw``, ``sumwsq``, ``sumwr``, ``sumwsqr``, ``sumwsqrsq``)
        that ``update`` maintains and ``recomputeduals`` feeds into
        ``intervalimpl``.  The "LB" suffix presumably stands for "lower
        bound"; only the bound associated with ``rmin`` is computed.
        """

        # Kept at class scope so it can serve as the default of ``wmax`` below.
        from math import inf

        @staticmethod
        def intervalimpl(n, sumw, sumwsq, sumwr, sumwsqr, sumwsqrsq,
                         wmin, wmax, alpha=0.05,
                         rmin=0, rmax=1, raiseonerr=False):
            """Compute the bound and its dual variables from sufficient statistics.

            Args:
                n: (Decayed) total count of observations; expected to be
                    positive, since several expressions divide by it.
                sumw: (Decayed) sum of ``c * w``.
                sumwsq: (Decayed) sum of ``c * w**2``.
                sumwr: (Decayed) sum of ``c * w * r``.
                sumwsqr: (Decayed) sum of ``c * w**2 * r``.
                sumwsqrsq: (Decayed) sum of ``c * w**2 * r**2``.
                wmin: Smallest possible ``w``; must be below 1.
                wmax: Largest possible ``w``; must be above 1, may be ``inf``.
                alpha: Tail probability passed to ``chi2.isf`` (1 degree of
                    freedom) to obtain the divergence budget.
                rmin: Smallest possible ``r``; also the reward attached to
                    the extra "fake" observation.
                rmax: Largest possible ``r``; the result is clipped to
                    ``[rmin, rmax]``.
                raiseonerr: Accepted for interface compatibility; currently
                    not consulted.

            Returns:
                A pair ``((bound,), (dual,))``.  ``dual`` is ``None`` when the
                optimum is degenerate (``kappa`` is numerically zero),
                otherwise a dict with keys ``'kappastar'``, ``'betastar'``,
                ``'gammastar'``, ``'wfake'`` and ``'qfunc'``.

            Raises:
                AssertionError: If ``wmin >= 1``, ``wmax <= 1`` or
                    ``rmin > rmax``.
                ValueError: If neither fake-observation choice yields a
                    feasible candidate (``min`` of an empty list).
            """
            from math import inf, isclose, sqrt
            from scipy.stats import chi2

            assert wmin < 1
            assert wmax > 1
            assert rmin <= rmax

            def makedual(kappa, beta, gamma, wfake, sign):
                # Dual certificate shared by the infinite and finite branches.
                return {
                    'kappastar': kappa,
                    'betastar': beta,
                    'gammastar': gamma,
                    'wfake': wfake,
                    # Q_{w,r} &= -\frac{\gamma + \beta w + w r}{(N+1) \kappa} \\
                    'qfunc': lambda c, w, r, k=kappa, g=gamma, b=beta, s=sign, num=n: -c * (g + (b + s * r) * w) / ((num + 1) * k),
                }

            # Unconstrained optimum: add one fake observation on whichever
            # side moves the mean of w toward 1.
            uncwfake = wmax if sumw < n else wmin
            if uncwfake == inf:
                uncgstar = 1 + 1 / n
            else:
                unca = (uncwfake + sumw) / (1 + n)
                uncb = (uncwfake**2 + sumwsq) / (1 + n)
                uncgstar = (n + 1) * (unca - 1)**2 / (uncb - unca*unca)
            Delta = chi2.isf(q=alpha, df=1)
            phi = (-uncgstar - Delta) / (2 * (n + 1))

            bounds = []
            for r, sign in ((rmin, 1),):
                candidates = []
                for wfake in (wmin, wmax):
                    if wfake == inf:
                        x = sign * (r + (sumwr - sumw * r) / n)
                        y = (  (r * sumw - sumwr)**2 / (n * (1 + n))
                             - (r**2 * sumwsq - 2 * r * sumwsqr + sumwsqrsq) / (1 + n)
                            )
                        z = phi + 1 / (2 * n)
                        # Snap tiny products to zero so round-off does not
                        # flip the feasibility test below.
                        if isclose(y*z, 0, abs_tol=1e-9):
                            y = 0

                        if z <= 0 and y * z >= 0:
                            # Same guard as the finite branch: avoids 0/0
                            # when z is exactly zero.
                            kappa = sqrt(y / (2 * z)) if y * z > 0 else 0
                            if isclose(kappa, 0):
                                candidates.append((sign * r, None))
                            else:
                                gstar = x - sqrt(2 * y * z)
                                gamma = ( -kappa * (1 + n) / n
                                         + sign * (r * sumw - sumwr) / n )
                                beta = -sign * r
                                candidates.append(
                                    (gstar, makedual(kappa, beta, gamma, wfake, sign)))
                    else:
                        # Moments of the sample augmented with the fake
                        # observation (wfake, r).
                        barw = (wfake + sumw) / (1 + n)
                        barwsq = (wfake*wfake + sumwsq) / (1 + n)
                        barwr = sign * (wfake * r + sumwr) / (1 + n)
                        barwsqr = sign * (wfake * wfake * r + sumwsqr) / (1 + n)
                        barwsqrsq = (wfake * wfake * r * r + sumwsqrsq) / (1 + n)

                        # Skip when the augmented w has no variance.
                        if barwsq > barw**2:
                            x = barwr + ((1 - barw) * (barwsqr - barw * barwr) / (barwsq - barw**2))
                            y = (barwsqr - barw * barwr)**2 / (barwsq - barw**2) - (barwsqrsq - barwr**2)
                            z = phi + (1/2) * (1 - barw)**2 / (barwsq - barw**2)

                            if isclose(y*z, 0, abs_tol=1e-9):
                                y = 0

                            if z <= 0 and y * z >= 0:
                                kappa = sqrt(y / (2 * z)) if y * z > 0 else 0
                                if isclose(kappa, 0):
                                    candidates.append((sign * r, None))
                                else:
                                    gstar = x - sqrt(2 * y * z)
                                    beta = (-kappa * (1 - barw) - (barwsqr - barw * barwr)) / (barwsq - barw*barw)
                                    gamma = -kappa - beta * barw - barwr

                                    candidates.append(
                                        (gstar, makedual(kappa, beta, gamma, wfake, sign)))

                # Take the smallest candidate, then clip it into [rmin, rmax].
                best = min(candidates, key=lambda x: x[0])
                vbound = min(rmax, max(rmin, sign*best[0]))
                bounds.append((vbound, best[1]))

            return (bounds[0][0], ), (bounds[0][1], )

        def __init__(self, alpha, tau=1, wmin=0, wmax=inf):
            """Create an empty estimator.

            Args:
                alpha: Tail probability forwarded to ``intervalimpl``.
                tau: Per-unit-count decay factor applied to the running
                    statistics on every update (1 means no decay).
                wmin: Smallest admissible ``w``.
                wmax: Largest admissible ``w`` (``inf`` by default).
            """
            self.alpha = alpha
            self.tau = tau
            self.n = 0
            self.sumw = 0
            self.sumwsq = 0
            self.sumwr = 0
            self.sumwsqr = 0
            self.sumwsqrsq = 0
            self.wmin = wmin
            self.wmax = wmax

            self.duals = None
            self.mleduals = None

        def update(self, c, w, r):
            """Fold ``c`` copies of the observation ``(w, r)`` into the statistics.

            Updates with ``c <= 0`` are ignored.  Any cached duals are
            invalidated.

            Args:
                c: Count (multiplicity) of the observation.
                w: Observed ``w``; must lie in ``[wmin, wmax]`` up to 1e-6.
                r: Observed ``r``; must lie in ``[0, 1]``.

            Returns:
                ``self``, to allow chaining.

            Raises:
                AssertionError: If ``w`` or ``r`` is out of range.
            """
            if c > 0:
                assert w + 1e-6 >= self.wmin and w <= self.wmax + 1e-6, 'w = {} < {} < {}'.format(self.wmin, w, self.wmax)
                assert r >= 0 and r <= 1, 'r = {}'.format(r)

                # Older statistics are discounted by tau once per unit count.
                decay = self.tau ** c
                self.n = decay * self.n + c
                self.sumw = decay * self.sumw + c * w
                self.sumwsq = decay * self.sumwsq + c * w**2
                self.sumwr = decay * self.sumwr + c * w * r
                self.sumwsqr = decay * self.sumwsqr + c * (w**2) * r
                self.sumwsqrsq = decay * self.sumwsqrsq + c * (w**2) * (r**2)

                self.duals = None
                self.mleduals = None

            return self

        def recomputeduals(self):
            """Recompute ``self.duals`` from the current statistics.

            Stores the ``((bound,), (dual,))`` pair returned by
            ``intervalimpl``.
            """
            self.duals = self.intervalimpl(self.n, self.sumw, self.sumwsq,
                                           self.sumwr, self.sumwsqr, self.sumwsqrsq,
                                           self.wmin, self.wmax, self.alpha, raiseonerr=True)

    @staticmethod
    def flass():
        """Smoke test: stream 10 seeded samples and print the resulting duals.

        Prints the ``(w, r)`` pairs first, then a list with one tuple per
        step: ``(degenerate, kappastar, gammastar, betastar, n)``, with
        zeros in place of the duals when the step's dual is ``None``.
        """
        from pprint import pformat
        import numpy

        ocrl = OnlineDRO.OnlineCressieReadLB(alpha=0.05, tau=0.999)

        ws = numpy.random.RandomState(seed=42).exponential(size=10)
        rs = numpy.random.RandomState(seed=2112).random_sample(size=10)
        duals = []

        print(list(zip(ws, rs)))

        for (w, r) in zip(ws, rs):
            ocrl.update(1, w, r)
            ocrl.recomputeduals()
            dual = ocrl.duals[1][0]
            if dual is None:
                duals.append( ( True, 0, 0, 0, 0 ) )
            else:
                duals.append( ( False, dual['kappastar'],
                               dual['gammastar'], dual['betastar'], ocrl.n ) )

        print(pformat(duals))
        
# Run the smoke test; it prints the sampled (w, r) pairs followed by the per-step duals.
OnlineDRO.flass()

[(0.4692680899768591, 0.08779271803562538), (3.010121430917521, 0.1488852932982503), (1.3167456935454493, 0.5579699034039329), (0.9129425537759532, 0.6202896863254631), (0.16962487046234628, 0.7284609393916186), (0.16959629191460518, 0.10028857734263497), (0.059838768608680676, 0.5165390922355259), (2.0112308644799395, 0.4596272470443479), (0.9190821536272645, 0.4352012556681023), (1.2312500617045903, 0.40365563207132593)]
[(True, 0, 0, 0, 0),
 (False, 0.186284935714629, -0.5242563567278763, 0, 1.999),
 (False,
  0.24176630719751424,
  -0.3939735949427358,
  -0.1283677781597634,
  2.997001),
 (False,
  0.2789701026811336,
  -0.5061803928309371,
  -0.11471449055314126,
  3.994003999),
 (False,
  0.28140131203664326,
  -0.43475483491188227,
  -0.16912405473103076,
  4.9900099950009995),
 (False,
  0.29153800750906095,
  -0.3748233156965521,
  -0.22291421513333443,
  5.985019985005999),
 (False, 0.37075115114133017, -0.7039218308392182, 0, 6.979034965020992),
 (False, 0.563270986745603, -