forked from tensorflow/lucid
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request tensorflow#235 from tensorflow/rl-util
RL util
- Loading branch information
Showing
6 changed files
with
750 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
import sys | ||
import importlib | ||
from lucid.modelzoo.vision_base import Model | ||
from lucid.misc.channel_reducer import ChannelReducer | ||
import lucid.optvis.param as param | ||
import lucid.optvis.objectives as objectives | ||
import lucid.optvis.render as render | ||
import lucid.optvis.transform as transform | ||
from lucid.misc.io import show, save | ||
from lucid.misc.io.showing import _image_url, _display_html | ||
|
||
try: | ||
import lucid.scratch.web.svelte as lucid_svelte | ||
except NameError: | ||
lucid_svelte = None | ||
from .joblib_wrapper import load_joblib, save_joblib | ||
from .util import ( | ||
zoom_to, | ||
get_var, | ||
get_shape, | ||
concatenate_horizontally, | ||
hue_to_rgb, | ||
channels_to_rgb, | ||
conv2d, | ||
norm_filter, | ||
brightness_to_opacity, | ||
) | ||
from .attribution import ( | ||
gradient_override_map, | ||
maxpool_override, | ||
get_acts, | ||
get_grad_or_attr, | ||
get_attr, | ||
get_grad, | ||
get_paths, | ||
get_multi_path_attr, | ||
) | ||
from .nmf import argmax_nd, LayerNMF, rescale_opacity | ||
|
||
|
||
def all_():
    """Return this module's ``__all__`` list of re-exported names."""
    exported_names = __all__
    return exported_names
|
||
|
||
def reload(globals_dict):
    """Re-import this module and copy its exported names into ``globals_dict``.

    Args:
        globals_dict: Mapping (typically the caller's ``globals()``) that is
            updated in place with every name listed in the reloaded module's
            ``__all__``.
    """
    this_module = sys.modules[__name__]
    refreshed = importlib.reload(this_module)
    for exported in refreshed.__all__:
        globals_dict.update({exported: getattr(refreshed, exported)})
