# Code of RNN

In [85]:
import re, random, math, csv, io, string, itertools, sys, functools,math
import numpy as np
import tensorflow as tf

class RNN:
    hp = dict(
        batch_sz=5,
        backprop_len=10,
        n_classes=2,
        state_sz=4,
        ckpt_path="./checkpoints/"
    )
    def __init__(self, **hyper_parameters):
        if hyper_parameters is not None:
            for k, v in hyper_parameters.items():
                self.hp[k] = v
    
    def _init_graph(self):
        tf.reset_default_graph()
        self.inputs_batch = tf.placeholder(tf.float32, [self.hp['batch_sz'], self.hp['backprop_len']], "inputs_batch")
        self.labels_batch = tf.placeholder(tf.int32,   [self.hp['batch_sz'], self.hp['backprop_len']],   "labels_batch")
        self.init_state = tf.placeholder(tf.float32,   [self.hp['batch_sz'], self.hp['state_sz']],     "init_state")
        
        self.w_tarns = tf.Variable(
            np.random.rand(
                self.hp['state_sz']+1, #prev_state + 1 for inputs
                self.hp['state_sz']
            ),
            dtype=tf.float32
        )
        self.b_trans = tf.Variable(np.zeros([1,self.hp['state_sz']]), dtype=tf.float32)
        self.w_out = tf.Variable(np.random.rand(self.hp["state_sz"], self.hp["n_classes"]), dtype=tf.float32)
        self.b_out = tf.Variable(np.zeros([1,self.hp['n_classes']]), dtype=tf.float32)
        
        self.inputs_seq = tf.unstack(self.inputs_batch, axis=1)
        self.labels_seq = tf.unstack(self.labels_batch, axis=1)

    def _forward(self):
        curr_state = self.init_state
        state_seq = []
        for curr_input in self.inputs_seq:
            curr_input = tf.reshape(curr_input, [self.hp['batch_sz'], 1])
            comm = tf.concat([curr_input, curr_state], axis=1)
            next_state = tf.tanh(tf.matmul(comm, self.w_tarns) + self.b_trans)
            state_seq.append(next_state)
            curr_state = next_state
            
        self.last_state = curr_state
        self.logits_seq = [tf.matmul(state, self.w_out)+self.b_out for state in state_seq]
        self.preds_seq = [tf.nn.softmax(logit) for logit in self.logits_seq]
        
    def _train_graph(self):
        self._init_graph()
        self._forward()
        
        # backward
        losses = [
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits = logits,
                labels = labels
            ) 
            for logits, labels in zip(self.logits_seq, self.labels_seq)
        ]
        self.total_loss = tf.reduce_mean(losses)
        self.train_step = tf.train.AdagradOptimizer(0.3).minimize(self.total_loss)
        tf.summary.histogram(self.total_loss.op.name, self.total_loss)
        tf.summary.histogram(self.last_state.op.name, self.last_state)
        self.merged = tf.summary.merge_all()
    
    def _infer_graph(self):
        self._init_graph()
        self._forward()
        
    def predict(self, x):
        nobs = len(x)
        self._infer_graph()
        prediction=[]
        divider = self.hp['batch_sz']*self.hp["backprop_len"]
        pad_len = divider - (nobs%divider)
        x_batch = np.reshape(
            np.pad(x, (0, pad_len), mode='constant'),
            (self.hp['batch_sz'], -1, self.hp["backprop_len"])
        )
        n_batches=x_batch.shape[1]
        saver = tf.train.Saver()
        print(f"Started trainning: n_batches={n_batches}")
        saver = tf.train.Saver()
        with tf.Session() as sess:
            train_writer = tf.summary.FileWriter( './mdl_tr', sess.graph)
            sess.run(tf.global_variables_initializer())
            ckpt = saver.restore(sess, "checkpoints/nn_model.ckpt")
            losses=[]
            curr_state = np.zeros([self.hp['batch_sz'], self.hp['state_sz']])
            for batch_idx in range(n_batches):
                start = batch_idx*self.hp['batch_sz']
                finish = start + self.hp['batch_sz']
                print(f"x_batch={x_batch.shape} batch_idx = {batch_idx}/{n_batches}")
                preds = sess.run(
                    [self.preds_seq],
                    feed_dict={
                        self.inputs_batch:x_batch[:,batch_idx,:],
                        self.init_state:curr_state
                    }
                )
                pred=np.array(preds)[0]
        return pred.transpose((1,0,2)).argmax(axis=2).reshape([1,-1]).squeeze().tolist()
    
    def train(self, x, y,n_epochs=10,report_freq=100, ckpt_freq=100):
        nobs = len(x)
        divider = self.hp['batch_sz']*self.hp["backprop_len"]*n_epochs
        pad_len = divider - (nobs%divider)
        
        if nobs<pad_len:
            raise Exception("Wrong n_epochs({n_epochs}): choose smaller value.")
        x_batch=np.reshape(
            np.pad(x, (0, pad_len), mode='constant'),
            (n_epochs,self.hp['batch_sz'], -1, self.hp["backprop_len"])
        )
        y_batch=np.reshape(
            np.pad(y, (0, pad_len), mode='constant'),
            (n_epochs,self.hp['batch_sz'], -1, self.hp["backprop_len"])
        )
        n_batches=x_batch.shape[2]
        print(f"Started trainning: n_epochs={n_epochs} n_batches={n_batches}")
        self._train_graph()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            train_writer = tf.summary.FileWriter( './mdl_tr', sess.graph)
            sess.run(tf.global_variables_initializer())
            losses=[]
            for epoch_num in range(n_epochs):
                curr_state = np.zeros([self.hp['batch_sz'], self.hp['state_sz']])
                for batch_idx in range(n_batches):
                    start = batch_idx*self.hp['batch_sz']
                    finish = start + self.hp['batch_sz']
                    summary, train_step, loss, curr_state, preds = sess.run(
                        [self.merged, self.train_step, self.total_loss, self.last_state, self.preds_seq],
                        feed_dict={
                            self.inputs_batch:x_batch[epoch_num,:,batch_idx,:],
                            self.labels_batch:y_batch[epoch_num,:,batch_idx,:],
                            self.init_state:curr_state
                        }
                    )
                    if ckpt_freq>0 and batch_idx%ckpt_freq==0:
                        saver.save(sess, self.hp['ckpt_path'] + "nn_model.ckpt")
                    if report_freq>0 and batch_idx%report_freq==0:
                        train_writer.add_summary(summary, global_step=epoch_num)
                        ypred = [1 if x[0]<0.5 else 0 for pred in preds for x in pred]
                        ytrue = [x for pred in y_batch[epoch_num,:,batch_idx,:] for x in pred]
                        print(f"epoch={epoch_num} batch_idx={batch_idx}")
                        print(f"ytrue={y_batch[epoch_num,0, batch_idx, :]}")
                        print(f"ypred={[x[1] for x in np.array(preds)[:,0,:]]}")
                        print(f"loss={loss}")
                print(f"finished training")
                train_writer.add_summary(summary, global_step=epoch_num)
                saver.save(sess, self.hp['ckpt_path'] + "nn_model.ckpt")

