# Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from abc_py.interface import ABC
import functools
import os
from torch.distributions import Categorical

# Definitions

In [2]:
possible_actions = [
    functools.partial(ABC.resub, zero_cost=False, preserve_levels=False),
    functools.partial(ABC.resub, zero_cost=False, preserve_levels=True),
    functools.partial(ABC.resub, zero_cost=True, preserve_levels=False),
    functools.partial(ABC.resub, zero_cost=True, preserve_levels=True),
    functools.partial(ABC.rewrite, zero_cost=False, preserve_levels=True, verbose=False),
    functools.partial(ABC.rewrite, zero_cost=True, preserve_levels=True, verbose=False),
    functools.partial(ABC.rewrite, zero_cost=False, preserve_levels=False, verbose=False),
    functools.partial(ABC.rewrite, zero_cost=True, preserve_levels=False, verbose=False),
    functools.partial(ABC.refactor, zero_cost=False, preserve_levels=True),
    functools.partial(ABC.refactor, zero_cost=False, preserve_levels=False),
    functools.partial(ABC.refactor, zero_cost=True, preserve_levels=True),
    functools.partial(ABC.refactor, zero_cost=True, preserve_levels=False),
    functools.partial(ABC.balance),
]

num_actions = len(possible_actions)
num_features = 8
learning_rate = 0.01
discount_factor = 0.99

