In [1]:
import math
import time

import numpy as np
import pandas as pd

In [2]:
df = pd.read_csv('data/poisson_sim.csv')
dummy_prog = pd.get_dummies(df.prog)
df['intercept'] = 1
df['general'] = dummy_prog[1]
df['academic'] = dummy_prog[2]
df['vocational'] = dummy_prog[3]

features = ['intercept', 'math', 'vocational', 'academic']
X = df[features].values
Y = df[['num_awards']].values.flatten().astype('float64')

In [42]:
def fitCD(X, Y, b=None, precision=1e-10):
    n, p = X.shape
    std = X.std(axis=0)
    std[std == 0] = 1
    X /= std
    b = np.zeros(p) if b is None else b * std
    sum_update = 1
    nb_iteration = 0
    r = Y - X @ b 
    divisor = n * (1 + np.square(X).mean(axis=0)) # 2np + 3p
    Xj = [row for row in X.T]
    while math.sqrt(sum_update / p) > precision:
        nb_iteration += 1
        sum_update = 0
        for j in range(p):
            #print((np.dot(r, Xj[j]) - b[j] * np.dot(ones, Xj[j]))/ n - b[j])
            update = float((np.dot(r, Xj[j]) + b[j] * Xj[j].sum())/ n - b[j]) #  2n + 1
            print(update)
            sum_update += update**2
            b[j] += update
            print('beta', b[j])
            r -= update * Xj[j] # 2n
    print("number or iterations", nb_iteration, math.sqrt(sum_update / p))
    return b / std

In [43]:
def irls_poisson(X, Y, solver):
    n, p = X.shape
    b = np.zeros(p) # p
    yp = np.zeros(n) # n
    mu = np.ones(n) # n
    previous_ll = n
    current_ll = n
    while previous_ll >= current_ll:
        previous_ll = current_ll
        W = np.sqrt(mu) # n
        Z = yp + Y / mu - 1 # 3n
        b = solver(W[:, np.newaxis] * X, W * Z, b) # ?
        yp = X @ b # 2np
        mu = np.exp(yp) # n
        current_ll = -(Y.T @ yp - sum(mu)) # 2n
    return b

In [44]:
irls_poisson(X, Y, fitCD)

-0.37
beta -0.37
0.5220600745457049
beta 0.5220600745457049
-1.7223872551759316
beta -1.7223872551759316
-2.9430706582442214
beta -2.9430706582442214
1.14749191501185
beta 0.7774919150118501
3.1101798157802936
beta 3.6322398903259985
-8.280121045463837
beta -10.002508300639768
-17.95530634608931
beta -20.89837700433353
6.136028525594251
beta 6.913520440606101
19.484881316645385
beta 23.117121206971383
-53.16386785241647
beta -63.16637615305624
-111.88951828738985
beta -132.78789529172337
38.557308108199216
beta 45.470828548805315
121.29906452677929
beta 144.41618573375067
-329.94473571130305
beta -393.1111118643593
-696.8034365858948
beta -829.5913318776181
239.71583399904736
beta 285.1866625478527
755.6316373497679
beta 900.0478230835186
-2056.0499453802877
beta -2449.161057244647
-4340.453740821517
beta -5170.045072699136
1493.3975685029704
beta 1778.584231050823
4706.819369908391
beta 5606.86719299191
-12806.602955372102
beta -15255.76401261675
-27036.77153890485
beta -32206.8166116

beta 2.9947571650970567e+56
7.923629641467281e+56
beta 9.43894329852942e+56
-2.1559179337372703e+57
beta -2.5682153323183982e+57
-4.5514644176610565e+57
beta -5.4218857401847865e+57
1.5659692982880882e+57
beta 1.865445014797794e+57
4.9356574169178565e+57
beta 5.8795517467707986e+57
-1.3429290390137573e+58
beta -1.5997505722455972e+58
-2.8351235642440483e+58
beta -3.377312138262527e+58
9.754479110573371e+57
beta 1.1619924125371166e+58
3.0744387660028447e+58
beta 3.662393940679925e+58
-8.365153309431073e+58
beta -9.96490388167667e+58
-1.7660086703836e+59
beta -2.1037398842098527e+59
6.076100139551242e+58
beta 7.238092552088358e+58
1.915078970736921e+59
beta 2.2813183648049135e+59
-5.210684098519131e+59
beta -6.207174486686798e+59
-1.1000531557789932e+60
beta -1.3104271441999785e+60
3.7848246418239076e+59
beta 4.508633897032743e+59
1.1929095822998059e+60
beta 1.4210414187802972e+60
-3.2457538756581066e+60
beta -3.866471324326787e+60
-6.852270692852651e+60
beta -8.16269783705263e+60
2.3575

