In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Larger default font (16 pt) for every text element drawn later in this
# cell: axis labels, tick labels.
plt.rc('font', size=16)
plt.figure(figsize=(6, 4))

# Fixed seed so the random draws below -- and therefore the saved figure --
# are reproducible from run to run.
np.random.seed(12345678)

# --- Synthetic sample ---------------------------------------------------
# Main population: log10(size) ~ Normal(log_size_mean, log_norm_scale);
# log10(complexity) follows log10(size) with slope `reg_coef` plus Gaussian
# noise of std `log_noise_scale`. After exponentiation, a constant
# `additive_intercept` is added in linear space.
n_pts = 50
log_size_mean = -.75
log_norm_scale = 1.25
reg_coef = 1.
log_noise_scale = .75
additive_intercept = .1

# NOTE: the order of the np.random.randn calls below fixes which random
# numbers each quantity receives -- keep it unchanged.
log_size = log_size_mean + log_norm_scale * np.random.randn(n_pts)
log_complexity = reg_coef * log_size + log_noise_scale * np.random.randn(n_pts)
size = 10 ** log_size
complexity = 10 ** log_complexity + additive_intercept

# Extra points from a separate distribution, appended after the main
# population. Sizes are log10-centred at -2 and complexities at -0.1,
# both with std 0.5 in log10 space.
n_add_pts = 5
size = np.append(size, 10 ** (-2 + .5 * np.random.randn(n_add_pts)))
complexity = np.append(complexity, 10 ** (-.1 + .5 * np.random.randn(n_add_pts)))


# Scatter the sample. The x-axis (data size) is log-scaled. The y values are
# log10(complexity) on a linear axis, and the major y ticks are removed.
plt.semilogx(size, np.log10(complexity), 'x', markersize=6.5)
plt.xlabel('Data size (GB)')
plt.ylabel('Complexity')
plt.yticks(ticks=[], minor=False)

ax = plt.gca()

# Keep only the left and bottom frame lines. The hidden right/top spines are
# never drawn, so there is no point repositioning them.
for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
ax.tick_params(which='minor', length=4)

# Pin the lower limits explicitly. The arrowheads below are anchored at
# exactly these values, read back via ax.get_xlim()[0] / ax.get_ylim()[0].
ax.set_xlim(left=10 ** -2.9)
ax.set_ylim(bottom=-1.15)

# Arrowhead at the right end of the x-axis line. With the y-axis transform,
# x is in axes coordinates (1 = right edge) and y is in data coordinates
# (the lower y-limit). clip_on=False lets it sit on the axes boundary.
ax.plot(
    1, ax.get_ylim()[0],
    ">k", ms=9, zorder=4,
    transform=ax.get_yaxis_transform(),
    clip_on=False,
)
# Arrowhead at the top of the y-axis line. With the x-axis transform, x is
# in data coordinates (the lower x-limit) and y is in axes coordinates
# (1 = top edge).
ax.plot(
    ax.get_xlim()[0], 1,
    "^k", ms=9, zorder=4,
    transform=ax.get_xaxis_transform(),
    clip_on=False,
)

# Write the current figure at 300 dpi, cropped to the drawn artists with
# zero padding around them.
plt.savefig('data_size_vs_complexity.png', dpi=300, bbox_inches='tight', pad_inches=0.)