# Testing predictive performance of our method on Flatland

This notebook tests the predictive performance of the representations and dynamics learned with our method. Test results for several seeds are saved and can be plotted and compared against other methods using the `fig6_createplot` notebook.

In [1]:
import os
import gym
import math
import numpy as np
import time
from PIL import Image
import matplotlib.pyplot as plt
from IPython import display
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle

# The Flatland `env` module lives in src/flatland/flat_game, and
# `from env import Env` relies on the working directory being on sys.path (as in
# a Jupyter kernel). All later relative paths (e.g. the 'results' folder) also
# resolve from here. Only change directory if we are not already inside it, so
# re-running this cell does not fail on a second relative chdir.
FLAT_GAME_DIR = os.path.join('src', 'flatland', 'flat_game')
if not os.getcwd().endswith(FLAT_GAME_DIR):
    os.chdir(FLAT_GAME_DIR)
from env import Env


pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html
Loading chipmunk for Linux (64bit) [/home/william/.local/lib/python3.6/site-packages/pymunk/libchipmunk.so]


# Flatworld environment

We start by defining the Flatworld environment, which is based on the code available at https://github.com/Caselles/NeurIPS19-SBDRL. The environment returns pixel observations of a ball on a cyclic 2D grid. There are four discrete actions; each one moves the ball by a fixed step in one of the four directions (up, down, left, right).

In [2]:
RADIUS = 15
PERIOD = 5

class FlatWorld():
    """Thin wrapper around the Flatland ``Env`` with four discrete move actions.

    Observations are returned as float tensors scaled from [0, 255] to [0, 1].
    Every action moves the agent by ``step_size`` along one axis, with
    ``step_size = 0.1 * (63 - 2 * radius) / period``.
    """

    class action_space():
        """Minimal gym-like discrete action space with ``n`` actions."""

        def __init__(self, n_actions):
            self.n = n_actions

        def sample(self, k=1):
            """Return a tensor of ``k`` uniformly random action indices."""
            return torch.randint(0, self.n, (k,))

    class observation_space():
        """Minimal gym-like observation space describing 84x84 frames."""

        def __init__(self):
            self.shape = [84, 84]

    def __init__(self, env_parameters, period=10, radius=15):
        """
        Args:
            env_parameters: keyword arguments for ``Env``; must contain an
                ``'agent'`` dict. The caller's dicts are NOT modified.
            period: number of start positions along each axis; also divides the
                movement range to obtain ``step_size``.
            radius: agent radius used for the simulation (overrides
                ``env_parameters['agent']['radius']``).
        """
        self.action_space = self.action_space(4)
        self.observation_space = self.observation_space()
        self.period = period

        # Work on copies so the caller's (typically global) parameter dicts
        # are not mutated as a side effect of constructing an environment.
        params = dict(env_parameters)
        params['agent'] = dict(env_parameters['agent'])
        params['agent']['radius'] = radius

        # The step size depends on the radius of the agent that is actually
        # simulated, so it is computed after the override above.
        self.step_size = 0.1 * (63 - 2 * radius) / period
        offsets = [27 + 10 * self.step_size * i for i in range(period)]
        self.start_positions = [(x, y) for x in offsets for y in offsets]

        self.env = Env(**params)

    def reset(self, start_position=None):
        """Reset the environment and return the first frame scaled to [0, 1].

        Args:
            start_position: (x, y) start position; when None the agent starts at
                ``start_positions[12]`` (the centre of the grid when period == 5).
        """
        if start_position is None:
            start_position = self.start_positions[12]
        obs = self.env.reset(position=start_position)
        return torch.FloatTensor(obs) / 255

    def step(self, action):
        """Apply a discrete action and return the next frame scaled to [0, 1].

        The reward, done flag and info returned by ``Env.step`` are discarded.
        """
        obs, _reward, _done, _info = self.env.step(self.create_action_dict(action))
        return torch.FloatTensor(obs) / 255

    def create_action_dict(self, action):
        """Translate a discrete action (0-3) into the velocity dict expected by ``Env``.

        Actions 0/1 move by +/- ``step_size`` laterally, actions 2/3 by
        +/- ``step_size`` longitudinally; the angular velocity is always 0.

        Raises:
            ValueError: if ``action`` is not one of 0, 1, 2, 3.
        """
        if action == 0:
            longitudinal, lateral = 0, self.step_size
        elif action == 1:
            longitudinal, lateral = 0, -self.step_size
        elif action == 2:
            longitudinal, lateral = self.step_size, 0
        elif action == 3:
            longitudinal, lateral = -self.step_size, 0
        else:
            # Fail loudly instead of handing an empty dict to Env.step.
            raise ValueError("Invalid action {!r}; expected 0, 1, 2 or 3.".format(action))
        return {
            'longitudinal_velocity': longitudinal,
            'lateral_velocity': lateral,
            'angular_velocity': 0,
        }
        

### Hyperparameters for Flatland

In [3]:
# Agent configuration. FlatWorld overrides 'radius' with its own `radius`
# argument when it builds the Env.
agent_parameters = dict(
    radius=15,
    speed=10,
    rotation_speed=math.pi / 8,
    living_penalty=0,
    position=(30, 30),
    angle=0,
    # A single, non-displayed proximity sensor.
    sensors=[
        dict(
            nameSensor='proximity_test',
            typeSensor='proximity',
            fovResolution=64,
            fovRange=300,
            fovAngle=math.pi,
            bodyAnchor='body',
            d_r=0,
            d_theta=0,
            d_relativeOrientation=0,
            display=False,
        ),
    ],
    actions=['forward', 'turn_left', 'turn_right', 'left', 'right', 'backward'],
    measurements=['health', 'poisons', 'fruits'],
    texture=dict(type='color', c=(255, 255, 255)),
    normalize_measurements=False,
    normalize_states=False,
    normalize_rewards=False,
)

# Environment configuration: 84x84 frames; poisons and fruits have number=0
# and there are no obstacles (presumably leaving only the agent and walls).
env_parameters = dict(
    map=False,
    n_rooms=2,
    display=False,
    horizon=10001,
    shape=(84, 84),
    mode='time',
    poisons=dict(
        number=0,
        positions='random',
        size=10,
        reward=-10,
        respawn=True,
        texture=dict(type='color', c=(255, 255, 255)),
    ),
    fruits=dict(
        number=0,
        positions='random',
        size=10,
        reward=10,
        respawn=True,
        texture=dict(type='color', c=(255, 150, 0)),
    ),
    obstacles=[],
    walls_texture=dict(type='color', c=(1, 1, 1)),
    # Same object, not a copy: edits to agent_parameters are visible here.
    agent=agent_parameters,
)

### Define Encoder and Decoder