beta -5.960829584373286e+90
-1.0563952099968542e+91
beta -1.2584200598946637e+91
3.6346158376951284e+90
beta 4.3296991853834865e+90
1.1455664320225278e+91
beta 1.3646443720654492e+91
-3.1169392397641675e+91
beta -3.713022198201496e+91
-6.580323777540858e+91
beta -7.838743837435522e+91
2.264015284495948e+91
beta 2.6969852030342968e+91
7.135774528372701e+91
beta 8.50041890043815e+91
-1.941552668781162e+92
beta -2.3128548886013114e+92
-4.098907360380642e+92
beta -4.882781744124194e+92
1.4102632677905646e+92
beta 1.6799617880939943e+92
4.444899631867993e+92
beta 5.294941521911808e+92
-1.2094001440773888e+93
beta -1.44068563293752e+93
-2.5532241447337015e+93
beta -3.0415023191461207e+93
8.784580643509688e+92
beta 1.0464542431603682e+93
2.768743975699271e+93
beta 3.2982381278904517e+93
-7.533397017823927e+93
beta -8.974082650761447e+93
-1.5904125075531747e+94
beta -1.894562739467787e+94
5.4719468942294895e+93
beta 6.518401137389858e+93
1.7246605858115516e+94
beta 2.0544843986005968e+94
-4.69

4.2684943498907675e+123
beta 5.084800522207563e+123
-1.161402530833283e+124
beta -1.3835089638634256e+124
-2.4518940220075043e+124
beta -2.9207938400622237e+124
8.435945903962427e+123
beta 1.0049234840589773e+124
2.6588604907597645e+124
beta 3.167340542980521e+124
-7.234418157728235e+124
beta -8.61792712159166e+124
-1.527293609469727e+125
beta -1.8193729934759493e+125
5.254781060400454e+124
beta 6.259704544459431e+124
1.6562137676260886e+125
beta 1.9729478219241408e+125
-4.506345103563483e+125
beta -5.368137815722649e+125
-9.51356685317588e+125
beta -1.1332939846651829e+126
3.273222031897271e+125
beta 3.899192486343214e+125
1.0316615157534585e+126
beta 1.2289562979458726e+126
-2.8070185811304386e+126
beta -3.343832362702703e+126
-5.926035027493566e+126
beta -7.059329012158748e+126
2.0389017823858127e+126
beta 2.4288210310201342e+126
6.426256706054682e+126
beta 7.655213004000554e+126
-1.7485019752661162e+127
beta -2.0828852115363866e+127
-3.6913485435126157e+127
beta -4.39728144472849e+

OverflowError: (34, 'Result too large')

In [5]:
irls_poisson(X, Y, fitCD)

number or iterations 1576 9.933620773673003e-11
number or iterations 939 9.839206338178205e-11
number or iterations 1063 9.877450642927196e-11
number or iterations 965 9.87371509075679e-11
number or iterations 624 9.951406951952467e-11
number or iterations 98 9.906425074093217e-11
number or iterations 1 9.584553914760944e-11


array([-5.24712438,  0.0701524 ,  0.36980922,  1.08385914])

In [704]:
def irls1(X, Y, b):
    yp = X @ b # 2np
    mu = np.exp(yp) # n
    W = np.sqrt(mu) # n
    Z = yp + Y / mu - 1 # 3n
    return np.linalg.lstsq(W[:, np.newaxis] * X, W * Z, rcond=None)[0]

