<a href="https://colab.research.google.com/github/LucaAmbrogioni/TaylorAlgebra/blob/master/TDDualControl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! git clone https://github.com/3ammor/Weights-Initializer-pytorch.git
import sys

sys.path
sys.path.append('/content/Weights-Initializer-pytorch')

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import PIL.Image
from weight_initializer import Initializer

Cloning into 'Weights-Initializer-pytorch'...
remote: Enumerating objects: 21, done.[K
remote: Total 21 (delta 0), reused 0 (delta 0), pack-reused 21[K
Unpacking objects: 100% (21/21), done.


In [0]:
class Dynamics:
  
    def __init__(self, environment, g = 1., noise=0.05, lam=0.01, control=None):
        self.environment = environment
        self.g = g
        self.noise = noise
        self.xt = []
        self.yt = []
        self.control = control
        self.lam = lam
        self.cost = 0.
        
    def compute_force(self, r, t):
        upscaled_h_cells = torch.nn.functional.interpolate(environment.h_cells, scale_factor=environment.scale,  mode="bilinear", align_corners=True)
        _, Gx, Gy = environment.extrapolate(r[:,0], r[:,1], upscaled_h_cells, 
                                            activation=lambda x: x, 
                                            derivative=True,
                                            d_activation=lambda x: x)
        if self.control is not None:
          Ux = environment.extrapolate(r[:,0], r[:,1], self.control[t][0], 
                                       activation=lambda x: x, 
                                       derivative=False,
                                       std=.5, 
                                       normalized=True)
          Uy = environment.extrapolate(r[:,0], r[:,1], self.control[t][1], 
                                       activation=lambda x: x, 
                                       derivative=False,
                                       std=.5, 
                                       normalized=True)
          control_force = torch.cat((Ux, Uy), 1)
        else:
          control_force = 0.
        grad = torch.cat((Gx.unsqueeze(1), Gy.unsqueeze(1)), 1)
        env_x_repulsion = 3*torch.sigmoid(-10*r[:,0]) - 3*torch.sigmoid(-10*(environment.resolution - r[:,0]))
        env_y_repulsion = 3*torch.sigmoid(-10*r[:,1]) - 3*torch.sigmoid(-10*(environment.resolution - r[:,1]))
        repulsion_force = torch.stack((env_x_repulsion, env_y_repulsion),1)
        F = (self.g * grad + control_force + repulsion_force)
        if self.control is not None:
          control_cost = self.lam*torch.sum(control_force**2,1)
        else:
          control_cost = 0.
        return F, control_cost

    def compute_reward(self, r):
        R = environment.extrapolate(r[:,0], r[:,1], environment.r, 
                                    activation=lambda x: x, 
                                    derivative=False,
                                    std=.5, 
                                    normalized=True)
        return torch.sum(R, 1) 

    def integrate(self, r, dt, N): #Midpoint integration
        num_samples = self.environment.num_samples
        for n in range(N):
            F0, control_cost0 = self.compute_force(r, n)
            F, control_cost = self.compute_force(r + 0.5*dt*F0, n)
            r = r + (F * dt + torch.normal(0., self.noise, (self.environment.num_samples,2)) * dt**(1/2.))
            self.xt += [r.detach().numpy()[:, 0]]
            self.yt += [r.detach().numpy()[:, 1]]
            self.cost +=  0.5*(control_cost0 + control_cost)
        self.cost += - 0.0001*self.compute_reward(r)
        return r

    def sample(self, dt, num_iter):
        r0 = torch.empty(self.environment.num_samples, 2).uniform_(10, self.environment.resolution-10)
        r = self.integrate(r0, dt, num_iter)
        return r

    def reset(self):
        self.xt = []
        self.yt = []
        self.cost = 0.

In [0]:
def logit(x):
  return np.log(x) + np.log(1 - x)

In [0]:
class GaussianEnvironment:

    def __init__(self, resolution, std, num_samples, scale = 5):
        if resolution % scale != 0:
          raise(ValueError)("The resulition should have {} as a factor".format(scale))
        latent_resolution   = int(resolution/scale)
        self.distribution_mean = torch.zeros([num_samples, 1, latent_resolution, latent_resolution])
        self.distribution_std = torch.ones([num_samples, 1, latent_resolution, latent_resolution])
        self.r_distribution_logits = logit(3./(latent_resolution*latent_resolution))*torch.ones([num_samples, latent_resolution*latent_resolution])
        self.num_samples = num_samples
        self.resolution   = resolution
        self.latent_resolution   = latent_resolution
        self.scale = scale
        self.std          = std
        self.environment_hardness = 0.01
        self.reward_hardness = 0.01
        self.is_generated = False
        self.colors = [(torch.tensor([0., 0., 1.]).unsqueeze(1).unsqueeze(2).expand(1,3,resolution, resolution), 0, 0.2),
                       (torch.tensor([0., 1., 0.]).unsqueeze(1).unsqueeze(2).expand(1,3,resolution, resolution), 0.2, 0.4),
                       (torch.tensor([0.58824, 0.29412, 0]).unsqueeze(1).unsqueeze(2).expand(1,3,resolution, resolution), 0.4, 0.6),
                       (torch.tensor([0.5, 0.5, 0.5]).unsqueeze(1).unsqueeze(2).expand(1,3,resolution, resolution), 0.6, 0.8),
                       (torch.tensor([1., 1., 1.]).unsqueeze(1).unsqueeze(2).expand(1,3,resolution, resolution), 0.6, 0.8)]
        self.halfsize     = None
        self.kernel       = None
        self.set_kernel()
        self.c   = None
        self.h   = None
        self.dxh = None
        self.dyx = None

    def visibility_map(self, x0, y0, v0, k0):
        arange        = torch.arange(0., self.resolution).float()
        x, y          = torch.meshgrid([arange, arange])
        h = self.extrapolate(x0, y0, self.h, 
                             activation=lambda x: x, 
                             derivative=False,
                             normalized=True) 
        x0 = x0.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(self.num_samples,1,self.resolution,self.resolution)
        y0 = y0.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(self.num_samples,1,self.resolution,self.resolution)
        h = h.unsqueeze(2).unsqueeze(3).expand(self.num_samples,1,self.resolution,self.resolution)
        
        d_map = torch.sqrt(((x0 - x) ** 2 + (y0 - y) ** 2))
        visibility_mask = 1./(1. + F.relu(d_map - h*k0 + 1)**2)
        hard_mask = 1. - torch.sigmoid(10000*(d_map - h*k0 + 1))
        likelihood_variance = v0 + F.relu(d_map - h*k0 + 1)**3
        return likelihood_variance, visibility_mask, hard_mask
        
    def env_bayesian_update(self, inference_net, x0, y0, v0 = 0.00001, k0 = 30., data=None):
        prior_mean = self.distribution_mean
        prior_std = self.distribution_std

        likelihood_variance, visibility_mask, hard_mask = self.visibility_map(x0, y0, v0, k0)
        
        if data is None:
           mean = self.h
           distribution = torch.distributions.normal.Normal(mean,torch.sqrt(likelihood_variance))
           sample = distribution.rsample()
           data = hard_mask*sample 

        posterior_mean, posterior_var = inference_net.get_posterior_parameters(data, likelihood_variance, prior_mean, prior_std)

        self.distribution_mean = posterior_mean
        self.distribution_std = torch.sqrt(posterior_var)

        variational_loss1 = inference_net.neg_ELBO_loss(data, prior_mean, prior_std, self, likelihood_variance,hard_mask)
        latent = self.h_cells
        variational_loss2 = inference_net.FAVI_loss(data, latent, prior_mean, prior_std, likelihood_variance)
        return variational_loss1 + variational_loss2

    def rew_bayesian_update(self, inference_net, x0, y0, v0 = 0.001, k0 = 20., data=None):
        prior_logits = self.r_distribution_logits
        
        likelihood_variance, visibility_mask, hard_mask = self.visibility_map(x0, y0, v0, k0)
        
        if data is None:
           mean = self.r
           distribution = torch.distributions.normal.Normal(mean,torch.sqrt(likelihood_variance))
           sample = distribution.rsample()
           data = hard_mask*sample 

        posterior_logits = inference_net.get_posterior_parameters(data, likelihood_variance, prior_logits, hard_mask)
        self.r_distribution_logits = posterior_logits

        #variational_loss = inference_net.neg_ELBO_loss(data, prior_logits, self, likelihood_variance) 
        latent = self.r_cells
        variational_loss = inference_net.FAVI_loss(data, latent, prior_logits, likelihood_variance, hard_mask)
        return variational_loss

    def dsigmoidd(self, x):
        sigmoid = torch.sigmoid(x);
        return sigmoid * (1 - sigmoid)

    def get_statistics(self):
        return self.distribution_mean, self.distribution_std, self.r_distribution_logits

    def filter_environment(self, cells):
        upscaled_cells = torch.nn.functional.interpolate(cells, scale_factor=self.scale,  mode="bilinear", align_corners=True)
        pre_map = torch.nn.functional.conv2d(upscaled_cells, 
                                             self.kernel.unsqueeze(0).unsqueeze(0), padding = self.halfsize)
        env_map = torch.sigmoid(self.environment_hardness * pre_map)

        dxh = torch.nn.functional.conv2d(upscaled_cells, 
                                         self.dxkernel.unsqueeze(0).unsqueeze(0), padding = self.halfsize)
        dyh = torch.nn.functional.conv2d(upscaled_cells, 
                                         self.dykernel.unsqueeze(0).unsqueeze(0), padding = self.halfsize)

        dxh = dxh * self.environment_hardness * self.dsigmoidd(self.environment_hardness * pre_map)
        dyh = dyh * self.environment_hardness * self.dsigmoidd(self.environment_hardness * pre_map)
        return env_map, dxh, dyh

    def filter_reward(self, r_cells):
        upscaled_r_cells = torch.nn.functional.interpolate(r_cells.view((self.num_samples,1,self.latent_resolution, self.latent_resolution)), 
                                                           scale_factor=self.scale,  mode="bilinear", align_corners=True)
        reward = (0.1/3)*torch.nn.functional.conv2d(upscaled_r_cells, 
                                                    self.kernel.unsqueeze(0).unsqueeze(0), padding = self.halfsize)
        return reward

    def generate(self):
        mean = self.distribution_mean
        std = self.distribution_std
        distribution = torch.distributions.normal.Normal(mean, 
                                                         std)
        cells = distribution.rsample()
        
        r_logits = self.r_distribution_logits
        r_distribution = torch.distributions.bernoulli.Bernoulli(logits=r_logits)
        r_cells = r_distribution.sample()
        
        env_map, dxh, dyh = self.filter_environment(cells)
        reward = self.filter_reward(r_cells)

        self.c = self.paint(env_map)
        self.h_cells = cells
        self.h = env_map
        self.r = reward
        self.r_cells = r_cells
        self.dxh = dxh
        self.dyh = dyh
        self.is_generated = True
    
    def set_kernel(self):
        self.halfsize = 4*int(np.ceil(2 * self.std))
        arange        = torch.arange(-self.halfsize, self.halfsize + 1).float()
        x, y          = torch.meshgrid([arange, arange])
        self.kernel   = torch.exp(-(x ** 2 + y ** 2) / (2 * self.std ** 2))
        self.dxkernel = -self.kernel.detach() * x / self.std **2
        self.dykernel = -self.kernel.detach() * y / self.std **2

    def extrapolate(self, x0, y0, image, activation, derivative=False, d_activation = None, std=None, normalized=False):
        if std is None: #
          std = self.std
        arange        = torch.arange(0., self.resolution).float()
        x, y          = torch.meshgrid([arange, arange])
        x = x.unsqueeze(0).unsqueeze(0).expand(self.num_samples,1,self.resolution,self.resolution)
        y = y.unsqueeze(0).unsqueeze(0).expand(self.num_samples,1,self.resolution,self.resolution)
        x0 = x0.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(self.num_samples,1,self.resolution,self.resolution)
        y0 = y0.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(self.num_samples,1,self.resolution,self.resolution)
        
        weights = torch.exp(-((x0 - x) ** 2 + (y0 - y) ** 2) / (2 * std ** 2))

        if derivative:
          dx_weights = -(x - x0)*weights / self.std **2
          dy_weights = -(y - y0)*weights / self.std **2

        if normalized:
          weights = weights/torch.sum(weights, (1,2,3), keepdim=True).expand(self.num_samples,1,self.resolution,self.resolution)
        extr = torch.sum(image * weights, (1,2,3))

        if derivative:
          dx_extr = d_activation(extr)*torch.sum(image * dx_weights, (1,2,3))
          dy_extr = d_activation(extr)*torch.sum(image * dy_weights, (1,2,3))
          return activation(extr), dx_extr, dy_extr
        else:
          extr = activation(torch.sum(image * weights, (2,3)))
          return activation(extr)

    def soft_indicator(self, lower, upper, soft):
        indicator = lambda height: torch.sigmoid(soft * (height - lower)) * (1 - torch.sigmoid(soft * (height - upper)))
        return indicator

    def paint(self, x):
        return sum([color.expand(self.num_samples,3,self.resolution,self.resolution) * self.soft_indicator(lower, upper, 10.)(x) for color, lower, upper in self.colors])
        

In [0]:
class HJB(torch.nn.Module):

    def __init__(self, image_size, x_force, y_force, noise_map, reward, lam, dt, intermediate_reward=False):
        super(HJB, self).__init__()
        self.image_size = image_size
        self.x_force = x_force
        self.y_force = y_force
        self.noise_map = noise_map
        self.reward = reward
        self.lam = lam
        self.dt = dt
        self.kx, self.ky, self.k_laplace = self._get_derivative_filters()
        self.intermediate_reward = intermediate_reward
        #self.kx_minus, self.kx_plus, self.ky_minus, self.ky_plus = self._get_derivative_filters()

    def _get_derivative_filters(self): #Upwind method
        ky = torch.tensor([[1., 2. , 1.], [0., 0., 0.], [-1., -2. , -1.]])/4.
        ky = ky.expand(1,1,3,3)
        kx = torch.transpose(ky, 3, 2)
        k_laplace = torch.tensor([[1., 1. , 1.], [1., -8. , 1.], [1., 1. , 1.]])
        k_laplace = k_laplace.expand(1,1,3,3)
        return kx, ky, k_laplace

    def backward_update(self, V, control=False):
       Vpad = torch.nn.functional.pad(V, (1,1,1,1), "reflect")
       dVx = torch.nn.functional.conv2d(Vpad, self.kx, padding = 0)
       dVy = torch.nn.functional.conv2d(Vpad, self.ky, padding = 0)
       LV = torch.nn.functional.conv2d(Vpad, self.k_laplace, padding = 0)
       if self.intermediate_reward:
          r = self.reward
       else:
          r = 0.
       update = (-r - dVx**2/(2*self.lam) - dVy**2/(2*self.lam) + self.x_force * dVx + self.y_force * dVy + self.noise_map**2*LV)
       if control:
          Ux = -(1/self.lam)*dVx 
          Uy = -(1/self.lam)*dVy
          return update, Ux, Uy
       else:
          return update

    def backward_step(self, V):
        update, Ux, Uy = self.backward_update(V, control=True)
        Vprev = V + self.dt*update
        return Vprev, Ux, Uy

    def RK_backward_step(self, V):
        k1, Ux, Uy = self.backward_update(V, control=True)
        k1 *= self.dt
        k2 = self.dt*self.backward_update(V + k1/2)
        k3 = self.dt*self.backward_update(V + k2/2)
        k4 = self.dt*self.backward_update(V + k3)
        return V + (k1 + 2*k2 + 2*k3 + k4)/6., Ux, Uy 

    def compute_value(self, N, RK = False, plot=False):
        Vn = -self.reward 
        V_list = [-Vn]
        U_list = [None]
        for n in reversed(range(N)):
         if n % 20 == 0:
            if plot:
              x,y = (np.arange(0, resolution), np.arange(0, resolution))
              plt.imshow(Vn[0,:,:,:].detach().numpy().squeeze(), extent = [0, resolution, 0, resolution], origin="lower")
              plt.quiver(x, y, environment.dyh[0,:,:,:].detach().numpy().squeeze(), environment.dxh[0,:,:,:].detach().numpy().squeeze())
              #plt.quiver(x, y, Ux[0,:,:,:].detach().numpy().squeeze(), Uy[0,:,:,:].numpy().squeeze(), color="red")
              fig = plt.gcf()
              fig.set_size_inches(18.5, 18.5)
              plt.show()
         if not RK:
           Vn, Ux, Uy = self.backward_step(Vn)
         else:
           Vn, Ux, Uy = self.RK_backward_step(Vn) 
         V_list.append(-Vn)
         U_list.append((-Uy, -Ux)) #TODO: flipped/sign flipped
        return list(reversed(V_list)), list(reversed(U_list))

In [0]:
class EnvInferenceNet(nn.Module):

    def __init__(self, gain, h_size=30, k_size=3, var_k_size=3, latent_resolution = 8, scale_factor = 5):
        super(EnvInferenceNet, self).__init__()

        self.conv_in = nn.Conv2d(h_size, h_size, k_size, padding=0) #Input: h_mean, h_std, r_mean, r_std times h_size
        self.out =  nn.Linear(latent_resolution*latent_resolution*h_size*scale_factor**2, latent_resolution*latent_resolution)

        self.var_l1 = nn.Conv2d(1, h_size, 1, padding=0, bias=False)
        self.var_out = nn.Conv2d(h_size, 1 , 1, padding=0, bias=False)

        # Parameters
        self.h_size = h_size
        self.k_size = k_size
        self.k_pad = int((k_size - 1)/2)
        self.var_k_pad = int((var_k_size - 1)/2)
        self.latent_resolution = latent_resolution
        self.scale_factor = scale_factor
        self.gain = gain

    def forward(self,  data, likelihood_var):

        activation = lambda x: torch.relu(x)
        b_size = data.shape[0]

        x = data.repeat(1,self.h_size,1,1)
        x_pad = torch.nn.functional.pad(x, (self.k_pad,self.k_pad,self.k_pad,self.k_pad), "reflect")
        h = activation(self.conv_in(x_pad)).view(b_size, self.h_size*self.latent_resolution*self.latent_resolution*self.scale_factor**2)
        latent_data = self.out(h)
        latent_data = latent_data.view(b_size,1,self.latent_resolution,self.latent_resolution)
        
        x = F.interpolate(likelihood_var, scale_factor=1/self.scale_factor,  mode="bilinear", align_corners=True)
        x = activation(self.var_l1(x))
        x = self.var_out(x)**2

        return self.gain*latent_data, x

    def get_posterior_parameters(self, data, likelihood_var, prior_mean, prior_std):
       latent_data, latent_variance = self(data, likelihood_var)
       posterior_var = 1/(1/prior_std**2 + 1/latent_variance)
       posterior_mean = (prior_mean/prior_std**2 + latent_data/latent_variance)*posterior_var
       return posterior_mean, posterior_var

    def neg_ELBO_loss(self, data, prior_mean, prior_std, environment, lk_variance, mask):
       prior_distribution = torch.distributions.normal.Normal(prior_mean, prior_std)
       posterior_mean, posterior_var =  self.get_posterior_parameters(data, lk_variance, prior_mean, prior_std)
       post_distribution = torch.distributions.normal.Normal(posterior_mean,torch.sqrt(posterior_var))

       posterior_sample = post_distribution.rsample()
       lik_filter = lambda x: environment.filter_environment(x)[0]
       avg_log_lik = torch.mean(-0.5*mask*(data - lik_filter(posterior_sample))**2/lk_variance)

       KL_regularization = torch.distributions.kl.kl_divergence(post_distribution, prior_distribution)
       return torch.mean(-avg_log_lik + KL_regularization)

    def FAVI_loss(self, data, latent, prior_mean, prior_std, lk_variance):
       posterior_mean, posterior_var =  self.get_posterior_parameters(data, lk_variance, prior_mean, prior_std)
       loss =  torch.mean(0.5*(latent - posterior_mean)**2/posterior_var + 0.5*torch.log(2*np.pi*posterior_var))
       return loss

In [0]:
class RewInferenceNet(nn.Module):

    def __init__(self, gain, h_size=60, k_size=5, latent_resolution = 8, scale_factor = 5):
        super(RewInferenceNet, self).__init__()

        self.l = nn.Linear(latent_resolution*latent_resolution*scale_factor**2, latent_resolution*latent_resolution)

        #self.var_l1 = nn.Conv2d(1, h_size, 1, padding=0, bias=False)
        #self.var_out = nn.Conv2d(h_size, 1 , 1, padding=0, bias=False)

        # Parameters
        self.h_size = h_size
        self.k_size = k_size
        self.k_pad = int((k_size - 1)/2)
        self.latent_resolution = latent_resolution
        self.scale_factor = scale_factor
        self.gain = gain

    def forward(self, data, likelihood_var, hard_mask):
        b_size = data.shape[0]
        mask = F.interpolate(hard_mask, scale_factor=1/self.scale_factor,  mode="bilinear", align_corners=True).view((b_size,self.latent_resolution*self.latent_resolution))
        x = mask*(0.1*F.softplus(self.l(data.view(b_size, self.latent_resolution*self.latent_resolution*self.scale_factor**2))) - 2.)
        return x

    def get_posterior_parameters(self, data, likelihood_var, prior_logits, hard_mask):
       latent_logits = self(data, likelihood_var, hard_mask)
       posterior_logits = prior_logits + latent_logits
       return posterior_logits

    def neg_ELBO_loss(self, data, prior_logits, environment, lk_variance):
       prior_distribution = torch.distributions.categorical.Categorical(logits=prior_logits)
       posterior_logits =  self.get_posterior_parameters(data, lk_variance, prior_logits)
       post_distribution = torch.distributions.categorical.Categorical(logits=posterior_logits)

       enumeration = post_distribution.enumerate_support(expand=False)
       log_probs = post_distribution.log_prob(enumeration).transpose(1,0)
       probs = torch.exp(log_probs).unsqueeze(2).unsqueeze(3)
       log_lk = torch.sum(-0.5*(data - environment.filter_reward(enumeration[:,0]).transpose(1,0))**2/lk_variance, (2,3))
       avg_log_lik = torch.mean(probs*log_lk.detach())
#
       KL_regularization = torch.distributions.kl.kl_divergence(post_distribution, prior_distribution)

       return torch.mean(-avg_log_lik + KL_regularization)

    def FAVI_loss(self, data, latent, prior_logits, lk_variance, hard_mask):
       b_size = data.shape[0]
       weights = F.interpolate(hard_mask, scale_factor=1/self.scale_factor,  mode="bilinear", align_corners=True).view((b_size,self.latent_resolution*self.latent_resolution))
       loss_fn = torch.nn.BCEWithLogitsLoss(weight=weights.detach())
       posterior_logits = self.get_posterior_parameters(data, lk_variance, prior_logits, hard_mask)
       loss =  loss_fn(posterior_logits, latent.detach())
       
       if False and iteration % 10 == 0:
          plot_map(data)
          mean_r_cells = torch.sigmoid(torch.nn.functional.interpolate(posterior_logits.view((b_size,1,self.latent_resolution, self.latent_resolution)), 
                                                                       scale_factor=self.scale_factor,  mode="bilinear", align_corners=True))
          r_mean = (0.1/3)*torch.nn.functional.conv2d(mean_r_cells, 
                                                      environment.kernel.unsqueeze(0).unsqueeze(0), padding = environment.halfsize)
          r_var = r_mean*(1 - r_mean)
          plot_map(r_mean)
          plot_map(r_var)

       return loss

In [0]:
# Policy network TODO: Work in progress

class ValueNet(nn.Module):
    def __init__(self, environment, smoothing_std=2, h_size=40, k_size=1):
        super(ValueNet, self).__init__()
        
        self.conv_in = nn.Conv2d(3, h_size, k_size, padding=0, bias=False) 
        self.out = nn.Linear(environment.latent_resolution*environment.latent_resolution*h_size,
                             environment.latent_resolution*environment.latent_resolution, bias=False) 
        self.conv_out = nn.Conv2d(h_size, 1, k_size, padding=0, bias=False)
        
        # Smoothing layer
        self.smoothing_std = smoothing_std

        # Parameters
        self.h_size = h_size
        self.k_size = k_size
        self.k_pad = int((k_size - 1)/2)
        self.halfsize = 4*int(np.ceil(2 * smoothing_std))
        arange        = torch.arange(-self.halfsize, self.halfsize + 1).float()
        x, y          = torch.meshgrid([arange, arange])
        self.smoothing_ker = torch.exp(-(x ** 2 + y ** 2) / (2 * smoothing_std ** 2))
        self.environment = environment
        self.V_trace = None

    def forward(self,  h_mean, h_std, r_logits, N, g, dt, exploit=False):

        activation = lambda x: F.softplus(x)

        predicted_reward = 0.1*environment.filter_reward(torch.sigmoid(r_logits))

        x = torch.cat((h_mean, 
                       h_std, 
                       r_logits.view((environment.num_samples, 
                                      1, 
                                      environment.latent_resolution, 
                                      environment.latent_resolution))), 
                      1)
        x = activation(self.conv_in(x))
        y = x.view((environment.num_samples, self.h_size*environment.latent_resolution**2))
        z = self.conv_out(x)
        x = z #+ 0.1*self.out(y).view((environment.num_samples, 1, environment.latent_resolution, environment.latent_resolution))
        x = F.interpolate(x, scale_factor=environment.scale,  
                          mode="bilinear", align_corners=True)
        x_pad = torch.nn.functional.pad(x, (self.halfsize,self.halfsize,self.halfsize,self.halfsize), "reflect")
        output = torch.nn.functional.conv2d(x_pad.view(environment.num_samples, 
                                                       1, 
                                                       x_pad.shape[2], 
                                                       x_pad.shape[3]), self.smoothing_ker.unsqueeze(0).unsqueeze(0)).view(environment.num_samples, 
                                                                                                                           1, 
                                                                                                                           environment.resolution, 
                                                                                                                           environment.resolution)
        value = 0.01*F.softplus(output)

        if exploit is False:
            hjb_input = value
        else:
            hjb_input = predicted_reward

        hjb = HJB(image_size=environment.resolution, 
                  x_force= -g*environment.dyh, #TODO: this should be changed
                  y_force= -g*environment.dxh, #TODO: this should be changed
                  noise_map= 0.25, #TODO: this should be changed
                  reward=hjb_input, 
                  lam= 0.02, 
                  dt=dt)
        _, Ulist = hjb.compute_value(N, RK=True)
        return value, Ulist

    def TDloss(self, reward, value, future_value, kernel, gamma=0.9):
        reward = reward.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        future_value = future_value.unsqueeze(2).unsqueeze(3)
        #
        TD_target = (reward + gamma*future_value).detach()
        loss = torch.mean(kernel*(value - TD_target)**2)
        return loss

    def TD_lambda_loss(self, reward, value, future_value, kernel, step, gamma=0.95, lam = 0.2):
        if step == 0:
          self.V_trace = 0.
        else:
          self.V_trace = gamma*lam*self.V_trace + value
        future_value = future_value.unsqueeze(2).unsqueeze(3)
        loss = torch.mean(kernel*(reward + gamma*future_value.detach() - value).detach()*V_trace)
        return loss


In [0]:
def plot_trajectories(Ulist, environment, dynamics, value):
    x0_range = [20., 40.]
    y0_range = [20., 40.]

    _, _, r_logits = environment.get_statistics()
    r_map = environment.filter_reward(torch.sigmoid(r_logits))

    x,y = (np.arange(0, environment.resolution), np.arange(0, environment.resolution))
    #plt.imshow(environment.h[0,0,:,:].detach().numpy().squeeze(), extent = [0, environment.resolution, 0, environment.resolution], origin="lower")
    #plt.contour(x, y, environment.h[0,0,:,:].detach().numpy().squeeze(), colors='red')
    plt.imshow(r_map[0,0,:,:].detach().numpy().squeeze(), extent = [0, environment.resolution, 0, environment.resolution], origin="lower")
    plt.quiver(x, y, environment.dyh[0,0,:,:].detach().numpy().squeeze(), environment.dxh[0,:,:,:].detach().numpy().squeeze())
    plt.plot(np.array(dynamics.yt)[:,0], np.array(dynamics.xt)[:,0], linewidth=4, color = "red")
    plt.plot(np.array(dynamics.yt)[0,0], np.array(dynamics.xt)[0,0], "xb")
    plt.colorbar()

    fig = plt.gcf()
    fig.set_size_inches(18.5, 18.5)
    plt.show()

In [0]:
def plot_map(mp, norm=False, lim=1.):
    x0_range = [20., 40.]
    y0_range = [20., 40.]

    x,y = (np.arange(0, environment.resolution), np.arange(0, environment.resolution))
    plt.imshow(mp[0,0,:,:].detach().numpy().squeeze(), extent = [0, environment.resolution, 0, environment.resolution], origin="lower")
    plt.colorbar()

    fig = plt.gcf()
    fig.set_size_inches(18.5, 18.5)
    #plt.clim(0,1.)
    if norm:
      plt.clim(0,1.)
    plt.show()

In [0]:
import pickle

def save_network(net, name):
  pickle.dump(net, open( "{}.p".format(name), "wb" ))

In [0]:
def load_network(name):
  return pickle.load( open( "{}.p".format(name), "rb" ) )

In [0]:
# Train
N_iters = 2000 
RL_batch_size = 3
VI_batch_size = 20
N_steps = 5
N_intergration_steps = 400 #200
N_VI_iterations = 400

resolution = 40 #40
scale = 8 #5
std = 7.5
g = 0.0005 #0.005
noise = 0.3
dt = 0.1

environment = GaussianEnvironment(resolution=resolution, std=std, num_samples=VI_batch_size, scale=scale)
net = ValueNet(environment) #TODO: Multiple networks
#Initializer.initialize(model=net, initialization=nn.init.xavier_uniform, gain=nn.init.calculate_gain('relu'))
optimizer = optim.Adam(net.parameters(), lr=0.00001) 

env_inference_net = EnvInferenceNet(gain=1., scale_factor = scale, latent_resolution = int(resolution/scale)) #TODO: Multiple networks
#Initializer.initialize(model=env_inference_net, initialization=nn.init.xavier_uniform, gain=nn.init.calculate_gain('relu'))
env_VI_optimizer = optim.Adam(env_inference_net.parameters(), lr=0.00001) 

reward_inference_net = RewInferenceNet(gain=1., scale_factor = scale, latent_resolution = int(resolution/scale)) #TODO: Multiple networks
#Initializer.initialize(model=reward_inference_net, initialization=nn.init.xavier_uniform, gain=nn.init.calculate_gain('relu'))
reward_VI_optimizer = optim.Adam(reward_inference_net.parameters(), lr=0.0001) 

loss_list = []
env_VI_loss_list = []
reward_VI_loss_list = []

In [0]:
load_value_net = False
try:
  env_inference_net = load_network("env_net")
  reward_inference_net = load_network("reward_net")
  print("Loading inference networks")
  N_VI_itr = 0
except:
  print("Training inference networks")
  N_VI_itr = N_VI_iterations

if load_value_net:
  try:
    net = load_network("value_net")
    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    print("Loading value networks")
  except:
    print("Training value network")

for iteration in range(N_iters):
   if iteration > N_VI_itr:
      batch_size = RL_batch_size
   else:
      batch_size = VI_batch_size
   print("Iteration: {}".format(iteration))
   environment = GaussianEnvironment(resolution=resolution, std=std, num_samples=batch_size, scale=scale)
   dynamics = Dynamics(environment, g=g, noise=noise, lam=0.0000)
   environment.generate()
   r = dynamics.sample(dt, 1)
   if iteration > N_VI_itr:
      environment.env_bayesian_update(env_inference_net, r[:,0], r[:,1])
      environment.generate() 
   total_loss = 0.
   total_reward = 0.
   total_env_VI_loss = 0.
   total_reward_VI_loss = 0.
   reward = torch.zeros((batch_size,))
   for step in range(N_steps):
      print("Step: {}".format(step))
      # zero the parameter gradients
      optimizer.zero_grad()
      env_VI_optimizer.zero_grad()
      env_VI_optimizer.zero_grad()

      ## Control ##
      if step == N_steps - 1:
        exploit = True
      else:
        exploit = False
      if iteration > N_VI_itr:
        h_mean, h_std, r_logits = environment.get_statistics()
        value, Ulist = net.forward(h_mean, 
                                   h_std, 
                                   r_logits, 
                                   N_intergration_steps, g, dt, exploit=exploit)
        dynamics.control = Ulist
        print(value.max())

      old_r = r

      if iteration > N_VI_itr:
        r = dynamics.integrate(r, dt, N_intergration_steps).detach()
      else:
        r = dynamics.sample(dt, 1)

      if np.any(np.isnan(r.detach().numpy())):
         print("not a number found in the new coordinates")
         break

      if np.any(r.detach().numpy() > resolution + 8) or np.any(r.detach().numpy() < -8):
         print("The agent has left the environment")
         break

      if iteration % 1 == 0 and iteration > N_VI_itr:
         plot_trajectories(Ulist, environment, dynamics, value)
         save_network(net, "value_net")

      ## Reward ##
      new_reward = -dynamics.cost

      # Bayesian update
      env_VI_loss = environment.env_bayesian_update(env_inference_net, r[:,0], r[:,1])
      reward_VI_loss = environment.rew_bayesian_update(reward_inference_net, r[:,0], r[:,1])

      ## Information gain ##
      if iteration > N_VI_itr:
        if step < N_steps - 1:
          new_h_mean, new_h_std, new_r_logits = environment.get_statistics() 
          future_value_map,_ = net.forward(new_h_mean, new_h_std, new_r_logits, N_intergration_steps, g, dt, exploit=exploit)
          future_value = environment.extrapolate(r[:,0], r[:,1], future_value_map, 
                                                 activation=lambda x: x, 
                                                 derivative=False,
                                                 std=.5, 
                                                 normalized=True).detach()
        else:
          future_value = new_reward.unsqueeze(1)
      
      
        ## TD kernel ##
        arange        = torch.arange(0, environment.resolution).float()
        x, y          = torch.meshgrid([arange, arange])
        x = x.unsqueeze(0).unsqueeze(0).expand(environment.num_samples,1,environment.resolution,environment.resolution)
        y = y.unsqueeze(0).unsqueeze(0).expand(environment.num_samples,1,environment.resolution,environment.resolution)
        x0 = old_r[:,0].unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(environment.num_samples,1,environment.resolution,environment.resolution)
        y0 = old_r[:,1].unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(environment.num_samples,1,environment.resolution,environment.resolution)
        kernel   = torch.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * 3. ** 2))
        
        ## TD learning ##
        loss = net.TDloss(reward, value, future_value, kernel, gamma=0.95)
        #loss = net.TD_lambda_loss(reward, value, future_value, kernel, step, gamma=0.95, lam = 0.2)
        if not np.isnan(loss.detach().numpy()):
          loss.backward(retain_graph=True)
          optimizer.step()
          environment.generate()
          total_loss += float(loss.detach().numpy())
          total_reward += float(torch.sum(reward).detach().numpy())
        else:
          break

      ## Reward ##
      reward = new_reward
 
      ## VI update ##
      if iteration < N_VI_itr:
        env_VI_loss.backward(retain_graph=True)
        reward_VI_loss.backward(retain_graph=True)
        env_VI_loss.backward(retain_graph=True)
        reward_VI_loss.backward(retain_graph=True)
        env_VI_optimizer.step()
        reward_VI_optimizer.step()
      total_env_VI_loss += float(env_VI_loss.detach().numpy())
      total_reward_VI_loss += float(reward_VI_loss.detach().numpy())
   if iteration == N_VI_itr:
     save_network(env_inference_net, "env_net")
     save_network(reward_inference_net, "reward_net")
     
   if iteration > N_VI_itr:
     print("Reward: {}".format(total_reward))
   else:
     print("VI env loss: {}".format(total_env_VI_loss))
     print("VI rew loss: {}".format(total_reward_VI_loss))
   #loss_list += [loss.detach().numpy()]#
   env_VI_loss_list += [total_env_VI_loss]
   reward_VI_loss_list += [total_reward_VI_loss]
   if iteration == N_VI_itr:
     plt.plot(env_VI_loss_list)
     plt.show()
     plt.plot(reward_VI_loss_list)
     plt.show()


