In [1]:
import glob
import pickle

import torch

from domain import make_domain
from search import run_search_on_batch

In [7]:
def get_policy(filename):
    """Load a saved policy checkpoint onto the CPU.

    Args:
        filename: Path to a checkpoint written with ``torch.save``.

    Returns:
        Whatever object was stored in the checkpoint, with tensors mapped to CPU.
    """
    # weights_only is passed explicitly: the default is changing to True in newer
    # PyTorch releases, which would reject checkpoints that contain arbitrary
    # pickled Python objects (presumably the case here, since the loaded value is
    # used directly as the policy -- confirm against the training code).
    # NOTE(security): full unpickling can execute arbitrary code; only load
    # checkpoints from a trusted source.
    return torch.load(
        filename,
        map_location=torch.device('cpu'),
        weights_only=False,
    )

def get_tactics(filename):
    """Load the pickled tactics object stored at ``filename``.

    Args:
        filename: Path to a file written with ``pickle.dump``.

    Returns:
        The unpickled object (``get_results`` calls ``len`` on it, so it is
        expected to be a sized collection).
    """
    # NOTE(security): pickle.load can execute arbitrary code; only use this on
    # files produced by a trusted training run.
    with open(filename, 'rb') as f:
        return pickle.load(f)

In [30]:
def get_results(filename, tactics_filename, domain, seeds):
    """Evaluate a saved policy on a batch of seeds and return the results.

    Loads the policy and the tactics from disk, prints how many tactics were
    loaded, builds the domain from ``domain`` and the tactics, and runs
    ``run_search_on_batch`` in "on-policy-sample" mode with ``epsilon=0``.

    Args:
        filename: Path to the policy checkpoint (read by ``get_policy``).
        tactics_filename: Path to the pickled tactics (read by ``get_tactics``).
        domain: Domain name passed to ``make_domain``.
        seeds: Seeds passed to ``run_search_on_batch``.

    Returns:
        A ``(eval_results, tactics)`` tuple.
    """
    loaded_policy = get_policy(filename)
    loaded_tactics = get_tactics(tactics_filename)
    print(len(loaded_tactics))

    search_domain = make_domain(domain, loaded_tactics)
    # NOTE(review): the meaning of the positional 1000 and 18 is defined by
    # run_search_on_batch and is not visible from this notebook -- confirm there.
    search_options = dict(output_path=None, debug=False, epsilon=0)
    outcome = run_search_on_batch(
        search_domain,
        seeds,
        loaded_policy,
        "on-policy-sample",
        1000,
        18,
        **search_options,
    )

    return outcome, loaded_tactics

In [22]:
def _first_match(pattern):
    """Return the first path matching ``pattern``, or raise if nothing matches."""
    matches = glob.glob(pattern)
    if not matches:
        # A bare ``[0]`` on an empty list gives an IndexError with no hint about
        # which checkpoint is missing; name the pattern instead.
        raise FileNotFoundError(f"No file matches {pattern!r}")
    return matches[0]


# CPO run: checkpoint and tactics saved at the same training step.
base_path = "/network/scratch/m/moksh.jain/peano_experiments/"
exp_name = "all_cpo"
step_number = 99
cpo_filename = _first_match(f"{base_path}/{exp_name}/{step_number}.pt")
cpo_tactics_filename = _first_match(f"{base_path}/{exp_name}/tactics-{step_number}.pkl")

# TB run: same layout, different experiment and step.
tb_base_path = "/network/scratch/m/moksh.jain/peano_experiments/"
tb_exp_name = "continue_train_partest_3_0.25"
tb_step_number = 20
tb_filename = _first_match(f"{tb_base_path}/{tb_exp_name}/{tb_step_number}.pt")
tb_tactics_filename = _first_match(f"{tb_base_path}/{tb_exp_name}/tactics-{tb_step_number}.pkl")

# Seed list used by the single-run cells: seed 1 repeated 50 times.
eval_interval = [1] * 50
domain = "two-step-eq"

