In [None]:
#*****************************************************************************************************
# Fix settings concerning the ansatz, optimisation and backend
#*****************************************************************************************************
from time import time 
import pickle 
import numpy as np

from discopy import Ty, Id, Box, Diagram, Word
from discopy.rigid import Cup, Cap, Functor, Swap
from discopy.quantum.circuit import bit, qubit
from discopy.quantum import Measure
from discopy.quantum.tk import to_tk
from discopy.quantum.tk import Circuit as tk_Circuit_inDCP

from pytket import Circuit as tk_Circuit

#-----------------------------
# atomic pregroup grammar types
#-----------------------------
s, n = Ty('S'), Ty('N')   # sentence type and noun type

#----------------------------------------
# settings concerning the ansaetze
#----------------------------------------
q_s = 1        # number of qubits for sentence type s
q_n = 1        # number of qubits for noun type n
depth = 1      # number of IQP layers for non-single-qubit words
p_n = 3        # number of parameters for a single-qubit word (noun); valued in {1,2,3}.

#----------------------------------------
# Parameters concerning the optimisation
#----------------------------------------
n_runs = 20       # number of runs over training procedure
niter  = 2000     # number of iterations for any optimisation run of training.

#----------------------------------------
# Reproducibility
#----------------------------------------
# The random draws visible in this notebook (initial parameters, SPSA perturbations)
# all go through numpy's global generator; seeding it once here makes a
# "Restart & Run All" repeatable. Successive runs still differ from each other
# because they keep consuming the same random stream.
SEED = 0
np.random.seed(SEED)


In [None]:
#******************************************
# Data import
#******************************************

# Load the train, dev and test splits. Every line of each file is a string of the form
# 'label sentence' with the label in {0,1} and with the sentence of the form
# "word1_POStag1 word2_POStag2 ...".

def _read_lines(path):
    """Return all lines of the text file at `path`, line endings included."""
    with open(path) as handle:
        return handle.readlines()

training_data_raw = _read_lines('../datasets/mc_train_data.txt')
dev_data_raw = _read_lines('../datasets/mc_dev_data.txt')
testing_data_raw = _read_lines('../datasets/mc_test_data.txt')

In [None]:
#***************************************************************
# Turn the raw input data into data structures convenient below
#***************************************************************

vocab = dict()          # dictionary to be filled with the vocabulary in the form { word : POStag }
data = dict()           # dictionary to be filled with all the data (train, dev and test subsets); entries of the
                        # form { sentence : label } with label encoding '1' as [1.0, 0.0] and '0' as [0.0, 1.0]
training_data = []      # list of sentences in the train dataset as strings "word1 word2 ..."
dev_data = []           # list of sentences in the dev dataset as strings "word1 word2 ..."
testing_data = []       # list of sentences in the test dataset as strings "word1 word2 ..."


def _parse_dataset(raw_lines, sentences):
    """Parse raw 'label sentence' lines into the shared containers.

    For every non-blank line of `raw_lines`:
      * each 'word_POStag' token is split and recorded in `vocab`,
      * the untagged sentence is appended to `sentences`,
      * the one-hot label is stored in `data` under the untagged sentence.

    raw_lines: iterable of strings 'label word1_tag1 word2_tag2 ...'.
    sentences: list that receives the untagged sentences (modified in place).
    """
    for sent in raw_lines:
        if not sent.strip():
            # A blank line (e.g. a trailing newline at the end of the file) is not a
            # sample; without this guard it would be registered as the sentence ''
            # and fail later when its words are looked up in the vocabulary.
            continue
        words_untagged = []
        for word in sent[2:].split():      # sent[0] is the label, sent[1] the separator
            word_untagged, tag = word.split('_')
            vocab[word_untagged] = tag
            words_untagged.append(word_untagged)
        sentence = ' '.join(words_untagged)
        sentences.append(sentence)
        label = np.array([1.0, 0.0]) if sent[0] == '1' else np.array([0.0, 1.0])
        data[sentence] = label