In [705]:
def irls2(X, Y, b, precision=0.0000001): #2np + 4n
    yp = X @ b # 2np
    mu = np.exp(yp) # n
    W = np.sqrt(mu) # n
    Z = yp + Y / mu - 1 # 3n
    return fitCD(W[:, np.newaxis] * X, W * Z, b, precision)

In [714]:
def irls3(X, Y, b):
    yp = X @ b # 2np
    mu = np.exp(yp) # n
    W = np.sqrt(mu) # n
    Z = yp + Y / mu - 1 # 3n
    Xt = W[:, np.newaxis] * X
    return np.linalg.solve(Xt.T @ Xt, Xt.T @ (W * Z)) #npp

In [715]:
def ll(X, Y, b):
    yp = X @ b
    return -(Y.T @ yp - sum(np.exp(yp)))

In [716]:
#%%timeit
b = np.zeros(X.shape[1])
previous_ll = ll(X, Y, b)
current_ll = previous_ll
while previous_ll >= current_ll:
    previous_ll = current_ll
    b = irls3(X, Y, b)
    current_ll = ll(X, Y, b)
print(b)

ValueError: operands could not be broadcast together with shapes (4,4) (200,1) 

In [634]:
b

array([-5.2471244 ,  0.0701524 ,  0.36980923,  1.08385915])

In [670]:
#%%timeit
b = np.zeros(X.shape[1])
previous_ll = ll(X, Y, b)
current_ll = previous_ll
while previous_ll >= current_ll:
    previous_ll = current_ll
    b = irls2(X, Y, b, precision=0.00001)
    current_ll = ll(X, Y, b)
    print(b)

number or iterations 581 3.9931638543425046e-05
[-3.19286779  0.0478418   0.21204731  0.47864324]
number or iterations 310 3.9979100887482767e-05
[-4.67012341  0.06490974  0.3048447   0.84415958]
number or iterations 299 3.990422639741532e-05
[-5.18416698  0.0697457   0.35575101  1.04694408]
number or iterations 154 3.954099722352021e-05
[-5.24293673  0.07010468  0.36850545  1.08262421]
number or iterations 26 3.9767711313582736e-05
[-5.24452538  0.07011935  0.36893462  1.08329661]
number or iterations 1 3.891291737691065e-05
[-5.24456849  0.07011981  0.36895369  1.08331127]
number or iterations 1 3.8064361516512385e-05
[-5.24461069  0.07012027  0.36897224  1.08332545]
number or iterations 1 3.724270344034643e-05
[-5.24465204  0.07012072  0.36899026  1.08333915]
number or iterations 1 3.6443810414091456e-05
[-5.24469256  0.07012116  0.36900778  1.08335241]
number or iterations 1 3.5667230166605375e-05
[-5.24473227  0.0701216   0.3690248   1.08336524]
number or iterations 1 3.4912355929

[-5.24705412  0.07015141  0.36979066  1.0838498 ]
number or iterations 1 8.297964031608715e-07
[-5.2470551   0.07015143  0.36979092  1.08384993]
number or iterations 1 8.181708891015468e-07
[-5.24705607  0.07015144  0.36979117  1.08385006]
number or iterations 1 8.067084874876725e-07
[-5.24705703  0.07015145  0.36979143  1.08385019]
number or iterations 1 7.954068997328453e-07
[-5.24705797  0.07015147  0.36979168  1.08385032]
number or iterations 1 7.842638599626949e-07
[-5.2470589   0.07015148  0.36979192  1.08385044]
number or iterations 1 7.732771356926303e-07
[-5.24705982  0.07015149  0.36979217  1.08385056]
number or iterations 1 7.624445248748544e-07
[-5.24706072  0.0701515   0.3697924   1.08385068]
number or iterations 1 7.517638576512243e-07
[-5.24706162  0.07015152  0.36979264  1.0838508 ]
number or iterations 1 7.412329947994308e-07
[-5.24706249  0.07015153  0.36979287  1.08385092]
number or iterations 1 7.308498276474724e-07
[-5.24706336  0.07015154  0.3697931   1.08385103]


