INTERPRETER tutorial: extracting an interpretable oblique decision-tree policy from a PPO agent on CartPole-v1

In [None]:
# Install the notebook's Python dependencies (IPython %pip magics).
# NOTE(review): gymnasium is not listed here; presumably it is pulled in as a
# dependency of stable-baselines3. `progress_bar=True` in the training cell
# also needs tqdm and rich (e.g. via `stable-baselines3[extra]`) -- confirm
# they are installed in the kernel.
%pip install stable-baselines3==2.3
%pip install scikit-learn
%pip install cloudpickle==3.0.0
%pip install joblib

Train a PPO agent on CartPole-v1 and save the final policy (the oracle that the decision tree will imitate).

In [None]:
from stable_baselines3 import PPO
import gymnasium as gym

# Train the PPO oracle that the decision tree will later imitate.
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env)
# SB3 declares total_timesteps as an int; pass one rather than the float 6e4.
model.learn(total_timesteps=60_000, progress_bar=True)  # 4 minutes on cpu
# Saved as oracle_cartpole.zip, which the extraction cell loads.
model.save("oracle_cartpole")

INTERPRETER decision-tree (DT) extraction classes (DAgger-style imitation)

In [None]:
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from gymnasium import Env
from tqdm import tqdm
import os
from pickle import dump
from statistics import mean
from operator import itemgetter
from copy import deepcopy
import time
from typing import Union
from stable_baselines3 import PPO, DQN

class DecisionTreeExtractor: #Dagger
    """DAgger-style distillation of a Stable-Baselines3 policy into a decision tree.

    :meth:`imitate` first collects ``data_per_iter`` (state, action) pairs
    from the oracle. It then repeats three steps:

    1. Fit ``dtpolicy`` on the aggregated dataset.
    2. Roll the tree out in ``env`` to score it.
    3. Add the states the tree visited, relabelled by the oracle, to the
       dataset.

    After :meth:`imitate` the instance holds one entry per iteration in:

    - ``list_acc``: training accuracy.
    - ``list_eval``: mean episode return of the tree.
    - ``list_dt``: deep copy of the fitted tree.
    - ``times``: seconds elapsed since ``imitate`` started.
    """

    def __init__(self, model: Union[PPO,DQN], dtpolicy: DecisionTreeClassifier, env: Env, data_per_iter: int=10_000):
        self.model = model  # oracle whose actions are imitated
        # A single (non-vectorized) Gymnasium env: reset() returns (obs, info)
        # and step() returns (obs, reward, terminated, truncated, info).
        self.env = env
        self.data_per_iter = data_per_iter  # environment steps per rollout
        self.dt = dtpolicy
        # SB3's non-deterministic DQN.predict performs epsilon-greedy
        # exploration, so DQN is queried greedily. PPO samples from its
        # action distribution.
        self.deterministic = isinstance(model, DQN)

    def _oracle_actions(self, states):
        """Return the oracle's action(s) for ``states``.

        Always uses ``self.deterministic``. This way the relabelled DAgger
        data is produced exactly like the initial oracle rollout.
        """
        return self.model.predict(states, deterministic=self.deterministic)[0]

    def collect_data(self):
        """Roll out the oracle for ``data_per_iter`` environment steps.

        Returns:
            (S, A): arrays of the visited states and the oracle's action in each.
        """
        S, A = [], []
        s, _ = self.env.reset()
        for _ in tqdm(range(self.data_per_iter)):
            action = self._oracle_actions(s)
            S.append(s)
            A.append(action)
            s, _, term, trunc, _ = self.env.step(action)
            if term or trunc:
                s, _ = self.env.reset()
        return np.array(S), np.array(A)

    def collect_data_dt(self, mask):
        """Roll out the current tree for ``data_per_iter`` environment steps.

        The tree only sees the ``mask`` components of each observation. The
        full observations are returned so the oracle can relabel them.

        Returns:
            (S, mean_return): the visited states, and the mean return of the
            completed episodes. If no episode finished, it is the return of
            the partial episode instead.
        """
        S = []
        episodes = []
        ep_reward = 0
        s, _ = self.env.reset()
        for _ in range(self.data_per_iter):
            action = self.dt.predict(s[mask].reshape(1, -1))[0]
            S.append(s)
            s, r, term, trunc, _ = self.env.step(action)
            ep_reward += r
            if term or trunc:
                s, _ = self.env.reset()
                episodes.append(ep_reward)
                ep_reward = 0
        if not episodes:
            # No episode finished within the budget: report the partial return.
            episodes.append(ep_reward)
        return np.array(S), mean(episodes)

    def fit_DT(self, S, A):
        """Fit the tree on ``(S, A)`` and return its accuracy on that same data."""
        self.dt.fit(S, A)
        return self.dt.score(S, A)

    def imitate(self, nb_iter: int, mask):
        """Run the DAgger loop.

        Args:
            nb_iter: number of fit/evaluate rounds. At least one round always
                runs, even if ``nb_iter < 1``.
            mask: indices of the observation components given to the tree.
        """
        start_time = time.time()
        self.list_acc, self.list_eval, self.list_dt, self.times = [], [], [], []
        DS, DA = self.collect_data()
        n_rounds = max(nb_iter, 1)
        for it in range(n_rounds):
            acc_dt = self.fit_DT(DS[:, mask], DA)
            S_dt, eval_dt = self.collect_data_dt(mask)
            self.times.append(time.time() - start_time)

            print("Accuracy: {} - Evaluation: {}".format(acc_dt, eval_dt))
            self.list_dt.append(deepcopy(self.dt))
            self.list_acc.append(acc_dt)
            self.list_eval.append(eval_dt)
            # Aggregate the tree's states with oracle labels for the next fit.
            # After the last round no further fit happens, so skip the work.
            if it < n_rounds - 1:
                DS = np.concatenate((DS, S_dt))
                DA = np.concatenate((DA, self._oracle_actions(S_dt)))


