In [None]:
# JBY: Set up env using "ipython --pylab" imports
# `%pylab` star-imports numpy and matplotlib.pyplot into the namespace (and
# defines `np` / `plt`). Every later cell relies on the bare names it provides,
# e.g. plot, linspace, log, subplot, gca, legend, title, rcParams and plt.
# NOTE(review): %pylab is discouraged in current IPython; if this notebook is
# revived, switch to explicit imports (import numpy as np; import
# matplotlib.pyplot as plt) and qualify the names in the cells below.
%pylab

In [None]:
%autoreload 2

In [None]:
%matplotlib inline
figsize(17,6)
rcParams['font.size'] = 16

# SB Figure 4 (original and rescaled)

In [None]:
# Shared x-axis for every plot below: 200 evenly spaced percentiles from
# 0 to 1, both endpoints included.
pct = linspace(0.0, 1.0, num=200)

In [None]:
# Probability of selection vs. percentile: constant 1 for SGD, and
# pct**beta for SB with beta = 1, 2, 3.
plot(pct, 0*pct + 1, 'k')
for sb_beta, sb_color in ((1, 'r'), (2, 'b'), (3, 'g')):
    plot(pct, pct**sb_beta, sb_color)
legend(('SGD', 'SB Beta 1', 'SB Beta 2', 'SB Beta 3'))
title('SB Figure 4')
xlabel('Percentile'); ylabel('Probability of Selection\n == Per-example Weight')

In [None]:
# Same SB curves, each divided by its mean over pct so it averages 1
# (the same average weight as the SGD line).
plot(pct, 0*pct + 1, 'k')
for sb_beta, sb_color in ((1, 'r'), (2, 'b'), (3, 'g')):
    sb_weights = pct**sb_beta
    plot(pct, sb_weights / sb_weights.mean(), sb_color)
legend(('SGD', 'SB Beta 1', 'SB Beta 2', 'SB Beta 3'))
title('SB Figure 4, with normalization (integral is 1, so average batch weight does not change)')
xlabel('Percentile'); ylabel('Per-example weight')

# Power parameterization

In [None]:
def power_reweighting(pct, aa):
    '''Reweight the 100th percentile example to have aa times the weight of
    the 0th percentile example.

    The weight at percentile ``pct`` is ``aa**pct * log(aa) / (aa - 1)``.
    This grows geometrically in pct, with weight(1) / weight(0) == aa.
    The factor log(aa) / (aa - 1) normalizes the weights so they integrate
    to 1 over pct in [0, 1], i.e. the average weight matches uniform (SGD)
    weighting.

    Parameters
    ----------
    pct : float or numpy array
        Percentile(s), typically in [0, 1].
    aa : float
        Ratio of the weight at pct=1 to the weight at pct=0.
        Valid values are 0 < aa < inf; aa == 1 gives equal weights.

    Returns
    -------
    Weights with the same shape as ``pct``.

    Raises
    ------
    ValueError
        If aa is not a finite positive number. Such values would otherwise
        yield inf/nan weights (log(0) = -inf, log of a negative is nan,
        inf / inf is nan).
    '''
    # `not (0 < aa < inf)` also rejects nan, since every comparison with nan is False.
    if not 0 < aa < float('inf'):
        raise ValueError('aa must satisfy 0 < aa < inf, got %r' % (aa,))
    if aa == 1:
        # log(aa) / (aa - 1) -> 1 as aa -> 1: return equal weights everywhere.
        return 1.0 + pct * 0
    return aa ** pct * log(aa) / (aa - 1)

In [None]:
# One small panel per power ratio aa, from strongly favouring low
# percentiles (aa << 1) to strongly favouring high ones (aa >> 1).
aas = [1/100, 1/25, 1/10, 1/4, 1/3, 1/2, 1, 2, 3, 4, 10, 25, 100]
N_aas = len(aas)

# The aa == 1 (uniform) panel is drawn in black; all others in grey.
color_highlights = {1: 'k'}
base_clr = (.7, .7, .7)

