# Injecting alpha-stable noise into gradients


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from argparse import Namespace
from functools import reduce

import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
# Apply seaborn's default theme; the sns.set(...) call below then fixes the base style.
sns.set()
%matplotlib inline

# Base figure style for this notebook; later plotting cells call sns.set(...) again with their own sizes.
sns.set(rc={'figure.figsize':(5,5)}, style="whitegrid", font_scale=1.0)


## Data Generation

Code for generating the synthetic dataset: a sum of sinusoids with random phases, sampled on $[0, 1)$.

In [None]:
def make_phased_waves(opt):
    """Sample a sum of phase-shifted sinusoids on opt.N evenly spaced points in [0, 1).

    The signal is sum_i A_i * sin(2*pi*k_i*t + 2*pi*phi_i).

    Args:
        opt: namespace with
            N   -- number of sample points,
            K   -- frequencies k_i,
            PHI -- phases phi_i, in cycles (they are multiplied by 2*pi),
            A   -- amplitudes A_i, or None for unit amplitudes.

    Returns:
        (t, yt): two 1-D float arrays of length opt.N.
    """
    # linspace(endpoint=False) always yields exactly N points; arange with the
    # float step 1/N can yield N + 1 (e.g. N = 49) because of rounding.
    t = np.linspace(0., 1., opt.N, endpoint=False)
    amplitudes = opt.A if opt.A is not None else [1.] * len(opt.K)
    # Start from zeros so an empty component list gives a flat signal instead of
    # the TypeError reduce() raises on an empty sequence.
    yt = sum((Ai * np.sin(2 * np.pi * ki * t + 2 * np.pi * phi)
              for ki, Ai, phi in zip(opt.K, amplitudes, opt.PHI)),
             np.zeros_like(t))
    return t, yt


def to_torch_dataset_1d(opt, t, yt, loss):
    """Convert NumPy inputs/targets into float32 torch tensors.

    Inputs are reshaped to (-1, opt.INP_DIM). Targets are reshaped to
    (-1, opt.OUT_DIM) for loss == 'mse', otherwise to a single column.
    Both are moved to the GPU when opt.CUDA is set.
    """
    target_cols = opt.OUT_DIM if loss == 'mse' else 1
    inputs = torch.from_numpy(t).view(-1, opt.INP_DIM).float()
    targets = torch.from_numpy(yt).view(-1, target_cols).float()
    if not opt.CUDA:
        return inputs, targets
    return inputs.cuda(), targets.cuda()

## Model Training

In [None]:
class Lambda(nn.Module):
    """Wrap an arbitrary callable so it can be used as a layer (e.g. in nn.Sequential)."""

    def __init__(self, lambd):
        super().__init__()
        # The wrapped callable; applied unchanged in forward().
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)
    
def make_model(opt, sig, act='RELU'):
    """Build an MLP: opt.DEPTH x (Linear + activation), then a Linear output layer.

    Args:
        opt: namespace with INP_DIM, WIDTH, DEPTH, OUT_DIM and CUDA.
        sig: unused here (noise is applied in the forward helpers); kept for
            call-site compatibility.
        act: 'RELU', 'SIGMOID' or 'ELU'.

    Returns:
        nn.Sequential, moved to the GPU when opt.CUDA is set.

    Raises:
        ValueError: for an unknown `act` (previously it silently produced a
            purely linear network).
    """
    # nn.Sigmoid is the module class; the old `nn.sigmoid()` does not exist.
    activations = {'RELU': nn.ReLU, 'SIGMOID': nn.Sigmoid, 'ELU': nn.ELU}
    if act not in activations:
        raise ValueError(f"Unknown activation {act!r}; expected one of {sorted(activations)}")
    layers = []
    # Track the running input width so DEPTH == 0 gives Linear(INP_DIM, OUT_DIM).
    in_features = opt.INP_DIM
    for _ in range(opt.DEPTH):
        layers.append(nn.Linear(in_features, opt.WIDTH))
        layers.append(activations[act]())
        in_features = opt.WIDTH
    layers.append(nn.Linear(in_features, opt.OUT_DIM))
    model = nn.Sequential(*layers)
    if opt.CUDA:
        model = model.cuda()
    return model

In [None]:
import levy
from src.longtail import longtail
from scipy import stats


def tile(a, dim, n_tile):
    """Repeat each slice of `a` along `dim` n_tile times, keeping copies adjacent.

    Example: tile([[1, 2], [3, 4]], 0, 2) -> [[1, 2], [1, 2], [3, 4], [3, 4]].

    torch.repeat_interleave computes exactly this. It runs on `a`'s own
    device; the old hand-built CPU LongTensor index failed for CUDA tensors.
    """
    return torch.repeat_interleave(a, n_tile, dim=dim)


