<a href="https://colab.research.google.com/github/astrivedi/CSCI2270/blob/master/Sketching_with_RL_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch tqdm



In [None]:
# Importing
import torch
import torch.nn as nn
import random
from tqdm import tqdm

# Feedforward policy network used to back each hole.
class FFN(nn.Module):
    """Multi-layer perceptron with tanh non-linearities.

    Architecture: ``in_dim -> hidden_dim`` followed by further
    ``hidden_dim -> hidden_dim`` layers (each followed by tanh), then a final
    linear ``hidden_dim -> out_dim`` layer with no activation (raw logits).

    Args:
        in_dim: number of input features.
        out_dim: number of outputs.
        hidden_dim: width of every hidden layer.
        hidden_layers: number of tanh hidden layers; values below 1 still
            produce a single hidden layer.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=32, hidden_layers=2):
        super().__init__()

        # Widths of the consecutive Linear+Tanh stages; at least one stage.
        widths = [in_dim] + [hidden_dim] * max(hidden_layers, 1)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.Tanh()]
        # Output projection to logits (no activation).
        layers.append(nn.Linear(hidden_dim, out_dim))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` (last dimension must equal ``in_dim``)."""
        return self.model(x)

# Filler: learns boolean choices for the "holes" of a program sketch.
class Filler():
    """Fill boolean holes in a program sketch with small policy networks.

    Each hole is backed by its own ``FFN`` mapping the hole's numeric inputs
    to two logits; the chosen index (0 or 1) is returned to the sketch as a
    bool. The networks are trained with REINFORCE against a running-average
    reward baseline, where the reward is accumulated by the assertion helpers
    while a test function runs the sketch.
    """

    def __init__(self):
        # When True, `sample` acts greedily and records no history.
        self.evaluation_mode = False
        # Torch device for the hole networks and their inputs.
        self.device = 'cpu'
        # One network per hole; the list index is the hole number.
        self.holes = []
        # Probability of each action sampled during the current program run.
        self.history = []
        # Reward accumulated during the current program run.
        self.reward = 0

    def evaluate(self, mode=True):
        """Turn greedy evaluation mode on (default) or off with ``mode=False``."""
        self.evaluation_mode = bool(mode)

    def add_hole(self, size):
        """Create a hole taking ``size`` numeric inputs and return its index."""
        self.holes.append(FFN(size, 2).to(self.device))
        return len(self.holes) - 1

    def less_than_equal_assert(self, a, b):
        """Soft assertion: add +1 to the reward if ``a <= b``, else -1."""
        self.reward += 1 if a <= b else -1

    def except_assert(self):
        """Penalise a program run that raised an exception (-10)."""
        self.reward -= 10

    def reset(self):
        """Clear the action history and reward before a new program run."""
        self.history = []
        self.reward = 0

    def sample(self, hole_number, *inputs):
        """Return the boolean choice of hole ``hole_number`` for ``inputs``.

        In evaluation mode the most likely action is returned and nothing is
        recorded. Otherwise an action is sampled from the network's softmax
        and its probability is appended to ``self.history`` for training.
        """
        state = torch.tensor(inputs, device=self.device).float().view(1, -1)
        net = self.holes[hole_number]

        if self.evaluation_mode:
            # Greedy choice; no autograd graph is needed at evaluation time.
            with torch.no_grad():
                action = torch.argmax(net(state), dim=1)
            return bool(action.item())

        p = torch.softmax(net(state), dim=1)
        action = torch.distributions.Categorical(p).sample()
        self.history.append(p[0, action])
        return bool(action.item())

    def train(self, function, lr=1e-3, baseline_lr=0.1, num_train_steps=500):
        """Train every hole network with REINFORCE on repeated runs of ``function``.

        Args:
            function: zero-argument callable that runs the sketch (generating
                its own random inputs) and scores it via the assertion helpers.
            lr: Adam learning rate for the hole networks.
            baseline_lr: mixing rate of the exponential moving-average baseline.
            num_train_steps: number of program runs (one gradient step each
                run that sampled at least one hole).

        Any ``Exception`` raised by ``function`` is penalised through
        ``except_assert``. ``KeyboardInterrupt`` and ``SystemExit`` propagate,
        so training can still be interrupted.
        """
        # Sampling must be stochastic and recorded while training. Without
        # this, a prior `evaluate()` would leave the history empty and
        # training would silently do nothing.
        self.evaluation_mode = False

        parameters = [param for net in self.holes for param in net.parameters()]
        optim = torch.optim.Adam(parameters, lr=lr)

        baseline = torch.zeros(1, device=self.device)

        for _ in tqdm(range(num_train_steps)):
            self.reset()

            # The function generates its own random inputs.
            try:
                function()
            except Exception:
                # Any crash of the sketch counts as a failed run.
                self.except_assert()

            if not self.history:
                # No hole was sampled on this run; nothing to update.
                continue

            # The baseline is updated first (including this run's reward), so
            # the advantage equals (1 - baseline_lr) * (reward - old baseline).
            baseline = (1 - baseline_lr) * baseline + baseline_lr * self.reward
            advantage = self.reward - baseline

            log_probs = torch.log(torch.stack(self.history))
            loss = -(log_probs * advantage).sum()

            optim.zero_grad()
            loss.backward()
            optim.step()

