Merge pull request tensorflow#235 from tensorflow/rl-util
RL util
colah committed Mar 26, 2020
2 parents 3f72cb2 + 8d6d6fd commit 3e95f64
Showing 6 changed files with 750 additions and 0 deletions.
91 changes: 91 additions & 0 deletions lucid/scratch/rl_util/__init__.py
@@ -0,0 +1,91 @@
import numpy as np
import tensorflow as tf
import sys
import importlib
from lucid.modelzoo.vision_base import Model
from lucid.misc.channel_reducer import ChannelReducer
import lucid.optvis.param as param
import lucid.optvis.objectives as objectives
import lucid.optvis.render as render
import lucid.optvis.transform as transform
from lucid.misc.io import show, save
from lucid.misc.io.showing import _image_url, _display_html

try:
    import lucid.scratch.web.svelte as lucid_svelte
except (ImportError, NameError):
    lucid_svelte = None
from .joblib_wrapper import load_joblib, save_joblib
from .util import (
zoom_to,
get_var,
get_shape,
concatenate_horizontally,
hue_to_rgb,
channels_to_rgb,
conv2d,
norm_filter,
brightness_to_opacity,
)
from .attribution import (
gradient_override_map,
maxpool_override,
get_acts,
get_grad_or_attr,
get_attr,
get_grad,
get_paths,
get_multi_path_attr,
)
from .nmf import argmax_nd, LayerNMF, rescale_opacity


def all_():
return __all__


def reload(globals_dict):
m = importlib.reload(sys.modules[__name__])
for f in m.__all__:
globals_dict.update({f: getattr(m, f)})


__all__ = [
"np",
"tf",
"Model",
"ChannelReducer",
"param",
"objectives",
"render",
"transform",
"show",
"save",
"_image_url",
"_display_html",
"lucid_svelte",
"load_joblib",
"save_joblib",
"zoom_to",
"get_var",
"get_shape",
"concatenate_horizontally",
"hue_to_rgb",
"channels_to_rgb",
"conv2d",
"norm_filter",
"brightness_to_opacity",
"gradient_override_map",
"maxpool_override",
"get_acts",
"get_grad_or_attr",
"get_attr",
"get_grad",
"get_paths",
"get_multi_path_attr",
"argmax_nd",
"LayerNMF",
"rescale_opacity",
"all_",
"reload",
]
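
For context, a minimal usage sketch of the all_/reload helpers above, assuming lucid is installed and the module is star-imported in a notebook session (nothing below is part of the diff):

# Usage sketch (assumption: lucid.scratch.rl_util is importable).
from lucid.scratch.rl_util import *

print(all_())        # lists the names re-exported via __all__

# After editing the module on disk, refresh the current namespace in place:
reload(globals())
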
54 changes: 54 additions & 0 deletions lucid/scratch/rl_util/arch.py
@@ -0,0 +1,54 @@
import tensorflow as tf


def clear_cnn(unscaled_images, batch_norm=False):
"""A simple convolutional architecture designed with interpretability in
mind:
- Later convolutional layers have been replaced with dense layers, to allow
for non-visual processing
- There are no residual connections, so that the flow of information passes
through every layer
- A pool size equal to the stride has been used, to avoid gradient gridding
- L2 pooling has been used instead of max pooling, for more continuous
gradients
Batch norm has been optionally included to help with optimization.
"""

def conv_layer(out, filters, kernel_size):
out = tf.layers.conv2d(
out, filters, kernel_size, padding="same", activation=None
)
if batch_norm:
out = tf.layers.batch_normalization(out)
out = tf.nn.relu(out)
return out

def pool_l2(out, pool_size):
return tf.sqrt(
tf.layers.average_pooling2d(
out ** 2, pool_size=pool_size, strides=pool_size, padding="same"
)
+ 1e-8
)

out = tf.cast(unscaled_images, tf.float32) / 255.0
with tf.variable_scope("1a"):
out = conv_layer(out, 16, 7)
out = pool_l2(out, 2)
with tf.variable_scope("2a"):
out = conv_layer(out, 32, 5)
with tf.variable_scope("2b"):
out = conv_layer(out, 32, 5)
out = pool_l2(out, 2)
with tf.variable_scope("3a"):
out = conv_layer(out, 32, 5)
out = pool_l2(out, 2)
with tf.variable_scope("4a"):
out = conv_layer(out, 32, 5)
out = pool_l2(out, 2)
out = tf.layers.flatten(out)
out = tf.layers.dense(out, 256, activation=tf.nn.relu)
out = tf.layers.dense(out, 512, activation=tf.nn.relu)
return out
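
A minimal sketch of wiring clear_cnn into a TF1 graph; the 64x64 RGB observation shape and the outer scope name are assumptions for illustration, not part of the diff:

# Sketch: building the clear_cnn feature extractor on a batch of uint8 frames.
import tensorflow as tf

obs_ph = tf.placeholder(tf.uint8, [None, 64, 64, 3], name="obs")
with tf.variable_scope("model"):
    features = clear_cnn(obs_ph, batch_norm=False)  # -> [batch, 512] embedding
# A policy/value head (e.g. tf.layers.dense over `features`) would normally follow.
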
192 changes: 192 additions & 0 deletions lucid/scratch/rl_util/attribution.py
@@ -0,0 +1,192 @@
import numpy as np
import tensorflow as tf
import lucid.optvis.render as render
import itertools
from lucid.misc.gradient_override import gradient_override_map


def maxpool_override():
    """Gradient override for MaxPool ops: backpropagates through a smooth
    average-pool-based surrogate instead of the hard max, giving more
    continuous attributions. Intended for use with gradient_override_map."""
def MaxPoolGrad(op, grad):
inp = op.inputs[0]
op_args = [
op.get_attr("ksize"),
op.get_attr("strides"),
op.get_attr("padding"),
]
smooth_out = tf.nn.avg_pool(inp ** 2, *op_args) / (
1e-2 + tf.nn.avg_pool(tf.abs(inp), *op_args)
)
inp_smooth_grad = tf.gradients(smooth_out, [inp], grad)[0]
return inp_smooth_grad

return {"MaxPool": MaxPoolGrad}


def get_acts(model, layer_name, obses):
with tf.Graph().as_default(), tf.Session():
t_obses = tf.placeholder_with_default(
obses.astype(np.float32), (None, None, None, None)
)
T = render.import_model(model, t_obses, t_obses)
t_acts = T(layer_name)
return t_acts.eval()


def get_grad_or_attr(
model,
layer_name,
prev_layer_name,
obses,
*,
act_dir=None,
act_poses=None,
score_fn=tf.reduce_sum,
grad_or_attr,
override=None,
integrate_steps=1
):
    """Gradient of a scalar score on `layer_name` with respect to
    `prev_layer_name` (or the observations when None), returned either as the
    raw gradient or as the gradient-times-input attribution, optionally
    averaged along an integrated-gradients path."""
with tf.Graph().as_default(), tf.Session(), gradient_override_map(override or {}):
t_obses = tf.placeholder_with_default(
obses.astype(np.float32), (None, None, None, None)
)
T = render.import_model(model, t_obses, t_obses)
t_acts = T(layer_name)
if prev_layer_name is None:
t_acts_prev = t_obses
else:
t_acts_prev = T(prev_layer_name)
if act_dir is not None:
t_acts = act_dir[None, None, None] * t_acts
if act_poses is not None:
t_acts = tf.gather_nd(
t_acts,
tf.concat([tf.range(obses.shape[0])[..., None], act_poses], axis=-1),
)
t_score = score_fn(t_acts)
t_grad = tf.gradients(t_score, [t_acts_prev])[0]
if integrate_steps > 1:
acts_prev = t_acts_prev.eval()
grad = (
sum(
[
t_grad.eval(feed_dict={t_acts_prev: acts_prev * alpha})
for alpha in np.linspace(0, 1, integrate_steps + 1)[1:]
]
)
/ integrate_steps
)
else:
acts_prev = None
grad = t_grad.eval()
if grad_or_attr == "grad":
return grad
elif grad_or_attr == "attr":
if acts_prev is None:
acts_prev = t_acts_prev.eval()
return acts_prev * grad
else:
raise NotImplementedError


def get_attr(model, layer_name, prev_layer_name, obses, **kwargs):
kwargs["grad_or_attr"] = "attr"
return get_grad_or_attr(model, layer_name, prev_layer_name, obses, **kwargs)


def get_grad(model, layer_name, obses, **kwargs):
kwargs["grad_or_attr"] = "grad"
return get_grad_or_attr(model, layer_name, None, obses, **kwargs)


def get_paths(acts, nmf, *, max_paths, integrate_steps):
    """Yields integration paths from zero activations to `acts`. For each
    sampled split of the NMF factors, one group is ramped in early and the
    other late (and vice versa for the complementary path), so that
    path-integrated attributions account for interactions between factor
    groups."""
acts_reduced = nmf.transform(acts)
residual = acts - nmf.inverse_transform(acts_reduced)
combs = itertools.combinations(range(nmf.features), nmf.features // 2)
if nmf.features % 2 == 0:
combs = np.array([comb for comb in combs if 0 in comb])
else:
combs = np.array(list(combs))
if max_paths is None:
splits = combs
else:
num_splits = min((max_paths + 1) // 2, combs.shape[0])
splits = combs[
np.random.choice(combs.shape[0], size=num_splits, replace=False), :
]
for i, split in enumerate(splits):
indices = np.zeros(nmf.features)
indices[split] = 1.0
indices = indices[tuple(None for _ in range(acts_reduced.ndim - 1))]
complements = [False, True]
if max_paths is not None and i * 2 + 1 == max_paths:
complements = [np.random.choice(complements)]
for complement in complements:
path = []
for alpha in np.linspace(0, 1, integrate_steps + 1)[1:]:
if complement:
coordinates = (1.0 - indices) * alpha ** 2 + indices * (
1.0 - (1.0 - alpha) ** 2
)
else:
coordinates = indices * alpha ** 2 + (1.0 - indices) * (
1.0 - (1.0 - alpha) ** 2
)
path.append(
nmf.inverse_transform(acts_reduced * coordinates) + residual * alpha
)
yield path


def get_multi_path_attr(
model,
layer_name,
prev_layer_name,
obses,
prev_nmf,
*,
act_dir=None,
act_poses=None,
score_fn=tf.reduce_sum,
override=None,
max_paths=50,
integrate_steps=10
):
    """Integrated-gradients-style attribution from `prev_layer_name` to
    `layer_name`, averaged over multiple paths through the NMF factor space of
    the earlier layer (see get_paths)."""
with tf.Graph().as_default(), tf.Session(), gradient_override_map(override or {}):
t_obses = tf.placeholder_with_default(
obses.astype(np.float32), (None, None, None, None)
)
T = render.import_model(model, t_obses, t_obses)
t_acts = T(layer_name)
if prev_layer_name is None:
t_acts_prev = t_obses
else:
t_acts_prev = T(prev_layer_name)
if act_dir is not None:
t_acts = act_dir[None, None, None] * t_acts
if act_poses is not None:
t_acts = tf.gather_nd(
t_acts,
tf.concat([tf.range(obses.shape[0])[..., None], act_poses], axis=-1),
)
t_score = score_fn(t_acts)
t_grad = tf.gradients(t_score, [t_acts_prev])[0]
acts_prev = t_acts_prev.eval()
path_acts = get_paths(
acts_prev, prev_nmf, max_paths=max_paths, integrate_steps=integrate_steps
)
deltas_of_path = lambda path: np.array(
[b - a for a, b in zip([np.zeros_like(acts_prev)] + path[:-1], path)]
)
grads_of_path = lambda path: np.array(
[t_grad.eval(feed_dict={t_acts_prev: acts}) for acts in path]
)
path_attrs = map(
lambda path: (deltas_of_path(path) * grads_of_path(path)).sum(axis=0),
path_acts,
)
total_attr = 0
num_paths = 0
for attr in path_attrs:
total_attr += attr
num_paths += 1
return total_attr / num_paths
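
A hedged usage sketch for these helpers: `model` must be a lucid modelzoo Model whose graph actually contains the layer names used; the layer names and random observations below are illustrative placeholders, not values from the diff.

# Sketch only; `model` is assumed to be a lucid Model loaded elsewhere, and the
# layer names and observation shape are assumptions for illustration.
import numpy as np
from lucid.scratch.rl_util import get_acts, get_grad, get_attr, maxpool_override

obses = np.random.rand(8, 64, 64, 3).astype(np.float32)

acts = get_acts(model, "2b/Relu", obses)              # activations at the layer
grad = get_grad(model, "2b/Relu", obses,              # gradient w.r.t. the observations,
                override=maxpool_override())          # using the smoothed MaxPool gradient
attr = get_attr(model, "2b/Relu", "1a/Relu", obses,   # gradient * input attribution between
                integrate_steps=10)                   # two layers, averaged over 10 steps
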
25 changes: 25 additions & 0 deletions lucid/scratch/rl_util/joblib_wrapper.py
@@ -0,0 +1,25 @@
from lucid.misc.io.saving import nullcontext
from lucid.misc.io.loading import load_using_loader
from lucid.misc.io.writing import write_handle


def load_joblib(url_or_handle, *, cache=None, **kwargs):
import joblib

return load_using_loader(
url_or_handle,
decompressor=nullcontext,
loader=joblib.load,
cache=cache,
**kwargs
)


def save_joblib(value, url_or_handle, **kwargs):
import joblib

if hasattr(url_or_handle, "write") and hasattr(url_or_handle, "name"):
joblib.dump(value, url_or_handle, **kwargs)
else:
with write_handle(url_or_handle) as handle:
joblib.dump(value, handle, **kwargs)
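
A small round-trip sketch for the wrappers above; the local filename is an example, and load_joblib also accepts the URLs and handles supported by lucid's IO helpers:

# Sketch: saving and restoring an arbitrary picklable object via joblib.
from lucid.scratch.rl_util import save_joblib, load_joblib

trajectories = {"obses": [], "rewards": []}
save_joblib(trajectories, "trajectories.joblib")
restored = load_joblib("trajectories.joblib")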