def calc_noise_grads(noisy_grads, grads):
    """Return the element-wise differences noisy_grads[i] - grads[i].

    Pairs are formed with zip, so the result is as long as the shorter input.
    """
    return [noisy - clean for noisy, clean in zip(noisy_grads, grads)]

def extract_grads(model):
    """Return the weight gradient of every nn.Linear layer in `model`, in order.

    Bias gradients are not included. The returned entries are the live `.grad`
    tensors (not copies), so a later zero_grad() can clear or overwrite them.
    An entry is None for a layer whose weight has no gradient yet.
    """
    return [layer.weight.grad for layer in model if isinstance(layer, nn.Linear)]

def inject_noise_grads(model, noise):
    """Add externally sampled noise, in place, to the weight gradients of `model`.

    Args:
        model: iterable of layers (an nn.Sequential here); only nn.Linear
            layers are touched.
        noise: sequence of array-likes, one per nn.Linear layer in order.
            noise[j] must broadcast to the j-th Linear weight. Surplus entries
            are ignored; too few entries raise IndexError.

    Every touched layer's weight.grad must already exist (e.g. after backward()).
    Noise is converted to the gradient's own dtype and device, so float64 NumPy
    samples work with float32 and/or CUDA gradients. torch.tensor() alone always
    produced a CPU tensor.
    """
    linear_layers = [layer for layer in model if isinstance(layer, nn.Linear)]
    for j, layer in enumerate(linear_layers):
        grad = layer.weight.grad
        grad += torch.as_tensor(noise[j], dtype=grad.dtype, device=grad.device)

            
def extract_dh_dW(layer, a):
    """Per-sample gradient of a summed activation w.r.t. `layer.weight`.

    For each sample j, this computes d(sum of a[j]) / d(layer.weight). The old
    version built the same quantity from one autograd.grad call per (sample,
    unit) followed by a sum over units. It then stacked along dim=2 and
    `.view`-ed the result, which reinterpreted memory and scrambled the entries.
    By linearity of the vector-Jacobian product, one call per sample with a row
    of ones gives the summed result directly.

    Args:
        layer: module owning a `.weight` parameter (an nn.Linear in this file).
        a: activation tensor whose first axis is the batch, computed from
            `layer` with autograd enabled.

    Returns:
        NumPy array of shape (B, *layer.weight.shape). A sample whose
        activations do not depend on the weight contributes zeros.
    """
    batch = a.shape[0]
    if batch == 0:
        return np.zeros((0, *layer.weight.shape), dtype=np.float32)
    per_sample = []
    for j in range(batch):
        # Select every unit of sample j at once.
        selector = torch.zeros_like(a)
        selector[j] = 1.
        (grad,) = torch.autograd.grad(a, layer.weight, grad_outputs=selector,
                                      retain_graph=True, allow_unused=True)
        if grad is None:
            grad = torch.zeros_like(layer.weight)
        per_sample.append(grad.detach())
    # detach/cpu so this also works for CUDA tensors.
    return torch.stack(per_sample, dim=0).cpu().numpy().copy()


def make_pred(x, model):
    """Noise-free forward pass through `model` that keeps intermediate activations.

    Returns:
        (output, acts): acts[0] is the input x, followed by the output of every
        non-Linear layer, plus the final output when the last layer is Linear.
    """
    acts = [x]
    h = x
    final_idx = len(model) - 1
    for idx, layer in enumerate(model):
        h = layer(h)
        is_linear = isinstance(layer, nn.Linear)
        if idx == final_idx or not is_linear:
            acts.append(h)
    return h, acts

def make_noisy_pred(x, model, sig, n_samples=1, calc_grads=False, noise_type='add', act='RELU'):
    """Forward pass that perturbs the output of every non-Linear layer with Gaussian noise.

    noise_type 'add':  h -> h + sig * eps
    noise_type 'mult': h -> h * (1 + sig * eps),  eps ~ N(0, 1) element-wise
    Any other noise_type applies no noise.

    Args:
        x: input batch. Side effect: requires_grad is switched on for it in place.
        model: nn.Sequential.
        sig: noise scale (standard deviation of the noise term).
        n_samples: when > 1, each sample is repeated n_samples times (adjacent
            copies, via tile) before the forward pass.
        calc_grads: when True, also collect extract_dh_dW for each noisy
            activation except the model's final layer.
        act: unused; kept for call-site compatibility.

    Returns:
        (output, acts, dh_dw): acts starts with the (possibly tiled) input and
        holds each noisy non-Linear output; dh_dw is empty unless calc_grads.
    """
    x.requires_grad_(True)
    h = tile(x, 0, n_samples) if n_samples > 1 else x
    activations = [h]
    layer_jacobians = []
    last_idx = len(model) - 1
    for idx, layer in enumerate(model):
        h = layer(h)
        if isinstance(layer, nn.Linear):
            continue
        if noise_type == 'add':
            h = h + torch.randn_like(h) * sig
        elif noise_type == 'mult':
            h = h * (1 + torch.randn_like(h) * sig)
        if calc_grads and idx != last_idx:
            # model[idx - 1] is the Linear layer feeding this activation.
            layer_jacobians.append(extract_dh_dW(model[idx - 1], h))
        activations.append(h)
    return h, activations, layer_jacobians

    
    