In [None]:
# Sketch 1: absolute value. One Filler with a single hole taking 1 input.
filler = Filler()
hole = filler.add_hole(1)  # index of the new hole in filler.holes

# The program to fill: the hole chooses between returning n and -n.
def absolute_value(n):
    """Return ``n`` or ``-n``, as chosen by hole ``hole`` of the global ``filler``."""
    return n if filler.sample(hole, n) else -n
    
def test_abs():
    """Run ``absolute_value`` on one random integer in [-10, 10] and score it.

    The run is rewarded when the result is non-negative (``0 <= result``).
    """
    sample_n = random.randint(-10, 10)
    filler.less_than_equal_assert(0, absolute_value(sample_n))

# Train the hole on random absolute-value runs (default num_train_steps=500).
filler.train(test_abs)

# Put in eval mode: hole decisions become greedy and are no longer recorded.
filler.evaluate()

100%|██████████| 500/500 [00:00<00:00, 551.62it/s]


In [None]:
# Print the learned absolute value for every integer in [-10, 10].
for x in range(-10, 11):
    print(f'Absolute value of {x} is {absolute_value(x)}')

Absolute value of -10 is 10
Absolute value of -9 is 9
Absolute value of -8 is 8
Absolute value of -7 is 7
Absolute value of -6 is 6
Absolute value of -5 is 5
Absolute value of -4 is 4
Absolute value of -3 is 3
Absolute value of -2 is 2
Absolute value of -1 is 1
Absolute value of 0 is 0
Absolute value of 1 is 1
Absolute value of 2 is 2
Absolute value of 3 is 3
Absolute value of 4 is 4
Absolute value of 5 is 5
Absolute value of 6 is 6
Absolute value of 7 is 7
Absolute value of 8 is 8
Absolute value of 9 is 9
Absolute value of 10 is 10


In [None]:
# Sketch 2: quicksort. A fresh Filler with two holes, each taking 2 inputs:
# hole0 is sampled in partition, hole1 in quick_sort.
filler = Filler()
hole0 = filler.add_hole(2)
hole1 = filler.add_hole(2)

# Program with holes: Lomuto-style partition whose comparison is a hole.
def partition(l, low, high):
    """Partition ``l[low..high]`` in place around the pivot value ``l[high]``.

    Every element in the range for which hole ``hole0`` returns True is swapped
    towards the front. Then ``l[high]`` is swapped into the next free slot,
    and that slot's index is returned. The hole is intended to learn
    ``l[t] < pivot``.
    """
    pivot = l[high]
    store = low

    for i in range(low, high + 1):
        # hole0 stands in for the comparison l[i] < pivot
        if not filler.sample(hole0, l[i], pivot):
            continue
        l[store], l[i] = l[i], l[store]
        store += 1

    l[store], l[high] = l[high], l[store]
    return store

def quick_sort(l, low, high, depth=0):
    """Sort ``l[low..high]`` in place by recursive partitioning.

    Whether to partition and recurse is decided by hole ``hole1``, which is
    intended to learn ``low < high``. Recursion stops once ``depth`` exceeds
    ``len(l)``, bounding runs where the hole does not yet terminate.
    """
    if depth > len(l):
        return
    if not filler.sample(hole1, low, high):
        return

    split = partition(l, low, high)
    next_depth = depth + 1
    quick_sort(l, low, split - 1, depth=next_depth)
    quick_sort(l, split + 1, high, depth=next_depth)

# Reward function for the quicksort sketch.
def test_quick_sort(max_length=10, val_range=(-10,10)):
    """Sort one random list with ``quick_sort`` and score its ordering.

    The list has a random length in [0, max_length], with integers drawn
    from [val_range[0], val_range[1]]. Each adjacent pair is checked with
    ``filler.less_than_equal_assert``.
    """
    size = random.randint(0, max_length)
    values = [random.randint(val_range[0], val_range[1]) for _ in range(size)]

    quick_sort(values, 0, len(values) - 1)

    # Adjacent pairs (values[i], values[i+1]); none for lists of length 0 or 1.
    for left, right in zip(values, values[1:]):
        filler.less_than_equal_assert(left, right)

# Train both holes on 4000 runs of test_quick_sort (one random list per run).
filler.train(test_quick_sort, num_train_steps=4000)

# Switch to greedy evaluation: hole decisions are no longer sampled or recorded.
filler.evaluate()

100%|██████████| 4000/4000 [00:28<00:00, 142.25it/s]


In [None]:
# Show the learned quicksort on 10 random lists: "<input> ==> <sorted>".
for _ in range(10):
    length = random.randint(0, 10)
    l = [random.randint(-10, 10) for _ in range(length)]
    print(l, end='')  # unsorted input, printed before the in-place sort
    quick_sort(l, 0, len(l) - 1)
    print(' ==> ', l, sep='')

[0, 1, -8, -10, 3, 5, -6] ==> [-10, -8, -6, 0, 1, 3, 5]
[-6, 3, -2, -8, 1, 1] ==> [-8, -6, -2, 1, 1, 3]
[] ==> []
[-9] ==> [-9]
[1, -3, -3] ==> [-3, -3, 1]
[] ==> []
[-7, 4] ==> [-7, 4]
[2, -8] ==> [-8, 2]
[-4, -1, -9, 6, 8, -10] ==> [-10, -9, -4, -1, 6, 8]
[-1, -1, 6, -2] ==> [-2, -1, -1, 6]