# Go through the train, dev and test data (same order as the files were loaded)
_parse_dataset(training_data_raw, training_data)
_parse_dataset(dev_data_raw, dev_data)
_parse_dataset(testing_data_raw, testing_data)

In [None]:
#*****************************************************
# The sentences as diagrams via CFG production rules
#*****************************************************

#----------------------------
# Further POS tags:
#----------------------------
nphr = Ty('NP')
adj = Ty('ADJ')
tv = Ty('TV')
iv = Ty('IV')
vphr = Ty('VP')

#----------------------------
# The vocabulary in DisCoPy
#----------------------------
# Only words tagged N, TV or ADJ receive a box; words with any other tag are left out.
_pos_tag_types = {'N': n, 'TV': tv, 'ADJ': adj}
vocab_dict_boxes = {
    word: Word(word, _pos_tag_types[tag])
    for word, tag in vocab.items()
    if tag in _pos_tag_types
}

#-------------------------------------
# The CFG production rules as boxes
#-------------------------------------
r0 = Box('R0', nphr @ vphr, s)     # NP VP  -> S
r1 = Box('R1', tv @ nphr, vphr)    # TV NP  -> VP
r2 = Box('R2', adj @ n, nphr)      # ADJ N  -> NP
r3 = Box('R3', iv, vphr)           # IV     -> VP
r4 = Box('R4', n, nphr)            # N      -> NP

#---------------------------------------------
# The needed grammatical sentence structures
#---------------------------------------------
# Keyed by the '_'-joined POS tags of the sentence; each value consumes the word
# types from left to right and reduces them to the sentence type.
grammar_dict = dict()
grammar_dict['N_TV_N'] = ((Id(n @ tv) @ r4) >> (r4 @ r1) >> r0)
grammar_dict['N_TV_ADJ_N'] = ((Id(n @ tv) @ r2) >> (r4 @ r1) >> r0)
grammar_dict['ADJ_N_TV_N'] = ((Id(adj @ n @ tv) @ r4) >> (r2 @ r1) >> r0)

#---------------------------------------------
# Create CFG diagrams for the sentences
#---------------------------------------------
# sentences_dict maps each sentence string to [CFG diagram, grammar id].
sentences_dict = dict()
for sentstr in data:
    tokens = sentstr.split(' ')
    grammar_id = '_'.join(vocab[word] for word in tokens)
    sentence = Id(Ty())
    for word in tokens:
        sentence = sentence @ vocab_dict_boxes[word]
    sentence = sentence >> grammar_dict[grammar_id]
    sentences_dict[sentstr] = [sentence, grammar_id]

In [None]:
#***************************************************************
# Translation to pregroup grammar 
#***************************************************************
from discopy.grammar.pregroup import draw

# From POS tags to Pregroup types:
ob_pg = {
    n: n,
    s: s,
    adj: n @ n.l,
    tv: n.r @ s @ n.l,
    vphr: n.r @ s,
    nphr: n,
}

# From CFG rules to Pregroup reductions:
ar_pg = {
    r0: Cup(n, n.r) @ Id(s),
    r1: Id(n.r @ s) @ Cup(n.l, n),
    r2: Id(n) @ Cup(n.l, n),
    r3: Id(n.r @ s),
    r4: Id(n),
}


def _as_pregroup_word(cfg_box):
    """Return a Word with the same name whose type is the pregroup image of `cfg_box.cod`."""
    return Word(cfg_box.name, ob_pg[cfg_box.cod])


# The vocabulary as DisCoPy boxes with pregroup types
vocab_pg = [_as_pregroup_word(vocab_dict_boxes[word]) for word in vocab.keys()]

# The mapping of morphisms: every CFG word box is sent to its pregroup-typed counterpart
ar_pg.update({
    cfg_box: _as_pregroup_word(cfg_box)
    for cfg_box in (vocab_dict_boxes[word] for word in vocab.keys())
})

# The functor that translates from CFG to pregroup
t2p = Functor(ob_pg, ar_pg)