In [None]:

import itertools 
import levy
from scipy import optimize

# Some constants of the program.
# Dimensions: 0 - x, 1 - alpha, 2 - beta
# NOTE(review): `size` is not read anywhere in this notebook.
size = (200, 76, 101)  # size of the grid (xs, alpha, beta)
_lower = np.array([-np.pi / 2 * 0.999, 0.5, -1.0])  # lower limit of parameters
_upper = np.array([np.pi / 2 * 0.999, 2.0, 1.0])  # upper limit of parameters

# Box constraints for the L-BFGS-B fit, indexed like the parameter vector:
# alpha in [0.5, 2.0], beta in [-1, 1], location unbounded, scale in [1e-6, 1e10].
par_bounds = ((_lower[1], _upper[1]), (_lower[2], _upper[2]), (None, None), (1e-6, 1e10))  # parameter bounds for fit.
# Parameter names per levy parameterisation key; fit_levy_custom uses them to
# build the keyword arguments for levy.Parameters.
par_names = {  # names of the parameters
    '0': ['alpha', 'beta', 'mu', 'sigma'],
    '1': ['alpha', 'beta', 'mu', 'sigma'],
    'M': ['alpha', 'beta', 'gamma', 'lambda'],
    'A': ['alpha', 'beta', 'gamma', 'lambda'],
    'B': ['alpha', 'beta', 'gamma', 'lambda']
}

def fit_levy_custom(x, par='0', **kwargs):
    """Maximum-likelihood fit of a Levy alpha-stable law to the centred sample `x`.

    Args:
        x: 1-D NumPy sample; its mean is subtracted before fitting.
        par: levy parameterisation key (see `par_names`).
        **kwargs: parameter values to hold fixed, e.g. mu=0 or alpha=2.0. Names
            missing from kwargs are passed as None. Presumably levy.Parameters
            treats those as the free `variables`; the bounds lookup relies on
            that (confirm against the levy package).

    Returns:
        (parameters, nll): the fitted levy.Parameters and the final negative
        log-likelihood.
    """
    x = x - x.mean()

    values = {name: kwargs.get(name) for name in par_names[par]}

    parameters = levy.Parameters(par=par, **values)
    temp = levy.Parameters(par=par, **values)

    def neglog_density(param):
        temp.x = param
        alpha, beta, mu, sigma = temp.get('0')
        return np.sum(levy.neglog_levy(x, alpha, beta, mu, sigma))

    free = parameters.variables
    bounds = tuple(par_bounds[i] for i in free)
    # One starting value per *free* parameter, indexed like the parameter vector.
    # The old fixed 3-element start only matched the case where exactly one
    # parameter (mu) was held fixed. With alpha also fixed (alpha=2.0) it made
    # fmin_l_bfgs_b fail because len(x0) != len(bounds).
    start = (1.5, 0.0, 0.0, 1.0)
    x0 = [start[i] for i in free]
    res = optimize.fmin_l_bfgs_b(neglog_density, x0, bounds=bounds, factr=10, approx_grad=True)
    parameters.x = res[0]

    return parameters, neglog_density(parameters.x)


def estimate_all_params(X, sigma=None, alpha=None):
    """Fit an alpha-stable law to X with location fixed at 0.

    Args:
        X: 1-D NumPy sample.
        sigma: if given, the scale is held fixed at this value.
        alpha: if given, the stability index is held fixed (2.0 gives a Gaussian).

    Returns:
        [alpha, beta, sigma, mu] as np.float32 values (note: sigma before mu).
    """
    fixed = {'mu': 0}
    # `is not None` so an explicitly passed value is never dropped for being falsy.
    if sigma is not None:
        fixed['sigma'] = sigma
    if alpha is not None:
        fixed['alpha'] = alpha
    fitted, _ = fit_levy_custom(X, **fixed)
    # Reads levy.Parameters internals: parameter names and the full value vector.
    p = fitted.__dict__
    r = dict(zip(p["pnames"], p["_x"]))
    return [
        np.float32(r['alpha']),
        np.float32(r['beta']),
        np.float32(r['sigma']),
        np.float32(r['mu'])
    ]