In [4]:
class Encoder(nn.Module):
    """Convolutional encoder mapping one frame to a unit-norm latent vector.

    For an 84x84 frame: conv(1->5, kernel 10, stride 3) gives 25x25 maps, a 4x4
    max-pool shrinks them to 6x6, i.e. 5*6*6 = 180 features, which two fully
    connected layers map to ``n_out`` values that are then L2-normalised.
    """

    def __init__(self, n_out=4, n_hid = 64):
        super().__init__()
        # Creation order and attribute names are kept: they fix the parameter
        # initialisation order and the keys of saved state_dicts.
        self.conv = nn.Conv2d(1, 5, 10, stride=3)
        self.fc1 = nn.Linear(180, n_hid)
        self.fc2 = nn.Linear(n_hid, n_out)

    def forward(self, x):
        """Encode a single (presumably 84x84) image; returns shape (n_out,)."""
        # Add batch and channel axes: (H, W) -> (1, 1, H, W).
        feature_maps = F.relu(self.conv(x[None, None]))
        pooled = F.max_pool2d(feature_maps, 4, 4)
        hidden = F.relu(self.fc1(pooled.view(-1, 180)))
        latent = F.normalize(self.fc2(hidden))
        return latent.squeeze()

class Decoder(nn.Module):
    """Decoder mapping a latent vector back to an 84x84 image in (0, 1).

    Two fully connected layers produce 180 = 5*6*6 features; a transposed
    convolution (kernel 34, stride 10) upsamples 6 -> (6-1)*10 + 34 = 84 pixels
    per side, and a sigmoid squashes the result to (0, 1).
    """

    def __init__(self, n_in=4, n_hid = 64):
        super().__init__()
        # Creation order and attribute names are kept: they fix the parameter
        # initialisation order and the keys of saved state_dicts.
        self.fc1 = nn.Linear(n_in, n_hid)
        self.fc2 = nn.Linear(n_hid, 180)
        self.conv = nn.ConvTranspose2d(5, 1, 34, stride=10)

    def forward(self, x):
        """Decode a single latent vector into an (84, 84) image."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        feature_maps = hidden.view(1, 5, 6, 6)
        image = torch.sigmoid(self.conv(feature_maps))
        return image.squeeze()

### Define the action representations and a LatentWorld that mirrors the actual environment in latent space, where the representations act

In [5]:
class Representation():
    """Action representation: a dim x dim rotation built from dim*(dim-1)/2 angles.

    The matrix is the product of Givens rotations, one per coordinate pair
    (i, j) with i < j, taken in the order (0,1), (0,2), ..., (dim-2, dim-1),
    with angle ``thetas[k]`` for the k-th pair. The product is cached until
    ``clear_matrix`` is called, so callers must clear it after updating the
    angles.
    """

    def __init__(self, dim=4):
        self.dim = dim
        self.params = dim*(dim-1)//2
        # Angles drawn uniformly from [-pi/dim, pi/dim); a leaf tensor with
        # requires_grad so an optimiser can update it (replaces the deprecated
        # torch.autograd.Variable(..., requires_grad=True)).
        self.thetas = (np.pi*(2*torch.rand(self.params)-1)/dim).requires_grad_()

        self.__matrix = None

    def set_thetas(self, thetas):
        """Replace the angle tensor (marked as requiring grad) and drop the cached matrix."""
        self.thetas = thetas
        self.thetas.requires_grad = True
        self.clear_matrix()

    def clear_matrix(self):
        """Invalidate the cached matrix so the next ``get_matrix`` rebuilds it."""
        self.__matrix = None

    def get_matrix(self):
        """Return the (dim, dim) rotation matrix, building and caching it if needed.

        For dim < 2 there are no angles and the identity matrix is returned.
        """
        if self.__matrix is None:
            # Starting from the identity handles dim < 2 (no rotations) and lets
            # the product be accumulated iteratively, without recursion or list
            # slicing, so large dims are fine too.
            matrix = torch.eye(self.dim, self.dim)
            k = 0
            for i in range(self.dim - 1):
                for j in range(i + 1, self.dim):
                    c, s = torch.cos(self.thetas[k]), torch.sin(self.thetas[k])
                    k += 1

                    rotation = torch.eye(self.dim, self.dim)
                    rotation[i, i] = c
                    rotation[i, j] = s
                    rotation[j, i] = -s
                    rotation[j, j] = c

                    matrix = matrix @ rotation

            self.__matrix = matrix

        return self.__matrix

In [6]:
class LatentWorld():
    """Latent-space counterpart of the observed environment.

    The state is a latent vector; taking action ``a`` left-multiplies it by the
    matrix of ``action_reps[a]``.
    """

    class action_space():
        """Minimal discrete action space with ``n`` actions."""

        def __init__(self, n_actions):
            self.n = n_actions

        def sample(self, k=1):
            """Return a tensor of ``k`` uniformly random action indices."""
            return torch.randint(0, self.n, (k,))

    class observation_space():
        """Minimal observation space: a latent vector of ``n_features`` entries."""

        def __init__(self, n_features):
            self.shape = [n_features]

    def __init__(self,
                 dim=4,
                 n_actions=4,
                 action_reps=None):
        """
        Args:
            dim: dimension of the latent space.
            n_actions: number of discrete actions.
            action_reps: optional list of ``n_actions`` representations acting
                on ``dim``-dimensional vectors; fresh random ``Representation``
                objects are created when omitted.

        Raises:
            Exception: if ``action_reps`` has the wrong length, or any of its
                representations has a different ``dim``.
        """
        self.dim = dim

        self.action_space = self.action_space(n_actions)
        self.observation_space = self.observation_space(dim)

        if action_reps is None:
            self.action_reps = [Representation(dim=self.dim) for _ in range(n_actions)]
        else:
            if len(action_reps) != n_actions:
                raise Exception("Must pass an action representation for every action.")
            # Every supplied representation must act on the latent dimension.
            if not all(rep.dim == self.dim for rep in action_reps):
                raise Exception("Action representations do not act on the dimension of the latent space.")
            self.action_reps = action_reps

    def reset(self, state_init):
        """Set the latent state (e.g. an encoded frame) and return it."""
        self.state = state_init
        return self.get_observation()

    def clear_representations(self):
        """Drop every representation's cached matrix (call after updating angles)."""
        for rep in self.action_reps:
            rep.clear_matrix()

    def get_representation_params(self):
        """Return the trainable angle tensors, one per action (for an optimiser)."""
        return [rep.thetas for rep in self.action_reps]

    def get_representations(self):
        """Return the angle tensors of every action representation."""
        return [rep.thetas for rep in self.action_reps]

    def save_representations(self, path):
        """Save the per-action angle tensors to ``path`` (``.pth`` appended if missing)."""
        if os.path.splitext(path)[-1] != '.pth':
            path += '.pth'
        return torch.save(self.get_representations(), path)

    def load_reprentations(self, path):
        """Load angles saved by ``save_representations`` into the representations.

        The misspelled name is kept for existing callers; ``load_representations``
        is a correctly spelled alias.

        NOTE: ``Representation.set_thetas`` rebinds ``thetas``, so an optimiser
        built over the previous tensors will not update the loaded ones.

        Raises:
            ValueError: if the file holds a different number of angle tensors
                than there are actions.
        """
        rep_thetas = list(torch.load(path))
        if len(rep_thetas) != len(self.action_reps):
            raise ValueError("Expected {} angle tensors, found {} in {}".format(
                len(self.action_reps), len(rep_thetas), path))
        for rep, thetas in zip(self.action_reps, rep_thetas):
            rep.set_thetas(thetas)

    # Correctly spelled alias of load_reprentations.
    load_representations = load_reprentations

    def get_observation(self):
        """Return the current latent state."""
        return self.state

    def step(self, action):
        """Apply ``action``'s representation matrix to the latent state and return it."""
        self.state = torch.mv(self.action_reps[action].get_matrix(), self.state)
        return self.get_observation()