# Same layout as sentences_dict: sentence string -> [pregroup diagram, grammar id]
sentences_pg_dict = {
    sentstr: [t2p(cfg_diagram), grammar_id]
    for sentstr, (cfg_diagram, grammar_id) in sentences_dict.items()
}

In [None]:
#******************************************************************************************************
# (Optional) For visualisation: the sentences as pregroup diagrams -- before 'bending nouns around'
#******************************************************************************************************

# Element 0 of every entry is the pregroup diagram (element 1 is the grammar id).
for sentstr, entry in sentences_pg_dict.items():
    entry[0].draw()

In [None]:
#******************************************************
# Bending the nouns around
#******************************************************
# For every sentence the noun states are replaced by effects on the adjoint type,
# connected to the rest of the diagram through caps; normal_form() then yanks the
# resulting snakes.
sentences_pg_psr_dict = dict()

for sentstr, (diagram, grammar_id) in sentences_pg_dict.items():
    num_words = grammar_id.count('_') + 1     # one tag per word, tags joined by '_'
    words = diagram[:num_words].boxes         # the word layer of the diagram
    grammar = diagram[num_words:]             # everything after the word layer
    if grammar_id == 'N_TV_N':
        noun1 = Box(words[0].name, n.r, Ty())
        noun2 = Box(words[2].name, n.l, Ty())
        words_new = (Cap(n.r, n) @ Cap(n, n.l)) >> (noun1 @ Id(n) @ words[1] @ Id(n) @ noun2)
    elif grammar_id == 'ADJ_N_TV_N':
        noun1 = Box(words[1].name, n.l, Ty())
        noun2 = Box(words[3].name, n.l, Ty())
        words_new = (Cap(n, n.l) @ Cap(n, n.l)) >> (words[0] @ Id(n) @ noun1 @ words[2] @ Id(n) @ noun2)
    elif grammar_id == 'N_TV_ADJ_N':
        noun1 = Box(words[0].name, n.r, Ty())
        noun2 = Box(words[3].name, n.l, Ty())
        words_new = (Cap(n.r, n) @ Cap(n, n.l)) >> (noun1 @ Id(n) @ words[1] @ words[2] @ Id(n) @ noun2)
    # plug the newly wired word layer into the unchanged grammar part
    sentence = words_new >> grammar
    # Yank snakes and add to dictionary
    sentences_pg_psr_dict[sentstr] = sentence.normal_form()

# Now for the vocab: nouns become effects on n.r, all other words are kept as they are
# (the n.l case is dealt with in the definition of the quantum functor).
vocab_psr = [
    Box(word.name, n.r, Ty()) if word.cod == Ty('N') else word
    for word in vocab_pg
]

In [None]:
#******************************************************************************************************
# (Optional) For visualisation: the sentences as pregroup diagrams -- after 'bending nouns around'
#******************************************************************************************************
# Render every rewired sentence diagram.
for diagram in sentences_pg_psr_dict.values():
    diagram.draw()

In [None]:
#*****************************************************
# Translation to quantum circuits
#*****************************************************
from discopy.quantum import Ket, IQPansatz, Bra
from discopy.quantum.gates import sqrt, H, CZ, Rz, Rx, CX
from discopy.quantum.circuit import Id
from discopy import CircuitFunctor
from discopy.quantum.circuit import Circuit as DCP_Circuit

# assignment of number of qubits to atomic grammatical types
ob = {s: q_s, n: q_n}
# the same assignment expressed as qubit types, the form in which it is needed for
# discopy's cqmap module
ob_cqmap = {ty: qubit ** n_qubits for ty, n_qubits in ob.items()}

#-----------------------------------------
# parametrised part of ansaetze
#-----------------------------------------