def train_model(opt, model, input_, target, sig, loss_type='mse'):
    """Train `model` with full-batch SGD under activation-noise injection.

    The update rule is chosen by flags on `opt`:

    * opt.alpha_sim: differentiate the implicit regulariser
      mean(noisy loss) - mean(expected noisy loss) w.r.t. every Linear weight.
      Fit an alpha-stable law to each gradient; alpha is pinned to 2.0
      (Gaussian) when opt.gauss_inj_no_sim is set. Then step on the gradient
      of the expected noisy loss plus a sample drawn from the fitted law.
    * opt.exp_reg: step on the loss averaged over opt.NUM_EXP noisy passes.
    * otherwise: step on the loss of a single noisy forward pass.

    Args:
        opt: Namespace. Reads LR, NUM_ITER, REC_FRQ, NUM_EXP, OUT_DIM, CUDA,
            noise_type, act, alpha_sim, exp_reg, and gauss_inj_no_sim when
            alpha_sim is set.
        model: network built by make_model.
        input_: training inputs.
        target: regression targets for 'mse', class indices otherwise.
        sig: noise scale forwarded to make_noisy_pred.
        loss_type: 'mse' for regression; any other value uses cross-entropy.

    Returns:
        List of Namespace(iter_num, loss) recorded every opt.REC_FRQ
        iterations, where loss is the mean noise-free per-sample loss.
    """
    if loss_type == 'mse':
        loss_fn = nn.MSELoss(reduction='none')
        loss_dim = opt.OUT_DIM
    else:
        # Targets become integer class labels below. Previously no loss was
        # defined for this case, so the first use raised NameError.
        loss_fn = nn.CrossEntropyLoss(reduction='none')
        loss_dim = 1

    optim = torch.optim.SGD(model.parameters(), lr=opt.LR)
    frames = []
    model.train()
    if opt.CUDA:
        input_ = input_.cuda()
        target = target.cuda()

    # Inputs and targets are the same every iteration; prepare them once.
    x = input_
    yt = target.view(-1, opt.OUT_DIM) if loss_type == 'mse' else target.view(-1).long()
    # max(1, ...) avoids ZeroDivisionError when NUM_ITER < 100.
    progress_every = max(1, opt.NUM_ITER // 100)

    def per_sample_loss(pred, labels):
        """Loss summed over output columns, giving one value per sample."""
        return loss_fn(pred, labels).reshape(-1, loss_dim).sum(1)

    def single_noisy_loss():
        """Per-sample loss of one noisy forward pass."""
        pred, _, _ = make_noisy_pred(x, model, sig, calc_grads=False,
                                     noise_type=opt.noise_type, act=opt.act)
        return per_sample_loss(pred, yt)

    def expected_noisy_loss():
        """Per-sample loss averaged over opt.NUM_EXP independent noisy passes."""
        pred, _, _ = make_noisy_pred(x, model, sig, opt.NUM_EXP, calc_grads=False,
                                     noise_type=opt.noise_type, act=opt.act)
        losses = per_sample_loss(pred, tile(yt, 0, opt.NUM_EXP))
        # tile() places the NUM_EXP copies of each sample next to each other,
        # so average consecutive groups. The old reshape(NUM_EXP, -1).mean(0)
        # mixed different samples; only its overall mean was right.
        return losses.reshape(-1, opt.NUM_EXP).mean(1)

    for iter_num in range(opt.NUM_ITER):
        if iter_num % progress_every == 0:
            print(">", end='')

        if iter_num % opt.REC_FRQ == 0:
            # Monitoring only, so no autograd graph is needed.
            with torch.no_grad():
                pred, _ = make_pred(x, model)
                loss = per_sample_loss(pred, yt)
            frames.append(Namespace(iter_num=iter_num, loss=loss.mean().item()))

        optim.zero_grad()
        if opt.alpha_sim:
            noisy_loss = single_noisy_loss()
            expected_loss = expected_noisy_loss()
            imp_reg = noisy_loss.mean() - expected_loss.mean()
            # Keep the graph: expected_loss is back-propagated again below.
            imp_reg.backward(retain_graph=True)
            # Copy to NumPy (CUDA-safe) before zero_grad() clears or overwrites the .grad buffers.
            imp_reg_grads = [g.detach().cpu().numpy().copy() for g in extract_grads(model)]
            optim.zero_grad()
            alpha_w_grads = []
            for wg in imp_reg_grads:
                flat = wg.reshape(-1)
                if opt.gauss_inj_no_sim:
                    a, b, s, _ = estimate_all_params(flat, alpha=2.0)
                else:
                    a, b, s, _ = estimate_all_params(flat)
                alpha_w_grads.append(s * levy.random(a, b, shape=wg.shape))
            expected_loss.mean().backward()
            inject_noise_grads(model, alpha_w_grads)
        elif opt.exp_reg:
            expected_noisy_loss().mean().backward()
        else:
            single_noisy_loss().mean().backward()
        optim.step()
        optim.zero_grad()

    return frames

## Visualization

In [None]:
def plot_inferred_wave(opt, x, y, yinf):
    """Overlay the target function and the model's prediction on one axis and show it.

    `opt` is accepted for signature parity with the other plotting helpers; it is unused.
    """
    _, axis = plt.subplots(1, 1)
    axis.set_title("Function")
    for series, name in ((y, 'Target'), (yinf, 'Learnt')):
        axis.plot(x, series, label=name)
    axis.set_xlabel("x")
    axis.set_ylabel("f(x)")
    axis.legend()
    plt.show()
    
def _amplitude_spectrum(opt, yox):
    """One-sided amplitude spectrum of `yox`, assumed sampled opt.N times per unit of x.

    Returns (k, |FFT| / n) for the non-negative frequencies k (cycles per unit
    of x). A unit-amplitude sinusoid shows up as a peak of height 0.5.
    """
    samples = np.asarray(yox).reshape(-1)
    n = samples.shape[0]
    k = np.fft.rfftfreq(n, d=1.0 / opt.N)
    yok = np.abs(np.fft.rfft(samples)) / n
    return k, yok


def plot_wave_and_spectrum(opt, x, yox):
    """Plot a sampled function next to its one-sided amplitude spectrum.

    NOTE: this used to call an `fft(opt, yox)` helper that is not defined
    anywhere in this notebook (NameError). The spectrum is now computed
    locally with numpy.fft.
    """
    # Btw, "yox" --> "y of x", "yok" --> "y of k"
    k, yok = _amplitude_spectrum(opt, yox)
    fig, (ax0, ax1) = plt.subplots(1, 2)
    ax0.set_title("Function")
    ax0.plot(x, yox)
    ax0.set_xlabel("x")
    ax0.set_ylabel("f(x)")
    ax1.set_title("FT of Function")
    ax1.plot(k, yok)
    ax1.set_xlabel("k")
    ax1.set_ylabel("f(k)")
    plt.show()
    
    
def plot_multiple_skews(all_frames):
    """Plot, per layer, the mean +/- std across runs of a `skew` statistic over training.

    Args:
        all_frames: list with one frame list per run. Every frame needs
            `.iter_num` and `.skew`. The iteration axis is taken from the first run.

    NOTE(review): train_model in this notebook records only `iter_num` and
    `loss`, so its output cannot be passed here as-is. `.skew` is presumably a
    per-layer sequence (one value per layer), e.g. the fitted beta; confirm.
    """
    iter_nums = np.array([frame.iter_num for frame in all_frames[0]])
    # Per run: transpose [frame][layer] -> [layer][frame]; squeeze drops singleton axes.
    norms = np.array([np.array(list(zip(*[frame.skew for frame in frames]))).squeeze() for frames in all_frames])
    # Aggregate over runs (axis 0), giving one mean curve and one std curve per layer.
    means = norms.mean(0)
    stds = norms.std(0)
    plt.xlabel("Training Iteration")
    plt.ylabel(r'$\beta$')
    for layer_num, (mean_curve, std_curve) in enumerate(zip(means, stds)): 
        p = plt.plot(iter_nums, mean_curve, label=f'Layer {layer_num + 1}')
        # Shade +/- one std band in the same colour as the mean curve.
        plt.fill_between(iter_nums, mean_curve + std_curve, mean_curve - std_curve, color=p[0].get_color(), alpha=0.15)
    plt.legend()
    plt.show()

## Play

In [None]:
# Shared experiment configuration; the following cells fill in its fields.
opt = Namespace()

In [None]:
# Data Generation
opt.N = 200  # number of sample points t in [0, 1)
opt.K = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]  # frequencies of the summed sinusoids
opt.A = [1 for _ in opt.K]  # unit amplitude for every component
opt.PHI = [np.random.rand() for _ in opt.K]  # random phases (re-drawn on every repeat in go())
# Model parameters
opt.INP_DIM = 1
opt.OUT_DIM = 1
opt.WIDTH = 256  # hidden units per layer
opt.DEPTH = 6  # number of hidden (Linear + activation) layers
# Training
# --- Switch exp_reg on and off to approximate GNIs as R in the main paper. 
opt.exp_reg=False
opt.CUDA = False
opt.NUM_ITER = 60000  # SGD steps per run
opt.NUM_EXP = 16  # noisy forward passes averaged for the expected loss
opt.REC_FRQ = 100  # record the noise-free loss every REC_FRQ iterations
opt.LR = 0.0003
opt.alpha_sim = True