for ii, aa in enumerate(aas):
    ax = subplot(1, N_aas, ii + 1)
    ax.plot(pct, power_reweighting(pct, aa), c=color_highlights.get(aa, base_clr), lw=2)

    # Reference lines: zero weight (solid) and the uniform weight 1 (dotted).
    ax.axhline(0, ls='-', c=base_clr)
    ax.axhline(1, ls=':', c=base_clr)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('p=%g' % aa)

Power reweightings that approximate (but are not equivalent to) the SB curves

In [None]:
# Overlay each normalized SB curve (solid) with a hand-picked power
# reweighting (dotted) of similar shape.
plot(pct, 0*pct + 1, 'k')
for sb_beta, power_aa, sb_style, power_style in ((1, 5, 'r', 'r:'),
                                                 (2, 16, 'b', 'b:'),
                                                 (3, 50, 'g', 'g:')):
    sb_weights = pct**sb_beta
    plot(pct, sb_weights / sb_weights.mean(), sb_style)
    plot(pct, power_reweighting(pct, power_aa), power_style)
legend(('SGD', 'SB Beta 1', 'Power: 5', 'SB Beta 2', 'Power: 16', 'SB Beta 3', 'Power: 50'))
title('SB vs. Power re-weightings')
xlabel('Percentile'); ylabel('Per-example weight')

# SCRATCH: Linear-Quadratic parameterizations (two parameters instead of one; abandoned)

In [None]:
# Parameterization v1
#def linear_quadratic_reweighting(pct, aa, bb):
#    return 1.0 + aa * (pct-.5) + bb * (pct-.5)**2 - bb/12

In [None]:
# Parameterization v2: a quadratic in pct whose constant term is chosen
# so that the mean weight over [0, 1] is exactly 1.
def linear_quadratic_reweighting(pct, aa, bb):
    '''Quadratic per-example weights aa * pct**2 + bb * pct + offset.

    aa : coefficient of pct**2 (curvature)
    bb : coefficient of pct (slope)

    offset = 1 - aa/3 - bb/2, because pct**2 and pct integrate to 1/3 and
    1/2 over [0, 1]. The weights therefore integrate to 1 over [0, 1].
    Negative weights are not prevented.

    NOTE: redefined by "Parameterization v3" further down with different
    arguments; cells after that redefinition use the v3 version.
    '''
    offset = 1 - aa / 3.0 - bb / 2.0
    return aa * pct**2 + bb * pct + offset

In [None]:
# Reproduce the normalized SB Figure 4 curves with v2 (curvature, slope)
# pairs. SB Beta 3 (a cubic) has no quadratic equivalent, so it is omitted.
for lq_a, lq_b, lq_style in ((0, 0, 'k'), (0, 2, 'r'), (3, 0, 'b')):
    plot(pct, linear_quadratic_reweighting(pct, lq_a, lq_b), lq_style)
legend(('LQ(0, 0) == SGD', 'LQ(0, 2) == SB Beta 1', 'LQ(3, 0) == SB Beta 2'))
title('SB Figure 4, using LQ reweighting. (Beta = 3 is not possible using LQ)')
xlabel('Percentile'); ylabel('Per-example weight')

In [None]:
# Grid of v2 LQ curves. Rows are curvatures (largest at the top, hence
# the descending sort) and columns are slopes (ascending).
curvatures = sorted([-2, -1, 0, 1, 2, 3, 4], reverse=True)
slopes = sorted([-.5, 0, .5, 1, 1.5, 2, 3])
N_curvatures = len(curvatures)
N_slopes = len(slopes)

# Highlight the SGD (k), SB Beta 1 (r) and SB Beta 2 (b) equivalents.
color_highlights = {(0, 0): 'k', (0, 2): 'r', (3, 0): 'b'}

