In [1]:
%run ComputeRateFromParams.ipynb

Generating cache table of size 512x512x128 (~25.0 bits)...


  0%|          | 0/511 [00:00<?, ?it/s]

Done. Table generation took 1.1 seconds


In [2]:
def compute_func_and_jacobian(f, p: tuple):
    """
    Given some function f and a point p, computes the value of f at p and estimates its Jacobian by looking at p + epsilon e_i.
    """
    DELTA = 1E-10
    base_point = f(*p)
    derivatives = np.zeros(len(p))
    for i in range(len(p)):
        pt = p[:i] + (p[i] + DELTA,) + p[i+1:]
        derivatives[i] = (f(*pt) - base_point) / DELTA
    return (base_point, derivatives)

def get_direction(jacobian: np.ndarray):
    return jacobian / np.linalg.norm(jacobian)

def get_second_derivative(f, f_at_p: float, p: tuple, d: np.ndarray):
    '''
    Given a function f, a starting point p, the value of f at p and a direction d, 
        computes the first and second derivatives of f along d.
    '''
    DELTA = 1E-6
    plus_delta = f(*[p[i] + DELTA*d[i] for i in range(len(p))])
    minus_delta = f(*[p[i] - DELTA*d[i] for i in range(len(p))])
    first_derivative = (plus_delta - minus_delta) / (2 * DELTA)
    second_derivative = (plus_delta + minus_delta - (2 * f_at_p)) / (DELTA ** 2)
    return (first_derivative, second_derivative)
    

def do_optimization_step(f, current_point: tuple, jump_limit: float = 0.1):
    '''
    Given a function f, a starting point and an upper bound on the distance of the jump allowed,
        performs a step of gradient descent to maximize the function f.
    Will ensure that all entries of the output point and all intermediary points are positive.
    Returns as tuple of:
        - The next point
        - The value of f at the starting point
        - The jacobian of f at the starting point
    '''
    bp, jac = compute_func_and_jacobian(f, current_point)
    d = get_direction(jac)
    first_derivative, second_derivative = get_second_derivative(f, bp, current_point, d)
    jump_size = first_derivative / (-(second_derivative))
    jump_size = min(jump_size, jump_limit)
    if jump_size < 0:
        jump_size = jump_limit
    for i in range(len(current_point)):
        if current_point[i] + (jump_size * d[i]) < 0:
            jump_size = min(jump_size, -current_point[i] / (2 * d))
    print(jump_size, jump_size * np.dot(d, jac))
    next_point = tuple([current_point[i] + (jump_size * d[i]) for i in range(len(p))])
    return (next_point, bp, jac)

In [3]:
p = (1.79, 0.1, 0.23, 0.1)
t0 = time.time()
data = []
for i in it.count():
    p, bp, jac = do_optimization_step(compute_rate_from_beta_g, p, 0.025)
    data.append((p, bp, jac))
    print(p, bp, '%.1f' % (time.time() - t0))