Note that we only plot the positive frequencies. This is why the peaks of the unit-amplitude components in the spectrum are at $0.5$: the other half of each component's amplitude sits at the corresponding negative frequency.

### Train the Model

In [None]:
def go(opt, repeats=10, sig=0, act='RELU', data='regress'):
    """Train `repeats` freshly initialised models and return their recorded frames.

    Each repeat redraws opt.PHI, builds the dataset, trains with train_model,
    and plots the learnt function against the target.

    Args:
        opt: experiment Namespace (mutated: opt.PHI is redrawn each repeat;
            opt.INP_DIM / opt.OUT_DIM are overwritten for data='class').
        repeats: number of independent runs.
        sig: noise scale forwarded to training.
        act: activation name forwarded to make_model.
        data: 'regress' (sum of sinusoids, MSE) or 'class' (digits, cross-entropy).

    Returns:
        List with one frame list (the train_model output) per repeat.

    Raises:
        ValueError: for an unknown `data` value (previously an UnboundLocalError).
    """
    if data not in ('regress', 'class'):
        raise ValueError(f"data must be 'regress' or 'class', got {data!r}")
    all_frames = []
    for _ in range(repeats):
        # Sample random phase
        opt.PHI = [np.random.rand() for _ in opt.K]
        if data == 'regress':
            # make_phased_waves is deterministic given opt, so every input/output
            # column is the same wave: compute it once and repeat the column.
            t, yt = make_phased_waves(opt)
            x = np.repeat(t.reshape(-1, 1), opt.INP_DIM, axis=1)
            y = np.repeat(yt.reshape(-1, 1), opt.OUT_DIM, axis=1)
            loss_type = 'mse'
        else:
            # NOTE(review): `load_digits` (scikit-learn) is not imported anywhere in
            # this notebook; import it before using data='class'.
            x, y = load_digits(n_class=10, return_X_y=True)
            opt.INP_DIM, opt.OUT_DIM = x.shape[1], 10
            loss_type = 'ce'

        x, y = to_torch_dataset_1d(opt, x, y, loss_type)

        model = make_model(opt, sig, act)

        frames = train_model(opt, model, x, y, sig, loss_type=loss_type)
        all_frames.append(frames)
        # Inference for plotting only; no autograd graph needed.
        with torch.no_grad():
            yinf = model(x)
        plot_inferred_wave(opt, x.detach().cpu().numpy(), y.detach().cpu().numpy(), yinf.detach().cpu().numpy())

        print()
    return all_frames