def single_qubit_iqp_ansatz(params):
    """Build the parametrised single-qubit circuit used for one-qubit words.

    params: sequence of 1, 2 or 3 parameters, selecting the circuit:
        1 -> Rx(params[0])
        2 -> Rx(params[0]) >> Rz(params[1])
        3 -> IQPansatz(1, params)

    Raises:
        ValueError: if `params` has any other length. Without this check the
            function would fall through and return None, which only surfaces
            later as an obscure failure when the result is composed.
    """
    n_params = len(params)
    if n_params == 1:
        return Rx(params[0])
    if n_params == 2:
        return Rx(params[0]) >> Rz(params[1])
    if n_params == 3:
        return IQPansatz(1, params)
    raise ValueError(
        'single_qubit_iqp_ansatz expects 1, 2 or 3 parameters, got {}'.format(n_params))

def ansatz_state(state, params):
    """Return the parametrised state-preparation circuit for the box `state`.

    The number of qubits is the sum of `ob` over the atomic types in `state.cod`.
    A single qubit uses `single_qubit_iqp_ansatz` applied to Ket(0); otherwise an
    IQP ansatz is applied to the all-zero ket on that many qubits.
    """
    n_qubits = sum(ob[Ty(factor.name)] for factor in state.cod)
    if n_qubits != 1:
        return Ket(*([0] * n_qubits)) >> IQPansatz(n_qubits, params)
    return Ket(0) >> single_qubit_iqp_ansatz(params)
    
def ansatz_effect(effect, params):
    """Return the parametrised effect circuit for the box `effect`.

    The number of qubits is the sum of `ob` over the atomic types in `effect.dom`.
    A single qubit uses `single_qubit_iqp_ansatz` followed by Bra(0); otherwise an
    IQP ansatz is followed by the all-zero bra on that many qubits.
    """
    n_qubits = sum(ob[Ty(factor.name)] for factor in effect.dom)
    if n_qubits != 1:
        return IQPansatz(n_qubits, params) >> Bra(*([0] * n_qubits))
    return single_qubit_iqp_ansatz(params) >> Bra(0)
       
def ansatz(box, params):
    """Dispatch `box` to the state or the effect ansatz.

    A box with empty domain and non-empty codomain is treated as a state, one with
    non-empty domain and empty codomain as an effect. For any other box (both sides
    empty, or both non-empty) None is returned.
    """
    has_inputs = len(box.dom) != 0
    has_outputs = len(box.cod) != 0
    if has_outputs and not has_inputs:
        return ansatz_state(box, params)
    if has_inputs and not has_outputs:
        return ansatz_effect(box, params)
    return None

#----------------------------------------------------------
# Define parametrised functor to quantum circuits
#----------------------------------------------------------

def F(params):
    """Build the circuit functor for the given per-word parameters.

    params: indexable collection where `params[i]` parametrises the ansatz of
        `vocab_psr[i]`.

    Every box of `vocab_psr` is mapped to its ansatz circuit. A box with empty
    codomain is additionally registered under an `n.l`-typed box of the same
    name, mapped to the very same circuit.
    """
    ar = {}
    for i, pgbox in enumerate(vocab_psr):
        qbox = ansatz(pgbox, params[i])
        ar[pgbox] = qbox
        if pgbox.cod == Ty():
            ar[Box(pgbox.name, n.l, Ty())] = qbox
    return CircuitFunctor(ob_cqmap, ar)

In [None]:
#*****************************************************
# The functions to deal with the parametrisation
#*****************************************************

def param_shapes(vocab_psr):
    """Return the parameter-array shape of every state or effect box in `vocab_psr`.

    A box counts as state/effect when its domain or its codomain carries no qubits.
    One-qubit boxes get the shape `(p_n,)`, boxes on more qubits `(depth, arity - 1)`.
    Boxes with qubits on both sides contribute no entry to the result.
    """
    def qubit_count(ty):
        # total number of qubits assigned by `ob` to the atomic factors of `ty`
        return sum(ob[Ty(factor.name)] for factor in ty)

    parshapes = []
    for box in vocab_psr:
        dom_arity, cod_arity = qubit_count(box.dom), qubit_count(box.cod)
        if dom_arity != 0 and cod_arity != 0:
            continue
        arity = max(dom_arity, cod_arity)
        assert arity != 0
        parshapes.append((p_n,) if arity == 1 else (depth, arity - 1))
    return parshapes