In [None]:
def check_diversity(domain, num_seeds=100, samples_per_seed=50):
    """Count distinct successful trajectories found by the CPO and TB policies.

    For each ``i`` in ``range(num_seeds)`` the list ``[i] * samples_per_seed`` is
    passed as the seeds to ``get_results`` for both checkpoints, and the number
    of unique action sequences among successful episodes is recorded.

    Reads the notebook-level names ``cpo_filename``, ``cpo_tactics_filename``,
    ``tb_filename`` and ``tb_tactics_filename``.

    Args:
        domain: Domain name forwarded to ``get_results``.
        num_seeds: How many seed values to evaluate (default 100).
        samples_per_seed: How many times each seed is repeated (default 50).

    Returns:
        ``(cpo_div, tb_div)``: per-seed counts of unique successful
        trajectories for the CPO and TB policies, in that order.
    """
    cpo_div, tb_div = [], []
    for i in range(num_seeds):
        eval_interval = [i] * samples_per_seed
        # get_results returns (eval_results, tactics); only the search results
        # carry ``.episodes``, so unpack instead of using the tuple directly.
        cpo_results, _ = get_results(cpo_filename, cpo_tactics_filename, domain, eval_interval)
        tb_results, _ = get_results(tb_filename, tb_tactics_filename, domain, eval_interval)
        # A trajectory is identified by its action names joined with "+".
        tb_unique = {"+".join(e.actions) for e in tb_results.episodes if e.success}
        cpo_unique = {"+".join(e.actions) for e in cpo_results.episodes if e.success}
        print(f"Step {i}, Num Unique TB: {len(tb_unique)}, Num Unique CPO: {len(cpo_unique)}")
        cpo_div.append(len(cpo_unique))
        tb_div.append(len(tb_unique))
    return cpo_div, tb_div

In [6]:
# check_diversity returns (cpo_div, tb_div) in that order; unpack accordingly so
# the names match the policies they describe.
cpo, tb = check_diversity(domain)

  policy = torch.load(filename, map_location=torch.device('cuda'))
100%|██████████| 50/50 [00:01<00:00, 34.54it/s]


Solved 44/50


100%|██████████| 50/50 [00:34<00:00,  1.46it/s]


Solved 31/50
Step 0, Num Unique TB: 1, Num Unique CPO: 1


100%|██████████| 50/50 [00:00<00:00, 104.86it/s]


Solved 49/50


100%|██████████| 50/50 [00:00<00:00, 103.74it/s]


Solved 50/50
Step 1, Num Unique TB: 1, Num Unique CPO: 1


100%|██████████| 50/50 [00:00<00:00, 105.62it/s]


Solved 47/50


100%|██████████| 50/50 [00:00<00:00, 104.01it/s]


Solved 50/50
Step 2, Num Unique TB: 1, Num Unique CPO: 1


100%|██████████| 50/50 [00:00<00:00, 110.53it/s]


Solved 45/50


100%|██████████| 50/50 [00:00<00:00, 104.20it/s]


Solved 50/50
Step 3, Num Unique TB: 1, Num Unique CPO: 1


100%|██████████| 50/50 [00:00<00:00, 239.78it/s]


Solved 3/50


100%|██████████| 50/50 [00:00<00:00, 221.43it/s]


Solved 2/50
Step 4, Num Unique TB: 1, Num Unique CPO: 1


100%|██████████| 50/50 [00:00<00:00, 106.74it/s]


Solved 50/50


100%|██████████| 50/50 [00:05<00:00,  8.63it/s]


Solved 22/50
Step 5, Num Unique TB: 1, Num Unique CPO: 1


100%|██████████| 50/50 [00:00<00:00, 103.17it/s]


Solved 50/50


100%|██████████| 50/50 [00:00<00:00, 102.96it/s]


Solved 50/50
Step 6, Num Unique TB: 1, Num Unique CPO: 1