number or iterations 1 4.48094439820546e-08
[-5.24712066  0.07015234  0.36980824  1.08385865]
number or iterations 1 4.418213273963659e-08
[-5.24712071  0.07015235  0.36980825  1.08385866]
number or iterations 1 4.356360370231114e-08
[-5.24712076  0.07015235  0.36980827  1.08385866]
number or iterations 1 4.295373340841745e-08
[-5.24712081  0.07015235  0.36980828  1.08385867]
number or iterations 1 4.235240166749347e-08
[-5.24712086  0.07015235  0.3698083   1.08385868]
number or iterations 1 4.175948790645504e-08
[-5.24712091  0.07015235  0.36980831  1.08385868]
number or iterations 1 4.117487494219998e-08
[-5.24712096  0.07015235  0.36980832  1.08385869]
number or iterations 1 4.0598446356132376e-08
[-5.24712101  0.07015235  0.36980833  1.0838587 ]
number or iterations 1 4.003008731529282e-08
[-5.24712105  0.07015235  0.36980835  1.0838587 ]
number or iterations 1 3.9469684950324606e-08
[-5.2471211   0.07015235  0.36980836  1.08385871]
number or iterations 1 3.8917127928110763e-08
[-5

In [494]:
b

array([-5.24378781,  0.07010557,  0.36892796,  1.08341622])

In [411]:
list(zip(features, b))

[('intercept', -5.247124398538536),
 ('math', 0.07015239749371581),
 ('vocational', 0.3698092298424559),
 ('academic', 1.083859145620779)]

In [412]:
list(zip(features, np.exp(b)))

[('intercept', 0.005262629887655122),
 ('math', 1.0726716412681612),
 ('vocational', 1.447458456445014),
 ('academic', 2.9560654540575952)]

In [695]:
pt = 10
nt = 1000000
Xt = np.random.rand(nt * pt).reshape(nt, pt)
Yt = np.random.rand(nt)

In [696]:
start = time.time()

b = np.zeros(Xt.shape[1])
previous_ll = ll(Xt, Yt, b)
current_ll = previous_ll
#while previous_ll >= current_ll:
#    previous_ll = current_ll
b = irls3(Xt, Yt, b)
b = irls3(Xt, Yt, b)
b = irls3(Xt, Yt, b)
b = irls3(Xt, Yt, b)
#    current_ll = ll(Xt, Yt, b)
print(b)

print(time.time() - start)

[-0.13315003 -0.13711501 -0.13642069 -0.13795983 -0.134634   -0.13868153
 -0.13523663 -0.13392137 -0.13330512 -0.13625787]
0.49159812927246094


In [698]:
start = time.time()

b = np.zeros(Xt.shape[1])
previous_ll = ll(Xt, Yt, b)
#current_ll = previous_ll
#while previous_ll >= current_ll:
#    previous_ll = current_ll
b = irls2(Xt, Yt, b,  0.00000001)
b = irls2(Xt, Yt, b,  0.00000001)
b = irls2(Xt, Yt, b,  0.00000001)
b = irls2(Xt, Yt, b,  0.00000001)
#    current_ll = ll(Xt, Yt, b)
print(b)

print(time.time() - start)

number or iterations 142 9.959726289188059e-08
number or iterations 125 9.810231123293773e-08
number or iterations 100 9.422837971855437e-08
number or iterations 53 9.289810937541226e-08
[-0.13314979 -0.13711497 -0.13642087 -0.13796015 -0.13463436 -0.13868178
 -0.13523669 -0.13392123 -0.13330482 -0.13625754]
38.92078495025635


In [674]:
%%timeit
np.linalg.solve(Xt.T @ Xt, Xt.T @ Yt)

26.2 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [553]:
a = np.array([1, 5, 10, 20, 30, 40])
0.01 * 

array([7.38905610e-02, 4.03428793e+00, 5.98741417e+02, 1.31881573e+07,
       2.90488497e+11, 6.39843494e+15])

In [658]:
ptt = 2000
Xtt = np.random.rand(ptt**2).reshape(ptt, ptt)
Ytt = np.random.rand(ptt)
Xttp = Xtt.T @ Xtt
Yttp = Xtt.T @ Ytt

In [663]:
%%timeit
np.linalg.solve(Xttp, Yttp)

123 ms ± 13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
