# Testing Predictive Performance of a direct prediction model on Flatland

This notebook tests the predictive performance of a model trained to directly predict the next observation given the action and the current observation. Test results for several seeds are saved and can be plotted and compared against other methods using the fig6_createplot notebook. The network is the same as that used by our method, but the training protocol here uses batches of transitions collected from the environment.

In [1]:
import os
import gym
import math
import numpy as np
import time
from PIL import Image
import matplotlib.pyplot as plt
from IPython import display
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle

os.chdir('src/flatland/flat_game/')
from env import Env


pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html
Loading chipmunk for Linux (64bit) [/home/william/.local/lib/python3.6/site-packages/pymunk/libchipmunk.so]


# Flatworld environment

We start by defining the FlatWorld environment, which is based on the code available at https://github.com/Caselles/NeurIPS19-SBDRL. This environment returns pixel observations of a ball on a cyclical 2D grid. There are four discrete actions; each one moves the ball by a fixed step in one of the four axis directions.

In [2]:
# Shared FlatWorld settings: agent radius, and the `period` argument
# (number of start-position coordinates per axis).
RADIUS, PERIOD = 15, 5

class FlatWorld():
    """Gym-like wrapper around the Flatland ``Env``.

    Observations are returned as float tensors equal to the raw ``Env``
    observation divided by 255. Actions are integers 0-3, each moving the agent
    by ``step_size`` along one axis (see ``create_action_dict``).
    """

    class action_space():
        """Discrete action space with ``n`` actions."""

        def __init__(self, n_actions):
            self.n = n_actions

        def sample(self, k=1):
            """Return an int64 tensor of shape (k,) with uniformly random actions."""
            return torch.randint(0, self.n, (k,))

    class observation_space():
        """Nominal observation shape."""

        def __init__(self):
            self.shape = [84, 84]

    def __init__(self, env_parameters, period=10, radius=15):
        """
        env_parameters: keyword arguments for ``Env``; must contain an ``'agent'`` dict.
            The caller's dict is NOT modified; a copy with the agent radius set
            to ``radius`` is passed to ``Env``.
        period: number of start-position coordinates per axis (the start grid
            has ``period**2`` points); also divides the step size.
        radius: agent radius used both by ``Env`` and for the step size.
        """
        # Instance attributes shadow the nested classes after instantiation.
        self.action_space = self.action_space(4)
        self.observation_space = self.observation_space()
        self.period = period

        # Private copy: avoids mutating a shared (module-level) config dict and
        # makes the radius override take effect before anything depends on it.
        agent_params = dict(env_parameters['agent'], radius=radius)
        params = dict(env_parameters, agent=agent_params)

        # Step size derived from the radius the environment actually uses.
        self.step_size = 0.1 * (63 - 2 * radius) / period
        coords = [27 + 10 * self.step_size * i for i in range(period)]
        self.start_positions = [(i, j) for i in coords for j in coords]

        self.env = Env(**params)

    def reset(self, start_position=None):
        """Reset the environment and return the first observation (scaled by 1/255).

        start_position: position passed to ``Env.reset``; when None, the centre
            of the start-position grid is used (index 12 for ``period == 5``).
            A random start would be ``random.choice(self.start_positions)``.
        """
        if start_position is None:
            middle = self.period // 2
            obs = self.env.reset(position=self.start_positions[middle * self.period + middle])
        else:
            obs = self.env.reset(position=start_position)
        return torch.FloatTensor(obs) / 255

    def step(self, action):
        """Apply ``action`` and return the next observation (scaled by 1/255)."""
        action_dict = self.create_action_dict(action)
        obs, reward, done, info = self.env.step(action_dict)
        return torch.FloatTensor(obs) / 255

    def create_action_dict(self, action):
        """Translate a discrete action into ``Env`` velocity commands.

        0: +lateral, 1: -lateral, 2: +longitudinal, 3: -longitudinal.
        Any other value yields an empty dict.
        """
        if action == 0:
            longitudinal, lateral = 0, self.step_size
        elif action == 1:
            longitudinal, lateral = 0, -self.step_size
        elif action == 2:
            longitudinal, lateral = self.step_size, 0
        elif action == 3:
            longitudinal, lateral = -self.step_size, 0
        else:
            return {}
        return {
            'longitudinal_velocity': longitudinal,
            'lateral_velocity': lateral,
            'angular_velocity': 0,
        }
        

### Flatland hyperparameters

