In [None]:
# Reward Induced Program Synthesis - RIPS 001 "STRANGE MATTER"
# Installs the schema-games package (provides schema_games.breakout, used below) from GitHub.
# NOTE(review): unpinned git install — pin a commit (git+...schema-rl.git@<sha>) for reproducible runs.
!pip install git+https://github.com/ayaz-amin/schema-rl.git

Collecting git+https://github.com/ayaz-amin/schema-rl.git
  Cloning https://github.com/ayaz-amin/schema-rl.git to /tmp/pip-req-build-uggcn25v
  Running command git clone -q https://github.com/ayaz-amin/schema-rl.git /tmp/pip-req-build-uggcn25v
Collecting pygame
[?25l  Downloading https://files.pythonhosted.org/packages/87/4c/2ebe8ab1a695a446574bc48d96eb3503649893be8c769e7fafd65fd18833/pygame-2.0.0-cp36-cp36m-manylinux1_x86_64.whl (11.5MB)
[K     |████████████████████████████████| 11.5MB 7.1MB/s 
Building wheels for collected packages: schema-games
  Building wheel for schema-games (setup.py) ... [?25l[?25hdone
  Created wheel for schema-games: filename=schema_games-1.0.0-cp36-none-any.whl size=28598 sha256=306346abdec71bf0734721177c98da32e2824049de0d21ee46062651cb315366
  Stored in directory: /tmp/pip-ephem-wheel-cache-4beobeub/wheels/4a/fe/a0/6800016926ff46b11b889e96271961e5c5947e7f7e67c99435
Successfully built schema-games
Installing collected packages: pygame, schema-games
Succe

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.distributions as dist


class EntityExtractor(nn.Module):
    """Convolutional encoder that labels each coarse grid cell with an object type.

    Two conv + 2x2 max-pool stages downsample the image by 4 in each spatial
    dimension; a final 1x1 conv yields one sigmoid activation map per object
    type. Each coarse cell is then labelled with the object type whose map is
    strongest there.
    """

    def __init__(self, input_channels, num_objects):
        super(EntityExtractor, self).__init__()

        self.input_channels = input_channels
        self.filters = nn.Sequential(
                nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.MaxPool2d((2, 2)),
                nn.Conv2d(32, 64, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.MaxPool2d((2, 2)),
                nn.Conv2d(64, num_objects, kernel_size=1),
                nn.Sigmoid()
                )

    def obs_to_torch(self, obs):
        """Convert a numpy image to a float tensor of shape ``(1, C, H, W)``.

        Args:
            obs: ``(H, W, C)`` array (channels last), or ``(H, W)`` for a
                single-channel image.

        Returns:
            A float32 tensor with a leading batch dimension of 1.
        """
        # Copy so negatively-strided / read-only arrays are accepted by from_numpy.
        obs_t = torch.from_numpy(obs.copy()).float()
        if obs_t.dim() == 2:
            obs_t = obs_t.unsqueeze(-1)
        # Move channels first. A plain ``view(1, C, H, W)`` would reinterpret the
        # HWC memory layout as CHW and scramble pixels across channels.
        return obs_t.permute(2, 0, 1).unsqueeze(0)

    def parsed_objects(self, z):
        """Collapse per-object activation maps into a 2-D label grid.

        Args:
            z: tensor of shape ``(1, num_objects, H, W)``.

        Returns:
            float32 numpy array of shape ``(H, W)``; each entry is the index of
            the object channel with the largest activation at that cell (ties
            resolve to the lowest index).
        """
        # Sigmoid outputs are (practically) never exactly zero, so a "non-zero"
        # test would mark every channel active everywhere and every cell would
        # get the last index. Pick the strongest channel instead.
        labels = z[0].argmax(dim=0)
        return labels.to(torch.float32).detach().cpu().numpy()

    def forward(self, obs):
        """Label grid for a single numpy observation (see ``parsed_objects``)."""
        obs = self.obs_to_torch(obs)
        z = self.filters(obs)
        return self.parsed_objects(z)

In [None]:
def out_of_bounds(r, c, shape):
    """Return True when (r, c) falls outside a grid of ``shape[0]`` rows by ``shape[1]`` columns."""
    n_rows, n_cols = shape[0], shape[1]
    inside = 0 <= r < n_rows and 0 <= c < n_cols
    return not inside

def shifted(direction, local_program, cell, obs):
    """Run ``local_program`` on the neighbour of ``cell`` offset by ``direction``.

    A ``None`` cell is passed through unchanged as ``None``.
    """
    target = None if cell is None else (cell[0] + direction[0], cell[1] + direction[1])
    return local_program(target, obs)

def cell_is_value(value, cell, obs):
    """Check whether grid ``obs`` holds ``value`` at ``cell``.

    Args:
        value: label to compare against (e.g. an int or a 0-dim tensor).
        cell: ``(row, col)`` pair, or ``None`` for "no cell".
        obs: indexable grid exposing ``.shape``.

    Returns:
        ``False`` when ``cell`` is ``None`` or lies outside ``obs``; otherwise
        the result of ``obs[row, col] == value``.
    """
    if cell is None or out_of_bounds(cell[0], cell[1], obs.shape):
        # Missing / off-grid cells never match. Return False explicitly rather
        # than evaluating ``None == value``, which is elementwise for arrays and
        # depends on operator fallbacks for tensors.
        return False
    return obs[cell[0], cell[1]] == value

def at_cell_with_value(value, local_program, obs):
    """Anchor ``local_program`` at the first cell of ``obs`` equal to ``value``.

    "First" follows the row-major order returned by ``np.argwhere``; the anchor
    is that row of the ``argwhere`` result. When nothing matches, the program
    is called with ``None``.
    """
    hits = np.argwhere(obs == value)
    anchor = hits[0] if len(hits) else None
    return local_program(anchor, obs)

def scanning(direction, true_condition, false_condition, cell, obs, max_timeout=50):
    """Walk away from ``cell`` one ``direction`` step at a time.

    The start cell itself is never tested. At every visited cell,
    ``true_condition`` is checked first, then ``false_condition``, then whether
    the walk has left the grid.

    Returns:
        True on the first ``true_condition`` hit. False when ``cell`` is None,
        when ``false_condition`` fires, when the walk leaves the grid, or after
        ``max_timeout`` steps without a decision.
    """
    if cell is None:
        return False

    current = cell
    for _step in range(max_timeout):
        current = (current[0] + direction[0], current[1] + direction[1])
        if true_condition(current, obs):
            return True
        # A blocking object or stepping off the grid both end the scan unsuccessfully.
        if false_condition(current, obs) or out_of_bounds(current[0], current[1], obs.shape):
            return False

    return False


# Program-synthesis model classes: the policy container and its program cells
class Model(nn.Module):
    """Container for the program-synthesis policy.

    An ``EntityExtractor`` turns the raw observation into a 2-D grid of object
    labels. Every ``AtActionCell`` program is evaluated once per grid cell, and
    each evaluation whose condition fires casts one vote for its sampled action.
    The vote counts are used as logits of a categorical action distribution.
    """

    def __init__(self, input_channels, object_types, action_types, num_programs):
        super(Model, self).__init__()

        self.feature_extractor = EntityExtractor(input_channels, object_types)
        self.action_types = action_types
        self.programs = nn.ModuleList(
            [AtActionCell(object_types, action_types) for _ in range(num_programs)]
        )

    def forward(self, obs):
        """Return a ``torch.distributions.Categorical`` over ``self.action_types`` actions.

        NOTE(review): the vote counts come from sampled programs and an in-place
        count tensor, so no gradient flows from the returned distribution back
        to any model parameter.
        NOTE(review): ``AtActionCell.forward`` does not use its ``cell`` argument,
        so the per-cell loop only re-samples every program rows*cols times.
        """
        grid = self.feature_extractor(obs)
        action_counts = torch.zeros(self.action_types)
        for r in range(grid.shape[0]):
            for c in range(grid.shape[1]):
                for program in self.programs:
                    condition, action = program((r, c), grid)
                    if condition:
                        action_counts[action] += 1

        # The counts are unnormalised scores, i.e. logits. Categorical's first
        # positional argument is ``probs``; passing log-probabilities (all <= 0)
        # there is either rejected by argument validation or renormalised so
        # that the most likely action becomes the least likely.
        return dist.Categorical(logits=action_counts)


class AtActionCell(nn.Module):
    """A stochastic one-line program: "at an object, scan in a direction, then act".

    Each call samples, from learnable categorical logits:
      * an anchor object type,
      * a "positive" object type (hitting it makes the condition True),
      * a "negative" object type (hitting it makes the condition False),
      * one of eight grid directions, and
      * an action.
    The condition anchors at the first cell of the anchor type and scans in the
    sampled direction until it meets the positive type (True), or the negative
    type / grid edge (False).
    """

    def __init__(self, object_types, action_types):
        super(AtActionCell, self).__init__()
        # Logits for each categorical choice, initialised uniform.
        self.object_types = nn.Parameter(torch.ones(object_types))
        self.positive_object_types = nn.Parameter(torch.ones(object_types))
        self.negative_object_types = nn.Parameter(torch.ones(object_types))

        self.action_types = nn.Parameter(torch.ones(action_types))
        self.direction_types = nn.Parameter(torch.ones(8))
        # (row, col) unit steps: four axis-aligned, then four diagonal.
        self.directions = [
            (1, 0), (0, 1),
            (-1, 0), (0, -1),
            (1, 1), (-1, 1),
            (1, -1), (-1, -1)
        ]

    @staticmethod
    def _sample(logits):
        """Draw one index (0-dim tensor) from a categorical with unnormalised ``logits``."""
        # Must be ``logits=``: the positional argument is ``probs``, and feeding
        # it log-softmax output (all <= 0) is invalid / inverts the preferences.
        return dist.Categorical(logits=logits).sample()

    def forward(self, cell, obs):
        """Sample a program and evaluate its condition on the label grid ``obs``.

        Args:
            cell: unused — the program anchors itself at the first cell holding
                the sampled object type. Kept for interface compatibility.
            obs: 2-D array of object labels (as produced by ``EntityExtractor``).

        Returns:
            ``(condition, action)``: ``condition`` is truthy when the scan hits
            the positive object type; ``action`` is a 0-dim tensor index.
        """
        # Sampling order matches the original: anchor, positive, negative,
        # action, direction. Object indices become Python ints so grid
        # comparisons are plain numpy comparisons, not numpy-vs-tensor.
        anchor = int(self._sample(self.object_types))
        positive = int(self._sample(self.positive_object_types))
        negative = int(self._sample(self.negative_object_types))
        action = self._sample(self.action_types)
        direction = self.directions[int(self._sample(self.direction_types))]

        condition = at_cell_with_value(
            anchor,
            lambda start, grid: scanning(
                direction,
                lambda probe, g: cell_is_value(positive, probe, g),
                lambda probe, g: cell_is_value(negative, probe, g),
                start,
                grid,
            ),
            obs,
        )

        return condition, action

In [None]:
import torch.optim as optim

from gym.wrappers import Monitor
from schema_games.breakout.games import StandardBreakout

# Run configuration.
SEED = 0
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4

torch.manual_seed(SEED)
np.random.seed(SEED)

model = Model(input_channels=3, object_types=5, action_types=3, num_programs=10)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

env = Monitor(StandardBreakout(return_state_as_image=True), 'video', force=True)

try:
    for epoch in range(NUM_EPOCHS):
        obs = env.reset()
        step_losses = []
        done = False
        while not done:
            env.render()
            action_dist = model(obs)
            action = action_dist.sample()
            obs, reward, done, _ = env.step(action.item())
            # Policy-gradient style per-step term: -log pi(a|s) * r.
            step_losses.append(-action_dist.log_prob(action) * reward)

        episode_loss = torch.stack(step_losses).mean()
        optimizer.zero_grad()
        # NOTE(review): the model's action distribution is built from sampled
        # vote counts, so this loss is not connected to any parameter. Wrapping
        # it in ``torch.tensor(..., requires_grad=True)`` only creates a detached
        # leaf and makes ``optimizer.step()`` a silent no-op; skip the update
        # instead until a real gradient path (e.g. log-probs of the sampled
        # program choices) exists.
        if episode_loss.requires_grad:
            episode_loss.backward()
            optimizer.step()
        print(episode_loss)
finally:
    # Always close the env, even if an episode raises, so the Monitor can
    # presumably flush its recording.
    env.close()



[31m[---] Lives remaining:[0m 2
[31m[---] Lives remaining:[0m 1
[31m[---] Game over! You lost.[0m
[31m********************************************************************************[0m
tensor(-0.0062, grad_fn=<DivBackward0>)




[31m[---] Lives remaining:[0m 2
[31m[---] Lives remaining:[0m 1
[31m[---] Game over! You lost.[0m
[31m********************************************************************************[0m
tensor(-0.0188, grad_fn=<DivBackward0>)




[31m[---] Lives remaining:[0m 2
[31m[---] Lives remaining:[0m 1
[31m[---] Game over! You lost.[0m
[31m********************************************************************************[0m
tensor(0., grad_fn=<DivBackward0>)
[31m[---] Lives remaining:[0m 2
[31m[---] Lives remaining:[0m 1
[31m[---] Game over! You lost.[0m
[31m********************************************************************************[0m
tensor(-0.0578, grad_fn=<DivBackward0>)
[31m[---] Lives remaining:[0m 2
[31m[---] Lives remaining:[0m 1
[31m[---] Game over! You lost.[0m
[31m********************************************************************************[0m
tensor(-0.0035, grad_fn=<DivBackward0>)
[31m[---] Lives remaining:[0m 2
[31m[---] Lives remaining:[0m 1
[31m[---] Game over! You lost.[0m
[31m********************************************************************************[0m
tensor(0.0060, grad_fn=<DivBackward0>)
[31m[---] Lives remaining:[0m 2
[31m[---] Lives remaining:[0m 1
[



[31m[---] Lives remaining:[0m 2
[31m[---] Lives remaining:[0m 1
[31m[---] Game over! You lost.[0m
[31m********************************************************************************[0m
tensor(-0.0062, grad_fn=<DivBackward0>)


In [None]:
import io
import base64
from IPython.display import HTML

# Embed the ``video000000`` recording written by the gym Monitor into ./video.
video_path = 'video/openaigym.video.%s.video000000.mp4' % env.file_infix
# Read-only binary mode, and a context manager so the file handle is closed.
with io.open(video_path, 'rb') as video_file:
    video = video_file.read()
encoded = base64.b64encode(video)
# NOTE(review): width="54" renders very small — confirm it isn't meant to be 540.
HTML(data='''
    <video width="54" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>'''
.format(encoded.decode('ascii')))

In [None]:
# Persist only the learned parameters (the state_dict), not the module object itself.
model_state = model.state_dict()
torch.save(model_state, 'rips_sidereal.pt')