In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.optim import Adam
from abc_py.interface import ABC
import functools
import os

In [None]:
class AdvantageActorCritic(nn.Module):
    """Advantage actor-critic (A2C) agent with an on-policy rollout buffer.

    Policy and value estimates come from two separate MLP heads that read the
    same input:

    * policy head: ``input_dim -> 20 -> 20 -> n_actions`` (unnormalised logits)
    * value head:  ``input_dim -> 20 -> 1`` (state-value estimate)

    Transitions are collected with :meth:`remember` and consumed by
    :meth:`calc_loss`, which returns the n-step advantage actor-critic loss for
    the whole rollout.

    Args:
        input_dim: number of features in one observation vector.
        n_actions: number of discrete actions the policy chooses between.
        gamma: discount factor applied to future rewards.
    """

    def __init__(self, input_dim, n_actions, gamma):
        super(AdvantageActorCritic, self).__init__()
        self.gamma = gamma

        self.pi1 = nn.Linear(input_dim, 20)
        self.pi2 = nn.Linear(20, 20)
        self.pi = nn.Linear(20, n_actions)

        self.v1 = nn.Linear(input_dim, 20)
        self.v = nn.Linear(20, 1)

        # Rollout buffer: parallel lists with one entry per step taken.
        self.rewards = []
        self.actions = []
        self.states = []

    def remember(self, state, action, reward):
        """Append one (state, action, reward) transition to the rollout buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)

    def clear_memory(self):
        """Discard every transition in the rollout buffer."""
        self.states = []
        self.actions = []
        self.rewards = []

    def forward(self, state):
        """Return ``(logits, value)`` for a batch of states.

        Args:
            state: float tensor whose last dimension is ``input_dim``.
        Returns:
            Policy logits (last dim ``n_actions``) and value estimates (last dim 1).
        """
        pi1 = F.relu(self.pi1(state))
        pi2 = F.relu(self.pi2(pi1))
        pi = self.pi(pi2)

        v1 = F.relu(self.v1(state))
        v = self.v(v1)

        return pi, v

    def calc_R(self, final_state=None):
        """Discounted n-step returns ``G_t`` for every remembered step.

        ``G_t = r_t + gamma * G_{t+1}``, seeded with the critic's value estimate
        of the state reached after the rollout.

        Args:
            final_state: 1-D observation reached after the last remembered
                action. When ``None``, the value of the last remembered state
                is used as the bootstrap instead.
        Returns:
            1-D float tensor with one return per remembered reward. It is
            detached from the autograd graph because it serves as a regression
            target, not as something to differentiate through.
        Raises:
            ValueError: if ``final_state`` is None and the buffer is empty.
        """
        if final_state is None:
            if not self.states:
                raise ValueError("calc_R() needs final_state or at least one remembered state")
            final_state = self.states[-1]

        with torch.no_grad():
            _, bootstrap = self.forward(final_state.unsqueeze(0))

        # Accumulate as plain Python numbers so the result is a clean 1-D tensor
        # (building it from a list of shape-(1,) tensors would give shape (N, 1)).
        running = bootstrap.item()
        returns = []
        for reward in reversed(self.rewards):
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float)

    def calc_loss(self, final_state: torch.Tensor):
        """Return the scalar A2C loss for the remembered rollout.

        ``advantage_t = G_t - V(s_t)``, where ``G_t`` is the discounted n-step
        return from :meth:`calc_R`, bootstrapped from ``V(final_state)``.

        * critic loss: ``advantage_t ** 2``. Returns are detached targets.
        * actor loss: ``-log pi(a_t | s_t) * advantage_t``. The advantage is
          detached so the policy-gradient term does not train the critic.

        Args:
            final_state: 1-D observation reached after the last remembered action.
        Returns:
            0-dim tensor: mean of actor + critic loss over the rollout.
        Raises:
            ValueError: if no transitions have been remembered.
        """
        if not self.states:
            raise ValueError("calc_loss() needs at least one remembered transition")

        states = torch.stack(self.states)
        actions = torch.tensor(self.actions, dtype=torch.long)
        returns = self.calc_R(final_state)

        logits, values = self.forward(states)
        # (N, 1) -> (N,); squeeze(-1) keeps a 1-D tensor even when N == 1.
        values = values.squeeze(-1)

        advantage = returns - values
        critic_loss = advantage ** 2
        log_probs = Categorical(logits=logits).log_prob(actions)
        actor_loss = -log_probs * advantage.detach()

        return (actor_loss + critic_loss).mean()

    def select_action(self, observation):
        """Sample an action index from the current policy for one observation.

        Args:
            observation: 1-D float tensor of length ``input_dim``.
        Returns:
            The sampled action as a Python ``int``.
        """
        with torch.no_grad():  # acting never needs gradients
            logits, _ = self.forward(observation.unsqueeze(0))
        return Categorical(logits=logits).sample().item()

In [None]:
def perform_action(abc: ABC, area, delay, action, possible_actions):
    """Apply one ABC optimisation pass and score how it moved area and delay.

    Args:
        abc: ABC session the action operates on.
        area: area before the action.
        delay: delay before the action.
        action: index into ``possible_actions``.
        possible_actions: callables taking ``abc`` and returning a stats list
            whose entries 0-5 are observation features, 6 is area and 7 is delay.
    Returns:
        ``(observation_, new_area, new_delay, reward)``. ``observation_`` is a
        float tensor of stats 0-5. ``reward`` comes from the trend table below.
    """
    def improvement(new_value, old_value):
        # +1 when the metric shrank, -1 when it grew, 0 otherwise (incl. NaN).
        if new_value < old_value:
            return 1
        if new_value > old_value:
            return -1
        return 0

    # (delay trend, area trend) -> reward; +1 = decreased, -1 = increased, 0 = unchanged.
    reward_by_trend = {
        (1, 1): 3,   (1, -1): 1,   (1, 0): 2,
        (-1, 1): -1, (-1, -1): -3, (-1, 0): -2,
        (0, 1): 3,   (0, -1): -2,  (0, 0): 0,
    }

    stats_after = possible_actions[action](abc)
    observation_ = torch.tensor(stats_after[:6], dtype=torch.float)
    new_area = stats_after[6]
    new_delay = stats_after[7]

    trend = (improvement(new_delay, delay), improvement(new_area, area))
    return observation_, new_area, new_delay, reward_by_trend[trend]

In [None]:
def _make_observation(raw_stats, baseline):
    """Build the network input from the six raw ABC stats.

    Each stat is divided by the same stat at the start of the episode. A zero
    baseline entry yields 0.0 instead of inf/nan, so it cannot poison the
    network input. Features 0 and 1 are then overwritten with the baseline's
    stats 0 and 1 scaled by fixed constants, so they stay constant for the
    whole episode.

    Args:
        raw_stats: 1-D float tensor of current stats (same length as ``baseline``).
        baseline: 1-D float tensor of the stats read before any action.
    Returns:
        A new 1-D float tensor; the inputs are not modified.
    """
    nonzero = baseline != 0
    safe_baseline = torch.where(nonzero, baseline, torch.ones_like(baseline))
    observation = torch.where(nonzero, raw_stats / safe_baseline, torch.zeros_like(raw_stats))
    # NOTE(review): 512 and 130 look like rough maxima of stats 0 and 1 over
    # this benchmark set - confirm before training on other designs.
    observation[0] = baseline[0] / 512
    observation[1] = baseline[1] / 130
    return observation


def _train_on_design(actor_critic, optimizer, iterations, abc, input_dim, possible_actions,
                     design_path, liberty_files):
    """Run one ``iterations``-step rollout on one design and take one optimiser step.

    Returns:
        ``(loss, final_area / initial_area, final_delay / initial_delay, summed_reward)``
    Raises:
        ValueError: if ``input_dim`` differs from the number of observation features.
    """
    abc.read_aiger(design_path)
    init_stats = abc.read_libraries(*liberty_files)
    # stats[:6] are observation features, stats[6] is area, stats[7] is delay.
    init_area = area = init_stats[6]
    init_delay = delay = init_stats[7]
    baseline = torch.tensor(init_stats[:6], dtype=torch.float)
    if len(baseline) != input_dim:
        raise ValueError(f"input_dim={input_dim} but ABC provides {len(baseline)} observation features")
    # Normalising the baseline by itself gives the same encoding every later step uses.
    observation = _make_observation(baseline, baseline)

    score = 0
    actor_critic.clear_memory()
    for _ in range(iterations):
        action = actor_critic.select_action(observation)
        raw_stats, area, delay, reward = perform_action(abc, area, delay, action, possible_actions)
        actor_critic.remember(observation, action, reward)
        score += reward
        observation = _make_observation(raw_stats, baseline)

    # `observation` is now the state reached after the final action; it bootstraps the returns.
    loss = actor_critic.calc_loss(observation)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    actor_critic.clear_memory()
    return loss.item(), area / init_area, delay / init_delay, score


def train(actor_critic: AdvantageActorCritic, optimizer, episodes, iterations, abc: ABC, input_dim, possible_actions,
          benchmark_dir='benchmarks/arithmetic',
          liberty_files=("libraries/asap7sc7p5t_INVBUF_RVT_FF_nldm_201020.lib",
                         "libraries/asap7sc7p5t_SIMPLE_RVT_FF_nldm_201020.lib")):
    """Train ``actor_critic`` on every ``.aig`` design in ``benchmark_dir``.

    Each episode visits every design once, in sorted filename order. Each visit
    runs ``iterations`` ABC actions followed by one optimiser step. After every
    episode it prints averages over the designs of: loss, final/initial area,
    final/initial delay, and summed reward.

    Args:
        actor_critic: agent to train.
        optimizer: torch optimiser over ``actor_critic``'s parameters.
        episodes: number of passes over the benchmark set.
        iterations: ABC actions per design per episode (must be >= 1).
        abc: ABC session used to load designs and apply actions.
        input_dim: observation size; must equal the number of stats features (6).
        possible_actions: action table forwarded to ``perform_action``.
        benchmark_dir: directory scanned for ``.aig`` files.
        liberty_files: Liberty files passed to ``abc.read_libraries``.
    Raises:
        ValueError: if ``iterations`` < 1 or ``input_dim`` mismatches the stats.
        FileNotFoundError: if ``benchmark_dir`` contains no ``.aig`` files.
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1 so every design yields at least one transition")
    design_files = sorted(name for name in os.listdir(benchmark_dir) if name.endswith('.aig'))
    if not design_files:
        raise FileNotFoundError(f"no .aig designs found in {benchmark_dir!r}")

    for episode in range(episodes):
        total_loss = total_area = total_delay = total_score = 0.0
        for filename in design_files:
            loss, area_ratio, delay_ratio, score = _train_on_design(
                actor_critic, optimizer, iterations, abc, input_dim, possible_actions,
                os.path.join(benchmark_dir, filename), liberty_files)
            total_loss += loss
            total_area += area_ratio
            total_delay += delay_ratio
            total_score += score

        # Average over the designs actually trained on, not every file in the directory.
        n_designs = len(design_files)
        print(f"Episode {episode + 1}: Avg loss {total_loss / n_designs:.4f} "
              f"Avg area {total_area / n_designs:.4f} Avg delay {total_delay / n_designs:.4f} "
              f"Avg score {total_score / n_designs:.4f}")