In [3]:
# Agent configuration. Only `radius` is overridden later, by FlatWorld.
agent_parameters = dict(
    radius=15,
    speed=10,
    rotation_speed=math.pi / 8,
    living_penalty=0,
    position=(30, 30),
    angle=0,
    sensors=[
        dict(
            nameSensor='proximity_test',
            typeSensor='proximity',
            fovResolution=64,
            fovRange=300,
            fovAngle=math.pi,
            bodyAnchor='body',
            d_r=0,
            d_theta=0,
            d_relativeOrientation=0,
            display=False,
        ),
    ],
    actions=['forward', 'turn_left', 'turn_right', 'left', 'right', 'backward'],
    measurements=['health', 'poisons', 'fruits'],
    texture=dict(type='color', c=(255, 255, 255)),
    normalize_measurements=False,
    normalize_states=False,
    normalize_rewards=False,
)

# Environment configuration: no poisons, fruits or obstacles; the agent dict is shared by reference.
env_parameters = dict(
    map=False,
    n_rooms=2,
    display=False,
    horizon=10001,
    shape=(84, 84),
    mode='time',
    poisons=dict(
        number=0,
        positions='random',
        size=10,
        reward=-10,
        respawn=True,
        texture=dict(type='color', c=(255, 255, 255)),
    ),
    fruits=dict(
        number=0,
        positions='random',
        size=10,
        reward=10,
        respawn=True,
        texture=dict(type='color', c=(255, 150, 0)),
    ),
    obstacles=[],
    walls_texture=dict(type='color', c=(1, 1, 1)),
    agent=agent_parameters,
)

### Define Encoder and Decoder (we use a compact latent space for the observations to which we concatenate a one-hot embedding of the actions)

In [4]:
class Encoder(nn.Module):
    """Convolutional encoder from single-channel images to an ``n_out``-dim vector.

    Accepts one image of shape (H, W) -> returns (n_out,), or a batch
    (B, H, W) -> returns (B, n_out), including B == 1. With H = W = 84 the conv
    gives 5x25x25 and the 4x4 max-pool gives 5x6x6 = 180 features, matching ``fc1``.
    """

    def __init__(self, n_out=4, n_hid=64):
        super().__init__()
        self.conv = nn.Conv2d(1, 5, 10, stride=3)
        self.fc1 = nn.Linear(180, n_hid)
        self.fc2 = nn.Linear(n_hid, n_out)

    def forward(self, x):
        unbatched = x.dim() != 3
        if unbatched:
            x = x.unsqueeze(0)
        x = F.relu(self.conv(x.unsqueeze(1)))  # add channel dim
        x = F.max_pool2d(x, 4, 4)
        # flatten per sample: a wrong input size now fails in fc1 instead of
        # being silently reshaped across the batch by view(-1, 180)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Only drop the batch dim we added; a plain squeeze() would also collapse
        # batches of size 1 and n_out == 1.
        return x.squeeze(0) if unbatched else x

class Decoder(nn.Module):
    """Decodes a latent vector concatenated with a one-hot action into an 84x84 image.

    Input (n_in + n_actions,) -> output (84, 84); input (B, n_in + n_actions) ->
    output (B, 84, 84), including B == 1. Pixel values are in (0, 1) (sigmoid).
    """

    def __init__(self, n_in=4, n_actions=4, n_hid=64):
        super().__init__()
        self.fc1 = nn.Linear(n_in + n_actions, n_hid)
        self.fc2 = nn.Linear(n_hid, 180)
        # (6 - 1) * 10 + 34 = 84: maps 5x6x6 feature maps to one 84x84 channel
        self.conv = nn.ConvTranspose2d(5, 1, 34, stride=10)

    def forward(self, x):
        unbatched = x.dim() == 1
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = x.reshape(-1, 5, 6, 6)
        x = torch.sigmoid(self.conv(x)).squeeze(1)  # drop channel dim
        # Only drop the batch dim for unbatched input, so a batch of size 1 keeps
        # the same shape as its (1, 84, 84) BCE target.
        return x.squeeze(0) if unbatched else x

# Create dataset of transitions

The transitions are collected from episodes of length 10, each starting at the centre of the environment.