def rand_params(par_shapes):
    """Draw uniform random values in [0, 1) for every shape, returned as one flat array.

    The values for the shapes are drawn in order and laid out back to back.
    """
    chunks = [np.array([])]
    for shape in par_shapes:
        chunks.append(np.ravel(np.random.rand(*shape)))
    return np.concatenate(chunks)

def reshape_params(unshaped_pars, par_shapes):
    """Split a flat parameter vector into consecutive arrays of the given shapes.

    unshaped_pars: flat array holding all parameters back to back.
    par_shapes: list of shape tuples; shapes of any rank are supported.

    Returns a list with one array per entry of `par_shapes`, in the same order,
    each built from the next `prod(shape)` entries of `unshaped_pars`.
    """
    pars_reshaped = []
    start = 0
    for shape in par_shapes:
        # The shape tuples are plain Python ints, so the size is a static int.
        size = int(np.prod(shape))
        stop = start + size
        pars_reshaped.append(np.reshape(unshaped_pars[start:stop], shape))
        start = stop
    return pars_reshaped

In [None]:
#****************************************
# The parameters of the current model
#****************************************

# Shapes of the per-word parameter arrays, then one random draw of all parameters,
# both as a flat vector and split up per word.
par_shapes = param_shapes(vocab_psr)
rand_unshaped_pars = rand_params(par_shapes)
rand_shaped_pars = reshape_params(rand_unshaped_pars, par_shapes)

n_parameters = len(rand_unshaped_pars)
print('Number of parameters:    ', n_parameters)

In [None]:
#**************************************************************
# (Optional) Quantum circuit diagrams for the sentences 
#**************************************************************

# Functor for one random parameter set; used only to visualise the circuits.
func = F(rand_shaped_pars)

for diagram in sentences_pg_psr_dict.values():
    circuit = func(diagram)
    circuit.draw(draw_box_labels=True, figsize=(5, 5), nodesize=0.3)

In [None]:
#********************************************************************************************
# Encode data such that the circuits (for one call of cost function etc.) can be sent as one
# job to quantum hardware.
#********************************************************************************************

# For each split: the rewired sentence diagrams and the matching labels, in the
# order in which the sentences appear in that split.
train_circuits_pg_psr = [sentences_pg_psr_dict[sentstr] for sentstr in training_data]
train_labels = np.array([np.array(data[sentstr]) for sentstr in training_data])

dev_circuits_pg_psr = [sentences_pg_psr_dict[sentstr] for sentstr in dev_data]
dev_labels = np.array([np.array(data[sentstr]) for sentstr in dev_data])

test_circuits_pg_psr = [sentences_pg_psr_dict[sentstr] for sentstr in testing_data]
test_labels = np.array([np.array(data[sentstr]) for sentstr in testing_data])

In [None]:
#**********************************************************************************************************
# The cost function for optimisation and the error functions 
#**********************************************************************************************************
from jax import numpy as jnp
from jax import jit

def _predict_label_distributions(unshaped_params, circuits_pg_psr):
    """Evaluate the sentence circuits and return one normalised label distribution each.

    unshaped_params: flat parameter vector (reshaped with `par_shapes`).
    circuits_pg_psr: list of sentence diagrams to be mapped through the functor.
    """
    func = F(reshape_params(unshaped_params, par_shapes))
    circuits = [func(circ) for circ in circuits_pg_psr]
    results = DCP_Circuit.eval(*circuits)
    # Shift every entry by a tiny constant before normalising -- presumably so that
    # no entry is exactly zero when the cost takes log2 of the distribution.
    results_tweaked = [jnp.abs(jnp.array(res.array) - 1e-9) for res in results]
    return [res / jnp.sum(res) for res in results_tweaked]


