In [1]:
from torch import nn
import torch
import torch.nn.functional as F
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
from collections import deque, namedtuple
from mpl_toolkits.mplot3d import Axes3D

In [2]:
class RBF(nn.Module):
    """
    Transforms incoming data using a given radial basis function:
    u_{i} = rbf(||x - c_{i}|| / s_{i})
    Arguments:
        in_features: size of each input sample
        out_features: size of each output sample
    Shape:
        - Input: (N, in_features) where N is an arbitrary batch size
        - Output: (N, out_features) where N is an arbitrary batch size
    Attributes:
        centres: the learnable centres of shape (out_features, in_features).
            The values are initialised from a standard normal distribution.
            Normalising inputs to have mean 0 and standard deviation 1 is
            recommended.
        
        log_sigmas: logarithm of the learnable scaling factors of shape (out_features).
        
        basis_func: the radial basis function used to transform the scaled
            distances.
    """

    def __init__(self, in_features, out_features, basis_func):
        super(RBF, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.centres = nn.Parameter(torch.Tensor(out_features, in_features))
        self.log_sigmas = nn.Parameter(torch.Tensor(out_features))
        self.basis_func = basis_func
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.centres, 0, 0.1)
        nn.init.normal_(self.log_sigmas, 0, 1)

    def forward(self, input):
        size = (self.out_features, self.in_features)
        x = input.expand(size)
        c = self.centres
        distances = (x - c).pow(2).sum(-1).pow(0.5) / torch.exp(self.log_sigmas)
        return self.basis_func(distances)



# RBFs

def gaussian(alpha):
    phi = torch.exp(-1*alpha.pow(2))
    return phi

def linear(alpha):
    phi = alpha
    return phi

def quadratic(alpha):
    phi = alpha.pow(2)
    return phi

def inverse_quadratic(alpha):
    phi = torch.ones_like(alpha) / (torch.ones_like(alpha) + alpha.pow(2))
    return phi

def multiquadric(alpha):
    phi = (torch.ones_like(alpha) + alpha.pow(2)).pow(0.5)
    return phi

def inverse_multiquadric(alpha):
    phi = torch.ones_like(alpha) / (torch.ones_like(alpha) + alpha.pow(2)).pow(0.5)
    return phi

def spline(alpha):
    phi = (alpha.pow(2) * torch.log(alpha + torch.ones_like(alpha)))
    return phi

def poisson_one(alpha):
    phi = (alpha - torch.ones_like(alpha)) * torch.exp(-alpha)
    return phi

def poisson_two(alpha):
    phi = ((alpha - 2*torch.ones_like(alpha)) / 2*torch.ones_like(alpha)) \
    * alpha * torch.exp(-alpha)
    return phi

def matern32(alpha):
    phi = (torch.ones_like(alpha) + 3**0.5*alpha)*torch.exp(-3**0.5*alpha)
    return phi

def matern52(alpha):
    phi = (torch.ones_like(alpha) + 5**0.5*alpha + (5/3) \
    * alpha.pow(2))*torch.exp(-5**0.5*alpha)
    return phi

def basis_func_dict():
    """
    A helper function that returns a dictionary containing each RBF
    """
    
    bases = {'gaussian': gaussian,
             'linear': linear,
             'quadratic': quadratic,
             'inverse quadratic': inverse_quadratic,
             'multiquadric': multiquadric,
             'inverse multiquadric': inverse_multiquadric,
             'spline': spline,
             'poisson one': poisson_one,
             'poisson two': poisson_two,
             'matern32': matern32,
             'matern52': matern52}
    return bases

In [5]:
def tt(x):
    return torch.from_numpy(np.array(x)).float()

t = np.linspace(0,10,500)
x = 20 + np.random.randn(500,2)#25*np.sin(t * 20) + 1*np.random.randn( 500) 

rs = RunningMeanStdOne(2)
rt = RunningMeanStdWelford((2))
rs.fill( tt(x[0] ))


y = []
z = []

for it in x:
    d = rs.update(tt(it))
    g = rt.update(tt(it))
    y.append(d.squeeze().numpy())
    z.append(g.squeeze().numpy())
    
y = np.array(y)
z = np.array(z)
fig = plt.figure(figsize= (10,10))
ax = fig.add_subplot(111)
#ax.plot(t,x)
ax.plot(y[:,1])
ax.plot(z[:,1])



<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7ff476626b38>]

In [169]:
y[0,:]

array([ 0.13618661, -0.07495765], dtype=float32)

In [4]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 5.)
        m.bias.data.fill_(0)
        
        
