# Train our network
### Trains our network on pre-computed train data.
### Saves our network after each epoch.

In [1]:
import numpy as np
import os
import pacrr
import tensorflow as tf
import time

  from ._conv import register_converters as _register_converters


In [2]:
counter = 0

lq = 10
ld = 20
lf = 32
lg = 3
k = 3
lr = 0.01

saves_folder = r"D:\IRfiles\saves"
data_folder = r"D:\IRfiles\tdata"


In [3]:
# load network
yp, yn, prel, loss, opt, saver = pacrr.build(lq, ld, lf, lg, k=k, lr=lr)

In [4]:
# load epoch by epoch num (from pre-calculated epochs)
def load_epoch(counter):
    pathp = os.path.join(data_folder, "dp%d.npy" % counter)
    pathn = os.path.join(data_folder, "dn%d.npy" % counter)
    
    dp = np.load(pathp)
    dn = np.load(pathn)
    
    return dp, dn
    

In [5]:
def split_batches(a, batch_size):
    batches = []
    
    for i in range(int(a.shape[0]/batch_size)):
        batches.append(a[i*batch_size : (i+1) * batch_size])
    
    return batches

In [6]:
def train(dp, dn, sess, batch_size=32, print_every=0.25):
    dpb = split_batches(dp, batch_size)
    dnb = split_batches(dn, batch_size)
    
    if print_every < 1 and print_every > 0:
        print_every = int(np.floor(print_every * len(dpb)))
    
    batch_counter = 0
    for bp, bn in zip(dpb, dnb):
        feed_dict = { yp : bp, yn : bn }
        
        _, outls = sess.run([opt, loss], feed_dict=feed_dict)
        
        if print_every > 0 and batch_counter % print_every == 0:
            print("batch %d finished. loss: %f" % (batch_counter, np.sum(outls)))
        
        batch_counter += 1
        

In [7]:
# get mean loss of network
def get_loss(dp, dn, sess, batch_size=32):
    dpb = split_batches(dp, batch_size)
    dnb = split_batches(dn, batch_size)
    losses = []
    
    for bp, bn in zip(dpb, dnb):
        feed_dict = { yp : bp, yn : bn }
        
        outls = sess.run(loss, feed_dict=feed_dict)
        losses.append(outls)
        
    return np.mean(outls)

In [8]:
sess = None

In [9]:
# this is done for testing and debugging, if the session is already initialized, close it.
if sess is not None:
    sess.close()
    
sess = tf.Session()    

In [10]:
# initialize TensorFlow variables.
var_inits = tf.global_variables_initializer()
sess.run(var_inits)

In [14]:
# train for 400 epochs (we stopped it early)
for i in range(400):
    print("starting epoch %d" % (counter))
    st = time.time()
    dp, dn = load_epoch(counter%203)
    train(dp, dn, sess, print_every=0, batch_size=128)
    
    meanloss = get_loss(dp, dn, sess)
    del dp
    del dn
    
    fpath = os.path.join(saves_folder, str(counter))
    if not os.path.isdir(fpath):
        os.mkdir(fpath)
    
    pacrr.save(sess, os.path.join(fpath, "e%d.cpt" % counter), saver)
    
    print("epoch finished. loss: %f. secs: %f" % (meanloss, time.time()-st))
    counter += 1

starting epoch 0
epoch finished. loss: 0.472765. secs: 22.341220
starting epoch 1
epoch finished. loss: 0.473435. secs: 24.734217
starting epoch 2
epoch finished. loss: 0.472336. secs: 24.702677
starting epoch 3
epoch finished. loss: 0.471986. secs: 24.278812
starting epoch 4
epoch finished. loss: 0.472320. secs: 24.779385
starting epoch 5
epoch finished. loss: 0.466125. secs: 24.889526
starting epoch 6
epoch finished. loss: 0.474077. secs: 24.542462
starting epoch 7
epoch finished. loss: 0.474077. secs: 24.579564
starting epoch 8
epoch finished. loss: 0.474077. secs: 24.551034
starting epoch 9
epoch finished. loss: 0.474077. secs: 24.452001
starting epoch 10
epoch finished. loss: 0.474077. secs: 24.745042
starting epoch 11
epoch finished. loss: 0.474077. secs: 24.400964
starting epoch 12
epoch finished. loss: 0.474077. secs: 24.754376
starting epoch 13
epoch finished. loss: 0.474077. secs: 24.246185
starting epoch 14
epoch finished. loss: 0.474077. secs: 24.779265
starting epoch 15
ep

epoch finished. loss: 0.474077. secs: 24.568006
starting epoch 125
epoch finished. loss: 0.474077. secs: 24.154933
starting epoch 126
epoch finished. loss: 0.474077. secs: 24.850379
starting epoch 127
epoch finished. loss: 0.474077. secs: 24.326315
starting epoch 128
epoch finished. loss: 0.461917. secs: 24.079987
starting epoch 129
epoch finished. loss: 0.467000. secs: 24.556877
starting epoch 130
epoch finished. loss: 0.474077. secs: 24.012225
starting epoch 131
epoch finished. loss: 0.474077. secs: 24.578441
starting epoch 132
epoch finished. loss: 0.474077. secs: 24.242689
starting epoch 133
epoch finished. loss: 0.474077. secs: 24.400656
starting epoch 134
epoch finished. loss: 0.474077. secs: 24.469602
starting epoch 135
epoch finished. loss: 0.474077. secs: 23.989480
starting epoch 136
epoch finished. loss: 0.474077. secs: 24.234500
starting epoch 137
epoch finished. loss: 0.474077. secs: 24.123279
starting epoch 138
epoch finished. loss: 0.474077. secs: 24.421013
starting epoch

KeyboardInterrupt: 