Skip to content

Commit

Permalink
Changing variable minibatch to clearer minibatch_size.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexggmatthews committed Mar 15, 2016
1 parent 0fa7960 commit e4155c8
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions GPflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,26 +240,26 @@ class GPModel(Model):
The predictions can also be used to compute the (log) density of held-out
data via self.predict_density.
"""
def __init__(self, X, Y, kern, likelihood, mean_function, minibatch=None, name='model'):
def __init__(self, X, Y, kern, likelihood, mean_function, minibatch_size=None, name='model'):
self.X, self.Y, self.kern, self.likelihood, self.mean_function = X, Y, kern, likelihood, mean_function
Model.__init__(self, name)

self._tfX = tf.Variable(self.X, name="tfX") # When using minibatches, use _tfX variable
self._tfX = tf.Variable(self.X, name="tfX") # When using minibatch_sizees, use _tfX variable
self._tfY = tf.Variable(self.Y, name="tfY")
self.minibatch = minibatch
self.minibatch_size = minibatch_size

def _compile(self):
"""
compile the tensorflow function "self._objective"
"""
Model._compile(self)
def obj(x):
if self.minibatch / float(len(self.X)) > 0.5:
ss = np.random.permutation(len(self.X))[:self.minibatch]
if self.minibatch_size / float(len(self.X)) > 0.5:
ss = np.random.permutation(len(self.X))[:self.minibatch_size]
else:
# This is much faster than above, and for N >> minibatch, it doesn't make much difference. This actually
# This is much faster than above, and for N >> minibatch_size, it doesn't make much difference. This actually
# becomes the limit when N is around 10**6, which isn't uncommon when using SVI.
ss = np.random.randint(len(self.X), size=self.minibatch)
ss = np.random.randint(len(self.X), size=self.minibatch_size)
return self._session.run([self._minusF, self._minusG], feed_dict={self._free_vars: x,
self._tfX: self.X[ss, :],
self._tfY: self.Y[ss, :]})
Expand Down Expand Up @@ -297,15 +297,15 @@ def predict_density(self, Xnew, Ynew):
return self.likelihood.predict_density(pred_f_mean, pred_f_var, Ynew)

@property
def minibatch(self):
if self._minibatch is None:
def minibatch_size(self):
if self._minibatch_size is None:
return len(self.X)
else:
return self._minibatch
return self._minibatch_size

@minibatch.setter
def minibatch(self, val):
self._minibatch = val
@minibatch_size.setter
def minibatch_size(self, val):
self._minibatch_size = val



Expand Down

0 comments on commit e4155c8

Please sign in to comment.