## Runs

In [None]:
# Ten frequency components with equal unit amplitudes (presumably what "eq_amp" in the run names refers to).
opt.K = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
opt.A = [1 for _ in opt.K]

In [None]:
# Multiplicative noise, sig=0.5: activations are scaled by (1 + 0.5 * eps), eps ~ N(0, 1),
# i.e. noise std 0.5 (variance 0.25). Alpha-stable gradient-noise injection with a
# freely fitted alpha; 5 repeats.
sns.set(rc={'figure.figsize':(8,4), "lines.linewidth":1.5}, style="whitegrid", font_scale=1.5)

opt.act='RELU'
opt.noise_type='mult'
opt.exp_reg=False
opt.alpha_sim = True
opt.gauss_inj_no_sim = False

mult_eq_amp_frames_alpha_inject = go(opt, 5, 0.5, act=opt.act, data='regress')

In [None]:
import pickle

# Persist the per-run frames so later cells can reload them without retraining.
frames_file = 'mult_eq_amp_frames_alpha_inject.pickle'
with open(frames_file, 'wb') as out:
    pickle.dump(mult_eq_amp_frames_alpha_inject, out, pickle.HIGHEST_PROTOCOL)

In [None]:
# Multiplicative noise, sig=0.5 (activations scaled by 1 + 0.5 * eps, eps ~ N(0, 1):
# noise std 0.5, variance 0.25). gauss_inj_no_sim=True pins alpha to 2.0, so the
# injected gradient noise is Gaussian; 3 repeats.
opt.act='RELU'
opt.noise_type='mult'
opt.exp_reg=False
opt.alpha_sim = True
opt.gauss_inj_no_sim = True

mult_eq_amp_frames_gauss_inject = go(opt, 3, 0.5, act=opt.act, data='regress')

In [None]:
import pickle

# Persist the per-run frames so later cells can reload them without retraining.
frames_file = 'mult_eq_amp_frames_gauss_inject.pickle'
with open(frames_file, 'wb') as out:
    pickle.dump(mult_eq_amp_frames_gauss_inject, out, pickle.HIGHEST_PROTOCOL)

In [None]:
# Multiplicative noise, sig=0.5 (noise std 0.5, variance 0.25). Expected-loss training:
# every SGD step averages opt.NUM_EXP noisy forward passes; 5 repeats.
opt.act='RELU'
opt.noise_type='mult'
opt.exp_reg=True
opt.alpha_sim = False
opt.gauss_inj_no_sim = False
mult_eq_amp_frames_exp_reg = go(opt, 5, 0.5, act=opt.act, data='regress')

In [None]:
import pickle

# Persist the per-run frames so later cells can reload them without retraining.
frames_file = 'mult_eq_amp_frames_exp_reg.pickle'
with open(frames_file, 'wb') as out:
    pickle.dump(mult_eq_amp_frames_exp_reg, out, pickle.HIGHEST_PROTOCOL)