# Define a RepresentationLearner that puts all these elements together

Once initialized, the RepresentationLearner can be trained and tested on Flatland automatically.

In [7]:
def mk_dir(export_dir, quite=False):
    """Create ``export_dir`` (and any missing parents) if it does not exist yet.

    Args:
        export_dir: directory path to create.
        quite: if True, suppress the status message. (Sic, "quiet"; the name
            is kept for existing callers.)

    Raises:
        OSError: if the directory cannot be created (e.g. missing permissions
            or a parent path component that is a regular file).
    """
    if os.path.exists(export_dir):
        if not quite:
            print('Directory already exists: ', export_dir)
        return
    # exist_ok=True covers the race where another process creates the
    # directory between the existence check above and this call; any other
    # OSError is a real failure and propagates.
    os.makedirs(export_dir, exist_ok=True)
    if not quite:
        print('Created directory: ', export_dir)
            
def calc_entanglement(params):
    """Entanglement penalty for a vector of rotation angles.

    Returns the sum of squared angles minus the largest squared angle. The
    result is zero exactly when at most one angle is non-zero, so minimising it
    pushes each representation towards a rotation in a single coordinate plane.
    An empty ``params`` yields 0.
    """
    squared = params.abs().pow(2)
    if squared.numel() == 0:
        # max() is undefined on an empty tensor; no angles means no entanglement.
        return squared.sum()
    return squared.sum() - squared.max()

class RepresentationLearner:
    """Jointly learns an encoder, a decoder and per-action latent representations.

    In each random episode the first Flatland frame is encoded; the same random
    actions are then applied both in the observed ``FlatWorld`` and, through
    the action representations, in a ``LatentWorld``. Every latent state is
    decoded and compared with the true frame using binary cross-entropy.
    Plots, checkpoints and test scores are written to
    ``results/ourmethod_<exp_name>``.
    """

    def __init__(self,
                 dim=4,
                 n_hid=64,
                 lr_dec=1e-2,
                 lr_rep=1e-2,
                 lr_enc=1e-2,
                 reg=0,
                 exp_name=None):
        """
        Args:
            dim: latent dimension shared by the encoder, decoder and latent world.
            n_hid: hidden width of the encoder and decoder.
            lr_dec: Adam learning rate of the decoder.
            lr_rep: Adam learning rate of the representation angles.
            lr_enc: Adam learning rate of the encoder.
            reg: weight of the entanglement penalty added to the loss.
            exp_name: experiment name used for the results folder
                (``'default'`` when None).
        """
        self.dim = dim
        self.n_hid = n_hid
        self.lr_dec = lr_dec
        self.lr_rep = lr_rep
        self.lr_enc = lr_enc
        self.reg = reg
        self.exp_name = exp_name

        # Concatenating None would raise a TypeError, so fall back to a fixed name.
        folder_name = 'default' if exp_name is None else str(exp_name)
        mk_dir('results', quite=True)
        self.save_folder = 'results/ourmethod_' + folder_name
        mk_dir(self.save_folder, quite=True)

        self.obs_env = FlatWorld(env_parameters, period=PERIOD, radius=RADIUS)
        # The latent world must use the same `dim` as the encoder/decoder.
        self.lat_env = LatentWorld(dim=dim,
                                   n_actions=self.obs_env.action_space.n)
        self.decoder = Decoder(n_in=dim, n_hid=n_hid)
        self.encoder = Encoder(n_out=dim, n_hid=n_hid)

        self.optimizer_dec = optim.Adam(self.decoder.parameters(),
                                        lr=lr_dec,
                                        weight_decay=0)

        self.optimizer_enc = optim.Adam(self.encoder.parameters(),
                                        lr=lr_enc,
                                        weight_decay=0)

        self.optimizer_rep = optim.Adam(self.lat_env.get_representation_params(),
                                        lr=lr_rep,
                                        weight_decay=0)

        # Per-SGD-step reconstruction loss (without the regulariser).
        self.losses = []

    def _rollout(self, n_steps):
        """Yield (observed frame, latent state) pairs for one random episode.

        The first pair is the reset frame and its encoding; each of the next
        ``n_steps`` pairs comes from applying the same uniformly sampled action
        in the observed and in the latent world (``n_steps + 1`` pairs total).
        """
        obs_x = self.obs_env.reset()
        obs_z = self.lat_env.reset(self.encoder(obs_x))
        yield obs_x, obs_z
        for _ in range(n_steps):
            action = self.obs_env.action_space.sample().item()
            obs_x = self.obs_env.step(action)
            obs_z = self.lat_env.step(action)
            yield obs_x, obs_z

    def learn(self, n_sgd_steps=5000, ep_steps=5, batch_eps=16):
        """
        Learns a representation using our method.

        Args:
            n_sgd_steps: number of gradient descent steps.
            ep_steps: number of sequential actions per episode.
            batch_eps: number of episodes per gradient step.

        The reconstruction loss of every step is stored in ``self.losses`` and
        plotted to ``<save_folder>/losses.png``.
        """
        t_start = time.time()
        self.losses = []
        n_actions = len(self.lat_env.action_reps)

        for i in range(1, n_sgd_steps + 1):

            loss = torch.zeros(1)
            for _ in range(batch_eps):
                for obs_x, obs_z in self._rollout(ep_steps):
                    loss += F.binary_cross_entropy(self.decoder(obs_z), obs_x)

            # NOTE: each episode contributes ep_steps + 1 terms, but the
            # normaliser is kept at ep_steps * batch_eps so logged losses and
            # the effective weight of `reg` match previously reported runs.
            loss /= (ep_steps*batch_eps)
            raw_loss = loss.item()
            self.losses.append(raw_loss)

            # Mean entanglement penalty over all action representations.
            reg_loss = sum(calc_entanglement(r.thetas) for r in self.lat_env.action_reps) / n_actions
            loss += reg_loss.abs() * self.reg

            self.optimizer_dec.zero_grad()
            self.optimizer_enc.zero_grad()
            self.optimizer_rep.zero_grad()
            loss.backward()
            self.optimizer_enc.step()
            self.optimizer_dec.step()
            self.optimizer_rep.step()

            # The cached representation matrices are stale once the angles
            # have been updated, so they must be rebuilt on next use.
            self.lat_env.clear_representations()

            if i % 10 == 0:
                print("iter {} : loss={:.3f} : entanglement={:.2e} : last 10 iters in {:.3f}s".format(
                    i, raw_loss, reg_loss.item(), time.time() - t_start
                    ), end="\r" if i % 100 else "\n")
                t_start = time.time()

        fig, ax = plt.subplots()
        ax.plot(self.losses)
        ax.set_xlabel('SGD step')
        ax.set_ylabel('reconstruction loss (BCE)')
        fig.savefig(self.save_folder + '/losses.png', bbox_inches='tight')
        plt.close(fig)

    def save(self, fname='rep_learner'):
        '''
        Save the learner's state to ``<save_folder>/<fname>.pth``.

        The saved dict contains:
            history : list of per-step reconstruction losses
            representations : per-action angle tensors
            decoder / encoder : module state_dicts
            optimizer_dec / optimizer_enc / optimizer_rep : optimiser state_dicts

        Returns:
            The path of the written file.
        '''
        if os.path.splitext(fname)[-1] != '.pth':
            fname += '.pth'
        fname = os.path.join(self.save_folder, fname)

        state = {}

        state['history'] = self.losses
        state['representations'] = self.lat_env.get_representations()
        state['decoder'] = self.decoder.state_dict()
        state['encoder'] = self.encoder.state_dict()
        state['optimizer_dec'] = self.optimizer_dec.state_dict()
        # The encoder optimiser is needed to resume training exactly.
        state['optimizer_enc'] = self.optimizer_enc.state_dict()
        state['optimizer_rep'] = self.optimizer_rep.state_dict()

        torch.save(state, fname)

        print("Saved RepresentationLearner state to", fname)

        return fname

    def plot_representations(self, save=True, print_matrices=False):
        """
        Bar-plot the learned angles (theta / 2pi) of every action representation.

        Args:
            save: if True, save the figure to ``<save_folder>/thetas.png`` and
                close it.
            print_matrices: if True, also print each representation matrix and
                its fifth power.

        Returns:
            The matplotlib figure (already closed when ``save`` is True).
        """
        reps = self.lat_env.action_reps
        rep_thetas = [rep.thetas.detach().numpy() for rep in reps]
        if print_matrices:
            for rep in reps:
                print('MATRIX', rep.get_matrix())
                print('MATRIX POWER FIVE', torch.matrix_power(rep.get_matrix(), 5))

        # Symmetric y-limits of at least +/-0.22, widened to fit the largest
        # |theta| / 2pi (negative angles included).
        max_abs = max((float(np.abs(t).max()) for t in rep_thetas if t.size), default=0.0)
        plt_lim = max(0.22, max_abs / (2*np.pi))

        # Rotation-plane labels in the order Representation stores its angles:
        # 12, 13, ..., 23, ... (for dim=4: 12, 13, 14, 23, 24, 34).
        dim = self.lat_env.dim
        pair_labels = ["{}{}".format(i + 1, j + 1) for i in range(dim - 1) for j in range(i + 1, dim)]
        titles = ["up", "down", "right", "left"]

        # 'seaborn-paper' was renamed 'seaborn-v0_8-paper' in matplotlib 3.6.
        style = 'seaborn-paper' if 'seaborn-paper' in plt.style.available else 'seaborn-v0_8-paper'

        width = 0.75
        with plt.style.context(style, after_reset=True):

            fig, axs = plt.subplots(1, len(rep_thetas), figsize=(8, 1.5), gridspec_kw={"wspace": 0.4})
            # A single subplot comes back as a bare Axes, not an array.
            axs = np.atleast_1d(axs)

            for i, thetas in enumerate(rep_thetas):
                x = np.arange(len(thetas))
                axs[i].bar(x - width/2, thetas/(2*np.pi), width, label='Rep {}'.format(i))

                axs[i].set_xticks(x - 0.25)
                axs[i].set_xticklabels(pair_labels)
                axs[i].set_xlabel("$ij$")

                axs[i].set_ylim(-plt_lim, plt_lim)

                axs[i].set_title(titles[i] if i < len(titles) else "action {}".format(i))

            axs[0].set_ylabel(r"$\theta / 2\pi$")

        if save:
            fig_fname = os.path.join(self.save_folder, "thetas")
            fig.savefig(fig_fname + ".png", bbox_inches='tight')
            plt.close(fig)
            print("Representation thetas plots saved to", fig_fname + ".png")

        return fig

    def test(self, rollout_len=500, rollouts=50, save_scores=True):
        """
        Measure reconstruction error as a function of the number of steps into long rollouts.

        Args:
            rollout_len: number of actions per rollout.
            rollouts: number of rollouts to average the error over.
            save_scores: if True, pickle the averaged errors to
                ``<save_folder>/scores.pkl`` and plot them to ``scores.png``.

        Returns:
            Tensor of shape (rollout_len + 1,): mean BCE at each step, where
            step 0 is the reconstruction of the encoded reset frame.
        """
        all_errors = []
        with torch.no_grad():
            for _ in range(rollouts):
                errors = [F.binary_cross_entropy(self.decoder(obs_z), obs_x)
                          for obs_x, obs_z in self._rollout(rollout_len)]
                all_errors.append(torch.stack(errors))
            all_errors = torch.stack(all_errors).mean(dim=0)

        if save_scores:
            save_path = self.save_folder + '/scores.pkl'
            with open(save_path, 'wb') as output:
                pickle.dump(all_errors, output, pickle.HIGHEST_PROTOCOL)
            fig, ax = plt.subplots()
            ax.plot(all_errors.numpy())
            fig.savefig(self.save_folder + '/scores.png', bbox_inches='tight')
            plt.close(fig)
            print("Results saved.")

        return all_errors