|
||
|
||
# Names re-exported by this module, grouped by where they are imported from.
__all__ = [
    # Third-party numerics.
    "np",
    "tf",
    # lucid core.
    "Model",
    "ChannelReducer",
    "param",
    "objectives",
    "render",
    "transform",
    "show",
    "save",
    "_image_url",
    "_display_html",
    # Optional svelte integration (None when its import raised NameError).
    "lucid_svelte",
    # .joblib_wrapper
    "load_joblib",
    "save_joblib",
    # .util
    "zoom_to",
    "get_var",
    "get_shape",
    "concatenate_horizontally",
    "hue_to_rgb",
    "channels_to_rgb",
    "conv2d",
    "norm_filter",
    "brightness_to_opacity",
    # .attribution
    "gradient_override_map",
    "maxpool_override",
    "get_acts",
    "get_grad_or_attr",
    "get_attr",
    "get_grad",
    "get_paths",
    "get_multi_path_attr",
    # .nmf
    "argmax_nd",
    "LayerNMF",
    "rescale_opacity",
    # Helpers defined in this module.
    "all_",
    "reload",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import tensorflow as tf | ||
|
||
|
||
def clear_cnn(unscaled_images, batch_norm=False):
    """Build a small CNN whose design favours interpretability.

    Design choices:
      - The later stages are dense layers rather than convolutions, to allow
        for non-visual processing.
      - No residual connections are used, so information flows through every
        layer.
      - Pool size equals stride, to avoid gradient gridding.
      - L2 pooling replaces max pooling, for more continuous gradients.
    Batch norm can optionally be enabled to help with optimization.

    Args:
        unscaled_images: Image tensor; it is cast to float32 and divided by
            255 before the first layer.
        batch_norm: If true, apply batch normalization between each
            convolution and its ReLU.

    Returns:
        The output tensor of the final 512-unit dense ReLU layer.
    """

    def conv_layer(inputs, filters, kernel_size):
        # The convolution is linear so that batch norm (when enabled) is
        # applied before the ReLU non-linearity.
        pre_activation = tf.layers.conv2d(
            inputs, filters, kernel_size, padding="same", activation=None
        )
        if batch_norm:
            pre_activation = tf.layers.batch_normalization(pre_activation)
        return tf.nn.relu(pre_activation)

    def pool_l2(inputs, pool_size):
        # L2 pooling: square root of the average-pooled squares. The small
        # constant is presumably there to keep sqrt well-behaved at zero.
        mean_square = tf.layers.average_pooling2d(
            inputs ** 2, pool_size=pool_size, strides=pool_size, padding="same"
        )
        return tf.sqrt(mean_square + 1e-8)

    # (variable scope, filters, kernel size, apply 2x2 L2 pooling afterwards)
    # NOTE(review): pooling is assumed to live inside each variable scope;
    # the flattened source does not show the indentation -- confirm.
    conv_stack = (
        ("1a", 16, 7, True),
        ("2a", 32, 5, False),
        ("2b", 32, 5, True),
        ("3a", 32, 5, True),
        ("4a", 32, 5, True),
    )

    net = tf.cast(unscaled_images, tf.float32) / 255.0
    for scope_name, filters, kernel_size, pooled in conv_stack:
        with tf.variable_scope(scope_name):
            net = conv_layer(net, filters, kernel_size)
            if pooled:
                net = pool_l2(net, 2)

    net = tf.layers.flatten(net)
    for units in (256, 512):
        net = tf.layers.dense(net, units, activation=tf.nn.relu)
    return net
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
import lucid.optvis.render as render | ||
import itertools | ||
from lucid.misc.gradient_override import gradient_override_map | ||
|
||
|
||
def maxpool_override():
    """Return a gradient-override mapping that replaces the MaxPool gradient.

    The replacement gradient is that of a smooth surrogate of the pooling op:
    ``avg_pool(x ** 2) / (1e-2 + avg_pool(|x|))``, computed with the same
    ``ksize``, ``strides`` and ``padding`` as the overridden op.

    Returns:
        A dict ``{"MaxPool": gradient_fn}`` suitable for
        ``gradient_override_map``.
    """

    def MaxPoolGrad(op, grad):
        pooled_input = op.inputs[0]
        pool_args = [
            op.get_attr(attr_name) for attr_name in ("ksize", "strides", "padding")
        ]
        numerator = tf.nn.avg_pool(pooled_input ** 2, *pool_args)
        denominator = 1e-2 + tf.nn.avg_pool(tf.abs(pooled_input), *pool_args)
        surrogate = numerator / denominator
        # Back-propagate the incoming gradient through the surrogate instead
        # of through the real max pool.
        return tf.gradients(surrogate, [pooled_input], grad)[0]

    return {"MaxPool": MaxPoolGrad}
|
||
|
||
def get_acts(model, layer_name, obses):
    """Evaluate the activations of ``layer_name`` for the given observations.

    A fresh graph and session are created for the call, so nothing leaks into
    the caller's default graph.

    Args:
        model: Model accepted by ``render.import_model``.
        layer_name: Name of the layer whose activations are returned.
        obses: Array of observations; cast to float32 before being fed in.

    Returns:
        The evaluated activations of the requested layer.
    """
    with tf.Graph().as_default(), tf.Session():
        observations = obses.astype(np.float32)
        t_obses = tf.placeholder_with_default(
            observations, (None, None, None, None)
        )
        # The same tensor is passed for both image arguments of import_model.
        T = render.import_model(model, t_obses, t_obses)
        return T(layer_name).eval()
|
||
|
||
def get_grad_or_attr(
    model,
    layer_name,
    prev_layer_name,
    obses,
    *,
    act_dir=None,
    act_poses=None,
    score_fn=tf.reduce_sum,
    grad_or_attr,
    override=None,
    integrate_steps=1
):
    """Compute the gradient or attribution of a layer score w.r.t. an earlier layer.

    The activations of ``layer_name`` are optionally projected onto
    ``act_dir`` and gathered at ``act_poses``, then reduced to a score with
    ``score_fn``. That score is differentiated with respect to the
    activations of ``prev_layer_name`` (or the observations themselves when
    ``prev_layer_name`` is None).

    Args:
        model: Model accepted by ``render.import_model``.
        layer_name: Layer whose (processed) activations are scored.
        prev_layer_name: Layer to differentiate with respect to, or None to
            differentiate with respect to the input observations.
        obses: Array of observations; cast to float32 before being fed in.
        act_dir: Optional direction multiplied into the activations
            (broadcast over three leading axes).
        act_poses: Optional per-observation positions; each is prefixed with
            the observation index and used to gather from the activations.
        score_fn: Reduces the activations to the quantity differentiated.
        grad_or_attr: ``"grad"`` to return the gradient, ``"attr"`` to return
            the previous-layer activations times the gradient.
        override: Optional gradient override mapping.
        integrate_steps: When greater than 1, the gradient is averaged over
            that many scalings ``alpha`` of the previous-layer activations,
            with ``alpha`` evenly spaced in ``(0, 1]``.

    Returns:
        The evaluated gradient, or activations times gradient for ``"attr"``.

    Raises:
        NotImplementedError: If ``grad_or_attr`` is neither ``"grad"`` nor
            ``"attr"``.
    """
    # Fail fast: reject a bad mode before paying for the graph import and
    # gradient evaluation.
    if grad_or_attr not in ("grad", "attr"):
        raise NotImplementedError(
            "grad_or_attr must be 'grad' or 'attr', got {!r}".format(grad_or_attr)
        )

    with tf.Graph().as_default(), tf.Session(), gradient_override_map(override or {}):
        t_obses = tf.placeholder_with_default(
            obses.astype(np.float32), (None, None, None, None)
        )
        T = render.import_model(model, t_obses, t_obses)
        t_acts = T(layer_name)
        t_acts_prev = t_obses if prev_layer_name is None else T(prev_layer_name)

        if act_dir is not None:
            t_acts = act_dir[None, None, None] * t_acts
        if act_poses is not None:
            # Prefix each position with its observation index for gather_nd.
            t_acts = tf.gather_nd(
                t_acts,
                tf.concat([tf.range(obses.shape[0])[..., None], act_poses], axis=-1),
            )

        t_score = score_fn(t_acts)
        t_grad = tf.gradients(t_score, [t_acts_prev])[0]

        if integrate_steps > 1:
            acts_prev = t_acts_prev.eval()
            # Average the gradient over scaled copies of the activations,
            # skipping alpha == 0.
            alphas = np.linspace(0, 1, integrate_steps + 1)[1:]
            grad = (
                sum(
                    [
                        t_grad.eval(feed_dict={t_acts_prev: acts_prev * alpha})
                        for alpha in alphas
                    ]
                )
                / integrate_steps
            )
        else:
            acts_prev = None
            grad = t_grad.eval()

        if grad_or_attr == "grad":
            return grad
        # grad_or_attr == "attr" (validated above).
        if acts_prev is None:
            acts_prev = t_acts_prev.eval()
        return acts_prev * grad
|
||
|
||
def get_attr(model, layer_name, prev_layer_name, obses, **kwargs):
    """Return the attribution (activations times gradient) via ``get_grad_or_attr``.

    Any ``grad_or_attr`` supplied in ``kwargs`` is overridden with ``"attr"``;
    all other keyword arguments are forwarded unchanged.
    """
    options = dict(kwargs, grad_or_attr="attr")
    return get_grad_or_attr(model, layer_name, prev_layer_name, obses, **options)
|
||
|
||
def get_grad(model, layer_name, obses, **kwargs):
    """Return the gradient with respect to the observations via ``get_grad_or_attr``.

    The previous layer is fixed to None (the input observations) and any
    ``grad_or_attr`` supplied in ``kwargs`` is overridden with ``"grad"``;
    all other keyword arguments are forwarded unchanged.
    """
    options = dict(kwargs, grad_or_attr="grad")
    return get_grad_or_attr(model, layer_name, None, obses, **options)
|
||
|
||
def get_paths(acts, nmf, *, max_paths, integrate_steps):
    """Yield integration paths from zero to ``acts`` through NMF feature space.

    The activations are split into an NMF reconstruction plus a residual. The
    NMF features are partitioned into two groups (one of size
    ``features // 2``); along a path one group follows ``alpha ** 2`` and the
    other follows ``1 - (1 - alpha) ** 2``, while the residual is scaled
    linearly by ``alpha``. Each partition yields two paths (the split and its
    complement), except that the final one is chosen at random when needed to
    respect an odd ``max_paths``.

    Args:
        acts: Activations to build paths towards.
        nmf: Object providing ``transform``, ``inverse_transform`` and an
            integer ``features`` attribute.
        max_paths: Maximum number of paths to yield, or None to use every
            partition. When set, partitions are sampled at random without
            replacement.
        integrate_steps: Number of points on each path; ``alpha`` takes
            evenly spaced values in ``(0, 1]``.

    Yields:
        A list of ``integrate_steps`` arrays, the last of which corresponds to
        ``alpha == 1``.
    """
    acts_reduced = nmf.transform(acts)
    residual = acts - nmf.inverse_transform(acts_reduced)

    combs = itertools.combinations(range(nmf.features), nmf.features // 2)
    if nmf.features % 2 == 0:
        # With an even feature count each split and its complement have the
        # same size, so keep only splits containing feature 0 to avoid
        # enumerating every partition twice.
        combs = [comb for comb in combs if 0 in comb]
    else:
        combs = list(combs)
    # Force an integer dtype: an empty combination (e.g. a single feature)
    # would otherwise become a float array, which numpy rejects as an index.
    combs = np.array(combs, dtype=int)

    if max_paths is None:
        splits = combs
    else:
        num_splits = min((max_paths + 1) // 2, combs.shape[0])
        splits = combs[
            np.random.choice(combs.shape[0], size=num_splits, replace=False), :
        ]

    # Loop-invariant: the same alphas are used for every path (alpha == 0 is
    # skipped).
    alphas = np.linspace(0, 1, integrate_steps + 1)[1:]

    for i, split in enumerate(splits):
        indices = np.zeros(nmf.features)
        indices[split] = 1.0
        # Add leading singleton axes so the mask broadcasts over acts_reduced.
        indices = indices[tuple(None for _ in range(acts_reduced.ndim - 1))]
        complements = [False, True]
        if max_paths is not None and i * 2 + 1 == max_paths:
            # Only one path left in the budget: pick one of the pair at random.
            complements = [np.random.choice(complements)]
        for complement in complements:
            path = []
            for alpha in alphas:
                if complement:
                    coordinates = (1.0 - indices) * alpha ** 2 + indices * (
                        1.0 - (1.0 - alpha) ** 2
                    )
                else:
                    coordinates = indices * alpha ** 2 + (1.0 - indices) * (
                        1.0 - (1.0 - alpha) ** 2
                    )
                path.append(
                    nmf.inverse_transform(acts_reduced * coordinates) + residual * alpha
                )
            yield path
|
||
|
||
def get_multi_path_attr(
    model,
    layer_name,
    prev_layer_name,
    obses,
    prev_nmf,
    *,
    act_dir=None,
    act_poses=None,
    score_fn=tf.reduce_sum,
    override=None,
    max_paths=50,
    integrate_steps=10
):
    """Average path-integrated attribution over several NMF-based paths.

    The score is built as in ``get_grad_or_attr``: the activations of
    ``layer_name`` are optionally multiplied by ``act_dir`` and gathered at
    ``act_poses``, then reduced with ``score_fn``. For every path produced by
    ``get_paths`` over the previous-layer activations, the gradient is
    evaluated at each point of the path, multiplied by the step taken to
    reach that point (starting from zero), and summed along the path. The
    per-path results are averaged.

    Args:
        model: Model accepted by ``render.import_model``.
        layer_name: Layer whose (processed) activations are scored.
        prev_layer_name: Layer to attribute to, or None for the observations.
        obses: Array of observations; cast to float32 before being fed in.
        prev_nmf: NMF object passed to ``get_paths`` for the previous layer.
        act_dir: Optional direction multiplied into the activations.
        act_poses: Optional per-observation positions to gather.
        score_fn: Reduces the activations to the quantity differentiated.
        override: Optional gradient override mapping.
        max_paths: Forwarded to ``get_paths``.
        integrate_steps: Forwarded to ``get_paths``.

    Returns:
        The attribution averaged over all generated paths.
    """
    with tf.Graph().as_default(), tf.Session(), gradient_override_map(override or {}):
        t_obses = tf.placeholder_with_default(
            obses.astype(np.float32), (None, None, None, None)
        )
        T = render.import_model(model, t_obses, t_obses)
        t_acts = T(layer_name)
        t_acts_prev = t_obses if prev_layer_name is None else T(prev_layer_name)

        if act_dir is not None:
            t_acts = act_dir[None, None, None] * t_acts
        if act_poses is not None:
            # Prefix each position with its observation index for gather_nd.
            t_acts = tf.gather_nd(
                t_acts,
                tf.concat([tf.range(obses.shape[0])[..., None], act_poses], axis=-1),
            )

        t_score = score_fn(t_acts)
        t_grad = tf.gradients(t_score, [t_acts_prev])[0]
        acts_prev = t_acts_prev.eval()

        def path_deltas(path):
            # Step taken to reach each point; the first step starts at zero.
            starts = [np.zeros_like(acts_prev)] + path[:-1]
            return np.array([end - start for start, end in zip(starts, path)])

        def path_grads(path):
            return np.array(
                [t_grad.eval(feed_dict={t_acts_prev: point}) for point in path]
            )

        total_attr = 0
        num_paths = 0
        # Paths are generated lazily and consumed one at a time.
        for path in get_paths(
            acts_prev, prev_nmf, max_paths=max_paths, integrate_steps=integrate_steps
        ):
            total_attr += (path_deltas(path) * path_grads(path)).sum(axis=0)
            num_paths += 1
        return total_attr / num_paths
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from lucid.misc.io.saving import nullcontext | ||
from lucid.misc.io.loading import load_using_loader | ||
from lucid.misc.io.writing import write_handle | ||
|
||
|
||
def load_joblib(url_or_handle, *, cache=None, **kwargs):
    """Load a joblib-serialized object from a URL or file handle.

    NOTE(review): ``joblib.load`` is pickle-based, so this should only be
    pointed at trusted sources.

    Args:
        url_or_handle: Location or open handle passed to ``load_using_loader``.
        cache: Forwarded to ``load_using_loader``.
        **kwargs: Additional keyword arguments for ``load_using_loader``.

    Returns:
        Whatever ``load_using_loader`` returns for the joblib loader.
    """
    # Imported lazily so joblib is only required when this function is used.
    import joblib

    loader_options = dict(decompressor=nullcontext, loader=joblib.load, cache=cache)
    return load_using_loader(url_or_handle, **loader_options, **kwargs)
|
||
|
||
def save_joblib(value, url_or_handle, **kwargs):
    """Serialize ``value`` with joblib to a URL/path or an open handle.

    Args:
        value: Object to dump.
        url_or_handle: Either an object exposing both ``write`` and ``name``
            (used directly as the destination) or something accepted by
            ``write_handle``.
        **kwargs: Additional keyword arguments for ``joblib.dump``.
    """
    # Imported lazily so joblib is only required when this function is used.
    import joblib

    is_named_handle = hasattr(url_or_handle, "write") and hasattr(
        url_or_handle, "name"
    )
    if not is_named_handle:
        with write_handle(url_or_handle) as handle:
            joblib.dump(value, handle, **kwargs)
        return
    joblib.dump(value, url_or_handle, **kwargs)
Oops, something went wrong.