Skip to content

Commit

Permalink
Improved code-reuse between Add and Prod kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshensman committed Mar 4, 2016
1 parent 1fb6cce commit 5583917
Showing 1 changed file with 7 additions and 28 deletions.
35 changes: 7 additions & 28 deletions GPflow/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,21 +244,21 @@ def make_kernel_names(kern_list):



class Add(Kern):
class Combination(Kern):
"""
Add a list of kernels together.
Combine a list of kernels, e.g. by adding or multiplying (see inherriting classes).
The names of the kernels are generated from their class names.
The names of the kernels to be combined are generated from their class names.
"""
def __init__(self, kern_list):
for k in kern_list:
assert isinstance(k, Kern), "can only add Kern instances"
Kern.__init__(self, input_dim=np.max([k.input_dim for k in kern_list]))

#add kernels to a list, flattening out instances of Add kerns therein.
#add kernels to a list, flattening out instances of this class therein.
self.kern_list = []
for k in kern_list:
if isinstance(k, Add):
if isinstance(k, self.__class__):
self.kern_list.extend(k.kern_list)
else:
self.kern_list.append(k)
Expand All @@ -267,36 +267,15 @@ def __init__(self, kern_list):
names = make_kernel_names(self.kern_list)
[setattr(self, name, k) for name, k in zip(names, self.kern_list)]

class Add(Combination):
def K(self, X, X2=None):
return reduce(tf.add, [k.K(X, X2) for k in self.kern_list])

def Kdiag(self, X):
return reduce(tf.add, [k.Kdiag(X) for k in self.kern_list])


class Prod(Kern):
"""
Multiply a list of kernels together.
The names of the kernels are generated from their class names.
"""
def __init__(self, kern_list):
for k in kern_list:
assert isinstance(k, Kern), "can only Prod Kern instances"
Kern.__init__(self, input_dim=np.max([k.input_dim for k in kern_list]))

#add kernels to a list, flattening out instances of Prod kerns therein.
self.kern_list = []
for k in kern_list:
if isinstance(k, Prod):
self.kern_list.extend(k.kern_list)
else:
self.kern_list.append(k)

#generate a set of suitable names and add the kernels as atributes of this one.
names = make_kernel_names(self.kern_list)
[setattr(self, name, k) for name, k in zip(names, self.kern_list)]

class Prod(Combination):
def K(self, X, X2=None):
return reduce(tf.mul, [k.K(X, X2) for k in self.kern_list])

Expand Down

0 comments on commit 5583917

Please sign in to comment.