In [3]:
class ActorNetwork(nn.Module):
    def __init__(self, num_features, num_actions):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(num_features, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.softmax(self.fc3(x), dim=-1)

In [4]:
class CriticNetwork(nn.Module):
    def __init__(self, num_features):
        super(CriticNetwork, self).__init__()
        self.fc1 = nn.Linear(num_features, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [5]:
def select_action(actor: ActorNetwork, state: torch.Tensor):
    action_probs = actor(state)
    action_distribution = Categorical(action_probs)
    action = action_distribution.sample()
    return action.item(), action_distribution.log_prob(action).reshape(1), action_probs

In [6]:
def calculate_losses(critic: CriticNetwork, action_log_prob: torch.Tensor, reward: float, state: torch.Tensor, next_state: torch.Tensor):
    value = critic(state)
    next_value = critic(next_state)
    td_error = reward + discount_factor * next_value - value
    actor_loss = -action_log_prob * td_error.detach()
    critic_loss = td_error ** 2
    return actor_loss, critic_loss

In [7]:
def calculate_reward(state: torch.Tensor, next_state: torch.Tensor):
    old_levels, old_nodes = state[4], state[5]
    new_levels, new_nodes = next_state[4], next_state[5]
    delay_estimate = new_levels / old_levels
    area_estimate = new_nodes / old_nodes
    reward = 0

    if delay_estimate < 1:
        if area_estimate < 1:
            reward = 3
        elif area_estimate == 1:
            reward = 2
        else:
            reward = 1
    elif delay_estimate == 1:
        if area_estimate < 1:
            reward = 2
        elif area_estimate == 1:
            reward = 0
        else:
            reward = -2
    else:
        if area_estimate < 1:
            reward = 1
        elif area_estimate == 1:
            reward = -2
        else:
            reward = -3

    return reward

In [8]:
def train(actor: ActorNetwork, critic: CriticNetwork, actor_optimizer, critic_optimizer, episodes=50, iterations=50):
    for episode in range(episodes):
        total_actor_loss = 0
        total_critic_loss = 0

        for dir in os.listdir("./benchmarks"):
            if dir != "arithmetic":
                continue
            
            for filename in os.listdir(f"./benchmarks/{dir}"):
                if not filename.endswith(".aig"):
                    continue

                abc = ABC()
                init_stats = abc.read_aiger(f"./benchmarks/{dir}/{filename}")
                assert len(init_stats) == num_features

                state = torch.tensor(init_stats[:2] + [1] * (num_features - 2), dtype=torch.float)
                init_stats = torch.tensor(init_stats, dtype=torch.float)

                for i in range(iterations):
                    # select action from actor model
                    action, action_log_prob, _ = select_action(actor, state)

                    # take action and observe next state
                    action_to_be_taken = possible_actions[action]
                    new_stats = action_to_be_taken(abc)
                    next_state = torch.tensor(new_stats, dtype=torch.float) / init_stats # take ratio with respect to initial stats
                    next_state[0] = state[0] # keep the number of inputs same
                    next_state[1] = state[1] # keep the number of outputs same

                    if init_stats[2] == 0:
                        next_state[2] = 0

                    # calculate reward and update actor and critic models
                    reward = calculate_reward(state, next_state)
                    actor_loss, critic_loss = calculate_losses(critic, action_log_prob, reward, state, next_state)
                    total_actor_loss += actor_loss.detach().item()
                    total_critic_loss += critic_loss.detach().item()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()
                    critic_optimizer.zero_grad()
                    critic_loss.backward()
                    critic_optimizer.step()

                    # update state
                    state = next_state

        total_actor_loss /= iterations
        total_critic_loss /= iterations
        print(f"Episode {episode + 1}: Actor Loss: {total_actor_loss}, Critic Loss: {total_critic_loss}")
        
        if abc.quit() != 0:
            print("Error in quitting abc")
            return

In [9]:
actor = ActorNetwork(num_features, num_actions)
critic = CriticNetwork(num_features)
actor_optimizer = optim.Adam(actor.parameters(), lr=learning_rate)
critic_optimizer = optim.Adam(critic.parameters(), lr=learning_rate)

In [None]:
train(actor, critic, actor_optimizer, critic_optimizer)

Episode 1: Actor Loss: 0.002944391734402212, Critic Loss: 1.4680209745616548
Episode 2: Actor Loss: 5.177456864296701e-08, Critic Loss: 1.3974336014125912
Episode 3: Actor Loss: -2.1068808350455104e-08, Critic Loss: 1.3542953108400184
Episode 4: Actor Loss: -1.5668985365565118e-08, Critic Loss: 0.9317612949688495
Episode 5: Actor Loss: -8.567338237980282e-08, Critic Loss: 1.536460804142838
Episode 6: Actor Loss: -6.338802307007119e-08, Critic Loss: 1.3063443870092137
Episode 7: Actor Loss: -6.458423523730827e-08, Critic Loss: 1.2414588668119684
Episode 8: Actor Loss: -6.047178431979606e-08, Critic Loss: 1.1329745596959864
Episode 9: Actor Loss: -4.03291656564061e-08, Critic Loss: 0.9921232892846809
Episode 10: Actor Loss: -1.1910036661760825e-09, Critic Loss: 0.8952761254914475
Episode 11: Actor Loss: 2.2132935308810663e-08, Critic Loss: 0.8773737777642459
Episode 12: Actor Loss: 6.34354469617555e-08, Critic Loss: 1.2192334638896138
Episode 13: Actor Loss: 1.0260956356145102e-07, Criti

In [None]:
def generate_strategy(actor: ActorNetwork, file):
    abc = ABC()
    init_stats = abc.read_aiger(file)
    assert len(init_stats) == num_features

    state = torch.tensor(init_stats[:2] + [1] * (num_features - 2), dtype=torch.float)
    init_stats = torch.tensor(init_stats, dtype=torch.float)

    for i in range(50):
        action, _, prob = select_action(actor, state)
        print(prob)
        action_to_be_taken = possible_actions[action]
        print(f"Taking action {action}")
        new_stats = action_to_be_taken(abc)
        next_state = torch.tensor(new_stats, dtype=torch.float) / init_stats
        next_state[0] = state[0]
        next_state[1] = state[1]

        if init_stats[2] == 0:
            next_state[2] = 0

        state = next_state

    print(state)
    if abc.quit() != 0:
        print("Error in quitting abc")

In [None]:
generate_strategy(actor, "./benchmarks/arithmetic/adder.aig")

In [None]:
# abc = ABC()
# init_stats = abc.read_aiger("i10.aig")
# assert len(init_stats) == num_features

# state = torch.tensor(init_stats[:2] + [1] * (num_features - 2), dtype=torch.float)
# init_stats = torch.tensor(init_stats, dtype=torch.float)
# total_actor_loss = 0
# total_critic_loss = 0

# action_probs = actor(state)
# action_distribution = Categorical(action_probs)
# action = action_distribution.sample()
# action_log_prob = action_distribution.log_prob(action).reshape(1)
# action_to_be_taken = possible_actions[action]
# new_stats = action_to_be_taken(abc)
# next_state = torch.tensor(new_stats, dtype=torch.float) / init_stats
# next_state[0] = state[0]
# next_state[1] = state[1]

# if init_stats[2] == 0:
#     next_state[2] = 0
# print(next_state)

# abc.quit()