def _error_percentage(unshaped_params, circuits_pg_psr, labels):
    """Return the percentage of sentences whose rounded prediction differs from its label.

    `labels[i]` is the one-hot label of `circuits_pg_psr[i]`; the result is divided
    by `len(labels)`, i.e. the number of sentences in the split.
    """
    pred_labels_distrs = _predict_label_distributions(unshaped_params, circuits_pg_psr)
    assert len(pred_labels_distrs[0]) == 2  # rounding only makes sense if labels are binary tuples
    pred_labels = [jnp.round(res) for res in pred_labels_distrs]
    error = 0.0
    for i in range(len(pred_labels)):
        diff = jnp.sum(jnp.abs(labels[i] - pred_labels[i]))
        # a wrong prediction differs in both components (diff == 2); count it once
        error += jnp.min(jnp.array([diff, 1.0]))
    return error * 100 / len(labels)


def get_cost(unshaped_params):
    """Return the mean base-2 cross-entropy between train labels and predictions."""
    pred_labels_distrs = _predict_label_distributions(unshaped_params, train_circuits_pg_psr)
    cross_entropies = jnp.array([jnp.sum(train_labels[i] * jnp.log2(pred_labels_distrs[i]))
                                 for i in range(len(train_labels))])
    return -1 / len(training_data) * jnp.sum(cross_entropies)


def get_train_error(unshaped_params):
    """Return the classification error (in %) on the train set."""
    return _error_percentage(unshaped_params, train_circuits_pg_psr, train_labels)


def get_dev_error(unshaped_params):
    """Return the classification error (in %) on the dev set."""
    return _error_percentage(unshaped_params, dev_circuits_pg_psr, dev_labels)


def get_test_error(unshaped_params):
    """Return the classification error (in %) on the test set."""
    return _error_percentage(unshaped_params, test_circuits_pg_psr, test_labels)

In [None]:
#*************************************************************************************************
# Define the jitted versions of the above four functions. Then, one function at a time,
# call it twice so that jit can compile it and later calls during optimisation are fast.
#*************************************************************************************************

# One jit-compiled wrapper per cost/error function, in the same order as they are defined.
get_cost_jit, get_train_error_jit, get_dev_error_jit, get_test_error_jit = (
    jit(fn) for fn in (get_cost, get_train_error, get_dev_error, get_test_error)
)

In [None]:
# Call the jitted cost twice with fresh random parameters and report the time of each call.
for i in range(2):
    rand_unshaped_pars = rand_params(par_shapes)
    print('-------------')
    start = time()
    cost_value = get_cost_jit(rand_unshaped_pars)
    print('Cost: ', cost_value)
    print('Time taken for this iteration: ', time()-start)

In [None]:
# Call the jitted train error twice with fresh random parameters and report the time of each call.
for i in range(2):
    rand_unshaped_pars = rand_params(par_shapes)
    print('-------------')
    start = time()
    train_error_value = get_train_error_jit(rand_unshaped_pars)
    print('Train Error: ', train_error_value)
    print('Time taken for this iteration: ', time()-start)

In [None]:
# Call the jitted dev error twice with fresh random parameters and report the time of each call.
for i in range(2):
    rand_unshaped_pars = rand_params(par_shapes)
    print('-------------')
    start = time()
    dev_error_value = get_dev_error_jit(rand_unshaped_pars)
    print('Dev Error: ', dev_error_value)
    print('Time taken for this iteration: ', time()-start)

In [None]:
# Call the jitted test error twice with fresh random parameters and report the time of each call.
for i in range(2):
    rand_unshaped_pars = rand_params(par_shapes)
    print('-------------')
    start = time()
    test_error_value = get_test_error_jit(rand_unshaped_pars)
    print('Test Error: ', test_error_value)
    print('Time taken for this iteration: ', time()-start)

In [None]:
#**********************************************************************************
# Minimization algorithm
#**********************************************************************************

# This is building on the minimizeSPSA function from the noisyopt package (https://github.com/andim/noisyopt);
# here only adjusted for our purposes (mostly with quantum implementations in mind)