In [None]:
# Multiplicative noise, sig=0.5 (noise std 0.5, variance 0.25). Plain noisy training:
# one noisy forward pass per SGD step; 5 repeats.
opt.act='RELU'
opt.noise_type='mult'
opt.exp_reg=False
opt.alpha_sim = False
opt.gauss_inj_no_sim = False
mult_eq_amp_frames_noise = go(opt, 5, 0.5, act=opt.act, data='regress')

In [None]:
import pickle

# Persist the per-run frames so later cells can reload them without retraining.
frames_file = 'mult_eq_amp_frames_noise.pickle'
with open(frames_file, 'wb') as out:
    pickle.dump(mult_eq_amp_frames_noise, out, pickle.HIGHEST_PROTOCOL)

In [None]:
# Additive noise, sig=0.1: activations get + 0.1 * eps, eps ~ N(0, 1), i.e. noise std 0.1
# (variance 0.01). Alpha-stable gradient-noise injection with a freely fitted alpha; 5 repeats.
opt.act='RELU'
opt.noise_type='add'
opt.exp_reg=False
opt.alpha_sim = True
opt.gauss_inj_no_sim = False
add_eq_amp_frames_alpha_inject = go(opt, 5, 0.1, act=opt.act, data='regress')

In [None]:
import pickle