### Train RepresentationLearner over several seeds and save results

In [10]:
# Train and evaluate one RepresentationLearner per seed WITH the entanglement
# penalty; the per-step test errors are then averaged over seeds.
seeds = 20
reg = 1e-2
name = 'reg'  # select 'reg' or 'noreg'

errors = []
for seed in range(seeds):
    # Seed the Python, NumPy and PyTorch RNGs so that each run (weight init,
    # representation angles, sampled actions) is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    agent = RepresentationLearner(
        exp_name=name + str(seed),
        reg=reg
    )
    agent.learn(n_sgd_steps=5000, ep_steps=10, batch_eps=8)
    agent.plot_representations()
    agent.save()
    errors.append(agent.test())
errors = torch.stack(errors).mean(dim=0)

  mask = resized_img != 0


iter 100 : loss=0.177 : entanglement=1.57e-01 : last 10 iters in 4.163s
iter 200 : loss=0.128 : entanglement=1.08e-01 : last 10 iters in 2.078s
iter 300 : loss=0.070 : entanglement=1.01e-01 : last 10 iters in 2.068s
iter 400 : loss=0.065 : entanglement=5.49e-02 : last 10 iters in 2.066s
iter 500 : loss=0.054 : entanglement=2.86e-02 : last 10 iters in 2.085s
iter 600 : loss=0.051 : entanglement=1.39e-02 : last 10 iters in 2.047s
iter 700 : loss=0.051 : entanglement=6.48e-03 : last 10 iters in 2.071s
iter 800 : loss=0.053 : entanglement=2.88e-03 : last 10 iters in 2.066s
iter 900 : loss=0.052 : entanglement=1.38e-03 : last 10 iters in 2.122s
iter 1000 : loss=0.052 : entanglement=6.11e-04 : last 10 iters in 2.138s
iter 1100 : loss=0.050 : entanglement=3.67e-04 : last 10 iters in 2.163s
iter 1200 : loss=0.047 : entanglement=2.58e-04 : last 10 iters in 2.151s
iter 1300 : loss=0.049 : entanglement=9.57e-05 : last 10 iters in 2.193s
iter 1400 : loss=0.049 : entanglement=8.35e-05 : last 10 ite

  mask = resized_img != 0


