In [1]:
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import itertools
from util.save_load import load_kernel_model
from dataset.graphs_kernel import get_graph_data
from kernels.wrapper import MODELS, KernelModelWrapper
from dataset.ipc2023_learning_domain_info import IPC2023_LEARNING_DOMAINS, get_number_of_ipc2023_training_data
from itertools import product
from IPython.display import display, HTML

In [36]:
# Directory containing the kernel-training log files parsed below.
_LOG_DIR = "logs/train_kernel"
# WL iteration counts that were swept during training.
ITERATIONS = list(range(1, 6))
# Prune coefficients; log/model file names embed coefficient * iterations.
PRUNE_COEFS = list(range(6))
# Training-set size per domain, indexed by domain name in the loops below.
N_TRAINING_DATA = get_number_of_ipc2023_training_data()

### Train metrics

In [59]:
def get_data(domain, log_dir=None, models=None, iterations=None, prune_coefs=None):
  """Collect per-config training metrics for one domain from the kernel-training logs.

  For every (model, WL iterations, prune coefficient) combination the log file
  ``<log_dir>/<model>_llg_ipc2023-learning-<domain>_wl_<iterations>_<prune * iterations>.log``
  is parsed if it exists. Within one log, later metric lines overwrite earlier ones.

  Args:
    domain: domain name used in the log file names, e.g. "blocksworld".
    log_dir: directory containing the logs; defaults to ``_LOG_DIR``.
    models: model names to scan; defaults to ``MODELS``.
    iterations: WL iteration counts to scan; defaults to ``ITERATIONS``.
    prune_coefs: prune coefficients to scan; defaults to ``PRUNE_COEFS``.

  Returns:
    dict of equal-length lists keyed by "config", "mse", "f1" and
    "nonzero_weights". A config is kept only if its log contains both a
    ``train_mse`` and a ``train_f1_macro`` line; a missing ``zero_weights``
    line is recorded as "na".
  """
  log_dir = _LOG_DIR if log_dir is None else log_dir
  models = MODELS if models is None else models
  iterations = ITERATIONS if iterations is None else iterations
  prune_coefs = PRUNE_COEFS if prune_coefs is None else prune_coefs

  d = {key: [] for key in ("config", "mse", "f1", "nonzero_weights")}

  for model, n_iter, prune in product(models, iterations, prune_coefs):
    # File names encode coefficient * iterations, not the bare coefficient.
    threshold = str(prune * n_iter)
    log_name = "_".join([model, "llg", "ipc2023-learning-" + domain, "wl", str(n_iter), threshold]) + ".log"
    log_file = log_dir + "/" + log_name
    if not os.path.exists(log_file):
      continue

    stats = {"config": "_".join([model, "wl", str(n_iter), threshold])}

    # `with` guarantees the handle is closed (it previously leaked per config).
    with open(log_file, "r") as f:
      for line in f:
        toks = line.split()
        if "train_mse" in line:
          stats["mse"] = float(toks[-1])
        elif "train_f1_macro" in line:
          stats["f1"] = float(toks[-1])
        elif "zero_weights" in line:
          # Second token is expected to look like "<zeros>/<total>".
          parts = toks[1].split("/")
          zeros, weights = int(parts[0]), int(parts[1])
          stats["nonzero_weights"] = weights - zeros

    stats.setdefault("nonzero_weights", "na")

    # Skip configs whose log lacks an MSE or F1 line.
    if any(key not in stats for key in d):
      continue

    for key, value in stats.items():
      d[key].append(value)

  return d

def get_df(domain):
  """Return the training metrics collected by ``get_data`` for ``domain`` as a DataFrame."""
  metrics = get_data(domain)
  frame = pd.DataFrame(metrics)
  return frame

In [60]:
# For every domain: print its name and training-set size, then show the table of
# per-config training metrics parsed from its logs (empty if no logs were found).
# NOTE(review): leaves `domain` and `df` bound as notebook globals.
for domain in IPC2023_LEARNING_DOMAINS:
  print(domain, N_TRAINING_DATA[domain])
  df = get_df(domain)
  display(df)

blocksworld 4954


Unnamed: 0,config,mse,f1,nonzero_weights
0,linear-svr_wl_1_0,0.32,0.45,6150
1,linear-svr_wl_1_1,0.62,0.42,5268
2,linear-svr_wl_1_2,0.65,0.45,1887
3,linear-svr_wl_1_3,0.58,0.42,1644
4,linear-svr_wl_1_4,0.68,0.36,584
5,linear-svr_wl_1_5,0.6,0.37,543
6,linear-svr_wl_3_0,0.06,0.99,419878
7,linear-svr_wl_3_3,0.06,0.99,78437
8,linear-svr_wl_3_6,0.07,0.98,40158
9,linear-svr_wl_3_9,0.1,0.97,27476