In [5]:
def create_dataset(n_data=5000, episode_len=10):
    """Collect ``n_data`` transitions from FlatWorld using uniformly random actions.

    The environment is reset (to its default start position) at the start of
    every episode of ``episode_len`` steps.

    Returns: list of ``(obs, one_hot_action, next_obs)`` tuples.
    Raises: ValueError if ``episode_len`` < 1.
    """
    if episode_len < 1:
        raise ValueError('episode_len must be >= 1, got {}'.format(episode_len))
    dataset = []
    obs_env = FlatWorld(env_parameters, period=PERIOD, radius=RADIUS)
    n_actions = obs_env.action_space.n
    for i in range(n_data):
        if i % episode_len == 0:
            obs_x = obs_env.reset()
        action = obs_env.action_space.sample().item()
        action_embedding = torch.zeros(n_actions)
        action_embedding[action] = 1
        next_obs_x = obs_env.step(action)
        dataset.append((obs_x, action_embedding, next_obs_x))
        obs_x = next_obs_x
    return dataset

In [6]:
# Seed torch (used by action sampling) so the transition dataset is reproducible.
DATASET_SEED = 0
torch.manual_seed(DATASET_SEED)
dataset = create_dataset(n_data=5000)

# Define a RepresentationLearner that puts all these elements together

Once initialized, the RepresentationLearner can be trained and tested on Flatland automatically.

In [7]:
def mk_dir(export_dir, quite=False):
    """Create ``export_dir`` (and any missing parents) if it does not exist.

    export_dir: directory path to create.
    quite: if True, suppress the informational messages ("quiet").
    Raises: OSError (other than "already exists") if the directory cannot be created.
    """
    if os.path.exists(export_dir):
        if not quite:
            print('Directory already exists: ', export_dir)
        return
    try:
        os.makedirs(export_dir)
    except FileExistsError:
        # Guard against a race: another process created it after the exists() check.
        return
    if not quite:
        print('Created directory: ', export_dir)

class RepresentationLearner:
    """Trains and evaluates a direct next-observation predictor on Flatland.

    The model predicts ``Decoder(concat(Encoder(obs), one_hot(action)))`` and is
    trained with binary cross-entropy against the true next observation.
    Artifacts go to ``results/direct_<exp_name>``.
    """

    def __init__(self,
                 dim=4,
                 n_hid=64,
                 lr_dec=1e-3,
                 lr_enc=1e-3,
                 exp_name=None):
        """
        dim: size of the encoder's latent vector.
        n_hid: hidden width of the encoder/decoder fully connected layers.
        lr_dec, lr_enc: Adam learning rates for decoder and encoder. The defaults
            (1e-3) equal the rates this notebook has always trained with.
        exp_name: experiment name used in the results folder; 'default' if None.
        """
        self.dim = dim
        self.n_hid = n_hid
        self.lr_dec = lr_dec
        self.lr_enc = lr_enc
        self.exp_name = exp_name if exp_name is not None else 'default'

        mk_dir('results', quite=True)
        self.save_folder = 'results/direct_' + self.exp_name
        mk_dir(self.save_folder, quite=True)

        self.decoder = Decoder(n_in=dim, n_hid=n_hid)
        self.encoder = Encoder(n_out=dim, n_hid=n_hid)
        self.optimizer_dec = optim.Adam(self.decoder.parameters(), lr=lr_dec, weight_decay=0)
        self.optimizer_enc = optim.Adam(self.encoder.parameters(), lr=lr_enc, weight_decay=0)

        self.losses = []

    @staticmethod
    def _one_hot(action, n_actions=4):
        """Return a float one-hot vector of length ``n_actions`` for integer ``action``."""
        embedding = torch.zeros(n_actions)
        embedding[action] = 1
        return embedding

    def learn(self, n_sgd_steps=5000, batch_size=32, data=None):
        """
        Learns to directly predict the next observation from the current observation
        and the chosen action.
        n_sgd_steps: number of gradient descent steps
        batch_size: number of transitions in batch
        data: list of (obs, one_hot_action, next_obs) transitions; defaults to the
            notebook-level ``dataset`` built by ``create_dataset``.
        """
        if data is None:
            data = dataset

        t_start = time.time()
        self.losses = []
        for i in range(n_sgd_steps):
            batch = random.sample(data, batch_size)
            states, actions, next_states = (torch.stack(column) for column in zip(*batch))

            embedding = torch.cat((self.encoder(states), actions), dim=1)
            reconstructions = self.decoder(embedding)

            loss = F.binary_cross_entropy(reconstructions, next_states)
            self.losses.append(loss.item())

            self.optimizer_dec.zero_grad()
            self.optimizer_enc.zero_grad()
            loss.backward()
            self.optimizer_enc.step()
            self.optimizer_dec.step()

            if i % 10 == 0:
                # carriage return keeps one progress line, newline every 100 iters
                print("iter {} : loss={:.3f} : last 10 iters in {:.3f}s".format(i, loss.item(), time.time() - t_start),
                      end="\r" if i % 100 else "\n")
                t_start = time.time()

        fig, ax = plt.subplots()
        ax.plot(self.losses)
        ax.set(xlabel='SGD step', ylabel='BCE loss', title='Training loss')
        fig.savefig(self.save_folder + '/losses.png', bbox_inches='tight')
        plt.close(fig)

    def test(self, rollout_len=500, rollouts=50, save_scores=True):
        """
        Measures reconstruction error as a function of rollout length over open-loop
        rollouts: the model is fed its own predictions, starting from the true
        first observation, while random actions are applied to the real environment.
        rollout_len: length of each rollout (>= 1)
        rollouts: number of rollouts to average over (>= 1)
        save_scores: if True, pickle the scores and save a plot to ``save_folder``
        Returns: tensor of shape (rollout_len,) with the mean BCE at each step.
        Raises: ValueError if ``rollout_len`` or ``rollouts`` is < 1.
        """
        if rollout_len < 1 or rollouts < 1:
            raise ValueError('rollout_len and rollouts must be >= 1, got {} and {}'.format(rollout_len, rollouts))

        all_errors = []
        with torch.no_grad():
            for _ in range(rollouts):
                obs_env = FlatWorld(env_parameters, period=PERIOD, radius=RADIUS)
                reconstructed = obs_env.reset()
                errors = []
                for _ in range(rollout_len):
                    action = obs_env.action_space.sample().item()
                    embedding = torch.cat((self.encoder(reconstructed), self._one_hot(action)))
                    reconstructed = self.decoder(embedding)
                    obs_x = obs_env.step(action)
                    errors.append(F.binary_cross_entropy(reconstructed, obs_x))
                all_errors.append(torch.stack(errors))
            all_errors = torch.stack(all_errors).mean(dim=0)

        # All saving happens only when requested (previously save_scores=False hit a NameError).
        if save_scores:
            with open(self.save_folder + '/scores.pkl', 'wb') as output:
                pickle.dump(all_errors, output, pickle.HIGHEST_PROTOCOL)
            fig, ax = plt.subplots()
            ax.plot(all_errors)
            ax.set(xlabel='rollout step', ylabel='mean BCE', title='Open-loop prediction error')
            fig.savefig(self.save_folder + '/scores.png', bbox_inches='tight')
            plt.close(fig)
            print("Results saved.")

        return all_errors