In [None]:
# Action table: the policy outputs one logit per entry, and
# perform_action calls possible_actions[action](abc), so list order defines
# the action encoding.
possible_actions = [
    functools.partial(ABC.resub, zero_cost=False, preserve_levels=False),
    functools.partial(ABC.resub, zero_cost=False, preserve_levels=True),
    functools.partial(ABC.resub, zero_cost=True, preserve_levels=False),
    functools.partial(ABC.resub, zero_cost=True, preserve_levels=True),
    functools.partial(ABC.rewrite, zero_cost=False, preserve_levels=True, verbose=False),
    functools.partial(ABC.rewrite, zero_cost=True, preserve_levels=True, verbose=False),
    functools.partial(ABC.rewrite, zero_cost=False, preserve_levels=False, verbose=False),
    functools.partial(ABC.rewrite, zero_cost=True, preserve_levels=False, verbose=False),
    functools.partial(ABC.refactor, zero_cost=False, preserve_levels=True),
    functools.partial(ABC.refactor, zero_cost=False, preserve_levels=False),
    functools.partial(ABC.refactor, zero_cost=True, preserve_levels=True),
    functools.partial(ABC.refactor, zero_cost=True, preserve_levels=False),
    ABC.balance,  # no options to bind, so no functools.partial wrapper is needed
]

n_actions = len(possible_actions)
input_dim = 6          # must match the 6 ABC stats used as observation features (stats[:6])
learning_rate = 1e-3
gamma = 0.9            # discount factor for the n-step returns

# Seed torch so weight initialisation and action sampling are reproducible
# when the notebook is re-run from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Build the policy/value network, its Adam optimiser, and the ABC session used by train().
actor_critic = AdvantageActorCritic(input_dim=input_dim, n_actions=n_actions, gamma=gamma)
optimiser = Adam(params=actor_critic.parameters(), lr=learning_rate)
abc = ABC()

In [None]:
# 10 passes over the benchmark set, 10 ABC actions per design per pass.
train(
    actor_critic,
    optimiser,
    episodes=10,
    iterations=10,
    abc=abc,
    input_dim=input_dim,
    possible_actions=possible_actions,
)

In [None]:
# Close the ABC instance created for training (presumably releases the
# underlying ABC session - confirm in abc_py). Run this even if training was
# interrupted.
abc.quit()

In [None]:
# Persist the trained network by pickling the whole nn.Module.
# NOTE(review): saving actor_critic.state_dict() is the more portable option;
# reloading a pickled module needs this class definition importable.
torch.save(obj=actor_critic, f="actor_critic2.pth")