# Persist the frames of the additive alpha-stable-injection run. This dumps the
# variable actually produced by the training cell; the name used before,
# `add_eq_amp_frames_alpha_inject_copy`, is never defined (NameError).
with open('add_eq_amp_frames_alpha_inject.pickle', 'wb') as handle:
    pickle.dump(add_eq_amp_frames_alpha_inject, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Additive noise, sig=0.1 (noise std 0.1, variance 0.01). gauss_inj_no_sim=True pins
# alpha to 2.0, so the injected gradient noise is Gaussian; 3 repeats.
opt.act='RELU'
opt.noise_type='add'
opt.exp_reg=False
opt.alpha_sim = True
opt.gauss_inj_no_sim = True
add_eq_amp_frames_gauss_inject = go(opt, 3, 0.1, act=opt.act, data='regress')

In [None]:
import pickle

# Persist the per-run frames so later cells can reload them without retraining.
frames_file = 'add_eq_amp_frames_gauss_inject.pickle'
with open(frames_file, 'wb') as out:
    pickle.dump(add_eq_amp_frames_gauss_inject, out, pickle.HIGHEST_PROTOCOL)

In [None]:
# Additive noise, sig=0.1 (noise std 0.1, variance 0.01). Expected-loss training:
# every SGD step averages opt.NUM_EXP noisy forward passes; 5 repeats.
opt.act='RELU'
opt.noise_type='add'
opt.exp_reg=True
opt.alpha_sim = False
opt.gauss_inj_no_sim = False
add_eq_amp_frames_exp_reg = go(opt, 5, 0.1, act=opt.act, data='regress')

In [None]:
import pickle

# Persist the per-run frames so later cells can reload them without retraining.
frames_file = 'add_eq_amp_frames_exp_reg.pickle'
with open(frames_file, 'wb') as out:
    pickle.dump(add_eq_amp_frames_exp_reg, out, pickle.HIGHEST_PROTOCOL)

In [None]:
# Additive noise, sig=0.1 (noise std 0.1, variance 0.01). Plain noisy training:
# one noisy forward pass per SGD step; 5 repeats.
opt.act='RELU'
opt.noise_type='add'
opt.exp_reg=False
opt.alpha_sim = False
# Reset the flag train_model actually reads (`alpha_inj_no_sim` is read nowhere).
opt.gauss_inj_no_sim = False
add_eq_amp_frames_noise = go(opt, 5, 0.1, act=opt.act, data='regress')

In [None]:
import pickle

# Persist the per-run frames so later cells can reload them without retraining.
frames_file = 'add_eq_amp_frames_noise.pickle'
with open(frames_file, 'wb') as out:
    pickle.dump(add_eq_amp_frames_noise, out, pickle.HIGHEST_PROTOCOL)

In [None]:
# Noise-free baseline: sig=0.0, so the additive noise term is exactly zero; 5 repeats.
opt.act='RELU'
opt.noise_type='add'
opt.exp_reg=False
opt.alpha_sim = False
# Reset the flag train_model actually reads (`alpha_inj_no_sim` is read nowhere).
opt.gauss_inj_no_sim = False
add_eq_amp_frames_baseline = go(opt, 5, 0.0, act=opt.act, data='regress')

In [None]:
# Import locally so this cell also runs on its own, like the other save cells.
import pickle

# Persist the noise-free baseline frames.
with open('add_eq_amp_frames_baseline.pickle', 'wb') as handle:
    pickle.dump(add_eq_amp_frames_baseline, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import pickle
import seaborn as sns


def _load_pickle(path):
    """Return the object stored in the pickle file at `path`.

    Only use on files written by this notebook: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


add_eq_amp_frames_noise = _load_pickle('add_eq_amp_frames_noise.pickle')
add_eq_amp_frames_exp_reg = _load_pickle('add_eq_amp_frames_exp_reg.pickle')
add_eq_amp_frames_alpha_inject = _load_pickle('add_eq_amp_frames_alpha_inject.pickle')
add_eq_amp_frames_gauss_inject = _load_pickle('add_eq_amp_frames_gauss_inject.pickle')


In [None]:
import matplotlib.pyplot as plt

sns.set(rc={'figure.figsize':(4,4), "lines.linewidth":2.5}, style="whitegrid", font_scale=1.5)
fig, axes = plt.subplots(1, 1, figsize=(4,4))

# (runs, legend label) per training regime, in plotting order. Losses are plotted
# exactly as recorded. The curves used to be shifted by -0.05 / +0.1 under an
# L_train label, which misrepresented the data; line styles now tell them apart.
curves = [
    (add_eq_amp_frames_noise, 'M=1'),
    (add_eq_amp_frames_alpha_inject, r'M=16,' + r'$\mathcal{S}_\alpha$'),
    (add_eq_amp_frames_gauss_inject, r'M=16,' + r'$\mathcal{N}$'),
    (add_eq_amp_frames_exp_reg, 'M=16'),
]
line_styles = ["solid", "dotted", "dashed", "dashdot"]

for runs, label in curves:
    # Every 100th recorded frame of every run; seaborn aggregates the runs that
    # share an iteration into a mean line with a confidence band.
    picked = [frame for run in runs for frame in run[0::100]]
    ax = sns.lineplot(x=[f.iter_num for f in picked], y=[f.loss for f in picked], ax=axes, label=label)

# Assumes one Line2D per lineplot call, in call order (the bands are not Line2D).
for line, style in zip(ax.lines, line_styles):
    line.set_linestyle(style)
# Build the legend after restyling so its handles pick up the line styles.
ax.legend(fontsize=13)  # size in points

ax.set(ylabel=r'$\mathcal{L}_{\mathrm{train}}$', xlabel='Training Iteration', title='')
plt.show()

fig.savefig("SA_replacement_add.pdf", bbox_inches='tight')


In [None]:
import pickle


def _load_pickle(path):
    """Return the object stored in the pickle file at `path`.

    Only use on files written by this notebook: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


mult_eq_amp_frames_noise = _load_pickle('mult_eq_amp_frames_noise.pickle')
mult_eq_amp_frames_exp_reg = _load_pickle('mult_eq_amp_frames_exp_reg.pickle')
mult_eq_amp_frames_alpha_inject = _load_pickle('mult_eq_amp_frames_alpha_inject.pickle')
mult_eq_amp_frames_gauss_inject = _load_pickle('mult_eq_amp_frames_gauss_inject.pickle')


In [None]:
import matplotlib.pyplot as plt
import copy

sns.set(rc={'figure.figsize':(4,4), "lines.linewidth":2.5}, style="whitegrid", font_scale=1.5)
fig, axes = plt.subplots(1, 1, figsize=(4,4))

# (runs, legend label) per training regime, in plotting order. The S_alpha and N
# curves previously re-plotted the noise and exp_reg runs (copy-paste slip), so
# the loaded alpha/gauss injection runs never appeared in the figure.
curves = [
    (mult_eq_amp_frames_noise, 'M=1'),
    (mult_eq_amp_frames_alpha_inject, r'M=16,' + r'$\mathcal{S}_\alpha$'),
    (mult_eq_amp_frames_gauss_inject, r'M=16,' + r'$\mathcal{N}$'),
    (mult_eq_amp_frames_exp_reg, 'M=16'),
]
line_styles = ["solid", "dotted", "dashed", "dashdot"]

for runs, label in curves:
    # Every 100th recorded frame of every run; seaborn aggregates the runs that
    # share an iteration into a mean line with a confidence band.
    picked = [frame for run in runs for frame in run[0::100]]
    ax = sns.lineplot(x=[f.iter_num for f in picked], y=[f.loss for f in picked], ax=axes, label=label)

# Assumes one Line2D per lineplot call, in call order (the bands are not Line2D).
for line, style in zip(ax.lines, line_styles):
    line.set_linestyle(style)
# Build the legend after restyling so its handles pick up the line styles.
ax.legend(fontsize=13)  # size in points

ax.set(ylabel=r'$\mathcal{L}_{\mathrm{train}}$', xlabel='Training Iteration', title='')
plt.show()

fig.savefig("SA_replacement_mult.pdf", bbox_inches='tight')
