# In [None]:
# -*- coding: utf-8 -*-
"""Myrtle_Cifar10_SGD_Nesterov_minibatches_piecewise_linear_LR_stop_osc_epoch_foriloop_weight_decay.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1A0Tuq5UEuoSjdZp6IaD9EWNhzJ-vk_-i
"""

from tensorflow.keras.datasets import cifar10
import sys
import time
import jax
from jax import grad, jit
from jax.tree_util import tree_multimap
import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax
import functools
from jax.experimental import optimizers
import pickle
import matplotlib.pyplot as plt
import argparse
import random as random_1
import numpy as np
from jax.tree_util import tree_flatten

# Utilities for working with tree-like container data structures.
# This module provides a small set of utility functions for working with tree-like data structures,
# such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees
# in that they are defined recursively (any non-pytree is a pytree, i.e. a leaf,
# and any pytree of pytrees is a pytree) and can be operated on recursively
# (object identity equivalence is not preserved by mapping operations, and the structures cannot contain
# reference cycles).
# The primary purpose of this module is to enable the interoperability between user defined data
# structures and JAX transformations (e.g. jit). This is not meant to be a general purpose
# tree-like data structure handling library.


### TAKING A SUBSET OF CIFAR10 ###
def make_dataset_cifar10(num_samples, classes_to_classify, C):

  '''
  Build a class-restricted subset of CIFAR-10 with one-hot encoded labels.

  Training subset: drawn at random WITH replacement. For each of the
  num_samples slots a class is picked uniformly from classes_to_classify and
  then one training image of that class is picked uniformly.
  Test subset: every CIFAR-10 test image whose label is in
  classes_to_classify, kept in the original test-set order.

  :param num_samples: number of training samples to draw
  :param classes_to_classify: sequence of CIFAR-10 class labels (integers in 0..9) to keep
  :param C: width of the one-hot label vectors (number of classes); it should be
            larger than every label in classes_to_classify, otherwise that label
            cannot be represented in the one-hot matrix
  :return: tuple (x_train, y_train_vec, x_test, y_test_vec) placed on the default
           JAX device; images are scaled to [0, 1], labels are one-hot rows of width C
  '''

  (x_train, y_train), (x_test, y_test) = cifar10.load_data()  # 50 k for training, 10 k for testing
  x_train, x_test = x_train / 255.0, x_test / 255.0  # scale pixel values to [0, 1]

  # Work with flat label vectors so that indexing a label yields a scalar
  # (avoids relying on NumPy's deprecated size-1-array -> scalar conversion).
  y_train = np.asarray(y_train).flatten()
  y_test = np.asarray(y_test).flatten()

  # Training-example indices grouped by class label (CIFAR-10 has labels 0..9).
  pos_per_label = [np.where(y_train == label)[0] for label in range(10)]

  # Pick the class of every training slot; equal weights give a uniform choice.
  weights = [1 for _ in classes_to_classify]
  chosen_classes = [random_1.choices(classes_to_classify, weights)[0]
                    for _ in range(num_samples)]

  # Pick one training image (by index) for every chosen class.
  # dtype=int keeps the index array valid even when num_samples == 0.
  chosen_idx = np.array(
      [random_1.choices(pos_per_label[label])[0] for label in chosen_classes],
      dtype=int)

  x_train = x_train[chosen_idx]
  y_train = y_train[chosen_idx]

  # Keep all test samples whose class is in classes_to_classify (order preserved).
  keep_test = np.isin(y_test, classes_to_classify)
  x_test = x_test[keep_test]
  y_test = y_test[keep_test]

  # One-hot encode the labels. jax.nn.one_hot replaces jax.ops.index_update,
  # which no longer exists in current JAX releases.
  y_train_vec = jax.nn.one_hot(y_train.astype(int), C)
  y_test_vec = jax.nn.one_hot(y_test.astype(int), C)

  # device_put transfers the arrays to the default device; values are only
  # copied back to the host when needed (printing, plotting, saving, ...).
  return jax.device_put(x_train), jax.device_put(y_train_vec), jax.device_put(x_test), jax.device_put(y_test_vec)