def my_spsa(get_cost, get_train_error, x0,
            bounds=None, niter=100, a=1.0, c=1.0, alpha=0.602, gamma=0.101,
            print_iter=False, filename='spsa_output'):
    """Minimise `get_cost` with SPSA and record the trajectory of every iteration.

    get_cost: callable mapping a parameter vector to the cost.
    get_train_error: callable mapping a parameter vector to the train error.
    x0: initial parameter vector.
    bounds: optional sequence of (lower, upper) pairs, one per parameter; every
        point is clipped to these bounds before it is used.
    niter: number of iterations.
    a, alpha: step size schedule  ak = a / (k + 1 + 0.01 * niter) ** alpha.
    c, gamma: perturbation schedule  ck = c / (k + 1) ** gamma.
    print_iter: if True, print progress and timing for every iteration.
    filename: the histories are pickled to `filename + '.pickle'` after every
        iteration (the file is overwritten each time with the full histories).

    Returns the tuple (param_history, func_history, error_history), each a list
    with one entry per iteration.
    """
    stability_offset = 0.01 * niter
    n_params = len(x0)

    if bounds is not None:
        bounds = np.asarray(bounds)

    def project(point):
        # keep the point inside the box constraints, if there are any
        if bounds is None:
            return point
        return np.clip(point, bounds[:, 0], bounds[:, 1])

    # Key order matches the layout of the pickled dictionary.
    history = {'param_history': [], 'func_history': [], 'error_history': []}
    x = x0

    for k in range(niter):
        if print_iter:
            print('-------------', '\n', 'iteration: ', k, sep='')
        start = time()

        # step size and perturbation size of this iteration
        step = k + 1.0
        ak = a / (step + stability_offset) ** alpha
        ck = c / step ** gamma
        # random simultaneous perturbation direction, +-1 per parameter
        delta = np.random.choice([-1, 1], size=n_params)

        # cost at the point shifted in + direction
        xplus = project(x + ck * delta)
        if print_iter:
            print('Call for xplus')
        funcplus = get_cost(xplus)

        # cost at the point shifted in - direction
        xminus = project(x - ck * delta)
        if print_iter:
            print('Call for xminus')
        funcminus = get_cost(xminus)

        # gradient estimate from the two evaluations, then the update step
        grad = (funcplus - funcminus) / (xplus - xminus)
        x = project(x - ak * grad)
        history['param_history'].append(x)

        # cost and train error at the updated point
        current_func_value = get_cost(x)
        error = get_train_error(x)
        history['func_history'].append(current_func_value)
        history['error_history'].append(error)

        # checkpoint everything recorded so far
        with open(filename + '.pickle', 'wb') as file_handle:
            pickle.dump(history, file_handle)

        if print_iter:
            print('Time taken for this iteration: ', time() - start)

    return history['param_history'], history['func_history'], history['error_history']

In [None]:
#*****************************************************
# Optimisation settings
#*****************************************************

# every parameter is constrained to the interval [0, 1]
bounds = [[0.0, 1.0] for _ in range(len(rand_unshaped_pars))]

c_fix = 0.1
#leave alpha and gamma as the default (which is as recommended), i.e. alpha = 0.602 and gamma = 0.101

#-----------------------------------
# calculate well-educated guess for parameter 'a'.
# (below calculation follows the heuristics from:
#  https://www.jhuapl.edu/SPSA/PDF-SPSA/Spall_Implementation_of_the_Simultaneous.PDF )
#-----------------------------------

desired_param_change = 0.005  # rough change of a parameter in early iterations.

alpha = 0.602
A = niter*0.01
c_0 = c_fix

# Average, over `nruns` random points, the value of `a` for which the first SPSA step
# (a / (A + 1)**alpha times the gradient estimate) has the size `desired_param_change`.
nruns = 1000
a_est = 0.0
for l in range(nruns):
    rand_unshaped_pars = rand_params(par_shapes)
    # random +-1 perturbation direction, one entry per parameter
    delta_0 = 2 * np.array([np.random.randint(0,2) for i in range(len(rand_unshaped_pars))]) - 1
    cost_plus = get_cost_jit(rand_unshaped_pars + c_0*delta_0)
    cost_minus = get_cost_jit(rand_unshaped_pars - c_0*delta_0)
    g0_estimate = (cost_plus - cost_minus)/(2*c_0)
    a_est += np.abs(desired_param_change * ((A +1)**alpha) / g0_estimate )