iter 900 : loss=0.049 : entanglement=4.55e-01 : last 10 iters in 1.936s
iter 1000 : loss=0.049 : entanglement=4.58e-01 : last 10 iters in 1.947s
iter 1100 : loss=0.045 : entanglement=4.55e-01 : last 10 iters in 1.983s
iter 1200 : loss=0.044 : entanglement=3.74e-01 : last 10 iters in 2.095s
iter 1300 : loss=0.045 : entanglement=2.23e-01 : last 10 iters in 2.029s
iter 1400 : loss=0.042 : entanglement=9.98e-02 : last 10 iters in 2.011s
iter 1500 : loss=0.042 : entanglement=3.94e-02 : last 10 iters in 2.048s
iter 1600 : loss=0.041 : entanglement=1.50e-02 : last 10 iters in 2.043s
iter 1700 : loss=0.044 : entanglement=5.68e-03 : last 10 iters in 2.039s
iter 1800 : loss=0.040 : entanglement=2.22e-03 : last 10 iters in 2.051s
iter 1900 : loss=0.043 : entanglement=9.61e-04 : last 10 iters in 2.042s
iter 2000 : loss=0.040 : entanglement=4.71e-04 : last 10 iters in 2.020s
iter 2100 : loss=0.040 : entanglement=2.04e-04 : last 10 iters in 2.024s
iter 2200 : loss=0.037 : entanglement=1.09e-04 : las

  mask = resized_img != 0


iter 100 : loss=0.167 : entanglement=2.71e-01 : last 10 iters in 1.908s
iter 200 : loss=0.107 : entanglement=1.94e-01 : last 10 iters in 1.883s
iter 300 : loss=0.058 : entanglement=1.19e-01 : last 10 iters in 1.908s
iter 400 : loss=0.045 : entanglement=4.62e-02 : last 10 iters in 1.870s
iter 500 : loss=0.044 : entanglement=1.64e-02 : last 10 iters in 1.894s
iter 600 : loss=0.040 : entanglement=5.31e-03 : last 10 iters in 1.924s
iter 700 : loss=0.039 : entanglement=1.87e-03 : last 10 iters in 1.952s
iter 800 : loss=0.037 : entanglement=6.45e-04 : last 10 iters in 1.996s
iter 900 : loss=0.036 : entanglement=2.82e-04 : last 10 iters in 1.999s
iter 1000 : loss=0.036 : entanglement=3.22e-04 : last 10 iters in 2.041s
iter 1100 : loss=0.036 : entanglement=1.11e-04 : last 10 iters in 2.029s
iter 1200 : loss=0.034 : entanglement=2.99e-04 : last 10 iters in 2.043s
iter 1300 : loss=0.033 : entanglement=2.63e-04 : last 10 iters in 2.061s
iter 1400 : loss=0.034 : entanglement=1.29e-04 : last 10 ite

  mask = resized_img != 0


iter 300 : loss=0.066 : entanglement=1.94e-01 : last 10 iters in 1.914s
iter 400 : loss=0.054 : entanglement=1.24e-01 : last 10 iters in 1.883s
iter 500 : loss=0.050 : entanglement=7.63e-02 : last 10 iters in 1.881s
iter 600 : loss=0.046 : entanglement=4.59e-02 : last 10 iters in 1.896s
iter 700 : loss=0.042 : entanglement=2.60e-02 : last 10 iters in 1.905s
iter 800 : loss=0.048 : entanglement=1.45e-02 : last 10 iters in 1.950s
iter 900 : loss=0.042 : entanglement=7.86e-03 : last 10 iters in 1.960s
iter 1000 : loss=0.041 : entanglement=4.26e-03 : last 10 iters in 1.985s
iter 1100 : loss=0.041 : entanglement=2.24e-03 : last 10 iters in 1.985s
iter 1200 : loss=0.038 : entanglement=1.13e-03 : last 10 iters in 1.961s
iter 1300 : loss=0.039 : entanglement=5.69e-04 : last 10 iters in 2.009s
iter 1400 : loss=0.042 : entanglement=2.91e-04 : last 10 iters in 2.016s
iter 1500 : loss=0.040 : entanglement=2.15e-04 : last 10 iters in 1.998s
iter 1600 : loss=0.037 : entanglement=1.01e-04 : last 10 i

  mask = resized_img != 0


iter 1300 : loss=0.045 : entanglement=5.74e-03 : last 10 iters in 2.073s
iter 1400 : loss=0.043 : entanglement=4.04e-03 : last 10 iters in 2.053s
iter 1500 : loss=0.037 : entanglement=2.76e-03 : last 10 iters in 2.058s
iter 1600 : loss=0.045 : entanglement=1.61e-03 : last 10 iters in 2.049s
iter 1700 : loss=0.044 : entanglement=1.51e-03 : last 10 iters in 2.058s
iter 1800 : loss=0.044 : entanglement=8.16e-04 : last 10 iters in 2.075s
iter 1900 : loss=0.037 : entanglement=6.39e-04 : last 10 iters in 2.076s
iter 2000 : loss=0.041 : entanglement=4.68e-04 : last 10 iters in 2.096s
iter 2100 : loss=0.038 : entanglement=5.12e-04 : last 10 iters in 2.077s
iter 2200 : loss=0.040 : entanglement=2.81e-04 : last 10 iters in 2.046s
iter 2300 : loss=0.040 : entanglement=1.59e-04 : last 10 iters in 2.058s
iter 2400 : loss=0.036 : entanglement=1.55e-04 : last 10 iters in 2.086s
iter 2500 : loss=0.037 : entanglement=1.40e-04 : last 10 iters in 2.093s
iter 2600 : loss=0.038 : entanglement=1.02e-04 : la

  mask = resized_img != 0


iter 2500 : loss=0.039 : entanglement=2.29e-04 : last 10 iters in 2.110s
iter 2600 : loss=0.037 : entanglement=1.27e-04 : last 10 iters in 2.104s
iter 2700 : loss=0.038 : entanglement=1.84e-04 : last 10 iters in 2.105s
iter 2800 : loss=0.036 : entanglement=1.52e-04 : last 10 iters in 2.101s
iter 2900 : loss=0.039 : entanglement=4.07e-05 : last 10 iters in 2.109s
iter 3000 : loss=0.036 : entanglement=1.30e-04 : last 10 iters in 2.105s
iter 3100 : loss=0.042 : entanglement=1.07e-04 : last 10 iters in 2.097s
iter 3200 : loss=0.038 : entanglement=1.03e-04 : last 10 iters in 2.096s
iter 3300 : loss=0.037 : entanglement=1.76e-04 : last 10 iters in 2.105s
iter 3400 : loss=0.038 : entanglement=6.40e-05 : last 10 iters in 2.112s
iter 3500 : loss=0.040 : entanglement=2.73e-04 : last 10 iters in 2.103s
iter 3600 : loss=0.038 : entanglement=1.22e-04 : last 10 iters in 2.116s
iter 3700 : loss=0.041 : entanglement=1.31e-04 : last 10 iters in 2.129s
iter 3800 : loss=0.038 : entanglement=1.07e-04 : la

  mask = resized_img != 0