# # Network architecture described in 
# # Shankar et al., Neural Kernels Without Tangents, 2020.
# # https://arxiv.org/abs/2003.02237

### NORMAL INITIALIZATION ###

def MyrtleNetwork_ini(depth, myrtle_width, C):
  '''
  Assemble the Myrtle convolutional network as a serial stax layer stack.

  The network consists of three groups of (3x3 'SAME' convolution + ReLU)
  pairs. The first two groups are each followed by one 2x2 average pooling
  with stride 2, the third group by three such poolings. A Flatten layer and
  a Dense(C) layer close the stack.

  :param depth: 5, 7 or 10; selects how many conv+ReLU pairs each group holds
                (any other value raises KeyError)
  :param myrtle_width: number of output channels of every convolution
  :param C: output dimension of the final dense layer
  :return: the (init_fun, apply_fun) pair produced by stax.serial
  '''
  stax_lib = jax.experimental.stax

  # Convolution constructor with the padding fixed to 'SAME'.
  same_conv = functools.partial(stax_lib.Conv, padding='SAME')

  # conv+ReLU pairs per group, and poolings that follow each group.
  pairs_per_group = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}[depth]
  pools_after_group = (1, 1, 3)

  stack = []
  for num_pairs, num_pools in zip(pairs_per_group, pools_after_group):
    conv_relu = [same_conv(myrtle_width, (3, 3)), stax_lib.Relu]
    stack.extend(conv_relu * num_pairs)
    stack.extend([stax_lib.AvgPool((2, 2), strides=(2, 2))] * num_pools)

  # Classification head: flatten all but the leading dim, then a dense layer.
  stack.append(Flatten)
  stack.append(Dense(C))

  return stax_lib.serial(*stack)

#### MINIBATCH GENERATOR ####
def data_stream(batch_size, num_batches, X_data, Y_data):
  '''
  Endless minibatch generator over (X_data, Y_data).

  Each pass ("epoch") draws a fresh permutation of the row indices from an
  explicit JAX PRNG key; the key seed starts at 0 and grows by one per pass.
  Within a pass, num_batches consecutive slices of batch_size shuffled indices
  are yielded lazily, one (X_batch, Y_batch) pair at a time.

  :param batch_size: number of rows per minibatch
  :param num_batches: number of minibatches yielded per pass
  :param X_data: inputs, indexed along the first axis
  :param Y_data: targets, indexed along the first axis with the same indices
  '''
  epoch_seed = 0
  while True:
    num_rows = jnp.shape(X_data)[0]
    shuffled = jax.random.permutation(jax.random.PRNGKey(epoch_seed), num_rows)
    epoch_seed += 1
    for batch_no in range(num_batches):
      start = batch_no * batch_size
      rows = shuffled[start:start + batch_size]
      yield X_data[rows], Y_data[rows]

#### MINIBATCH GENERATOR ####

def data_stream_seed(batch_size, num_batches, X_data, Y_data, seed):
  '''
  Endless minibatch generator with a caller-chosen starting seed.

  The first pass shuffles the row indices with PRNGKey(seed), the next pass
  with PRNGKey(seed + 1), and so on. Every pass yields num_batches pairs
  (X_batch, Y_batch), each built from batch_size consecutive shuffled indices.

  :param batch_size: number of rows per minibatch
  :param num_batches: number of minibatches yielded per pass
  :param X_data: inputs, indexed along the first axis
  :param Y_data: targets, indexed along the first axis with the same indices
  :param seed: integer seed used for the first pass
  '''
  pass_seed = seed
  while True:
    prng_key = jax.random.PRNGKey(pass_seed)
    order = jax.random.permutation(prng_key, jnp.shape(X_data)[0])
    pass_seed += 1
    for k in range(num_batches):
      lo = k * batch_size
      hi = lo + batch_size
      picked = order[lo:hi]
      yield X_data[picked], Y_data[picked]