100%|██████████| 50/50 [00:00<00:00, 109.87it/s]


Solved 48/50


100%|██████████| 50/50 [00:18<00:00,  2.66it/s]


Solved 26/50
Step 7, Num Unique TB: 1, Num Unique CPO: 1


100%|██████████| 50/50 [00:00<00:00, 149.77it/s]


Solved 2/50


100%|██████████| 50/50 [00:00<00:00, 185.55it/s]


Solved 0/50
Step 8, Num Unique TB: 0, Num Unique CPO: 1


100%|██████████| 50/50 [00:00<00:00, 99.91it/s] 


Solved 50/50


 84%|████████▍ | 42/50 [00:23<00:00, 17.73it/s]

: 

In [31]:
# get_results returns an (eval_results, tactics) tuple, so cpo_results[0] holds
# the search results and cpo_results[1] the loaded tactics.
cpo_results = get_results(cpo_filename, cpo_tactics_filename, domain, eval_interval)

  policy = torch.load(filename, map_location=torch.device('cpu'))


38


100%|██████████| 50/50 [00:03<00:00, 12.81it/s]

Solved 44/50





In [34]:
# Steps of tactic 0 (element 1 of the get_results tuple is the tactics).
cpo_results[1][0].steps

(Step(arrows=('eval',), arguments=('?a@*',), result='?0', branch=None),
 Step(arrows=('rewrite',), arguments=('?0', '?a@*'), result='?1', branch=None))

In [35]:
# Steps of tactic 1.
cpo_results[1][1].steps

(Step(arrows=('tactic000',), arguments=('?a',), result='?0', branch=None),
 Step(arrows=('tactic000',), arguments=('?0',), result='?1', branch=None))

In [39]:
# Steps of tactic 30.
cpo_results[1][30].steps

(Step(arrows=('tactic025',), arguments=('?a', '?b'), result='?0', branch=None),
 Step(arrows=('tactic022',), arguments=('?0', '?c'), result='?1', branch=None))

In [40]:
# Steps of tactic 27.
cpo_results[1][27].steps

(Step(arrows=('tactic014',), arguments=('?a', '?b'), result='?0', branch=None),
 Step(arrows=('tactic020',), arguments=('?0', '?c'), result='?1', branch=None))

In [24]:
# cpo_results is the (eval_results, tactics) tuple from get_results, so the
# episodes live on element 0. Show the initial state of the first episode.
cpo_results[0].episodes[0].states[0]

'G:(= x ?)\n(= (+ (/ x 8) -3) -2)'

In [25]:
# Episodes live on element 0 of the (eval_results, tactics) tuple.
success_traj = [e for e in cpo_results[0].episodes if e.success]

In [27]:
# State sequence of every successful episode.
[episode.states for episode in success_traj]

[['G:(= x ?)\n(= (+ (/ x 8) -3) -2)',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-###',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-(= x 8)'],
 ['G:(= x ?)\n(= (+ (/ x 8) -3) -2)',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-###',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-(= x 8)'],
 ['G:(= x ?)\n(= (+ (/ x 8) -3) -2)',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-###',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-(= x 8)'],
 ['G:(= x ?)\n(= (+ (/ x 8) -3) -2)',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-###',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-(= x 8)'],
 ['G:(= x ?)\n(= (+ (/ x 8) -3) -2)',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-###',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-(= x 8)'],
 ['G:(= x ?)\n(= (+ (/ x 8) -3) -2)',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-###',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-(= x 8)'],
 ['G:(= x ?)\n(= (+ (/ x 8) -3) -2)',
  'G:(= x ?)\n(= (+ (/ x 8) -3) -2)\ntactic030:-###',
  

In [14]:
# Episodes live on element 0 of the (eval_results, tactics) tuple.
unsuccessful_trajs = [e for e in cpo_results[0].episodes if not e.success]

In [16]:
# Actions taken in the second unsuccessful episode.
unsuccessful_trajs[1].actions

['+_assoc_r']