class ObliqueDecisionTreeExtractor(DecisionTreeExtractor):
    """DAgger extractor whose tree can also split on pairwise feature differences.

    Each (masked) state ``x`` of length ``n`` is augmented with the
    ``n*(n-1)/2`` differences ``x[c] - x[r]``, one for each ``(r, c)`` pair
    from ``np.tril_indices(n, k=-1)`` and in that order. An axis-aligned
    split on one of these columns is therefore an oblique split
    ``x[c] - x[r] <= t`` on the original features.
    """

    def __init__(self, model: Union[PPO,DQN], dtpolicy: DecisionTreeClassifier, env: Env, data_per_iter: int = 10000):
        super().__init__(model, dtpolicy, env, data_per_iter)

    @staticmethod
    def _oblique_features(S):
        """Append the lower-triangle pairwise differences to each row of 2-D ``S``."""
        rows, cols = np.tril_indices(S.shape[1], k=-1)
        return np.hstack((S, S[:, cols] - S[:, rows]))

    def fit_DT(self, S, A):
        """Fit the tree on the augmented features of ``S`` and return its training accuracy."""
        return super().fit_DT(self._oblique_features(S), A)

    def collect_data_dt(self, mask):
        """Roll out the current oblique tree for ``data_per_iter`` environment steps.

        Returns:
            (S, mean_return): the full visited states, and the mean return of
            the completed episodes. If no episode finished, it is the return
            of the partial episode instead.
        """
        S = []
        episodes = []
        ep_reward = 0
        s, _ = self.env.reset()
        for _ in tqdm(range(self.data_per_iter)):
            # Build the tree's input from the current observation only.
            features = self._oblique_features(s[mask].reshape(1, -1))
            action = self.dt.predict(features)[0]
            S.append(s)
            s, r, term, trunc, _ = self.env.step(action)
            ep_reward += r
            if term or trunc:
                s, _ = self.env.reset()
                episodes.append(ep_reward)
                ep_reward = 0
        if not episodes:
            # No episode finished within the budget: report the partial return.
            episodes.append(ep_reward)
        return np.array(S), mean(episodes)

    def save_best_tree(self, save_dir: str):
        """Pickle every extracted tree plus the best-scoring one into ``save_dir``.

        File names embed the iteration index and evaluation score
        (``Oblique-Tree-<j>_<score>``). The best tree is saved as
        ``Best-Oblique-Tree-<score>`` and also stored as ``self.best_dt``.

        Raises:
            ValueError: if no tree has been extracted yet.
        """
        if not getattr(self, "list_dt", None):
            raise ValueError("No trees to save: call imitate() first.")
        os.makedirs(save_dir, exist_ok=True)

        for j, (dt, score) in enumerate(zip(self.list_dt, self.list_eval)):
            # os.path.join works whether or not save_dir ends with a slash.
            path = os.path.join(save_dir, "Oblique-Tree-{}_{}".format(j, score))
            with open(path, 'wb') as save:
                dump(dt, save)

        index, element = max(enumerate(self.list_eval), key=itemgetter(1))
        self.best_dt = self.list_dt[index]
        with open(os.path.join(save_dir, "Best-Oblique-Tree-" + str(element)), 'wb') as save:
            dump(self.best_dt, save)
        dump(self.best_dt, save)


Extract an oblique tree imitating the PPO oracle

In [None]:
from stable_baselines3.common.evaluation import evaluate_policy
from sklearn.tree import DecisionTreeClassifier
import gymnasium as gym

# Size budget for the extracted tree: at most three leaves keeps it readable.
interpretability = {"max_leaf_nodes": 3}
clf = DecisionTreeClassifier(**interpretability)

# Environment and the PPO oracle trained above (PPO comes from an earlier cell).
env = gym.make("CartPole-v1")
model = PPO.load("oracle_cartpole.zip")
oracle_mean_reward = evaluate_policy(model, env)[0]
print("Reward is {} out of 500".format(oracle_mean_reward))
exp_name = "trees/"

# DAgger: 10 iterations, 5000 environment steps per rollout, all observation features.
n_obs_features = env.observation_space.shape[0]
dagger = ObliqueDecisionTreeExtractor(model, clf, env, data_per_iter=5000)
dagger.imitate(nb_iter=10, mask=range(n_obs_features))

