Skip to content

Commit

Permalink
initial work on DataHo9lder Class
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshensman committed Jun 24, 2016
1 parent 9374793 commit 406c1a3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 31 deletions.
19 changes: 3 additions & 16 deletions GPflow/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import print_function
from .param import Parameterized, AutoFlow
from .param import Parameterized, AutoFlow, DataHolder
from scipy.optimize import minimize, OptimizeResult
import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -77,7 +77,6 @@ def __init__(self, name='model'):
self._needs_recompile = True
self._session = tf.Session()
self._free_vars = tf.placeholder(tf.float64)
self._data_dict = {}

@property
def name(self):
Expand All @@ -103,15 +102,6 @@ def __setstate__(self, d):
self._needs_recompile = True
self._session = tf.Session()

def get_feed_dict(self):
"""
Return a dicitonary containing all the placeholder-value pairs that
should be fed to tensorflow in order to evaluate the model
"""
d = Parameterized.get_feed_dict(self)
d.update(self._data_dict)
return d

def _compile(self, optimizer=None):
"""
compile the tensorflow function "self._objective"
Expand Down Expand Up @@ -324,11 +314,8 @@ def __init__(self, X, Y, kern, likelihood, mean_function, name='model'):
kern, likelihood, mean_function
Model.__init__(self, name)

# set of data is stored in dict self._data_dict
# self._data_dict will be feeded to tensorflow at the runtime.
self.X = tf.placeholder(tf.float64, shape=X.shape, name="X")
self.Y = tf.placeholder(tf.float64, shape=Y.shape, name="Y")
self._data_dict = {self.X: X, self.Y: Y}
self.X = DataHolder(X)
self.Y = DataHolder(Y)

def build_predict(self):
raise NotImplementedError
Expand Down
76 changes: 61 additions & 15 deletions GPflow/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,50 @@ def name(self):
referenced by a parent")
return matches[0]

def __getstate__(self):
d = object.__getstate__(self)
d.pop('_parent')
return d

def __setstate__(self, d):
object.__setstate__(self, d)
self._parent = None


class DataHolder(Parentable):
"""
"""
def __init__(self, array):
"""
"""
Parentable.__init__(self)
self._array = array

def make_tf_array(self, X):
self._tf_array = tf.placeholder(dtype=tf.float64,
shape=self._array.shape,
name=self.name)

@property
def value(self):
return self._array.copy()

def get_feed_dict(self):
return {self._tf_array: self._array}

def __getstate__(self):
d = Parentable.__getstate__(self)
d.pop('_tf_array')
return d

@property
def size(self):
return self._array.size

@property
def shape(self):
return self._array.shape


class Param(Parentable):
"""
Expand Down Expand Up @@ -249,18 +293,16 @@ def _html_table_rows(self, name_prefix=''):
return html

def __getstate__(self):
d = self.__dict__.copy()
d = Parentable.__getstate__(self)
d.pop('_tf_array')
d.pop('_log_jacobian')
d.pop('_parent')
return d

def __setstate__(self, d):
self.__dict__ = d
Parentable.__setstate__(self, d)
self._tf_array = None
self._log_jacobian = None
self._parent = None
# NB the parent property will be set by the parent object, aprt from
# NB the parent property will be set by the parent object, apart from
# for the top level, where it muct be None
# the tf_array and _log jacobian will be replaced when the model is recompiled

Expand Down Expand Up @@ -371,7 +413,7 @@ def __getattribute__(self, key):
representations.
"""
o = object.__getattribute__(self, key)
if isinstance(o, Param) and object.__getattribute__(self, '_tf_mode'):
if isinstance(o, (Param, DataHolder)) and object.__getattribute__(self, '_tf_mode'):
return o._tf_array
return o

Expand Down Expand Up @@ -437,6 +479,8 @@ def make_tf_array(self, X):
count = 0
for p in self.sorted_params:
count += p.make_tf_array(X[count:])
for d in self.data_holders:
d.make_tf_array()
return count

@property
Expand All @@ -450,6 +494,14 @@ def sorted_params(self):
and key is not '_parent']
return sorted(params, key=id)

@property
def data_holders(self):
"""
Return a list of all the child DataHolders
"""
return [child for key, child in self.__dict__.items()
if isinstance(child, DataHolder)]

@property
def fixed(self):
return all(p.fixed for p in self.sorted_params)
Expand All @@ -473,7 +525,7 @@ def get_feed_dict(self):
associated values
"""
d = {}
for p in self.sorted_params:
for p in self.sorted_params + self.data_holders:
d.update(p.get_feed_dict())
return d

Expand Down Expand Up @@ -566,16 +618,10 @@ def _repr_html_(self):
html.append("</table>")
return ''.join(html)

def __getstate__(self):
d = self.__dict__.copy()
d.pop('_parent')
return d

def __setstate__(self, d):
self.__dict__ = d
self._parent = None
Parentable.__setstate__(self, d)
# reinstate _parent graph
for p in self.sorted_params:
for p in self.sorted_params + self.data_holders:
p._parent = self


Expand Down

0 comments on commit 406c1a3

Please sign in to comment.