In [1]:
import argparse
import os
import sys
import numpy as np
import torch
import gym
import matplotlib

matplotlib.use("agg")
import matplotlib.pyplot as plt
import unittest
from DeepLearning_Models.utils.general import join, plot_combined
from DeepLearning_Models.ActorCritic.policy_gradient import PolicyGradient
from Explanations_Models.LIME import LimeModel
import random
import yaml
yaml.add_constructor("!join", join)
parser = argparse.ArgumentParser()
def weight_dict_corrector(loaded_state_dict):
    new_state_dict = {}
    for key, value in loaded_state_dict.items():
        new_key = key.replace("network.", "")  # Remove the 'network.' prefix
        new_state_dict[new_key] = value
    return new_state_dict

In [2]:
config_file = open("config_envs/{}.yml".format("cartpole"))
config = yaml.load(config_file, Loader=yaml.FullLoader)
env = gym.make(config["env"]["env_name"], render_mode="rgb_array")
seed = config["env"]["seed"][0]
model = PolicyGradient(env, config, seed)

In [3]:
model = PolicyGradient(env, config, seed)

model.network.load_state_dict(weight_dict_corrector(torch.load(config["output"]["actor_output"].format(seed))))
model.baseline_network.load_state_dict(torch.load(config["output"]["critic_output"].format(seed)))

<All keys matched successfully>

Begin LIME Explanations

In [4]:
config_file = open("config_explanations/{}.yml".format("GaussianSample"))
config = yaml.load(config_file, Loader=yaml.FullLoader)

In [5]:
LM = LimeModel(model.network, torch.tensor([0.,0.,0.,0.]), config)

In [6]:
samps = LM.sample(100, LM.point)

In [7]:
Y = LM.model(torch.tensor(LM.sample_points, device="cuda", dtype = torch.float32))

  Y = LM.model(torch.tensor(LM.sample_points, device="cuda", dtype = torch.float32))


In [8]:
Y[:,0]

tensor([  3.3547, -10.5668, -17.8137,  10.1920,   4.8997,  -6.6350,  14.3175,
         -7.7740, -13.1545,  -5.6261,  10.8121,  -9.5294,   2.2876,  -3.6910,
        -20.6059,   6.2828,  12.0113,  25.9428,  -0.3521,  13.7781,  19.1842,
         -6.3477, -16.1317,  -4.1602,  -6.5810, -16.9604,   3.8750,  11.1806,
         -4.5095,  14.2091,  -0.7084,  11.2509, -28.7032, -22.1905,  -9.4337,
        -28.5690,  -9.3899,  -8.8435, -19.5029,   3.1770,   6.9602,  -1.0129,
         -5.6001,  -3.2313,  -0.6847,  19.9492,  13.7667,  -3.6638,  -4.8174,
         -9.5965,  13.9765,  11.3342, -12.7250, -13.1623,  -6.7251,  21.9975,
          0.4634, -16.0289,  -0.3898,   5.5406,   5.8469,  -8.0073,  18.9156,
         -6.3588,  15.5878,   6.1713,  -8.4890, -17.6609,  -9.0555,  16.0272,
          1.1764,   7.4044,   4.5111,  14.4021,   4.8010,  -9.7887,  -7.7580,
        -19.4607,   4.7492,   7.4421, -19.9969,   8.1797,  -2.7411,   9.2231,
         22.5388,   4.3602,   4.4240, -23.1548,  21.9900, -17.53

In [9]:
LM.runner()

tensor([[0., 0., 0., 0.]], device='cuda:0')
tensor([[ 0.2719,  0.0805,  2.2126,  0.0684],
        [ 0.5232,  0.8501,  0.7001,  0.2933],
        [-0.1720,  0.4648, -0.9004, -0.5971],
        [ 0.2738,  1.6555, -0.2676, -0.6091],
        [ 1.1393,  1.6320, -0.6169,  1.4469],
        [ 0.0306,  0.3302, -0.8416,  3.5051],
        [-1.4211, -0.4201,  0.6641,  0.4464],
        [ 0.4221, -0.5076,  0.8310,  0.5409],
        [ 0.8098, -0.3150,  1.4641,  1.3306],
        [-0.0609, -0.7338,  1.6359,  0.7501],
        [-1.1909, -0.4802, -0.2009,  0.3843],
        [ 0.6437, -0.7280, -0.6158, -1.5993],
        [ 0.6977, -1.3572, -1.9675, -0.5005],
        [ 1.8424, -1.3201, -0.7002, -0.1138],
        [-0.2187, -0.4289, -0.3054, -0.8079],
        [-0.6823,  1.4480, -0.0263, -0.3958],
        [-0.9352,  0.1360, -0.2440,  1.2177],
        [-1.3048, -0.6143, -0.3758,  0.3014],
        [-0.3039,  0.3354, -0.2922,  1.0854],
        [ 0.5032, -0.5779, -1.7488, -0.9334],
        [ 1.0595,  1.9707, -0.4723, 

In [11]:
for i in LM.interpretable_models:
    for param in i.parameters():
        print(param)

Parameter containing:
tensor([[ 0.6485, -3.4888, -8.7510, -8.1977]], device='cuda:0',
       requires_grad=True)
Parameter containing:
tensor([[-0.7123,  3.4499,  9.1834,  8.5764]], device='cuda:0',
       requires_grad=True)


In [11]:
t

tensor([0, 0, 0])

In [12]:
t.unsqueeze(0)

tensor([[0, 0, 0]])