childsnack 2148


Unnamed: 0,config,mse,f1,nonzero_weights


ferry 3662


Unnamed: 0,config,mse,f1,nonzero_weights


floortile 8351


Unnamed: 0,config,mse,f1,nonzero_weights


miconic 1630


Unnamed: 0,config,mse,f1,nonzero_weights


rovers 4623


Unnamed: 0,config,mse,f1,nonzero_weights


satellite 26919


Unnamed: 0,config,mse,f1,nonzero_weights


sokoban 2422


Unnamed: 0,config,mse,f1,nonzero_weights


spanner 1416


Unnamed: 0,config,mse,f1,nonzero_weights


transport 4316


Unnamed: 0,config,mse,f1,nonzero_weights


### WL metrics

In [37]:
def get_models(domain, log_dir=None, iterations=None, prune_coefs=None, model_name="linear-svr"):
  """Load the trained kernel models of one domain, keyed by (iterations, prune coefficient).

  For each (WL iterations, prune coefficient) pair, the training log
  ``<log_dir>/<model_name>_llg_ipc2023-learning-<domain>_wl_<iterations>_<prune * iterations>.log``
  is scanned for its first "Model parameter file: <path>" line, and the model at
  <path> is loaded with ``load_kernel_model``. Configs whose log is missing or
  has no such line are skipped. Prints "<domain> <iterations> <prune>" for each
  log found.

  Args:
    domain: domain name used in the log file names, e.g. "blocksworld".
    log_dir: directory containing the logs; defaults to ``_LOG_DIR``.
    iterations: WL iteration counts to scan; defaults to ``ITERATIONS``.
    prune_coefs: prune coefficients to scan; defaults to ``PRUNE_COEFS``.
    model_name: model prefix of the log file names; defaults to "linear-svr".

  Returns:
    dict mapping (iterations, prune coefficient) to the first element returned
    by ``load_kernel_model``. Note the key holds the prune *coefficient*, while
    the file names hold coefficient * iterations.
  """
  log_dir = _LOG_DIR if log_dir is None else log_dir
  iterations = ITERATIONS if iterations is None else iterations
  prune_coefs = PRUNE_COEFS if prune_coefs is None else prune_coefs

  d = {}
  for n_iter, prune in itertools.product(iterations, prune_coefs):
    log_name = "_".join([model_name, "llg", "ipc2023-learning-" + domain, "wl", str(n_iter), str(prune * n_iter)]) + ".log"
    log_file = log_dir + "/" + log_name
    if not os.path.exists(log_file):
      continue

    print(domain, n_iter, prune)

    model_file = None
    # `with` guarantees the handle is closed (it previously leaked per config).
    with open(log_file, "r") as f:
      for line in f:
        toks = line.split()
        # len(toks) > 3 skips a bare label with no path after it
        # (assumes the label starts the line).
        if "Model parameter file:" in line and len(toks) > 3:
          model_file = toks[-1]
          break

    if model_file is None:
      continue
    model: KernelModelWrapper = load_kernel_model(model_file)[0]
    d[(n_iter, prune)] = model

  return d

In [38]:
# load all models
model_domain = {}
for domain in IPC2023_LEARNING_DOMAINS:
  print(domain, N_TRAINING_DATA[domain])
  models = get_models(domain)
  model_domain[domain] = models

blocksworld 4954
blocksworld 1 0


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


blocksworld 1 1
blocksworld 1 2
blocksworld 1 3
blocksworld 1 4
blocksworld 1 5
blocksworld 3 0
blocksworld 3 1
blocksworld 3 2
blocksworld 3 3
blocksworld 3 4
blocksworld 3 5
blocksworld 5 0
blocksworld 5 1
blocksworld 5 2
blocksworld 5 3
blocksworld 5 4
blocksworld 5 5
childsnack 2148
ferry 3662
floortile 8351
miconic 1630
rovers 4623
satellite 26919
sokoban 2422
spanner 1416
transport 4316


In [28]:
def _sorted_colour_counts(m):
  """Return (training counts per WL colour sorted in decreasing order, number of colours).

  Colours in ``m._kernel._train_histogram`` are matched to ``m.get_hash()`` keys
  after stripping parentheses and spaces from their string form; colours with
  no matching hash key are ignored.
  """
  h = m.get_hash()
  n_colours = len(h)
  hist = np.zeros(n_colours)
  strip = str.maketrans("", "", "() ")
  for col, cnt in m._kernel._train_histogram.items():
    key = str(col).translate(strip)
    if key in h:
      hist[h[key]] += cnt
  return sorted(hist, reverse=True), n_colours