class RunningMeanStdOne:
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
    def __init__(self, input_size):
        self.input_size = input_size
        self.mean = 0.
        self.var = 0.
        self.capacity = 150
        self.memory = deque(maxlen = self.capacity)
        
    def fill(self, item = None):
        for i in range(self.capacity):
            if item is not None:
                x = item.numpy() + 1*np.random.randn(self.input_size)
                self.memory.append(x)
            else:
                x = np.random.randn(self.input_size)
                self.memory.append(x)

    def update(self, xt):
        x = xt.numpy()
        self.memory.append(x)
        self.calculate()
        
        res =  ((x - self.mean) / self.std)
        res = np.array([res])
        return torch.from_numpy(res).float()

    def calculate(self):
        self.mean = np.mean(self.memory, axis = 0)
        self.std = np.std(self.memory, axis = 0)

    def reset(self):
        self.memory = deque(maxlen = self.capacity)
        
        
class RunningMeanStdWelford(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, input_size=()):
        self.mean = np.zeros(input_size, 'float32')
        self.var = np.ones(input_size, 'float32')
        self.M2 = np.ones(input_size, 'float32')
        self.count = 0

    def update(self, xt):
        x = xt.numpy()
        self.count += 1
        
        delta = x - self.mean
        self.mean += delta / self.count
        delta2 = x - self.mean
        self.M2 += delta * delta2
        res = (x - self.mean) /  (self.M2 / (self.count - 1) )**(0.5)
        return torch.from_numpy(res).float()
        
        
class RunningMeanStd(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, input_size=()):
        self.mean = np.zeros(input_size, 'float32')
        self.var = np.ones(input_size, 'float32')
        self.count = 0

    def update(self, xt):
        x = xt.numpy()
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

        res =  ((xt - self.mean) / self.var).float()

        if np.isnan(res).any():
            return (xt - self.mean)
        else:
            return ((xt - self.mean) / self.var).float()

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * (self.count)
        m_b = batch_var * (batch_count)
        M2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
        new_var = M2 / (self.count + batch_count)

        new_count = batch_count + self.count

        self.mean = new_mean
        self.var = new_var
        self.count = new_count
        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * (self.count)
        m_b = batch_var * (batch_count)
        M2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
        new_var = M2 / (self.count + batch_count)

        new_count = batch_count + self.count

        self.mean = new_mean
        self.var = new_var
        self.count = new_count

        
        
class RND(nn.Module):
    def __init__(self, state_dim = 16, k = 16):
        super(RND, self).__init__()      
        self.first = True
        f1 = state_dim
        f2 = 128
        f3 = 16
        self.target =   nn.Sequential(
                            RBF(f1, f2, gaussian),
                            nn.Linear(f2, f2),                  nn.LeakyReLU(),
                            nn.Linear(f2, k),                   nn.Sigmoid()
                            #nn.Linear(f3, f3),                   nn.Sigmoid(),
                            #nn.Linear(f3, f3),                   nn.Sigmoid(),
                            #nn.Linear(f3, f3),                   nn.Sigmoid(),
                            #nn.Linear(f3, k),                    nn.Softmax()
                            )  
        self.predictor = nn.Sequential(
                            RBF(f1, f2, gaussian),
                            nn.Linear(f2, f2),                  nn.LeakyReLU(),
                            nn.Linear(f2, k),                   nn.Sigmoid()
                            #nn.Linear(f3, f3),                   nn.Sigmoid(),
                            #nn.Linear(f3, f3),                   nn.Sigmoid(),
                            #nn.Linear(f3, f3),                   nn.Sigmoid(),
                            #nn.Linear(f3, k),                    nn.Softmax()
                            )


        self.predictor.apply(weights_init)
        self.target.apply(weights_init)

        for param in self.target.parameters():
            param.requires_grad = False

        self.input_norm = RunningMeanStdOne(input_size = state_dim)
        self.output_norm = RunningMeanStdWelford(input_size = 1) 
        
    def reset(self):
        self.predictor.apply(weights_init)
        self.target.apply(weights_init)
        self.first = False

    def forward(self, x):
        if self.first:
            self.input_norm.fill(x)
        #x = self.input_norm.update(x)
        to = self.target(x).detach()
        po = self.predictor(x)
        mse = (to - po).pow(2).sum(0)
        int_reward = mse.detach().float()
        #int_reward = self.output_norm.update(mse.detach().float().unsqueeze(0))
        return to, po, int_reward


In [248]:
rnd.parameters