0.015810148031582854 0.0005979157866015759
(1.791933286366758, 0.09975267533590945, 0.2143184675684512, 0.10050155330074194) 0.11891728072703167 5.9
0.019442279585798816 0.00014361412201387423
(1.7917294640834236, 0.09756030733488208, 0.21893332550330286, 0.11925941300696624) 0.11920838308631056 11.9
0.007939667787766003 0.00014152279099094073
(1.7926851960237333, 0.09724727282574759, 0.2112774558436848, 0.12110726857204467) 0.11928009278938469 18.3
0.014919729377223195 0.00010940076369106593
(1.7925171846254346, 0.09559991400612647, 0.21557290019817169, 0.1352990078444715) 0.11934982669299855 24.6
0.007033803019882058 0.00010783774528967395
(1.7933728358998173, 0.09528188601925681, 0.20889520384364202, 0.13731133868591336) 0.11940484525340893 30.5
0.012416454698619163 9.081534597274518e-05
(1.793280720344384, 0.09394609267064273, 0.2129930503173269, 0.1489553574258165) 0.1194580639584693 37.3
0.006514913257630268 9.007066474669933e-05
(1.794089216612277, 0.09362231486409583, 0.2068933

0.0036133770345573033 3.45444016529349e-05
(1.8170259064413317, 0.07094565382326466, 0.18371223199879502, 0.3411761662146975) 0.12056842877258821 340.6
0.007871456005197699 3.705318614325585e-05
(1.8188492209284686, 0.07014503845090593, 0.17956388405784296, 0.3475625241994256) 0.1205858036577816 346.6
0.0036271311869682087 3.612445138593028e-05
(1.8188634735455704, 0.06995222379046889, 0.18263476709012452, 0.3494830390832911) 0.12060404839951859 352.7
0.008669893286040703 3.9265566102605207e-05
(1.8209441692106683, 0.06906330554539519, 0.17831849251519555, 0.35665362854915605) 0.12062223860925055 358.9
0.0036620122154115585 3.825391358506201e-05
(1.820948985572901, 0.06888208383512051, 0.18148191814979978, 0.3584894409618312) 0.1206415579475905 364.5
0.009900798864929809 4.304679389697208e-05
(1.823433249332435, 0.06786448692560526, 0.17689514748374252, 0.3668429003125196) 0.12066082608019493 370.3
0.003745650828084908 4.1710013146543484e-05
(1.823418825091269, 0.06769307652573056, 0.1

0.025 2.3306894030280292e-05
(2.0515143806057816, 0.02456285310349418, 0.15828528041616097, 0.6529852411576488) 0.12149428239559887 661.0
0.0011751255352984386 4.911257955096292e-06
(2.051615171298665, 0.02452262267207235, 0.1594260531248101, 0.6532455878486735) 0.12151325692371095 666.7
0.00959183434856176 7.624311844578415e-06
(2.060088878397938, 0.023043910543879116, 0.15812317879215512, 0.6572848332090382) 0.12151571976404177 672.3
0.001384473207675222 6.900309684371553e-06
(2.0601587689669185, 0.02300354902941527, 0.1594764652919195, 0.657565667624386) 0.12151951774099777 677.8
0.020053485588009224 1.4538943426204372e-05
(2.0778583203535304, 0.019591078592916016, 0.15781437339202728, 0.6661947267259049) 0.12152298534000681 683.4
0.0017917117149417571 1.1677132559229994e-05
(2.0778666375467845, 0.019548119145190827, 0.15957655009853788, 0.6665157398637022) 0.12153044450588643 688.9
0.025 1.5021497757686621e-05
(2.099785258515983, 0.014438008618915626, 0.16028099471345944, 0.6773768

0.0005715247740247383 1.9137033598872848e-07
(2.124129284744939, 0.005747110661644662, 0.160185504561826, 0.6886168364405847) 0.12156017459058857 1046.9
0.0002531962456695285 1.8923886340775276e-07
(2.1242319626666175, 0.005708982550618443, 0.15995730006007353, 0.688622713153135) 0.12156026982743069 1054.9
0.0006016281455280445 2.0130534711608656e-07
(2.12467986742426, 0.005498879037562672, 0.1601958498919991, 0.6888682498002215) 0.12156036426399723 1063.0
0.0002562656125761218 1.9469262104708834e-07
(2.124782197094766, 0.00546133631306775, 0.15996399314084778, 0.6888740544109914) 0.12156046370857042 1071.1
0.0005410607902385016 1.7470064370285497e-07
(2.1251742777658906, 0.005266458731571678, 0.1601867768199328, 0.6891007914539015) 0.12156056107239431 1079.6
0.00024833025778887234 1.791159056099019e-07
(2.1252794891376428, 0.00522799590039065, 0.1599652691362617, 0.6891081017808216) 0.12156064932839747 1087.8
0.000494882779297286 1.607527355041306e-07
(2.125635960072034, 0.00505504618

0.0001873978358599059 9.362046131202442e-08
(2.133777606817588, 0.0008551231875648347, 0.16004609014589408, 0.6929829233929589) 0.12156347133922964 1380.8
0.0003368096761796551 9.586455788742469e-08
(2.1340048900403668, 0.0007221033823257987, 0.16021014790568888, 0.6931139725344165) 0.12156351791668094 1386.5
0.0001870018210902573 9.359434449330834e-08
(2.1340913266865265, 0.0006803628303410094, 0.16004992640800814, 0.6931232021223088) 0.12156356497951555 1392.2
0.00031859032386726227 8.635472246886018e-08
(2.1343006058396594, 0.0005501954381352365, 0.16020798681282944, 0.6932488022375951) 0.12156361151433615 1398.0
0.00018336876014931353 8.827742867289374e-08
(2.134385656425472, 0.00050737941917422, 0.16005162905469772, 0.6932592683755634) 0.12156365568552124 1403.7
0.00031209595817288123 8.439555867823231e-08
(2.1345884296674065, 0.0003731582685099157, 0.1602085941234937, 0.6933760311665351) 0.12156369990268925 1409.4
0.00018064547982280863 8.551507213153203e-08
(2.1346724210939096, 

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [4]:
1 / 0.12156387313014512 

8.226128160044798

In [7]:
1 / 0.1185

8.438818565400844

In [None]:
p =(1.79, 0.1, 0.23, 0.1)
bp, jac = compute_func_and_jacobian(compute_rate_from_beta_g, p)

In [None]:
d = get_direction(jac)

In [None]:
get_second_derivative(compute_rate_from_beta_g, bp, p, d)

In [None]:
np.dot(d, jac)

In [None]:
np.dot((-0.5 * (_13[0] / _13[1]) * d), jac)

In [None]:
bp