### piecewise_linear takes the epoch and returns the learning rate given by the
### piecewise-linear function that goes through the points (points_x, points_y) ###

def piecewise_linear(x, points_x, points_y):
  '''
  Evaluate the piecewise-linear function through (points_x, points_y) at x.

  Inside the breakpoint range the value is linearly interpolated on the
  segment whose left breakpoint is the last one that is <= x. Outside the
  range the function is held constant: points_y[0] for x <= points_x[0] and
  points_y[-1] for x >= points_x[-1]. (Previously these cases failed with an
  unbound local variable or an IndexError.)

  :param x: position at which to evaluate (e.g. the epoch)
  :param points_x: breakpoint positions, expected in ascending order
  :param points_y: function values at the breakpoints
  :return: interpolated (or clamped) value at x
  :raises ValueError: if points_x is empty
  '''
  last = len(points_x) - 1
  if last < 0:
    raise ValueError('points_x must contain at least one breakpoint')

  # Hold the end values outside the breakpoint range.
  if x <= points_x[0]:
    return points_y[0]
  if x >= points_x[last]:
    return points_y[last]

  # Here points_x[0] < x < points_x[last]: pick the last segment whose left
  # breakpoint is <= x. The final breakpoint is never a left breakpoint.
  seg = 0
  for k in range(last):
    if points_x[k] <= x:
      seg = k

  slope = (points_y[seg + 1] - points_y[seg]) / (points_x[seg + 1] - points_x[seg])
  return points_y[seg] + (x - points_x[seg]) * slope