for cc, curvature in enumerate(curvatures):
    for ss, slope in enumerate(slopes):
        ax = subplot(N_curvatures, N_slopes, cc * N_slopes + ss + 1)
        yy = linear_quadratic_reweighting(pct, curvature, slope)
        # Curves containing any negative weight are drawn (and framed) faintly.
        is_valid = yy.min() >= 0
        base_clr = (.6, .6, .6) if is_valid else (.9, .9, .9)
        clr = color_highlights.get((curvature, slope), base_clr)
        ax.plot(pct, yy, c=clr, lw=2)

        ax.axhline(0, ls=':', c=base_clr)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.setp(ax.spines.values(), color=base_clr)
        # Label slopes along the top row and curvatures down the left column.
        if cc == 0:
            ax.set_title('s=%g' % slope)
        if ss == 0:
            ax.set_ylabel('c=%g' % curvature)

In [None]:
# Parameterization v3: pin the weights at both ends of the percentile range
# and solve for the quadratic whose mean weight over [0, 1] is 1.
def linear_quadratic_reweighting(pct, uu, vv):
    '''uu = weight of 0th percentile point
    vv = weight of 100th percentile point
    Curvature is solved for to force average weight to be 1.

    Writing the weights as y = a x^2 + b x + c:
      y(0) = c = uu
      y(1) = a + b + c = vv         =>  b = vv - uu - a
      mean = a/3 + b/2 + c = 1      =>  a = -6 + 3 uu + 3 vv
    Negative weights are not prevented, e.g. uu = vv = 4 dips to -0.5 at
    pct = 0.5.
    '''
    const_coef = uu
    quad_coef = -6 + 3 * uu + 3 * vv
    lin_coef = vv - uu - quad_coef
    return quad_coef * pct**2 + lin_coef * pct + const_coef


In [None]:
# Reproduce the normalized SB Figure 4 curves with v3 (start weight, end
# weight) pairs. SB Beta 3 (a cubic) has no quadratic equivalent.
for lq_u, lq_v, lq_style in ((1, 1, 'k'), (0, 2, 'r'), (0, 3, 'b')):
    plot(pct, linear_quadratic_reweighting(pct, lq_u, lq_v), lq_style)
legend(('LQ(1, 1) == SGD', 'LQ(0, 2) == SB Beta 1', 'LQ(0, 3) == SB Beta 2'))
title('SB Figure 4, using LQ reweighting. (Beta = 3 is not possible using LQ)')
xlabel('Percentile'); ylabel('Per-example weight')

In [None]:
# Grid of v3 LQ curves. Rows are the 0th-percentile weight u (largest at
# the top, hence the descending sort); columns are the 100th-percentile
# weight v (ascending).
uus = sorted([0, 1, 2, 3], reverse=True)
vvs = sorted([0, 1, 2, 3])
N_uus = len(uus)
N_vvs = len(vvs)

# Highlight the SGD (k), SB Beta 1 (r) and SB Beta 2 (b) equivalents.
color_highlights = {(1, 1): 'k', (0, 2): 'r', (0, 3): 'b'}

for ii, uu in enumerate(uus):
    for jj, vv in enumerate(vvs):
        ax = subplot(N_uus, N_vvs, ii * N_vvs + jj + 1)
        yy = linear_quadratic_reweighting(pct, uu, vv)
        # Curves containing any negative weight are drawn (and framed) faintly.
        is_valid = yy.min() >= 0
        base_clr = (.6, .6, .6) if is_valid else (.9, .9, .9)
        clr = color_highlights.get((uu, vv), base_clr)
        ax.plot(pct, yy, c=clr, lw=2)

        # Reference lines: zero weight (solid) and the uniform weight 1 (dotted).
        ax.axhline(0, ls='-', c=base_clr)
        ax.axhline(1, ls=':', c=base_clr)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.setp(ax.spines.values(), color=base_clr)
        # Label v along the bottom row and u down the left column.
        if ii == N_uus - 1:
            ax.set_xlabel('v=%g' % vv)
        if jj == 0:
            ax.set_ylabel('u=%g' % uu)