# For each loaded blocksworld model, plot its per-colour training counts in
# decreasing order (log y-axis) and save the figure under plots/.
domain = "blocksworld"
os.makedirs("plots", exist_ok=True)
for iterations, prune_coef in itertools.product(ITERATIONS, PRUNE_COEFS):
  # model_domain is keyed by the prune *coefficient*. Previously `prune` was
  # overwritten with coefficient * iterations before this lookup, which fetched
  # the wrong model (or none, hidden by a bare except) whenever iterations > 1.
  m = model_domain[domain].get((iterations, prune_coef))
  if m is None:
    continue  # no model was loaded for this configuration
  prune = iterations * prune_coef  # value used in file names

  try:
    hist, n_colours = _sorted_colour_counts(m)
  except Exception as e:  # best-effort per model, but never silent
    print(f"skipping {domain} wl_{iterations}_{prune}: {e!r}")
    continue

  fig, ax = plt.subplots()
  try:
    bins = np.arange(len(hist) + 1)
    ax.hist(bins[:-1], bins, weights=hist, log=True)
    ax.set(title=f"{domain} wl_{iterations}_{prune} ({n_colours} colours)",
           xlabel="colour rank", ylabel="training count")
    fig.savefig(f"plots/wl_count_{domain}_{iterations}_{prune}_{n_colours}.png")
  finally:
    plt.close(fig)  # avoid accumulating open figures across the loop

<Figure size 640x480 with 0 Axes>

In [48]:
# Colour hash (get_hash()) of the blocksworld model with 1 WL iteration and prune coefficient 0.
# NOTE(review): rebinds the globals domain/iterations/prune/m, which later cells
# (e.g. `model_domain[domain]`) depend on.
domain="blocksworld"
iterations=1
prune=0
m = model_domain[domain][(iterations, prune)]
h1 = m.get_hash()

In [49]:
# Colour hash (get_hash()) of the blocksworld model with 5 WL iterations and prune coefficient 5.
# NOTE(review): near-duplicate of the previous cell; also rebinds domain/iterations/prune/m.
domain="blocksworld"
iterations=5
prune=5
m = model_domain[domain][(iterations, prune)]
h2 = m.get_hash()

In [53]:
# Colour keys present in both hashes, printed alongside each hash's size.
common_colours = {colour for colour in h1.keys() if colour in h2.keys()}
print(len(common_colours), len(h1), len(h2))

200 6219 7743


In [41]:
# Show which (iterations, prune coefficient) models were loaded for `domain`.
# NOTE(review): `domain` is whatever the most recently executed cell bound it to.
model_domain[domain]

{(1, 0): <kernels.wrapper.KernelModelWrapper at 0x7fa3ab09c730>,
 (1, 1): <kernels.wrapper.KernelModelWrapper at 0x7fa3ab09c3d0>,
 (1, 2): <kernels.wrapper.KernelModelWrapper at 0x7fa3e44b9f60>,
 (1, 3): <kernels.wrapper.KernelModelWrapper at 0x7fa3ab09f0d0>,
 (1, 4): <kernels.wrapper.KernelModelWrapper at 0x7fa3a46109d0>,
 (1, 5): <kernels.wrapper.KernelModelWrapper at 0x7fa3a4611e70>,
 (3, 0): <kernels.wrapper.KernelModelWrapper at 0x7fa3a47ecd30>,
 (3, 1): <kernels.wrapper.KernelModelWrapper at 0x7fa3a46133d0>,
 (3, 2): <kernels.wrapper.KernelModelWrapper at 0x7fa3a47ee410>,
 (3, 3): <kernels.wrapper.KernelModelWrapper at 0x7fa3a47eee00>,
 (3, 4): <kernels.wrapper.KernelModelWrapper at 0x7fa3a47ec6a0>,
 (3, 5): <kernels.wrapper.KernelModelWrapper at 0x7fa3a47ec9d0>,
 (5, 1): <kernels.wrapper.KernelModelWrapper at 0x7fa3a47ec0d0>,
 (5, 2): <kernels.wrapper.KernelModelWrapper at 0x7fa3a47ec100>,
 (5, 3): <kernels.wrapper.KernelModelWrapper at 0x7fa3beeee5c0>,
 (5, 4): <kernels.wrapper

In [7]:
# Reload the 1-iteration, prune-0 blocksworld model together with the blocksworld
# domain file and training problem easy/p01, then inspect its `_train_histogram`.
from util.save_load import load_kernel_model_and_setup
m = load_kernel_model_and_setup(
  "trained_models_kernel/linear-svr_llg_ipc2023-learning-blocksworld_wl_1_0",
  "../benchmarks/ipc2023-learning-benchmarks/blocksworld/domain.pddl",
  "../benchmarks/ipc2023-learning-benchmarks/blocksworld/training/easy/p01.pddl",
)
m._kernel._train_histogram

llg created!
time taken: 0.1975s
num nodes: 58
num edges: 100
graph density: 0.060496067755595885
123