def main():
  '''
  Train a Myrtle CNN on CIFAR-10 with an oscillating, class-weighted cross-entropy loss.

  Configuration comes from the command line when run as a plain script
  ("cluster" mode) and from hard-coded defaults when run inside IPython/Colab.
  For every combination of num_periods, repetition index and w_max the function:

    1. builds the network and initialises its parameters (cast to float16),
    2. trains with the Nesterov optimizer, one jax.lax.fori_loop per epoch,
       using a learning rate that is piecewise linear in the epoch,
    3. after each epoch records train accuracy/loss, test accuracy and the
       current class-weight vector,
    4. pickles the configuration and the recorded curves to
       'T_{T}_w_max_{w_max}_sim_repet_{k}.pkl' in the working directory.

  NOTE: the training step always differentiates the weighted cross-entropy
  loss; loss_function_type only controls whether the losses are accumulated
  in the per-epoch report.
  '''

  # get_ipython is only defined inside IPython/Colab; a NameError therefore
  # means the code runs as a plain script (cluster mode).
  try:
    get_ipython
    CLUSTER = False
  except NameError:
    CLUSTER = True

  if CLUSTER:  # in the cluster, arguments are parsed (e.g. via the linux console)

    parser = argparse.ArgumentParser()

    parser.add_argument("-so", "--stop_osc", type=int )
    parser.add_argument("-soae", "--stop_osc_at_epoch", type=int )
    parser.add_argument("-softm", "--start_osc_from_the_middle", type=int )
    parser.add_argument('-wml','--w_max_list', nargs='+', type=int)  # list
    parser.add_argument('-np','--num_period',  type=int)
    parser.add_argument("-rs", "--repeat_sim", type=int)
    parser.add_argument("-md", "--myrtle_depth", type=int)
    parser.add_argument("-mw", "--myrtle_width", type=int)
    parser.add_argument("-lft", "--loss_function_type")
    parser.add_argument("-tds", "--type_data_set")
    parser.add_argument("-ot", "--optimizer_type" )
    parser.add_argument("-mn", "--momentum_nest", type=float)
    parser.add_argument("-wd", "--weight_decay", type=float)
    parser.add_argument("-mbs", "--mini_batch_size", type=int)
    parser.add_argument('-eelr','--epochs_extrema_LR', nargs='+', type=int)  # list
    parser.add_argument('-LRe','--LRs_extrema', nargs='+', type=float)  # list
    parser.add_argument("-stsd", "--steps_to_save_data", type=int)
    parser.add_argument("-si", "--seed_initialization", type=int)

    args = parser.parse_args()

    ### PARAMETERS ###
    # DEFINING THE OSCILLATIONS
    stop_osc = args.stop_osc
    stop_osc_at_epoch = args.stop_osc_at_epoch
    start_osc_from_the_middle = args.start_osc_from_the_middle
    w_max_list = args.w_max_list
    num_periods_list = [args.num_period]
    repeat_sim = args.repeat_sim

    # MYRTLE Neural Net
    myrtle_depth = args.myrtle_depth
    myrtle_width = args.myrtle_width

    # LOSS
    loss_function_type = args.loss_function_type

    ### DEFINING THE DATASET ###
    type_data_set = args.type_data_set
    if type_data_set == 'complete':
      C = 10
    elif type_data_set == 'customed':
      num_samples = 2000  # number samples in training dataset
      C = 2  # number of classes
      classes_to_classify = [aa for aa in range(C)]  # classes included in training, right now they have to be consecutive!!

    ### TRAINING THE NN ###
    optimizer_type = args.optimizer_type
    momentum_nest = args.momentum_nest
    weight_decay = args.weight_decay
    mini_batch_size = args.mini_batch_size
    epochs_extrema_LR = args.epochs_extrema_LR
    LRs_extrema = args.LRs_extrema
    steps_to_save_data = args.steps_to_save_data
    seed_initialization = args.seed_initialization

  else:

    ### PARAMETERS ###

    # DEFINING THE OSCILLATIONS
    stop_osc = 1
    stop_osc_at_epoch = 6
    start_osc_from_the_middle = 0
    w_max_list = [1]
    num_periods_list = [11]  # [11,31,51,101]
    repeat_sim = 1

    # MYRTLE Neural Net
    myrtle_depth = 5  # 10
    myrtle_width = 64

    # LOSS
    loss_function_type = 'cross_entropy'

    ### DEFINING THE DATASET ###
    type_data_set = 'complete'  # 'customed'
    if type_data_set == 'complete':
      C = 10
    elif type_data_set == 'customed':
      num_samples = 1280  # number samples in training dataset
      C = 2  # number of classes
      classes_to_classify = [aa for aa in range(C)]  # classes included in training, right now they have to be consecutive!!

    ### TRAINING THE NN ###
    optimizer_type = 'nesterov'
    momentum_nest = 0.9  # only for 'nesterov'
    mini_batch_size = 128  # 1024
    weight_decay = 0.0  # 0.0005*mini_batch_size
    epochs_extrema_LR = [0, 10, 30]  # [0, 15, 30, 35]
    LRs_extrema = [0.0, 0.04, 0.002]  # [0.1,0.01,0.001,0.0001]
    steps_to_save_data = 100
    seed_initialization = 1

  # Fail early with a clear message instead of a NameError further down:
  # only these dataset types and this optimizer are actually implemented below.
  if type_data_set not in ('complete', 'customed'):
    sys.exit("type_data_set must be 'complete' or 'customed'")
  if optimizer_type != 'nesterov':
    sys.exit("Only optimizer_type 'nesterov' is implemented")

  if type_data_set == 'complete':

    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0  # scale pixel values to [0, 1]

    num_samples = jnp.shape(x_train)[0]
    num_test_samples = jnp.shape(x_test)[0]

    C = 10
    classes_to_classify = [aaa for aaa in range(C)]

    # One-hot labels. jax.nn.one_hot replaces jax.ops.index_update, which no
    # longer exists in current JAX releases.
    y_train_vec = jax.nn.one_hot(y_train.flatten().astype(int), C)
    y_test_vec = jax.nn.one_hot(y_test.flatten().astype(int), C)

    X_data, Y_data, X_test, Y_test = jax.device_put(x_train), jax.device_put(y_train_vec), jax.device_put(x_test), jax.device_put(y_test_vec)

  elif type_data_set == 'customed':
    X_data, Y_data, X_test, Y_test = make_dataset_cifar10(num_samples, classes_to_classify, C)

  #### HALF PRECISION ###
  X_data = jax.device_put(jnp.array(X_data, jnp.float16))
  X_test = jax.device_put(jnp.array(X_test, jnp.float16))

  print('np.shape(X_data)',jnp.shape(X_data))
  print('np.shape(Y_data)',jnp.shape(Y_data))
  print('np.shape(X_test)',jnp.shape(X_test))
  print('np.shape(Y_test)',jnp.shape(Y_test))

  if stop_osc and stop_osc_at_epoch > epochs_extrema_LR[-1]:
    sys.exit('You would stop the oscillations when the simulation has alreay finished!')

  ### LOOP FOR W_MAX ###

  for num_periods in num_periods_list:

    for sim_stats in range(repeat_sim):

      for w_max in w_max_list:

        print('epochs_extrema_LR', epochs_extrema_LR)
        print('LRs_extrema', LRs_extrema)

        # Time is measured in optimizer steps (one step = one minibatch).
        num_mini_batches = int(jnp.shape(X_data)[0]/mini_batch_size)
        total_time = (epochs_extrema_LR[-1]) * num_mini_batches
        if stop_osc:
          total_time_osc = (stop_osc_at_epoch) * num_mini_batches
        else:
          total_time_osc = (epochs_extrema_LR[-1]) * num_mini_batches

        # Oscillation period in steps.
        # NOTE(review): the step function below divides by (T - 1), so T is
        # presumably expected to be >= 2 -- confirm.
        T = int(float(total_time_osc)/float(num_periods))

        ### Building the Neural Network ###
        init_fn, apply_fn = MyrtleNetwork_ini(myrtle_depth, myrtle_width, C)

        print('\n \n \n \n', 'NEW SIMULATION, MAX WEIGHT: ', w_max, '\n \n \n \n')
        print('num_mini_batches per epoch: ', num_mini_batches)
        print('total_time', total_time)
        print('total_time_osc', total_time_osc)
        print('num_periods', num_periods)
        print('T', T)

        # w_max is overwritten during training (set to 1 once oscillations
        # stop, and returned by fori_loop), so keep the configured value.
        w_max_saved = w_max
        file_name =  'T_{0}_w_max_{1}_sim_repet_{2}.pkl'.format(T, w_max, sim_stats)

        # Per-epoch records.
        L = []
        L_weighted = []
        A = []
        A_test = []
        W = []
        times_saved = []

        if start_osc_from_the_middle:
          t = T // 2
        else:
          t = 0

        focus_class = 0

        ### LOSS FUNCTIONS ###
        @jit
        def c_fn(t , i, w_max):
          # Class-weight vector at step t: class i gets a triangular-wave
          # weight (1 -> w_max -> 1 over one period T), all other classes get
          # 1, and the vector is rescaled so that its entries sum to C.
          t = t % T
          slope = 2 * (w_max - 1) / T
          w_main_class = jnp.where(t < T / 2., 1+ t * slope, 2 * w_max - t * slope - 1)
          res = jnp.ones(C) + (w_main_class-1) * jnp.eye(C)[i]
          res = res / jnp.sum(res) * C  # normalizes to C value
          return res

        @jit
        def l2_squared(pytree):
          # Sum of squared entries over all leaves of the parameter pytree.
          leaves, _ = tree_flatten(pytree)
          return sum([jnp.vdot(x, x) for x in leaves])

        @jit
        def weighted_loss(params, X, Y, t, i, w_max):
          # Cross entropy with per-class weights c_fn(t, i, w_max) plus L2 penalty.
          w = c_fn(t, i, w_max)
          return -jnp.mean(jax.nn.log_softmax(apply_fn(params, X)) * Y * w) * C + weight_decay * l2_squared(params)

        @jit
        def loss_CE(params, X, Y):
          # Unweighted cross entropy plus L2 penalty.
          return -jnp.mean(jax.nn.log_softmax(apply_fn(params, X)) * Y) * C + weight_decay * l2_squared(params)

        @jit
        def accuracy(params, X, Y):
          return jnp.mean(jnp.argmax(apply_fn(params, X), axis=1) == jnp.argmax(Y, axis=1))

        ### INITIALIZATION ###
        key = random.PRNGKey(seed_initialization)
        _, params_ini = init_fn(key, X_data.shape)

        # ### HALF PRECISION ###
        # NOTE(sam): Explicitly cast to float16.
        params_ini = jax.tree_map(lambda x: jnp.array(x, jnp.float16), params_ini)

        # Learning rate for every optimizer step, indexed by the step counter.
        lr_per_step = jnp.array([piecewise_linear(iii/num_mini_batches, epochs_extrema_LR, LRs_extrema) for iii in range(total_time)], jnp.float16)

        # callable to get Learning Rate
        @jit
        def lr_callable(x):
          return lr_per_step[x]

        ### DEFINING NESTEROV OPTIMIZER ###
        (nest_init_fun, nest_update_fun, nest_get_params) = jax.experimental.optimizers.nesterov(lr_callable, momentum_nest)
        opt_state = nest_init_fun(params_ini)

        ### MINIMIZATION STARTS ###
        # Start of the whole simulation; kept separate from the per-epoch
        # timer so that the final "total time" covers all epochs.
        t_sim_start = time.time()

        @jit
        def step_nesterov_foriloop(_,all_params):
          # One optimizer step on minibatch number (t % num_mini_batches) of
          # the current permutation.
          t, opt_state, focus_class, w_max, perm = all_params
          i = t % num_mini_batches
          batch_idx = jax.lax.dynamic_slice(perm, [i * mini_batch_size], [mini_batch_size])
          X_mini, Y_mini = X_data[batch_idx], Y_data[batch_idx]
          # Adds 1 (mod C) exactly when (t-1) % T == T-1, otherwise adds 0.
          focus_class =  (focus_class + ((t-1) % (T))//(T-1) )  % C
          params = nest_get_params(opt_state)
          g = grad(weighted_loss)(params, X_mini, Y_mini, t, focus_class, w_max)
          return t + 1, nest_update_fun(t, g, opt_state), focus_class, w_max, perm

        @jit
        def compute_loss_acc(_, all_params):
          # Accumulates accuracy (and losses) over training minibatch i, taken
          # in the original, unshuffled order.
          i, t, focus_class, w_max, A_a, L_a, L_weighted_a, params_a = all_params
          batch_idx = jax.lax.dynamic_slice(jnp.arange(jnp.shape(X_data)[0]), [i * mini_batch_size], [mini_batch_size])
          X_mini, Y_mini = X_data[batch_idx], Y_data[batch_idx]
          A_a += accuracy(params_a, X_mini, Y_mini)
          if loss_function_type == 'cross_entropy':
            L_a += loss_CE(params_a, X_mini, Y_mini)
            L_weighted_a += weighted_loss(params_a, X_mini, Y_mini, t, focus_class, w_max)
          return i+1, t, focus_class, w_max, A_a, L_a, L_weighted_a, params_a

        @jit
        def compute_loss_acc_test(_, all_params):
          # Accumulates accuracy over test minibatch i.
          i, A_a, params_a = all_params
          batch_idx = jax.lax.dynamic_slice(jnp.arange(jnp.shape(X_test)[0]), [i * mini_batch_size], [mini_batch_size])
          X_mini, Y_mini = X_test[batch_idx], Y_test[batch_idx]
          A_a += accuracy(params_a, X_mini, Y_mini)
          return i+1, A_a, params_a

        ##### TRAINING EPOCH BY EPOCH #####

        seed_minibatch = 0
        for epoch_ind_2 in range(epochs_extrema_LR[-1]):

          if stop_osc and epoch_ind_2>=stop_osc_at_epoch:
            print('\n Minimization without oscillations now \n')
            w_max = 1

          t_epoch_start = time.time()

          # Fresh shuffle of the training set for this epoch.
          seed_minibatch += 1
          key = jax.random.PRNGKey(seed_minibatch)
          perm = jax.random.permutation(key, jnp.shape(X_data)[0])

          t, opt_state, focus_class, w_max, perm = jax.lax.fori_loop(0, num_mini_batches, step_nesterov_foriloop, (t, opt_state, focus_class, w_max, perm))

          params = nest_get_params(opt_state)
          t_new = time.time()

          #### COMPUTING LOSSES AND ACCURACIES ####
          A_a = 0.0
          L_a = 0.0
          L_weighted_a = 0.0
          i=0
          i, t, focus_class, w_max, A_a, L_a, L_weighted_a, params  = jax.lax.fori_loop(0, num_mini_batches, compute_loss_acc, (i, t, focus_class, w_max, A_a, L_a, L_weighted_a, params))

          A_a_t = 0.0
          i=0
          num_minibat_test = jnp.shape(X_test)[0]//mini_batch_size
          i, A_a_t, params = jax.lax.fori_loop(0, num_minibat_test, compute_loss_acc_test, (i, A_a_t, params))

          times_saved += [t]
          A += [A_a/num_mini_batches]
          L += [L_a/num_mini_batches]
          L_weighted += [L_weighted_a/num_mini_batches]
          A_test += [A_a_t/num_minibat_test]
          W += [c_fn(t, focus_class, w_max)]

          t_new_2 = time.time()
          print('Epoch', epoch_ind_2, 't', t, 'focus_class', focus_class, 'Acc: ', A[-1], 'Acc test: ',
                A_test[-1], 'loss: ', L[-1], 'weighted loss: ', L_weighted[-1], 'minimization time: ',
                (t_new-t_epoch_start), 'time acc and loss: ', t_new_2-t_new)


        dictionary_data =  {'epochs_extrema_LR': epochs_extrema_LR,
                            'LRs_extrema': LRs_extrema,
                            'stop_osc': stop_osc,
                            'stop_osc_at_epoch': stop_osc_at_epoch,
                            'start_osc_from_the_middle': start_osc_from_the_middle,
                            'T': T,
                            'total_time': total_time,
                            'total_time_osc': total_time_osc,
                            'w_max': w_max_saved,
                            'loss_function_type': loss_function_type,
                            'num_samples': num_samples,
                            'classes_to_classify': classes_to_classify,
                            'L': L,
                            'L_weighted': L_weighted,
                            'W': W,
                            'A': A,
                            'A_test': A_test,
                            'myrtle_depth': myrtle_depth,
                            'myrtle_width': myrtle_width,
                            'C': C,
                            'times_saved': times_saved,
                            'num_periods': num_periods,
                            'repeat_sim': repeat_sim,
                            'seed_initialization': seed_initialization,
                            'type_data_set': type_data_set,
                            'optimizer_type': optimizer_type,
                            'momentum_nest': momentum_nest,
                            'weight_decay': weight_decay,
                            'mini_batch_size': mini_batch_size,
                            'steps_to_save_data': steps_to_save_data
                                }

        # The next simulation starts from a different parameter initialization.
        seed_initialization += 1

        print('file: ', file_name)

        # The context manager closes the file even if pickling fails.
        with open(file_name, "wb") as a_file:
          pickle.dump(dictionary_data, a_file)

        t_final = time.time()

        print('total time simulation: ', t_final - t_sim_start)

# Run the experiment only when the file is executed (script or notebook cell),
# not when it is imported as a module.
if __name__ == "__main__":
  main()

# Scratch calculation left over from the notebook: 5e-4 * 128 == 0.064
# (presumably the weight_decay candidate 0.0005 * mini_batch_size noted in main -- confirm).