### Train and test the RepresentationLearner over multiple seeds

In [8]:
seeds = 20
errors = []
for seed in range(seeds):
    # Actually seed every RNG in play; previously `seed` only named the results folder.
    random.seed(seed)        # random.sample in learn()
    np.random.seed(seed)
    torch.manual_seed(seed)  # weight init and action sampling
    agent = RepresentationLearner(exp_name=str(seed))
    agent.learn(n_sgd_steps=5000)
    error = agent.test()
    errors.append(error)
# Mean per-step prediction error across seeds.
errors = torch.stack(errors).mean(dim=0)

iter 0 : loss=0.689 : last 10 iters in 0.045s
iter 100 : loss=0.180 : last 10 iters in 0.440s
iter 200 : loss=0.142 : last 10 iters in 0.314s
iter 300 : loss=0.127 : last 10 iters in 0.316s
iter 400 : loss=0.124 : last 10 iters in 0.299s
iter 500 : loss=0.117 : last 10 iters in 0.407s
iter 600 : loss=0.149 : last 10 iters in 0.409s
iter 700 : loss=0.096 : last 10 iters in 0.359s
iter 800 : loss=0.099 : last 10 iters in 0.324s
iter 900 : loss=0.065 : last 10 iters in 0.359s
iter 1000 : loss=0.086 : last 10 iters in 0.354s
iter 1100 : loss=0.058 : last 10 iters in 0.351s
iter 1200 : loss=0.069 : last 10 iters in 0.338s
iter 1300 : loss=0.066 : last 10 iters in 0.352s
iter 1400 : loss=0.049 : last 10 iters in 0.330s
iter 1500 : loss=0.045 : last 10 iters in 0.351s
iter 1600 : loss=0.047 : last 10 iters in 0.412s
iter 1700 : loss=0.036 : last 10 iters in 0.352s
iter 1800 : loss=0.033 : last 10 iters in 0.371s
iter 1900 : loss=0.030 : last 10 iters in 0.361s
iter 2000 : loss=0.032 : last 10

  mask = resized_img != 0