Training inference networks
Iteration: 0
Step: 0
Step: 1
Step: 2
Step: 3
Step: 4
VI env loss: 422241042432.0
VI rew loss: 0.22977444529533386
Iteration: 1
Step: 0
Step: 1
Step: 2
Step: 3
Step: 4
VI env loss: 363338297344.0
VI rew loss: 0.3573695085942745
Iteration: 2
Step: 0
Step: 1
Step: 2
Step: 3
Step: 4
VI env loss: 406224902144.0
VI rew loss: 0.31365251168608665
Iteration: 3
Step: 0
Step: 1
Step: 2
Step: 3
Step: 4
VI env loss: 503899504640.0
VI rew loss: 0.2277168594300747
Iteration: 4
Step: 0
Step: 1
Step: 2
Step: 3
Step: 4
VI env loss: 463332421632.0
VI rew loss: 0.4688900001347065
Iteration: 5
Step: 0
Step: 1
Step: 2
Step: 3
Step: 4
VI env loss: 406449203200.0
VI rew loss: 0.3349279426038265
Iteration: 6
Step: 0
Step: 1
Step: 2
Step: 3
Step: 4
VI env loss: 329373618176.0
VI rew loss: 0.2615988291800022
Iteration: 7
Step: 0
Step: 1
Step: 2
Step: 3
Step: 4
VI env loss: 324248328192.0
VI rew loss: 0.2598825991153717
Iteration: 8
Step: 0
Step: 1
Step: 2
Step: 3
Step: 4
VI env loss: 