In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import matplotlib
# Embed fonts as TrueType (font type 42) in PDF and PostScript output.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

In [None]:
# Datasets to plot and the regularizer the experiments were run with.
datasets = ['a9a', 'covtype']
reg_type = 'l2'

# Minibatch size of the NIM runs (it is encoded in their result file names).
nim_minibatch_size = 100 if reg_type == 'l2' else 5000

# Methods to compare, in legend order. LBFGS (and SFO) are only included for
# the 'l2' regularizer; SFO is additionally skipped whenever 'alpha' or
# 'mnist8m' is among the datasets (presumably no SFO results exist for those
# large datasets -- TODO confirm).
methods = ['NIM', 'SAG', 'newton']
if reg_type == 'l2':
    methods.append('LBFGS')
    if 'alpha' not in datasets and 'mnist8m' not in datasets:
        methods.append('SFO')
methods.append('SGD')

# (n, d) sizes of the benchmark datasets this notebook knows about.
_DATASET_SIZES = (
    ('a9a', (32561, 123)),
    ('mushrooms', (8124, 112)),
    ('ijcnn1', (49990, 22)),
    ('cod-rna', (59535, 8)),
    ('covtype', (581012, 54)),
    ('w8a', (49749, 300)),
    ('protein', (145751, 74)),
    ('quantum', (50000, 65)),
    ('SUSY', (5000000, 18)),
    ('alpha', (500000, 500)),
    ('mnist8m', (8100000, 784)),
    ('dna18m', (18000000, 800)),
)


def get_nd(dataset):
    """Return the ``(n, d)`` size tuple recorded for ``dataset``.

    ``n`` and ``d`` are presumably (number of samples, number of features);
    they are only used to annotate plot titles. Unknown names yield
    ``(-1, -1)``.
    """
    return next((size for name, size in _DATASET_SIZES if dataset == name),
                (-1, -1))
        
def construct_fname(reg_type, dataset, method, nim_minibatch_size=None,
                    sag_minibatch_size=10):
    """Return the path of the results file for one (reg_type, dataset, method) run.

    The path has the form
    ``<reg_type>/<dataset>/<reg_type>.<dataset>.<method>.<suffix>``, where the
    suffix encodes the run settings that distinguish the file:

    * ``'NIM'``    -> ``minibatch_size=<nim_minibatch_size>.dat``
    * ``'SAG'``    -> ``minibatch_size=<sag_minibatch_size>.dat``
    * ``'newton'`` -> ``exact=0.dat``
    * any other method -> ``dat``

    Parameters
    ----------
    reg_type : str
        Regularizer name; used as the top-level directory and filename prefix.
    dataset : str
        Dataset name; used as the sub-directory and in the filename.
    method : str
        Optimizer name.
    nim_minibatch_size : int, optional
        NIM minibatch size. When omitted it is 100 for ``'l2'`` and 5000
        otherwise, the same rule as ``nim_minibatch_size`` in the config cell.
    sag_minibatch_size : int, optional
        SAG minibatch size (default 10).

    Returns
    -------
    str
        Relative path to the results file.
    """
    if nim_minibatch_size is None:
        nim_minibatch_size = 100 if reg_type == 'l2' else 5000

    if method == 'NIM':
        suffix = 'minibatch_size=%d.dat' % nim_minibatch_size
    elif method == 'SAG':
        suffix = 'minibatch_size=%d.dat' % sag_minibatch_size
    elif method == 'newton':
        suffix = 'exact=0.dat'
    else:
        suffix = 'dat'
    return '%s/%s/%s.%s.%s.%s' % (reg_type, dataset, reg_type, dataset, method, suffix)

def tmethod(method):
    """Return the legend label for ``method``: 'newton' is shown as 'Newton', others unchanged."""
    return 'Newton' if method == 'newton' else method

# Matplotlib single-letter color codes per method.
_METHOD_COLORS = {
    'NIM': 'r',
    'SAG': 'b',
    'newton': 'k',
    'SGD': 'k',  # same color as Newton; SGD is drawn with a dashed line instead
    'SFO': 'g',
    'LBFGS': 'm',
}


def tcolor(method):
    """Return the plot color for ``method``; methods not listed get yellow ('y')."""
    return _METHOD_COLORS.get(method, 'y')
    
def tlinestyle(method):
    """Return the plot line style for ``method``: dashed for SGD, solid for everything else."""
    return 'solid' if method != 'SGD' else 'dashed'

import os

# Residuals at or below this value are dropped before plotting: the best run
# reaches a residual of exactly 0, which cannot be drawn on a log-scaled axis.
RESIDUAL_FLOOR = 1e-10
# Only epochs strictly below this value are plotted.
MAX_EPOCHS = 50
# Directory the PDF figures are written to.
PDF_DIR = 'pdf'


def _load_curve(reg_type, dataset, method):
    """Read one results table and return ``(epochs, f_values)``.

    The first line of the file is skipped (presumably a header). Column 0 is
    the epoch; the function value is column 2 for 4-column tables and
    column 3 otherwise. ``ndmin=2`` keeps a single-row table 2-D so that
    ``shape[1]`` is always defined.
    """
    table = np.loadtxt(construct_fname(reg_type, dataset, method),
                       skiprows=1, ndmin=2)
    f_col = 2 if table.shape[1] == 4 else 3
    return table[:, 0], table[:, f_col]


# Create the output directory up front so saving does not fail on a fresh run.
if not os.path.isdir(PDF_DIR):
    os.makedirs(PDF_DIR)

for dataset in datasets:
    # Read every results file once; the data is reused by both passes below.
    curves = {method: _load_curve(reg_type, dataset, method) for method in methods}

    # Best function value reached by any method; residuals are measured from it.
    f_opt = np.inf
    for _, f_values in curves.values():
        f_opt = min(f_opt, np.min(f_values))

    fig = plt.figure()
    # Narrow axes leave room on the right for the legend placed outside them.
    ax = fig.add_axes([0.15, 0.12, 0.6, 0.75])
    for method in methods:
        epochs, f_values = curves[method]
        residual_f = f_values - f_opt
        keep = (residual_f > RESIDUAL_FLOOR) & (epochs < MAX_EPOCHS)
        ax.semilogy(epochs[keep], residual_f[keep], label=tmethod(method),
                    linewidth=4, color=tcolor(method), linestyle=tlinestyle(method))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Residual in function')
    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0)
    ax.grid()
    n, d = get_nd(dataset)
    ax.set_title('%s (n=%d, d=%d)' % (dataset, n, d))
    fig.savefig(os.path.join(PDF_DIR, '%s.%s.epochs.pdf' % (reg_type, dataset)))