dagger.save_best_tree(exp_name)

Plot the best extracted tree

In [None]:
from sklearn.tree import plot_tree
import joblib
import glob

# Load the saved best tree (the first glob match if several exist) and draw it.
best_tree_path = glob.glob("trees/Best-Oblique-Tree-*")[0]
clf = joblib.load(best_tree_path)
plot_tree(clf)

Conversion code: translate the best tree into Python programs

In [None]:
from sklearn import tree

def _oblique_feature_names(names):
    """Return ``names`` followed by the names of the pairwise-difference features.

    For each ``(r, c)`` from ``np.tril_indices(len(names), k=-1)``, in that
    order, the extra name is ``"<names[c]> - <names[r]>"``. This matches how
    the oblique extractor lays out its feature columns.
    """
    rows, cols = np.tril_indices(len(names), k=-1)
    return list(names) + ["{} - {}".format(names[c], names[r]) for r, c in zip(rows, cols)]


def _tree_line_to_python(line, state_prefix, quote_class):
    """Translate one line of ``sklearn.tree.export_text`` output into Python source.

    A line looks like ``"|   |--- [0] - [1] <= 0.50"``. Each ``|`` bar is one
    tree level and adds 4 spaces of indentation under ``def play``.

    Raises:
        ValueError: on a line that is neither a split nor a leaf, such as a
            truncated branch.
    """
    segments = line.split("|")
    indent = "    " * (len(segments) - 1)
    content = segments[-1].split("--- ")[1]
    if "<=" in content:
        feature, threshold = content.split("<=")
        threshold = threshold.strip()
        if "-" in feature:
            # Oblique column "a - b" becomes state<a> - state<b>.
            left, right = feature.split(" - ")[:2]
            return "{}if {}{} - {}{} <= {}:\n".format(
                indent, state_prefix, left.strip(), state_prefix, right.strip(), threshold)
        return "{}if {}{} <= {}:\n".format(indent, state_prefix, feature.strip(), threshold)
    if ">" in content:
        # export_text prints the "<=" branch first, so ">" is its else-branch.
        return indent + "else:\n"
    if "class: " not in content:
        raise ValueError("Unexpected export_text line: {!r}".format(line))
    action = content.split("class: ")[1]
    if quote_class:
        action = '"{}"'.format(action)
    return "{}return {}\n".format(indent, action)


def convert(interpretable=False, overwrite=False):
    """Write the best saved oblique tree out as a Python ``play(state)`` function.

    Loads the first file matching ``trees/Best-Oblique-Tree-*`` and renders it
    with ``sklearn.tree.export_text``. Each text line is then translated into
    Python.

    Args:
        interpretable: if True, write ``play_cartpole_interpretable.py``. It
            uses named CartPole features (``state.Cart_Position``, ...) and
            string actions ("left"/"right"), and is meant for reading. If
            False, write ``play_cartpole_playable.py``. It indexes the raw
            observation (``state[0]``, ...) and returns the integer action.
        overwrite: if False (default), an existing output file is left
            untouched. If True, it is regenerated from the current best tree.
    """
    clf = joblib.load(glob.glob("trees/Best-Oblique-Tree-*")[0])
    if interpretable:
        feature_names = ["Cart_Position", "Cart_Velocity", "Pole_Angle", "Pole_Angular_Velocity"]
        actions = ["left", "right"]
        out_path = 'play_cartpole_interpretable.py'
        state_prefix = "state."
    else:
        feature_names = ["[0]", "[1]", "[2]", "[3]"]
        actions = [0, 1]
        out_path = 'play_cartpole_playable.py'
        state_prefix = "state"

    # export_text truncates below max_depth (default 10), and truncated
    # branches cannot be translated, so always export the whole tree.
    r = tree.export_text(
        clf,
        feature_names=_oblique_feature_names(feature_names),
        class_names=actions,
        max_depth=max(10, clf.get_depth()),
    )
    if os.path.exists(out_path) and not overwrite:
        return

    # Translate everything first so a parse error never leaves a half-written file.
    body = [_tree_line_to_python(line, state_prefix, quote_class=interpretable)
            for line in r.split("\n")[:-1]]
    with open(out_path, 'w') as the_file:
        the_file.write('def play(state):\n')
        the_file.writelines(body)

Generate the programs

In [None]:
# Human-readable program: named features, string actions.
convert(interpretable=True)
# Executable program: indexes the raw observation and returns the integer action.
convert(interpretable=False)

Play CartPole with the generated program

In [None]:
from play_cartpole_playable import play
env = gym.make("CartPole-v1", render_mode="human")
s, _ = env.reset()
done = False
sum_r = 0
try:
    while not done:
        a = play(s)
        # With render_mode="human", Gymnasium renders each frame inside step().
        # An extra env.render() call would draw every frame twice.
        s, r, term, trunc, _ = env.step(a)
        sum_r += r
        done = term or trunc
finally:
    # Close the render window even if play() raises.
    env.close()

print("Episode reward: {}".format(sum_r))