In [18]:
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import StandardScaler

In [19]:
class Batch:
    def __init__(self, data, batchSize):
        self.data = np.copy(data)
        self.count = 0
        self.batchSize = batchSize
        np.random.shuffle(self.data)
    
    def getNextBatch(self):
        if self.count >= len(self.data) / self.batchSize:
            self.count = 0
            np.random.shuffle(self.data)
        start = self.count * self.batchSize
        end = start + self.batchSize
        toReturn = self.data[start:end]
        self.count += 1
        return toReturn

def normalize(data, labels):
    scaler = StandardScaler()
    scaler.fit(data)
    cleanData = scaler.transform(data)
    fullData = np.hstack((cleanData, labels))
    return scaler, fullData 

In [20]:
LEARNING_RATE_START = 0.005
LEARNING_RATE_END   = 0.0001
N_EPOCHS      = 3000
anneal_rate = (-1.0 * np.log(LEARNING_RATE_END / LEARNING_RATE_START)) / float(N_EPOCHS)
batchSize = 50
dirtyData = np.load('train.npy')
scaler, fullData = normalize(dirtyData[:, :-1], dirtyData[:, -1:])

In [21]:
np.random.shuffle(fullData)
train_data = fullData[:80000]
valid_data = fullData[80000:]
trainBatch = Batch(train_data, batchSize)

In [22]:
sess = tf.Session()
V = tf.Variable(tf.random_normal([15, 55], stddev=0.1))
W = tf.Variable(tf.random_normal([1, 15], stddev=0.1))
X = tf.placeholder(tf.float32, [None, 55])
y = tf.placeholder(tf.float32, [None, 1]) 
lr  = tf.placeholder(tf.float32)
eval_dict={
    X   : valid_data[:,:-1],
    y   : valid_data[:,-1:]
}

In [23]:
hidden = tf.nn.tanh(tf.matmul(V,tf.cast(tf.transpose(X), tf.float32)))
estimate = tf.matmul(W, hidden)

In [30]:
loss = tf.reduce_mean(tf.abs(estimate - y))
metric = tf.metrics.mean_absolute_error(estimate, tf.transpose(y))
optimizer = tf.train.AdamOptimizer(lr)
train_step = optimizer.minimize(loss)

In [31]:
saver = tf.train.Saver([V,W])

In [32]:
init = tf.global_variables_initializer()
sess.run(init)

In [33]:
#Restore Variables

# saver = tf.train.import_meta_graph('ZestimateWeight.meta')
# saver.restore(sess,tf.train.latest_checkpoint('./'))

In [34]:
epochSize = len(fullData) / batchSize
best_loss = 0.066
curr_lr = LEARNING_RATE_START

for i in range(N_EPOCHS*epochSize):
    currBatch = trainBatch.getNextBatch()
    sess.run(train_step, feed_dict={
        X   : currBatch[:, :-1],
        y   : currBatch[:, -1:],
        lr  : curr_lr
    })
    checkpointed = False
    
    if i %  epochSize == 0:
        print("epoch #" + str(i /epochSize))
        tr_loss = sess.run(loss, eval_dict)
        if tr_loss < best_loss:
            best_loss = tr_loss
            saver.save(sess, "ZestimateWeight")
            curr_lr = LEARNING_RATE_START * np.exp(-1.0 * anneal_rate * i)
            checkpointed = True
        print("checkPointed = " + str(checkpointed))
        print(tr_loss)

        

epoch #0
checkPointed = False
0.10053
epoch #1
checkPointed = False
0.0663639
epoch #2
checkPointed = False
0.0663213
epoch #3
checkPointed = False
0.066447
epoch #4
checkPointed = False
0.0664959
epoch #5
checkPointed = False
0.0664109
epoch #6
checkPointed = False
0.0668904
epoch #7
checkPointed = False
0.0663458
epoch #8
checkPointed = False
0.0664414
epoch #9
checkPointed = False
0.0664182
epoch #10
checkPointed = False
0.066456
epoch #11
checkPointed = False
0.0663653
epoch #12
checkPointed = False
0.0662806
epoch #13
checkPointed = False
0.0663154
epoch #14
checkPointed = False
0.0663292
epoch #15
checkPointed = False
0.066324
epoch #16
checkPointed = False
0.0664108
epoch #17
checkPointed = False
0.0663288
epoch #18
checkPointed = False
0.0663925
epoch #19
checkPointed = False
0.066372
epoch #20
checkPointed = False
0.0665329
epoch #21
checkPointed = False
0.0665537
epoch #22
checkPointed = False
0.0662792
epoch #23
checkPointed = False
0.0665491
epoch #24
checkPointed = False
0

KeyboardInterrupt: 

In [16]:
testData = scaler.transform(np.load('testData.npy'))

In [269]:
preds = sess.run(estimate, { X: testData})

In [270]:
np.save("NNPredictions.npy", preds)

In [196]:
preds.shape

(1, 2985217)