Results saved.
iter 0 : loss=0.701 : last 10 iters in 0.033s
iter 100 : loss=0.183 : last 10 iters in 0.402s
iter 200 : loss=0.153 : last 10 iters in 0.477s
iter 300 : loss=0.162 : last 10 iters in 0.422s
iter 400 : loss=0.124 : last 10 iters in 0.344s
iter 500 : loss=0.130 : last 10 iters in 0.286s
iter 600 : loss=0.110 : last 10 iters in 0.453s
iter 700 : loss=0.115 : last 10 iters in 0.407s
iter 800 : loss=0.100 : last 10 iters in 0.331s
iter 900 : loss=0.089 : last 10 iters in 0.303s
iter 1000 : loss=0.106 : last 10 iters in 0.403s
iter 1100 : loss=0.084 : last 10 iters in 0.489s
iter 1200 : loss=0.065 : last 10 iters in 0.358s
iter 1300 : loss=0.071 : last 10 iters in 0.317s
iter 1400 : loss=0.046 : last 10 iters in 0.349s
iter 1500 : loss=0.056 : last 10 iters in 0.298s
iter 1600 : loss=0.056 : last 10 iters in 0.406s
iter 1700 : loss=0.064 : last 10 iters in 0.370s
iter 1800 : loss=0.043 : last 10 iters in 0.469s
iter 1900 : loss=0.042 : last 10 iters in 0.380s
iter 2000 : loss=

iter 2800 : loss=0.025 : last 10 iters in 0.269s
iter 2900 : loss=0.026 : last 10 iters in 0.295s
iter 3000 : loss=0.026 : last 10 iters in 0.371s
iter 3100 : loss=0.027 : last 10 iters in 0.299s
iter 3200 : loss=0.026 : last 10 iters in 0.275s
iter 3300 : loss=0.025 : last 10 iters in 0.322s
iter 3400 : loss=0.024 : last 10 iters in 0.320s
iter 3500 : loss=0.027 : last 10 iters in 0.285s
iter 3600 : loss=0.026 : last 10 iters in 0.323s
iter 3700 : loss=0.025 : last 10 iters in 0.277s
iter 3800 : loss=0.024 : last 10 iters in 0.285s
iter 3900 : loss=0.024 : last 10 iters in 0.326s
iter 4000 : loss=0.025 : last 10 iters in 0.327s
iter 4100 : loss=0.026 : last 10 iters in 0.320s
iter 4200 : loss=0.024 : last 10 iters in 0.330s
iter 4300 : loss=0.024 : last 10 iters in 0.275s
iter 4400 : loss=0.026 : last 10 iters in 0.334s
iter 4500 : loss=0.024 : last 10 iters in 0.282s
iter 4600 : loss=0.024 : last 10 iters in 0.332s
iter 4700 : loss=0.024 : last 10 iters in 0.333s
iter 4800 : loss=0.0

iter 600 : loss=0.122 : last 10 iters in 0.354s
iter 700 : loss=0.106 : last 10 iters in 0.354s
iter 800 : loss=0.082 : last 10 iters in 0.359s
iter 900 : loss=0.089 : last 10 iters in 0.360s
iter 1000 : loss=0.101 : last 10 iters in 0.351s
iter 1100 : loss=0.089 : last 10 iters in 0.296s
iter 1200 : loss=0.099 : last 10 iters in 0.356s
iter 1300 : loss=0.090 : last 10 iters in 0.352s
iter 1400 : loss=0.080 : last 10 iters in 0.353s
iter 1500 : loss=0.072 : last 10 iters in 0.351s
iter 1600 : loss=0.044 : last 10 iters in 0.354s
iter 1700 : loss=0.044 : last 10 iters in 0.359s
iter 1800 : loss=0.055 : last 10 iters in 0.352s
iter 1900 : loss=0.039 : last 10 iters in 0.364s
iter 2000 : loss=0.040 : last 10 iters in 0.361s
iter 2100 : loss=0.038 : last 10 iters in 0.350s
iter 2200 : loss=0.037 : last 10 iters in 0.398s
iter 2300 : loss=0.035 : last 10 iters in 0.320s
iter 2400 : loss=0.040 : last 10 iters in 0.300s
iter 2500 : loss=0.030 : last 10 iters in 0.432s
iter 2600 : loss=0.028 :