# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from abc_py.interface import ABC
import functools
import os
from torch.distributions import Categorical

# Definitions

In [None]:
# Action table: each entry is a callable that takes an ABC instance and
# applies one synthesis transformation to it. The position in this list is
# the action index produced by the policy, so the order must stay fixed:
# resub, resub(zero_cost), rewrite, rewrite(zero_cost), refactor,
# refactor(zero_cost), balance.
possible_actions = [
    *(functools.partial(ABC.resub, zero_cost=flag) for flag in (False, True)),
    *(
        functools.partial(ABC.rewrite, zero_cost=flag, preserve_levels=True, verbose=False)
        for flag in (False, True)
    ),
    *(functools.partial(ABC.refactor, zero_cost=flag) for flag in (False, True)),
    functools.partial(ABC.balance),
]

# Size of the policy output (one entry per action above).
num_actions = len(possible_actions)
# Length of the state vector fed to both networks.
num_features = 8
# Step size shared by the actor and critic optimizers.
learning_rate = 0.01
# Discount applied to the bootstrapped next-state value in the TD error.
discount_factor = 0.99

In [None]:
class ActorNetwork(nn.Module):
    """Policy network: maps a state vector to a probability per action.

    Three fully connected layers (two hidden layers of width 20 with ReLU)
    followed by a softmax over the last dimension.
    """

    def __init__(self, num_features, num_actions):
        super().__init__()
        hidden_width = 20
        self.fc1 = nn.Linear(num_features, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, num_actions)

    def forward(self, x):
        """Return action probabilities (summing to 1 along the last dim)."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return torch.softmax(logits, dim=-1)

In [None]:
class CriticNetwork(nn.Module):
    """Value network: maps a state vector to a single scalar estimate.

    One hidden layer of width 10 with ReLU, followed by a linear output
    with no activation.
    """

    def __init__(self, num_features):
        super().__init__()
        hidden_width = 10
        self.fc1 = nn.Linear(num_features, hidden_width)
        self.fc2 = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Return the value estimate; the last dimension has size 1."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
def select_action(actor: nn.Module, state: torch.Tensor):
    """Sample one action from the actor's policy for ``state``.

    Args:
        actor: Module that maps a float state vector to a vector of action
            probabilities (e.g. ``ActorNetwork``).
        state: Current state as a tensor, or anything ``torch.as_tensor``
            accepts (such as a list of numbers).

    Returns:
        A pair ``(action, log_prob)``: the sampled action index as a Python
        ``int``, and the log-probability of that action as a tensor of
        shape ``(1,)`` that stays attached to the actor's autograd graph.
    """
    # torch.tensor(existing_tensor) copy-constructs and emits a UserWarning on
    # every call. as_tensor reuses tensors and still converts plain sequences;
    # detach() keeps the original behaviour of not backpropagating into the
    # state itself.
    state_tensor = torch.as_tensor(state, dtype=torch.float).detach()
    action_probs = actor(state_tensor)
    action_distribution = Categorical(action_probs)
    action = action_distribution.sample()
    return action.item(), action_distribution.log_prob(action).reshape(1)

In [None]:
def calculate_losses(
    critic: nn.Module,
    action_log_prob: torch.Tensor,
    reward: float,
    state: torch.Tensor,
    next_state: torch.Tensor,
    gamma=None,
):
    """Compute the one-step actor and critic losses for a transition.

    Args:
        critic: Module that maps a state to a value estimate
            (e.g. ``CriticNetwork``).
        action_log_prob: Log-probability of the action that was taken,
            attached to the actor's autograd graph.
        reward: Scalar reward observed for the transition.
        state: State before the action.
        next_state: State after the action.
        gamma: Discount applied to the next-state value. When ``None``
            (the default) the module-level ``discount_factor`` is used.

    Returns:
        A pair ``(actor_loss, critic_loss)``. ``actor_loss`` is the negative
        log-probability weighted by the detached TD error; ``critic_loss``
        is the squared TD error.
    """
    if gamma is None:
        gamma = discount_factor

    value = critic(state)
    # The bootstrapped target is treated as a constant. Letting gradients
    # flow through critic(next_state) would also drag the target toward the
    # current estimate instead of only moving the estimate toward the target.
    with torch.no_grad():
        next_value = critic(next_state)
    td_target = reward + gamma * next_value
    td_error = td_target - value

    # The TD error only scales the policy gradient; detach it so that the
    # actor loss never backpropagates into the critic.
    actor_loss = -action_log_prob * td_error.detach()
    critic_loss = td_error ** 2
    return actor_loss, critic_loss

In [None]:
def calculate_reward(state: torch.Tensor, next_state: torch.Tensor):
    """Return the reward for moving from ``state`` to ``next_state``.

    Not implemented yet. The stub raises instead of silently returning
    ``None``, which would otherwise surface later as an opaque ``TypeError``
    when the reward is added to a tensor.

    Args:
        state: State before the action.
        next_state: State after the action.

    Raises:
        NotImplementedError: Always, until the reward is implemented.
    """
    # TODO: run asap7
    # TODO: refer to reward table in the paper
    raise NotImplementedError(
        "calculate_reward is not implemented yet; it must return a numeric reward"
    )

In [None]:
def train(
    actor: ActorNetwork,
    critic: CriticNetwork,
    actor_optimizer,
    critic_optimizer,
    episodes=50,
    iterations=50,
    aig_path="i10.aig",
):
    """Train the actor and critic online, one update per applied action.

    Every episode starts a fresh ABC session, loads the design, and then
    repeatedly samples an action, applies it, and takes one optimizer step
    for each network.

    Args:
        actor: Policy network used to pick actions.
        critic: Value network used to compute the TD error.
        actor_optimizer: Optimizer over the actor's parameters.
        critic_optimizer: Optimizer over the critic's parameters.
        episodes: Number of episodes to run.
        iterations: Number of actions applied per episode.
        aig_path: Path of the AIGER file loaded at the start of every
            episode.

    Raises:
        ValueError: If the statistics returned by ``read_aiger`` do not have
            ``num_features`` entries.
    """
    for _episode in range(episodes):
        abc = ABC()
        try:
            init_stats = abc.read_aiger(aig_path)
            # Explicit check rather than `assert`, which is stripped under -O.
            if len(init_stats) != num_features:
                raise ValueError(
                    f"expected {num_features} statistics from read_aiger, "
                    f"got {len(init_stats)}"
                )

            # Initial state: the first two raw stats (number of inputs and
            # outputs), followed by ratios of 1 for the remaining features.
            state = torch.tensor(init_stats[:2] + [1] * (num_features - 2), dtype=torch.float)
            init_stats = torch.tensor(init_stats, dtype=torch.float)

            for _step in range(iterations):
                # select action from actor model
                action, action_log_prob = select_action(actor, state)

                # take action and observe next state
                action_to_be_taken = possible_actions[action]
                new_stats = action_to_be_taken(abc)
                # take ratio with respect to initial stats
                # NOTE(review): a zero entry in init_stats makes this inf/nan.
                next_state = torch.tensor(new_stats, dtype=torch.float) / init_stats
                next_state[0] = state[0]  # keep the number of inputs same
                next_state[1] = state[1]  # keep the number of outputs same

                # calculate reward and update actor and critic models
                reward = calculate_reward(state, next_state)
                actor_loss, critic_loss = calculate_losses(
                    critic, action_log_prob, reward, state, next_state
                )
                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()
                critic_optimizer.zero_grad()
                critic_loss.backward()
                critic_optimizer.step()

                # update state
                state = next_state
        finally:
            # Always close the ABC session, even when the episode raised, so
            # a failure mid-episode does not leave the session open.
            quit_failed = abc.quit() != 0
            if quit_failed:
                print("Error in quitting abc")

        # Stop training if the session could not be shut down cleanly.
        if quit_failed:
            return

In [None]:
# Policy network (actor) and the Adam optimizer that updates it.
actor = ActorNetwork(num_features, num_actions)
actor_optimizer = optim.Adam(params=actor.parameters(), lr=learning_rate)

# Value network (critic) and its own Adam optimizer, same learning rate.
critic = CriticNetwork(num_features)
critic_optimizer = optim.Adam(params=critic.parameters(), lr=learning_rate)