# In [1]:
# Import necessary libraries
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.sparse import csc_matrix

from get_dataset import get_dataset
from model import create_logistic_regression_model
from train import train_model
from utils import get_trace, relative_round

# Global plot styling: a seaborn theme for all figures, plus matplotlib
# settings that embed TrueType (Type 42) fonts in PDF/PS output and fix the
# default figure size used by every new figure.
sns.set(
    style="whitegrid",
    font_scale=1.2,
    context="talk",
    palette=sns.color_palette("bright"),
    color_codes=False,
)
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'figure.figsize': (8, 6),
})

# Load dataset
dataset = 'w8a.txt'
# get_dataset returns a design matrix A and targets b
# (presumably sparse features and ±1 labels — TODO confirm against get_dataset).
A, b = get_dataset(dataset)

# Create logistic regression model
model = create_logistic_regression_model(A, b)

# Set training parameters
# Rows of A are treated as samples (n) and columns as features (dim).
n, dim = A.shape
# Smoothness constant of the loss. It is queried BEFORE model.l2 is assigned
# below, so it reflects whatever regularization the model was created with.
# NOTE(review): keep this ordering unless smoothness() is known not to depend on l2.
L = model.smoothness()
# l2 regularization strength, scaled as L / sqrt(n).
l2 = L / np.sqrt(n)
model.l2 = l2
# Starting point: an all-zero sparse column vector of shape (dim, 1).
x0 = csc_matrix((dim, 1))
n_epoch = 600
batch_size = 512
n_seeds = 2  # was set to 20 in the paper
# Iteration count that amounts to 250 passes over the data at this batch size
# (250 * n samples / batch_size samples per iteration).
stoch_it = 250 * n // batch_size
trace_len = 300
# Results directory keyed by dataset name and rounded l2; the training loop
# looks for cached traces here before training anything.
trace_path = f'../results/log_reg_{dataset}_l2_{relative_round(l2)}/'

# Train with different optimizers.
# For each optimizer, a previously saved trace is loaded from trace_path;
# training only runs when get_trace returns a falsy value (nothing cached).
optimizers = ['Nesterov', 'Sgd', 'Ig', 'Shuffling']
traces = []
# Legend labels and plot markers, aligned index-by-index with `optimizers`.
labels = ['Nesterov', 'SGD', 'IG', 'Shuffling']
markers = [',', 'o', 'D', '*']

for opt in optimizers:
    trace = get_trace(f'{trace_path}{opt}')
    if not trace:
        # The fifth positional argument is 1/l2 for every optimizer
        # (presumably the step size — confirm against train_model).
        # trained_model is not used afterwards; only the trace is kept.
        # NOTE(review): batch_size, n_seeds, stoch_it and trace_len are defined
        # above but never passed here. If train_model falls back to a None
        # default for one of them and compares it with an int, that would match
        # the TypeError recorded at the end of this cell — confirm its signature.
        trained_model, trace = train_model(model, opt, x0, n_epoch, 1/l2, trace_path=f'{trace_path}{opt}')
    traces.append(trace)
# NOTE: `trace` remains bound to the last optimizer's trace after this loop,
# and later statements in this cell read it.

# Plotting the results.
# Reference solution: the trace that reaches the lowest loss supplies both the
# optimal value f_opt and the reference point x_opt for the distance plot.
# (Previously x_opt came from whatever trace the training loop left in
# `trace`, regardless of which trace produced f_opt.)
best_trace = min(traces, key=lambda t: np.min(t.loss_vals))
f_opt = np.min(best_trace.loss_vals)
# NOTE(review): assumes the best trace's final stored iterate is its most
# accurate one — confirm if a stochastic method ever ends up as best_trace.
x_opt = best_trace.xs[-1]

# savefig does not create missing directories.
os.makedirs('./plots', exist_ok=True)

# Loss curves relative to f_opt. Each plot gets its own figure so that curves
# and legend entries from one plot do not leak into the next saved PDF.
plt.figure()
for trace, label, marker in zip(traces, labels, markers):
    trace.plot_losses(f_opt=f_opt, label=label, marker=marker)
plt.yscale('log')
plt.legend()
plt.xlabel('Data passes')
plt.tight_layout()
plt.savefig(f'./plots/{dataset}_func.pdf', dpi=300)

# Distance of each method's iterates to x_opt, on a fresh figure.
plt.figure()
for trace, label, marker in zip(traces, labels, markers):
    trace.plot_distances(x_opt=x_opt, label=label, marker=marker)
plt.yscale('log')
plt.legend()
plt.xlabel('Data passes')
plt.tight_layout()
plt.savefig(f'./plots/{dataset}_dist.pdf', dpi=300)

# Additional experiments can be added here, following a similar structure


# Recorded output of this cell:
# TypeError: '>' not supported between instances of 'NoneType' and 'int'