<bound method Module.parameters of RND(
  (target): Sequential(
    (0): RBF()
    (1): Linear(in_features=128, out_features=128, bias=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Linear(in_features=128, out_features=16, bias=True)
    (4): Sigmoid()
  )
  (predictor): Sequential(
    (0): RBF()
    (1): Linear(in_features=128, out_features=128, bias=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Linear(in_features=128, out_features=16, bias=True)
    (4): Sigmoid()
  )
)>

In [7]:
nx, ny = (0.5, 0.5)
xs = np.linspace(-nx, nx, 50)
ys = np.linspace(-ny, ny, 50)

rnd = RND(2)

xv, yv = np.meshgrid(xs, ys, sparse=False, indexing='ij')
riv = np.zeros(np.shape(xv))



for i, _ in enumerate(xs):
    for j, _  in enumerate(ys):
        food = torch.from_numpy( np.array([xv[i,j], yv[i,j] ] ) ).float()
        pred, targ, ri  = rnd(food)
        #ri = torch.sum((pred - targ).pow(2)).detach().numpy()
        riv[i,j] = ri
        

fig = plt.figure(figsize= (13,13))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(xv, yv, riv)

<IPython.core.display.Javascript object>

<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x7ff46dc6f438>

In [397]:

memory = deque(maxlen = 10)

In [417]:
memory.append(3)

In [418]:
memory

deque([1, 1, 1, 1, 1, 1, 2, 2, 2, 3])

In [320]:
rnd.target.weight

AttributeError: 'Sequential' object has no attribute 'weight'

In [338]:
rnd.input_norm.calculate()

In [342]:
rnd.input_norm.var
((x - rnd.input_norm.mean) / rnd.input_norm.var)

tensor([-1.1875e+01, -7.1904e-10])

In [219]:
torch.sum((pred - targ).pow(2))

tensor(nan, grad_fn=<SumBackward0>)

tensor([[ 2.0542e-09, -2.9341e-09,  1.1768e-09, -9.4877e-09,  7.2859e-09,
         -4.4815e-09,  2.1769e-09, -5.3731e-10, -3.4876e-09,  8.2436e-10,
          1.5965e-09, -6.2026e-10, -1.9587e-09,  7.6616e-10, -4.1638e-09,
          4.5467e-09],
        [ 2.0542e-09, -2.9341e-09,  1.1768e-09, -9.4877e-09,  7.2859e-09,
         -4.4815e-09,  2.1769e-09, -5.3731e-10, -3.4876e-09,  8.2436e-10,
          1.5965e-09, -6.2026e-10, -1.9587e-09,  7.6616e-10, -4.1638e-09,
          4.5467e-09],
        [ 2.0542e-09, -2.9341e-09,  1.1768e-09, -9.4877e-09,  7.2859e-09,
         -4.4815e-09,  2.1769e-09, -5.3731e-10, -3.4876e-09,  8.2436e-10,
          1.5965e-09, -6.2026e-10, -1.9587e-09,  7.6616e-10, -4.1638e-09,
          4.5467e-09],
        [ 2.0542e-09, -2.9341e-09,  1.1768e-09, -9.4877e-09,  7.2859e-09,
         -4.4815e-09,  2.1769e-09, -5.3731e-10, -3.4876e-09,  8.2436e-10,
          1.5965e-09, -6.2026e-10, -1.9587e-09,  7.6616e-10, -4.1638e-09,
          4.5467e-09],
        [ 2.0542e-09

ValueError: operands could not be broadcast together with shapes (16,) (2,) 

In [361]:
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure(figsize= (13,13))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(xv, yv, riv)

<IPython.core.display.Javascript object>

<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x7fc32d986a90>

In [76]:
help(Axes3D.plot_surface)

Help on function plot_surface in module mpl_toolkits.mplot3d.axes3d:

plot_surface(self, X, Y, Z, *args, norm=None, vmin=None, vmax=None, lightsource=None, **kwargs)
    Create a surface plot.
    
    By default it will be colored in shades of a solid color, but it also
    supports color mapping by supplying the *cmap* argument.
    
    .. note::
    
       The *rcount* and *ccount* kwargs, which both default to 50,
       determine the maximum number of samples used in each direction.  If
       the input data is larger, it will be downsampled (by slicing) to
       these numbers of points.
    
    .. note::
    
       To maximize rendering speed consider setting *rstride* and *cstride*
       to divisors of the number of rows minus 1 and columns minus 1
       respectively. For example, given 51 rows rstride can be any of the
       divisors of 50.
    
       Similarly, a setting of *rstride* and *cstride* equal to 1 (or
       *rcount* and *ccount* equal the number of rows an

In [68]:
# Make data.
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)

In [73]:
yv.shape

(20, 20)