a_est = a_est/nruns
print('Calculated good choice for a=', a_est)

In [None]:
#**********************************************************
# Training the model
#**********************************************************
# One row per run, one column per SPSA iteration.
param_histories = []
cost_histories = np.zeros((n_runs, niter))
error_train_histories = np.zeros((n_runs, niter))

for i in range(n_runs):
    print('---------------------------------')
    print('Start run ', i+1)
    rand_unshaped_pars = rand_params(par_shapes)   # fresh random starting point for every run
    start = time()
    # NOTE(review): the output prefix says 'RP_task' although this notebook loads the
    # mc_* datasets and titles its plot 'MC task' -- confirm the file name is intended.
    run_filename = 'RP_task_SPSAOutput_ECS_Run' + str(i)
    res = my_spsa(get_cost_jit, get_train_error_jit, rand_unshaped_pars,
                  bounds=bounds, niter=niter, a=a_est, c=c_fix,
                  print_iter=False, filename=run_filename)
    run_params, run_costs, run_errors = res
    param_histories.append(run_params)
    cost_histories[i] = run_costs
    error_train_histories[i] = run_errors
    print('run', i+1, 'done')
    print('Time taken: ', round(time() - start,2))

In [None]:
#************************************
# Calculate dev errors
#************************************
# Dev error of every recorded parameter vector: one row per run, one column per iteration.
error_dev_histories = np.zeros((n_runs,niter))

for i in range(n_runs):
    dev_errors = [get_dev_error_jit(params) for params in param_histories[i]]
    error_dev_histories[i] = dev_errors

In [None]:
#************************************
# Calculate test errors
#************************************
# Test error of every recorded parameter vector: one row per run, one column per iteration.
error_test_histories = np.zeros((n_runs,niter))

for i in range(n_runs):
    test_errors = [get_test_error_jit(params) for params in param_histories[i]]
    error_test_histories[i] = test_errors

In [None]:
#************************************
# Calculate average cost and errors
#************************************

# For every iteration, the mean over all runs (column-wise mean of each history table).
cost_history_mean = np.array([np.mean(cost_histories[:, k]) for k in range(niter)])
error_train_history_mean = np.array([np.mean(error_train_histories[:, k]) for k in range(niter)])
error_dev_history_mean = np.array([np.mean(error_dev_histories[:, k]) for k in range(niter)])
error_test_history_mean = np.array([np.mean(error_test_histories[:, k]) for k in range(niter)])

In [None]:
#****************************************************
# Summary plot
#****************************************************
from matplotlib import pyplot as plt

plt.rcParams["text.usetex"] = True   # labels below are written as LaTeX
fig, ax1 = plt.subplots(figsize=(13, 8))

# left axis: mean cost per SPSA iteration
ax1.plot(range(len(cost_history_mean)), cost_history_mean, '-k', markersize=4, label='cost')
ax1.set_ylabel(r"Cost", fontsize='x-large')
ax1.set_xlabel(r"SPSA~iterations", fontsize='x-large')
ax1.legend(loc='upper center', fontsize='x-large')

# right axis: mean train / dev / test errors on the same x-axis
ax2 = ax1.twinx()
ax2.set_ylabel(r"Error in \%", fontsize='x-large')
for series, fmt, series_label in (
        (error_train_history_mean, '-g', 'train error'),
        (error_dev_history_mean, '-b', 'dev error'),
        (error_test_history_mean, '-r', 'test error'),
):
    ax2.plot(range(len(series)), series, fmt, markersize=4, label=series_label)
ax2.legend(loc='upper right', fontsize='x-large')

plt.title('MC task, classical simulation -- results', fontsize='x-large')
plt.savefig('MC_task_ECS_Results.png', dpi=300, facecolor='white')
plt.show()