# Tests

In [87]:
# learning of NN that represents input sequence with specified delay (echo_step)
echo_step = 2
x = np.array(np.random.choice(2, 500000, p=[0.5, 0.5]))
y = np.roll(x, echo_step)
y[0:echo_step] = 0

rnn = RNN()
rnn.train(x, y, n_epochs=100, report_freq=0)
x = [1,1,0,0,1,0,1,0,0,1,0,1,1,0]
print(rnn.predict(x))

Started trainning: n_epochs=100 n_batches=101
Started trainning: n_batches=1
INFO:tensorflow:Restoring parameters from checkpoints/nn_model.ckpt
x_batch=(5, 1, 10) batch_idx = 0/1
[0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [91]:
# learning of NN that invert binary sequence (I know this task is strange for RNNs)
x = np.array(np.random.choice(2, 5000000, p=[0.5, 0.5]))
y = np.array([0 if v else 1 for v in x])

rnn = RNN()
rnn.train(x, y, n_epochs=200, report_freq=250)
x = [1,0,1,1,0,0]
print(rnn.predict(x))

Started trainning: n_epochs=200 n_batches=501
epoch=0 batch_idx=0
ytrue=[1 0 0 1 0 0 0 1 0 1]
ypred=[0.5, 0.3183783, 0.17439745, 0.14298423, 0.1317624, 0.1306295, 0.1305178, 0.1338595, 0.13082455, 0.1339109]
loss=0.8340187668800354
epoch=0 batch_idx=250
ytrue=[0 1 1 1 1 1 1 0 1 1]
ypred=[0.0036752897, 0.9972037, 0.99604046, 0.99699736, 0.9969029, 0.99690175, 0.9968978, 0.003681124, 0.99718446, 0.99603313]
loss=0.0031327870674431324
epoch=0 batch_idx=500
ytrue=[1 0 0 0 0 0 0 0 1 1]
ypred=[0.9988374, 0.0011398558, 0.0013434829, 0.0012048739, 0.0012253836, 0.0012299516, 0.0012316428, 0.0012320094, 0.9986046, 0.9981628]
loss=0.0013353407848626375
epoch=1 batch_idx=0
ytrue=[1 1 1 0 0 1 1 1 0 0]
ypred=[0.99818254, 0.9985152, 0.9986639, 0.0016732583, 0.0014135306, 0.9986093, 0.9981871, 0.99869055, 0.0016631973, 0.0014084204]
loss=0.0013822460314258933
epoch=1 batch_idx=250
ytrue=[0 0 1 1 0 1 1 0 0 0]
ypred=[0.0007545013, 0.00075662136, 0.9990984, 0.9988512, 0.0009644856, 0.99924374, 0.9989067

epoch=13 batch_idx=250
ytrue=[1 1 1 0 0 1 1 0 0 1]
ypred=[0.99991846, 0.9999131, 0.99992156, 7.47024e-05, 7.104758e-05, 0.99991846, 0.9999131, 7.243232e-05, 7.1052054e-05, 0.99991834]
loss=7.726607145741582e-05
epoch=13 batch_idx=500
ytrue=[1 0 1 1 1 0 0 1 0 0]
ypred=[0.99991727, 6.979193e-05, 0.9999263, 0.9999176, 0.99992454, 7.175221e-05, 6.840695e-05, 0.9999218, 6.3816726e-05, 6.8405774e-05]
loss=7.331339293159544e-05
epoch=14 batch_idx=0
ytrue=[1 0 1 1 1 0 1 1 1 1]
ypred=[0.9998956, 6.416216e-05, 0.9999248, 0.99991727, 0.99992454, 7.1654635e-05, 0.9999267, 0.99991775, 0.99992454, 0.99992526]
loss=7.801223546266556e-05
epoch=14 batch_idx=250
ytrue=[0 1 1 0 1 1 0 1 0 1]
ypred=[6.697411e-05, 0.99992895, 0.99992085, 6.7144094e-05, 0.99992895, 0.99992085, 6.71496e-05, 0.99992895, 6.213032e-05, 0.999928]
loss=7.01807948644273e-05
epoch=14 batch_idx=500
ytrue=[0 1 1 0 0 0 1 0 1 1]
ypred=[6.452744e-05, 0.99993145, 0.9999236, 6.452621e-05, 6.350856e-05, 6.199632e-05, 0.9999267, 5.9240283e-0

epoch=26 batch_idx=250
ytrue=[1 0 1 0 0 0 1 0 0 1]
ypred=[0.99996245, 3.314749e-05, 0.9999622, 3.3097014e-05, 3.45956e-05, 3.42259e-05, 0.99996114, 3.298947e-05, 3.4606423e-05, 0.99996126]
loss=3.693035614560358e-05
epoch=26 batch_idx=500
ytrue=[0 1 1 1 1 1 0 0 1 0]
ypred=[3.3987882e-05, 0.9999633, 0.99996066, 0.9999622, 0.99996233, 0.99996245, 3.403387e-05, 3.3932723e-05, 0.9999621, 3.238157e-05]
loss=3.582174758776091e-05
epoch=27 batch_idx=0
ytrue=[1 1 0 0 0 0 1 1 0 1]
ypred=[0.9999492, 0.99995947, 3.309954e-05, 3.3909404e-05, 3.354045e-05, 3.3548353e-05, 0.99996185, 0.9999603, 3.3487206e-05, 0.99996305]
loss=3.5757369914790615e-05
epoch=27 batch_idx=250
ytrue=[0 1 1 1 1 0 0 0 0 0]
ypred=[3.2812102e-05, 0.9999639, 0.9999614, 0.9999628, 0.99996305, 3.3291064e-05, 3.330507e-05, 3.290209e-05, 3.2918382e-05, 3.2932952e-05]
loss=3.4047978260787204e-05
epoch=27 batch_idx=500
ytrue=[1 1 0 0 1 0 0 1 1 1]
ypred=[0.999962, 0.9999635, 3.249082e-05, 3.2698557e-05, 0.99996364, 3.1129246e-05, 3.2

epoch=39 batch_idx=250
ytrue=[0 1 1 1 0 1 0 1 0 0]
ypred=[2.279115e-05, 0.9999753, 0.9999739, 0.9999746, 2.2213324e-05, 0.9999757, 2.1778827e-05, 0.99997556, 2.1756614e-05, 2.2771683e-05]
loss=2.3936936486279592e-05
epoch=39 batch_idx=500
ytrue=[0 1 0 0 1 1 1 0 0 0]
ypred=[2.1474541e-05, 0.9999759, 2.1476835e-05, 2.246929e-05, 0.99997556, 0.99997437, 0.99997497, 2.1912974e-05, 2.2486716e-05, 2.2344904e-05]
loss=2.3369513655779883e-05
epoch=40 batch_idx=0
ytrue=[1 0 1 1 1 0 0 1 0 1]
ypred=[0.9999671, 2.0993786e-05, 0.9999757, 0.99997425, 0.99997497, 2.1908836e-05, 2.2485365e-05, 0.9999757, 2.1456179e-05, 0.9999759]
loss=2.3598389816470444e-05
epoch=40 batch_idx=250
ytrue=[1 0 0 1 1 0 0 0 1 1]
ypred=[0.9999763, 2.1184665e-05, 2.2189357e-05, 0.9999759, 0.9999746, 2.1526559e-05, 2.2201844e-05, 2.20673e-05, 0.9999759, 0.9999746]
loss=2.2957059627515264e-05
epoch=40 batch_idx=500
ytrue=[1 1 1 1 1 0 1 0 1 0]
ypred=[0.9999757, 0.9999757, 0.9999757, 0.9999757, 0.9999757, 2.1370635e-05, 0.999976

epoch=52 batch_idx=250
ytrue=[0 1 1 0 0 1 0 1 1 1]
ypred=[1.6115495e-05, 0.9999821, 0.9999809, 1.626407e-05, 1.6914446e-05, 0.999982, 1.6102946e-05, 0.9999821, 0.9999809, 0.99998116]
loss=1.7378179109073244e-05
epoch=52 batch_idx=500
ytrue=[0 0 1 1 0 1 0 0 1 1]
ypred=[1.674006e-05, 1.668746e-05, 0.99998224, 0.99998105, 1.6094793e-05, 0.99998236, 1.5957547e-05, 1.6737296e-05, 0.99998224, 0.99998116]
loss=1.7320959159405902e-05
epoch=53 batch_idx=0
ytrue=[1 0 1 1 1 0 0 0 1 0]
ypred=[0.9999759, 1.5570007e-05, 0.9999821, 0.99998105, 0.9999814, 1.6114282e-05, 1.6741447e-05, 1.6687094e-05, 0.99998224, 1.5945163e-05]
loss=1.7487847799202427e-05
epoch=53 batch_idx=250
ytrue=[1 0 0 1 0 1 1 0 1 1]
ypred=[0.9999826, 1.5803194e-05, 1.6572305e-05, 0.99998236, 1.579021e-05, 0.9999825, 0.9999813, 1.5937076e-05, 0.9999825, 0.9999813]
loss=1.7075393770937808e-05
epoch=53 batch_idx=500
ytrue=[1 1 1 1 0 1 0 0 0 0]
ypred=[0.9999815, 0.99998176, 0.99998176, 0.99998176, 1.581215e-05, 0.9999827, 1.5649981e-0

epoch=65 batch_idx=250
ytrue=[1 0 1 0 0 1 1 1 1 0]
ypred=[0.999985, 1.2820773e-05, 0.99998593, 1.2770376e-05, 1.3376469e-05, 0.99998593, 0.999985, 0.9999851, 0.9999851, 1.2823377e-05]
loss=1.3794804544886574e-05
epoch=65 batch_idx=500
ytrue=[0 1 1 1 1 0 0 0 0 0]
ypred=[1.2718989e-05, 0.99998605, 0.9999852, 0.9999852, 0.9999852, 1.272139e-05, 1.3269209e-05, 1.3259684e-05, 1.3263213e-05, 1.3263592e-05]
loss=1.3556388694269117e-05
epoch=66 batch_idx=0
ytrue=[0 1 0 0 0 0 0 1 1 0]
ypred=[1.0910857e-05, 0.9999857, 1.2586235e-05, 1.3260607e-05, 1.3257358e-05, 1.3262479e-05, 1.326301e-05, 0.99998605, 0.9999852, 1.2718721e-05]
loss=1.3770961231784895e-05
epoch=66 batch_idx=250
ytrue=[1 1 1 0 0 1 1 0 1 1]
ypred=[0.99998534, 0.99998534, 0.99998534, 1.2619071e-05, 1.3160978e-05, 0.9999862, 0.99998534, 1.2617626e-05, 0.9999862, 0.99998534]
loss=1.366844298900105e-05
epoch=66 batch_idx=500
ytrue=[1 1 1 0 1 1 0 1 1 0]
ypred=[0.99998546, 0.99998546, 0.99998546, 1.25196875e-05, 0.9999863, 0.99998546, 1

epoch=78 batch_idx=250
ytrue=[1 1 0 0 1 0 0 1 1 0]
ypred=[0.99998844, 0.9999877, 1.05660565e-05, 1.1051815e-05, 0.99998844, 1.0563538e-05, 1.1057909e-05, 0.99998844, 0.9999877, 1.0567588e-05]
loss=1.1444026313256472e-05
epoch=78 batch_idx=500
ytrue=[1 1 0 0 0 1 0 1 1 1]
ypred=[0.99998784, 0.99998784, 1.0486571e-05, 1.0979059e-05, 1.0988412e-05, 0.99998856, 1.0493394e-05, 0.99998856, 0.99998784, 0.99998784]
loss=1.1129317499580793e-05
epoch=79 batch_idx=0
ytrue=[1 0 0 0 0 1 1 0 1 1]
ypred=[0.9999845, 1.021809e-05, 1.0926789e-05, 1.0979353e-05, 1.0991263e-05, 0.99998856, 0.99998784, 1.0496117e-05, 0.99998856, 0.99998784]
loss=1.1410647857701406e-05
epoch=79 batch_idx=250
ytrue=[0 0 0 0 0 1 0 0 1 1]
ypred=[1.0910732e-05, 1.091478e-05, 1.0919019e-05, 1.0919497e-05, 1.0919581e-05, 0.9999887, 1.0427575e-05, 1.091091e-05, 0.9999887, 0.99998796]
loss=1.1050640750909224e-05
epoch=79 batch_idx=500
ytrue=[1 0 0 0 1 1 0 1 1 0]
ypred=[0.9999887, 1.0359855e-05, 1.0839029e-05, 1.0843516e-05, 0.999988

epoch=91 batch_idx=250
ytrue=[0 0 0 1 0 0 1 1 0 0]
ypred=[8.966267e-06, 9.392236e-06, 9.41409e-06, 0.9999902, 9.007353e-06, 9.402632e-06, 0.9999902, 0.9999896, 8.98468e-06, 9.394843e-06]
loss=9.572459930495825e-06
epoch=91 batch_idx=500
ytrue=[1 0 1 1 0 1 0 0 1 1]
ypred=[0.9999896, 8.9164005e-06, 0.9999902, 0.9999896, 8.9286195e-06, 0.9999902, 8.942612e-06, 9.346425e-06, 0.99999034, 0.99998975]
loss=9.5963014246081e-06
epoch=92 batch_idx=0
ytrue=[0 0 1 0 0 0 1 1 1 0]
ypred=[7.733809e-06, 9.227271e-06, 0.9999902, 8.942117e-06, 9.345801e-06, 9.36025e-06, 0.99999034, 0.99998975, 0.9999896, 8.917506e-06]
loss=9.787033377506305e-06
epoch=92 batch_idx=250
ytrue=[0 1 1 0 0 1 0 0 0 0]
ypred=[9.295651e-06, 0.99999034, 0.99998975, 8.878862e-06, 9.287905e-06, 0.99999034, 8.899496e-06, 9.295402e-06, 9.308194e-06, 9.312776e-06]
loss=9.524776032776572e-06
epoch=92 batch_idx=500
ytrue=[0 1 0 1 0 1 1 0 1 1]
ypred=[9.240762e-06, 0.99999046, 8.84909e-06, 0.99999034, 8.840789e-06, 0.99999034, 0.99998975,

epoch=104 batch_idx=250
ytrue=[1 1 0 0 1 0 1 0 0 0]
ypred=[0.99999154, 0.99999106, 7.796102e-06, 8.160974e-06, 0.99999154, 7.830578e-06, 0.99999154, 7.822346e-06, 8.168573e-06, 8.186464e-06]
loss=8.35415175970411e-06
epoch=104 batch_idx=500
ytrue=[1 1 0 1 1 0 0 1 1 0]
ypred=[0.99999166, 0.99999106, 7.761044e-06, 0.99999154, 0.99999106, 7.75721e-06, 8.120735e-06, 0.99999166, 0.99999106, 7.760296e-06]
loss=8.346999493369367e-06
epoch=105 batch_idx=0
ytrue=[0 0 0 1 1 1 1 0 1 1]
ypred=[6.742937e-06, 8.027635e-06, 8.124811e-06, 0.99999166, 0.99999106, 0.99999106, 0.99999106, 7.74014e-06, 0.99999154, 0.99999106]
loss=8.201564924092963e-06
epoch=105 batch_idx=250
ytrue=[0 1 1 1 0 0 0 1 1 0]
ypred=[8.112205e-06, 0.99999166, 0.9999912, 0.99999106, 7.704686e-06, 8.079932e-06, 8.105261e-06, 0.99999166, 0.9999912, 7.720434e-06]
loss=8.192028872144874e-06
epoch=105 batch_idx=500
ytrue=[1 0 1 0 0 0 0 0 0 1]
ypred=[0.99999106, 7.659018e-06, 0.99999166, 7.699412e-06, 8.048131e-06, 8.066342e-06, 8.0712

epoch=117 batch_idx=250
ytrue=[0 1 1 1 0 0 0 1 0 1]
ypred=[6.924698e-06, 0.9999926, 0.99999213, 0.999992, 6.8698923e-06, 7.209986e-06, 7.2376656e-06, 0.9999926, 6.9242224e-06, 0.9999926]
loss=7.374259439529851e-06
epoch=117 batch_idx=500
ytrue=[0 0 1 0 0 1 1 0 0 0]
ypred=[7.2076623e-06, 7.2125986e-06, 0.9999926, 6.8939394e-06, 7.1899017e-06, 0.9999926, 0.99999213, 6.856036e-06, 7.1811646e-06, 7.2066728e-06]
loss=7.359953997365665e-06
epoch=118 batch_idx=0
ytrue=[1 0 1 0 0 1 1 1 0 0]
ypred=[0.9999901, 6.704791e-06, 0.9999925, 6.865792e-06, 7.1849186e-06, 0.9999926, 0.99999213, 0.99999213, 6.8399386e-06, 7.1785353e-06]
loss=7.386179731838638e-06
epoch=118 batch_idx=250
ytrue=[1 1 0 1 0 1 0 0 0 0]
ypred=[0.99999213, 0.99999213, 6.8071217e-06, 0.9999926, 6.846896e-06, 0.9999926, 6.8499917e-06, 7.1564227e-06, 7.175825e-06, 7.180665e-06]
loss=7.400485174002824e-06
epoch=118 batch_idx=500
ytrue=[0 1 0 1 0 0 0 1 1 0]
ypred=[7.149308e-06, 0.9999927, 6.8313793e-06, 0.9999926, 6.821139e-06, 7.125

epoch=130 batch_idx=250
ytrue=[1 1 1 0 0 1 1 0 0 1]
ypred=[0.99999297, 0.99999297, 0.99999285, 6.145188e-06, 6.456951e-06, 0.9999933, 0.99999297, 6.1651344e-06, 6.4603078e-06, 0.9999933]
loss=6.678081717836903e-06
epoch=130 batch_idx=500
ytrue=[0 0 0 1 1 1 1 1 1 0]
ypred=[6.1379787e-06, 6.433808e-06, 6.460209e-06, 0.99999344, 0.99999297, 0.99999297, 0.99999297, 0.99999297, 0.99999285, 6.1184937e-06]
loss=6.656624464085326e-06
epoch=131 batch_idx=0
ytrue=[0 0 0 1 1 0 0 0 0 1]
ypred=[5.3656304e-06, 6.365867e-06, 6.444739e-06, 0.99999344, 0.99999297, 6.141327e-06, 6.4343417e-06, 6.4602586e-06, 6.465916e-06, 0.99999344]
loss=6.580330591532402e-06
epoch=131 batch_idx=250
ytrue=[0 1 0 1 1 0 0 0 0 0]
ypred=[6.417183e-06, 0.99999344, 6.1540673e-06, 0.99999344, 0.99999297, 6.114061e-06, 6.408609e-06, 6.4346855e-06, 6.440389e-06, 6.4414025e-06]
loss=6.456354185502278e-06
epoch=131 batch_idx=500
ytrue=[1 0 1 1 1 0 0 1 0 1]
ypred=[0.99999297, 6.069652e-06, 0.99999344, 0.9999931, 0.99999297, 6.0754

epoch=143 batch_idx=250
ytrue=[0 0 0 1 1 0 1 0 1 0]
ypred=[5.849868e-06, 5.8702553e-06, 5.8749038e-06, 0.99999404, 0.9999937, 5.576185e-06, 0.99999404, 5.602517e-06, 0.99999404, 5.604516e-06]
loss=5.877001058252063e-06
epoch=143 batch_idx=500
ytrue=[1 1 0 1 1 1 0 1 1 1]
ypred=[0.9999937, 0.9999937, 5.541071e-06, 0.99999404, 0.9999937, 0.9999937, 5.5395703e-06, 0.99999404, 0.9999937, 0.9999937]
loss=5.910379059059778e-06
epoch=144 batch_idx=0
ytrue=[1 1 1 1 1 0 1 0 1 1]
ypred=[0.999992, 0.99999344, 0.99999344, 0.99999356, 0.99999356, 5.530301e-06, 0.99999404, 5.5804035e-06, 0.99999404, 0.9999937]
loss=6.2489316405844875e-06
epoch=144 batch_idx=250
ytrue=[1 1 1 1 0 1 0 1 0 1]
ypred=[0.9999937, 0.9999937, 0.9999937, 0.99999356, 5.5143337e-06, 0.99999404, 5.5608107e-06, 0.99999404, 5.5645028e-06, 0.99999404]
loss=5.929452072450658e-06
epoch=144 batch_idx=500
ytrue=[0 0 1 1 1 0 1 1 0 1]
ypred=[5.5158903e-06, 5.780988e-06, 0.99999416, 0.9999937, 0.9999937, 5.5009186e-06, 0.99999404, 0.999993

epoch=156 batch_idx=250
ytrue=[1 0 1 1 0 0 0 0 1 1]
ypred=[0.99999416, 5.0699673e-06, 0.9999945, 0.9999943, 5.0851527e-06, 5.3334843e-06, 5.359805e-06, 5.365349e-06, 0.99999464, 0.9999943]
loss=5.445465831144247e-06
epoch=156 batch_idx=500
ytrue=[1 0 1 0 1 0 0 0 0 1]
ypred=[0.99999416, 5.0558688e-06, 0.9999945, 5.0984872e-06, 0.9999945, 5.101736e-06, 5.3226713e-06, 5.343045e-06, 5.347521e-06, 0.99999464]
loss=5.414471161202528e-06
epoch=157 batch_idx=0
ytrue=[0 1 1 0 0 1 1 1 1 0]
ypred=[4.45005e-06, 0.9999945, 0.99999416, 5.059631e-06, 5.313623e-06, 0.99999464, 0.9999943, 0.99999416, 0.99999416, 5.0523454e-06]
loss=5.359634997148532e-06
epoch=157 batch_idx=250
ytrue=[0 0 1 1 1 1 0 0 1 1]
ypred=[5.32962e-06, 5.3303725e-06, 0.99999464, 0.9999943, 0.9999943, 0.99999416, 5.035246e-06, 5.2947307e-06, 0.99999464, 0.9999943]
loss=5.409703135228483e-06
epoch=157 batch_idx=500
ytrue=[0 1 1 1 0 1 1 0 1 0]
ypred=[5.065212e-06, 0.99999464, 0.9999943, 0.9999943, 5.022e-06, 0.99999464, 0.9999943, 5.

epoch=169 batch_idx=250
ytrue=[0 1 1 0 1 0 1 1 0 1]
ypred=[4.9131477e-06, 0.999995, 0.99999475, 4.6809173e-06, 0.999995, 4.70799e-06, 0.999995, 0.99999475, 4.6783025e-06, 0.999995]
loss=4.93525249112281e-06
epoch=169 batch_idx=500
ytrue=[1 1 1 0 0 0 0 1 0 0]
ypred=[0.99999464, 0.99999464, 0.99999464, 4.64624e-06, 4.8866614e-06, 4.914052e-06, 4.919862e-06, 0.999995, 4.7050185e-06, 4.8983916e-06]
loss=4.9614782255957834e-06
epoch=170 batch_idx=0
ytrue=[1 0 0 1 1 0 1 0 1 1]
ypred=[0.99999344, 4.5767515e-06, 4.8681586e-06, 0.999995, 0.99999475, 4.665079e-06, 0.999995, 4.6930218e-06, 0.999995, 0.99999475]
loss=4.968630491930526e-06
epoch=170 batch_idx=250
ytrue=[0 0 0 0 0 1 1 1 1 1]
ypred=[4.677575e-06, 4.8813336e-06, 4.90113e-06, 4.905479e-06, 4.9062974e-06, 0.9999951, 0.99999475, 0.99999475, 0.99999464, 0.99999464]
loss=4.956710199621739e-06
epoch=170 batch_idx=500
ytrue=[1 0 0 1 0 1 1 1 1 1]
ypred=[0.9999951, 4.674186e-06, 4.867945e-06, 0.9999951, 4.6748282e-06, 0.999995, 0.99999475, 0.9

epoch=182 batch_idx=250
ytrue=[1 1 1 1 0 1 0 0 0 0]
ypred=[0.99999535, 0.9999951, 0.9999951, 0.999995, 4.3105215e-06, 0.99999535, 4.3506025e-06, 4.544137e-06, 4.5625316e-06, 4.5666584e-06]
loss=4.584778707794612e-06
epoch=182 batch_idx=500
ytrue=[1 1 1 0 0 1 1 0 1 0]
ypred=[0.99999535, 0.9999951, 0.9999951, 4.301245e-06, 4.5235643e-06, 0.99999547, 0.9999951, 4.314918e-06, 0.99999535, 4.33958e-06]
loss=4.5919314288767055e-06
epoch=183 batch_idx=0
ytrue=[1 1 0 1 1 0 1 1 1 0]
ypred=[0.9999939, 0.9999949, 4.2655306e-06, 0.99999535, 0.9999951, 4.3102336e-06, 0.99999535, 0.9999951, 0.9999951, 4.3011505e-06]
loss=4.632462150766514e-06
epoch=183 batch_idx=250
ytrue=[1 1 1 0 0 1 0 1 1 0]
ypred=[0.9999951, 0.9999951, 0.9999951, 4.2860825e-06, 4.5108604e-06, 0.99999547, 4.3354185e-06, 0.99999547, 0.9999951, 4.3007526e-06]
loss=4.575241746351821e-06
epoch=183 batch_idx=500
ytrue=[0 1 0 0 0 1 1 1 1 1]
ypred=[4.5089896e-06, 0.99999547, 4.324286e-06, 4.508783e-06, 4.5253682e-06, 0.99999547, 0.9999951

epoch=195 batch_idx=250
ytrue=[0 0 0 1 1 0 0 1 1 0]
ypred=[4.009049e-06, 4.2184984e-06, 4.242188e-06, 0.9999957, 0.99999547, 4.0254545e-06, 4.2216657e-06, 0.9999957, 0.99999547, 4.024867e-06]
loss=4.262915354047436e-06
epoch=195 batch_idx=500
ytrue=[0 1 0 0 1 0 0 0 1 1]
ypred=[4.2307706e-06, 0.9999957, 4.047242e-06, 4.2170304e-06, 0.9999957, 4.0462346e-06, 4.216837e-06, 4.2327115e-06, 0.9999957, 0.99999547]
loss=4.262915354047436e-06
epoch=196 batch_idx=0
ytrue=[1 1 0 1 0 1 1 1 1 1]
ypred=[0.9999944, 0.99999523, 3.968371e-06, 0.9999957, 4.0328705e-06, 0.9999957, 0.99999547, 0.99999547, 0.99999535, 0.99999535]
loss=4.42742293671472e-06
epoch=196 batch_idx=250
ytrue=[0 0 1 1 1 1 1 0 1 1]
ypred=[4.195241e-06, 4.2189613e-06, 0.9999958, 0.99999547, 0.99999547, 0.99999547, 0.99999547, 3.985727e-06, 0.9999957, 0.99999547]
loss=4.274836101103574e-06
epoch=196 batch_idx=500
ytrue=[1 0 0 0 0 1 0 0 1 1]
ypred=[0.99999547, 3.9802876e-06, 4.1845965e-06, 4.2075817e-06, 4.2125685e-06, 0.9999958, 4.02