iter 200 : loss=0.066 : entanglement=2.33e-01 : last 10 iters in 1.901s
iter 300 : loss=0.051 : entanglement=9.85e-02 : last 10 iters in 1.930s
iter 400 : loss=0.048 : entanglement=3.23e-02 : last 10 iters in 1.883s
iter 500 : loss=0.049 : entanglement=9.00e-03 : last 10 iters in 1.907s
iter 600 : loss=0.046 : entanglement=2.20e-03 : last 10 iters in 1.929s
iter 700 : loss=0.046 : entanglement=5.16e-04 : last 10 iters in 1.978s
iter 800 : loss=0.042 : entanglement=1.39e-04 : last 10 iters in 1.975s
iter 900 : loss=0.042 : entanglement=9.45e-05 : last 10 iters in 2.004s
iter 1000 : loss=0.042 : entanglement=7.41e-05 : last 10 iters in 1.987s
iter 1100 : loss=0.044 : entanglement=8.88e-05 : last 10 iters in 2.026s
iter 1200 : loss=0.043 : entanglement=1.07e-04 : last 10 iters in 2.027s
iter 1300 : loss=0.042 : entanglement=9.62e-05 : last 10 iters in 2.036s
iter 1400 : loss=0.042 : entanglement=1.25e-04 : last 10 iters in 2.065s
iter 1500 : loss=0.039 : entanglement=1.05e-04 : last 10 it

  mask = resized_img != 0


iter 400 : loss=0.060 : entanglement=5.11e-02 : last 10 iters in 1.880s
iter 500 : loss=0.042 : entanglement=1.84e-02 : last 10 iters in 1.925s
iter 600 : loss=0.040 : entanglement=6.49e-03 : last 10 iters in 1.906s
iter 700 : loss=0.040 : entanglement=2.15e-03 : last 10 iters in 1.953s
iter 800 : loss=0.040 : entanglement=6.83e-04 : last 10 iters in 1.962s
iter 900 : loss=0.037 : entanglement=2.28e-04 : last 10 iters in 1.984s
iter 1000 : loss=0.039 : entanglement=1.33e-04 : last 10 iters in 1.979s
iter 1100 : loss=0.039 : entanglement=7.72e-05 : last 10 iters in 1.976s
iter 1200 : loss=0.036 : entanglement=1.30e-04 : last 10 iters in 1.984s
iter 1300 : loss=0.037 : entanglement=9.20e-05 : last 10 iters in 2.042s
iter 1400 : loss=0.036 : entanglement=1.92e-04 : last 10 iters in 2.047s
iter 1500 : loss=0.037 : entanglement=6.05e-05 : last 10 iters in 2.047s
iter 1600 : loss=0.038 : entanglement=1.05e-04 : last 10 iters in 2.030s
iter 1700 : loss=0.036 : entanglement=9.10e-05 : last 10 

  mask = resized_img != 0


iter 200 : loss=0.066 : entanglement=5.33e-01 : last 10 iters in 1.887s
iter 300 : loss=0.047 : entanglement=3.34e-01 : last 10 iters in 1.886s
iter 400 : loss=0.039 : entanglement=1.51e-01 : last 10 iters in 1.918s
iter 500 : loss=0.040 : entanglement=5.76e-02 : last 10 iters in 1.923s
iter 600 : loss=0.036 : entanglement=2.01e-02 : last 10 iters in 1.937s
iter 700 : loss=0.034 : entanglement=6.82e-03 : last 10 iters in 1.938s
iter 800 : loss=0.035 : entanglement=2.31e-03 : last 10 iters in 1.932s
iter 900 : loss=0.032 : entanglement=7.39e-04 : last 10 iters in 1.973s
iter 1000 : loss=0.035 : entanglement=3.26e-04 : last 10 iters in 1.950s
iter 1100 : loss=0.035 : entanglement=1.86e-04 : last 10 iters in 1.990s
iter 1200 : loss=0.034 : entanglement=1.35e-04 : last 10 iters in 1.964s
iter 1300 : loss=0.032 : entanglement=1.29e-04 : last 10 iters in 1.985s
iter 1400 : loss=0.031 : entanglement=2.13e-04 : last 10 iters in 2.076s
iter 1500 : loss=0.031 : entanglement=8.48e-05 : last 10 it

  mask = resized_img != 0


iter 700 : loss=0.050 : entanglement=7.66e-03 : last 10 iters in 1.932s
iter 800 : loss=0.046 : entanglement=2.05e-03 : last 10 iters in 1.934s
iter 900 : loss=0.046 : entanglement=5.96e-04 : last 10 iters in 1.994s
iter 1000 : loss=0.043 : entanglement=1.85e-04 : last 10 iters in 1.992s
iter 1100 : loss=0.046 : entanglement=1.95e-04 : last 10 iters in 2.049s
iter 1200 : loss=0.046 : entanglement=7.16e-05 : last 10 iters in 2.025s
iter 1300 : loss=0.042 : entanglement=7.76e-05 : last 10 iters in 2.044s
iter 1400 : loss=0.043 : entanglement=1.41e-04 : last 10 iters in 2.094s
iter 1500 : loss=0.043 : entanglement=7.32e-05 : last 10 iters in 2.102s
iter 1600 : loss=0.043 : entanglement=6.16e-05 : last 10 iters in 2.086s
iter 1700 : loss=0.046 : entanglement=8.74e-05 : last 10 iters in 2.111s
iter 1800 : loss=0.045 : entanglement=8.06e-05 : last 10 iters in 2.097s
iter 1900 : loss=0.041 : entanglement=1.67e-04 : last 10 iters in 2.106s
iter 2000 : loss=0.045 : entanglement=6.97e-05 : last 

In [11]:
# Train and evaluate one RepresentationLearner per seed WITHOUT the entanglement
# penalty; the per-step test errors are then averaged over seeds.
seeds = 20
reg = 0
name = 'noreg'  # select 'reg' or 'noreg'

errors = []
for seed in range(seeds):
    # Seed the Python, NumPy and PyTorch RNGs so that each run (weight init,
    # representation angles, sampled actions) is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    agent = RepresentationLearner(
        exp_name=name + str(seed),
        reg=reg
    )
    agent.learn(n_sgd_steps=5000, ep_steps=10, batch_eps=8)
    agent.plot_representations()
    agent.save()
    errors.append(agent.test())
errors = torch.stack(errors).mean(dim=0)

iter 100 : loss=0.155 : entanglement=5.32e-01 : last 10 iters in 1.915s
iter 200 : loss=0.118 : entanglement=6.55e-01 : last 10 iters in 1.939s
iter 300 : loss=0.062 : entanglement=1.02e+00 : last 10 iters in 1.939s
iter 400 : loss=0.050 : entanglement=1.04e+00 : last 10 iters in 1.894s
iter 500 : loss=0.047 : entanglement=1.04e+00 : last 10 iters in 1.917s
iter 600 : loss=0.044 : entanglement=1.05e+00 : last 10 iters in 1.917s
iter 700 : loss=0.043 : entanglement=1.05e+00 : last 10 iters in 1.925s
iter 800 : loss=0.041 : entanglement=1.05e+00 : last 10 iters in 1.952s
iter 900 : loss=0.043 : entanglement=1.06e+00 : last 10 iters in 1.962s
iter 1000 : loss=0.039 : entanglement=1.06e+00 : last 10 iters in 1.979s
iter 1100 : loss=0.039 : entanglement=1.06e+00 : last 10 iters in 1.949s
iter 1200 : loss=0.039 : entanglement=1.06e+00 : last 10 iters in 1.977s
iter 1300 : loss=0.039 : entanglement=1.06e+00 : last 10 iters in 1.966s
iter 1400 : loss=0.038 : entanglement=1.06e+00 : last 10 ite

  mask = resized_img != 0


iter 2400 : loss=0.039 : entanglement=1.06e+00 : last 10 iters in 2.116s
iter 2500 : loss=0.039 : entanglement=1.07e+00 : last 10 iters in 2.089s
iter 2600 : loss=0.037 : entanglement=1.05e+00 : last 10 iters in 2.116s
iter 2700 : loss=0.036 : entanglement=1.06e+00 : last 10 iters in 2.193s
iter 2800 : loss=0.036 : entanglement=1.06e+00 : last 10 iters in 2.133s
iter 2900 : loss=0.038 : entanglement=1.07e+00 : last 10 iters in 2.131s
iter 3000 : loss=0.038 : entanglement=1.05e+00 : last 10 iters in 2.141s
iter 3100 : loss=0.035 : entanglement=1.06e+00 : last 10 iters in 2.139s
iter 3200 : loss=0.039 : entanglement=1.07e+00 : last 10 iters in 2.145s
iter 3300 : loss=0.038 : entanglement=1.07e+00 : last 10 iters in 2.268s
iter 3400 : loss=0.034 : entanglement=1.07e+00 : last 10 iters in 2.127s
iter 3500 : loss=0.035 : entanglement=1.06e+00 : last 10 iters in 2.118s
iter 3600 : loss=0.037 : entanglement=1.06e+00 : last 10 iters in 2.093s
iter 3700 : loss=0.035 : entanglement=1.05e+00 : la

  mask = resized_img != 0


iter 2800 : loss=0.027 : entanglement=8.30e-01 : last 10 iters in 2.081s
iter 2900 : loss=0.028 : entanglement=8.30e-01 : last 10 iters in 2.109s
iter 3000 : loss=0.028 : entanglement=8.22e-01 : last 10 iters in 2.114s
iter 3100 : loss=0.028 : entanglement=8.26e-01 : last 10 iters in 2.118s
iter 3200 : loss=0.026 : entanglement=8.27e-01 : last 10 iters in 2.118s
iter 3300 : loss=0.026 : entanglement=8.26e-01 : last 10 iters in 2.111s
iter 3400 : loss=0.026 : entanglement=8.20e-01 : last 10 iters in 2.104s
iter 3500 : loss=0.027 : entanglement=8.23e-01 : last 10 iters in 2.084s
iter 3600 : loss=0.024 : entanglement=8.28e-01 : last 10 iters in 2.085s
iter 3700 : loss=0.027 : entanglement=8.31e-01 : last 10 iters in 2.092s
iter 3800 : loss=0.025 : entanglement=8.31e-01 : last 10 iters in 2.079s
iter 3900 : loss=0.027 : entanglement=8.28e-01 : last 10 iters in 2.081s
iter 4000 : loss=0.026 : entanglement=8.27e-01 : last 10 iters in 2.080s
iter 4100 : loss=0.028 : entanglement=8.23e-01 : la

  mask = resized_img != 0


iter 1700 : loss=0.033 : entanglement=5.95e-01 : last 10 iters in 2.038s
iter 1800 : loss=0.030 : entanglement=5.93e-01 : last 10 iters in 2.019s
iter 1900 : loss=0.033 : entanglement=5.91e-01 : last 10 iters in 2.014s
iter 2000 : loss=0.029 : entanglement=5.89e-01 : last 10 iters in 2.072s
iter 2100 : loss=0.032 : entanglement=5.90e-01 : last 10 iters in 2.044s
iter 2200 : loss=0.032 : entanglement=5.85e-01 : last 10 iters in 2.056s
iter 2300 : loss=0.033 : entanglement=5.82e-01 : last 10 iters in 2.062s
iter 2400 : loss=0.031 : entanglement=5.84e-01 : last 10 iters in 2.079s
iter 2500 : loss=0.030 : entanglement=5.87e-01 : last 10 iters in 2.061s
iter 2600 : loss=0.031 : entanglement=5.84e-01 : last 10 iters in 2.071s
iter 2700 : loss=0.033 : entanglement=5.88e-01 : last 10 iters in 2.073s
iter 2800 : loss=0.030 : entanglement=5.84e-01 : last 10 iters in 2.081s
iter 2900 : loss=0.028 : entanglement=5.85e-01 : last 10 iters in 2.086s
iter 3000 : loss=0.034 : entanglement=5.88e-01 : la

  mask = resized_img != 0


iter 1600 : loss=0.028 : entanglement=1.02e+00 : last 10 iters in 2.038s
iter 1700 : loss=0.029 : entanglement=1.02e+00 : last 10 iters in 2.021s
iter 1800 : loss=0.027 : entanglement=1.03e+00 : last 10 iters in 2.054s
iter 1900 : loss=0.027 : entanglement=1.02e+00 : last 10 iters in 2.039s
iter 2000 : loss=0.030 : entanglement=1.02e+00 : last 10 iters in 2.044s
iter 2100 : loss=0.029 : entanglement=1.03e+00 : last 10 iters in 2.036s
iter 2200 : loss=0.029 : entanglement=1.03e+00 : last 10 iters in 2.069s
iter 2300 : loss=0.029 : entanglement=1.03e+00 : last 10 iters in 2.088s
iter 2400 : loss=0.028 : entanglement=1.03e+00 : last 10 iters in 2.076s
iter 2500 : loss=0.028 : entanglement=1.03e+00 : last 10 iters in 2.081s
iter 2600 : loss=0.030 : entanglement=1.03e+00 : last 10 iters in 2.087s
iter 2700 : loss=0.028 : entanglement=1.03e+00 : last 10 iters in 2.066s
iter 2800 : loss=0.027 : entanglement=1.03e+00 : last 10 iters in 2.053s
iter 2900 : loss=0.027 : entanglement=1.03e+00 : la

  mask = resized_img != 0


iter 2500 : loss=0.027 : entanglement=8.06e-01 : last 10 iters in 2.107s
iter 2600 : loss=0.026 : entanglement=8.08e-01 : last 10 iters in 2.065s
iter 2700 : loss=0.026 : entanglement=8.10e-01 : last 10 iters in 2.121s
iter 2800 : loss=0.027 : entanglement=8.05e-01 : last 10 iters in 2.090s
iter 2900 : loss=0.026 : entanglement=8.01e-01 : last 10 iters in 2.199s
iter 3000 : loss=0.026 : entanglement=8.05e-01 : last 10 iters in 2.080s
iter 3100 : loss=0.026 : entanglement=8.06e-01 : last 10 iters in 2.091s
iter 3200 : loss=0.025 : entanglement=8.02e-01 : last 10 iters in 2.121s
iter 3300 : loss=0.025 : entanglement=8.02e-01 : last 10 iters in 2.108s
iter 3400 : loss=0.026 : entanglement=8.05e-01 : last 10 iters in 2.147s
iter 3500 : loss=0.025 : entanglement=7.93e-01 : last 10 iters in 2.086s
iter 3600 : loss=0.025 : entanglement=7.99e-01 : last 10 iters in 2.103s
iter 3700 : loss=0.024 : entanglement=8.03e-01 : last 10 iters in 2.094s
iter 3800 : loss=0.025 : entanglement=8.05e-01 : la

  mask = resized_img != 0


iter 100 : loss=0.138 : entanglement=7.81e-01 : last 10 iters in 1.925s
iter 200 : loss=0.112 : entanglement=8.18e-01 : last 10 iters in 1.915s
iter 300 : loss=0.051 : entanglement=8.76e-01 : last 10 iters in 1.932s
iter 400 : loss=0.044 : entanglement=8.87e-01 : last 10 iters in 1.888s
iter 500 : loss=0.039 : entanglement=8.93e-01 : last 10 iters in 1.995s
iter 600 : loss=0.033 : entanglement=8.92e-01 : last 10 iters in 1.942s
iter 700 : loss=0.032 : entanglement=8.94e-01 : last 10 iters in 1.953s
iter 800 : loss=0.030 : entanglement=8.95e-01 : last 10 iters in 1.931s
iter 900 : loss=0.028 : entanglement=8.94e-01 : last 10 iters in 1.999s
iter 1000 : loss=0.029 : entanglement=8.98e-01 : last 10 iters in 2.001s
iter 1100 : loss=0.027 : entanglement=8.99e-01 : last 10 iters in 2.046s
iter 1200 : loss=0.028 : entanglement=8.95e-01 : last 10 iters in 2.190s
iter 1300 : loss=0.030 : entanglement=8.93e-01 : last 10 iters in 2.020s
iter 1400 : loss=0.027 : entanglement=9.01e-01 : last 10 ite

  mask = resized_img != 0


Results saved.
iter 100 : loss=0.156 : entanglement=7.21e-01 : last 10 iters in 1.909s
iter 200 : loss=0.061 : entanglement=1.12e+00 : last 10 iters in 1.879s
iter 300 : loss=0.050 : entanglement=1.12e+00 : last 10 iters in 1.942s
iter 400 : loss=0.041 : entanglement=1.12e+00 : last 10 iters in 1.927s
iter 500 : loss=0.041 : entanglement=1.13e+00 : last 10 iters in 1.918s
iter 600 : loss=0.038 : entanglement=1.13e+00 : last 10 iters in 1.984s
iter 700 : loss=0.036 : entanglement=1.14e+00 : last 10 iters in 2.002s
iter 800 : loss=0.035 : entanglement=1.13e+00 : last 10 iters in 2.005s
iter 900 : loss=0.033 : entanglement=1.13e+00 : last 10 iters in 2.052s
iter 1000 : loss=0.032 : entanglement=1.14e+00 : last 10 iters in 2.042s
iter 1100 : loss=0.031 : entanglement=1.15e+00 : last 10 iters in 2.065s
iter 1200 : loss=0.030 : entanglement=1.14e+00 : last 10 iters in 2.059s
iter 1300 : loss=0.030 : entanglement=1.15e+00 : last 10 iters in 2.074s
iter 1400 : loss=0.030 : entanglement=1.14e+0

  mask = resized_img != 0


iter 100 : loss=0.141 : entanglement=7.11e-01 : last 10 iters in 1.896s
iter 200 : loss=0.062 : entanglement=1.02e+00 : last 10 iters in 1.903s
iter 300 : loss=0.052 : entanglement=1.03e+00 : last 10 iters in 1.925s
iter 400 : loss=0.047 : entanglement=1.03e+00 : last 10 iters in 1.958s
iter 500 : loss=0.047 : entanglement=1.04e+00 : last 10 iters in 1.944s
iter 600 : loss=0.043 : entanglement=1.04e+00 : last 10 iters in 1.948s
iter 700 : loss=0.041 : entanglement=1.04e+00 : last 10 iters in 1.959s
iter 800 : loss=0.040 : entanglement=1.04e+00 : last 10 iters in 2.008s
iter 900 : loss=0.040 : entanglement=1.04e+00 : last 10 iters in 2.018s
iter 1000 : loss=0.042 : entanglement=1.04e+00 : last 10 iters in 1.982s
iter 1100 : loss=0.042 : entanglement=1.04e+00 : last 10 iters in 2.003s
iter 1200 : loss=0.039 : entanglement=1.04e+00 : last 10 iters in 2.031s
iter 1300 : loss=0.042 : entanglement=1.04e+00 : last 10 iters in 2.037s
iter 1400 : loss=0.040 : entanglement=1.04e+00 : last 10 ite

  mask = resized_img != 0


iter 200 : loss=0.079 : entanglement=5.10e-01 : last 10 iters in 1.897s
iter 300 : loss=0.054 : entanglement=5.93e-01 : last 10 iters in 1.943s
iter 400 : loss=0.049 : entanglement=6.00e-01 : last 10 iters in 1.935s
iter 500 : loss=0.047 : entanglement=6.05e-01 : last 10 iters in 1.897s
iter 600 : loss=0.047 : entanglement=6.10e-01 : last 10 iters in 1.974s
iter 700 : loss=0.044 : entanglement=6.15e-01 : last 10 iters in 1.977s
iter 800 : loss=0.041 : entanglement=6.17e-01 : last 10 iters in 2.002s
iter 900 : loss=0.042 : entanglement=6.19e-01 : last 10 iters in 1.980s
iter 1000 : loss=0.038 : entanglement=6.17e-01 : last 10 iters in 2.007s
iter 1100 : loss=0.042 : entanglement=6.19e-01 : last 10 iters in 2.042s
iter 1200 : loss=0.040 : entanglement=6.23e-01 : last 10 iters in 2.062s
iter 1300 : loss=0.038 : entanglement=6.20e-01 : last 10 iters in 2.074s
iter 1400 : loss=0.038 : entanglement=6.23e-01 : last 10 iters in 2.097s
iter 1500 : loss=0.039 : entanglement=6.21e-01 : last 10 it

  mask = resized_img != 0


iter 800 : loss=0.034 : entanglement=4.75e-01 : last 10 iters in 1.991s
iter 900 : loss=0.034 : entanglement=4.77e-01 : last 10 iters in 2.003s
iter 1000 : loss=0.033 : entanglement=4.78e-01 : last 10 iters in 2.067s
iter 1100 : loss=0.034 : entanglement=4.80e-01 : last 10 iters in 2.044s
iter 1200 : loss=0.036 : entanglement=4.76e-01 : last 10 iters in 2.085s
iter 1300 : loss=0.032 : entanglement=4.73e-01 : last 10 iters in 2.092s
iter 1400 : loss=0.031 : entanglement=4.71e-01 : last 10 iters in 2.107s
iter 1500 : loss=0.031 : entanglement=4.74e-01 : last 10 iters in 2.121s
iter 1600 : loss=0.027 : entanglement=4.75e-01 : last 10 iters in 2.081s
iter 1700 : loss=0.034 : entanglement=4.73e-01 : last 10 iters in 2.186s
iter 1800 : loss=0.034 : entanglement=4.71e-01 : last 10 iters in 2.116s
iter 1900 : loss=0.032 : entanglement=4.69e-01 : last 10 iters in 2.090s
iter 2000 : loss=0.031 : entanglement=4.77e-01 : last 10 iters in 2.127s
iter 2100 : loss=0.033 : entanglement=4.78e-01 : last