In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
# tf.enable_eager_execution()
import math
import os
import pickle
import glob
from PIL import Image

In [3]:
# Short aliases for the Keras API namespaces.
tfk = tf.keras
tfkl = tf.keras.layers

In [None]:
# Trial directory to visualize. It is expected to contain `checkpoint_*`
# sub-directories and a `pixels` directory (see `visualize_convnet`).
# Set the SOFTLEARNING_SEED_DIR environment variable to point the notebook at a
# different run without editing this cell; the default below is the original
# hard-coded, machine-specific path.
# Alternative run kept for reference:
# seed_dir = '/nfs/kun1/users/justinvyu/ray_results/gym/DClaw/TurnFreeValve3ResetFreeSwapGoal-v0/2019-08-31T12-04-34-spatial_softmax_reset_free/id=2c4f706c-seed=519_2019-08-31_12-04-35_57pmudj'
seed_dir = os.environ.get(
    'SOFTLEARNING_SEED_DIR',
    '/home/justinyu/ray_results/gym/DClaw/TurnFreeValve3Fixed-v0/2019-09-02T11-39-20-spatial_softmax_larger_convnet/id=fa779d2c-seed=7846_2019-09-02_11-39-21nf8pk80r')

def visualize_convnet(seed_dir, n_images=3):
    """Restore a policy from a checkpoint and visualize its convnet on saved images.

    Takes the first ``checkpoint_*`` directory that ``glob`` yields under
    ``seed_dir``, unpickles ``checkpoint.pkl`` from it, rebuilds the environment
    and policy from the stored ``variant``, restores ``policy_weights``, and
    then calls ``visualize_layers`` on the policy's ``convnet_preprocessor``
    layer for up to ``n_images`` PNG files found in ``<seed_dir>/pixels``.

    Args:
        seed_dir: Trial directory containing ``checkpoint_*`` sub-directories
            and a ``pixels`` directory with ``*.png`` files.
        n_images: Maximum number of images to visualize. ``None`` means all
            images; values <= 0 visualize nothing.

    Returns:
        None. Results are rendered as matplotlib figures. If no checkpoint
        directory matches, the function returns without doing anything.
    """
    pixels_dir = os.path.join(seed_dir, 'pixels')

    # Only a single checkpoint is visualized. glob's ordering is
    # filesystem-dependent, so "first" is whichever entry the OS lists first.
    checkpoint_dir = next(
        glob.iglob(os.path.join(seed_dir, 'checkpoint_*')), None)
    if checkpoint_dir is None:
        return

    print(checkpoint_dir)
    print(pixels_dir)

    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pkl')
    # SECURITY: unpickling can execute arbitrary code. Only point this at
    # checkpoints produced by a trusted training run.
    with open(checkpoint_path, 'rb') as f:
        picklable = pickle.load(f)

    variant = picklable['variant']

    # Prefer the evaluation environment settings; fall back to training ones.
    environment_params = (
        variant['environment_params']['evaluation']
        if 'evaluation' in variant['environment_params']
        else variant['environment_params']['training'])

    # Imported lazily so the notebook can be loaded (and this function called
    # on a directory without checkpoints) without softlearning installed.
    from softlearning.environments.utils import get_environment_from_params
    from softlearning.policies.utils import get_policy_from_variant

    evaluation_environment = get_environment_from_params(environment_params)
    policy = get_policy_from_variant(variant, evaluation_environment)
    policy.set_weights(picklable['policy_weights'])

    convnet = policy.actions_model.get_layer('convnet_preprocessor')
    convnet.summary()

    # Sort so that the same images are picked on every run.
    image_paths = sorted(glob.iglob(os.path.join(pixels_dir, '*.png')))
    limit = None if n_images is None else max(n_images, 0)
    for im_path in image_paths[:limit]:
        # The context manager closes the underlying file once the pixel data
        # has been copied into the array.
        with Image.open(im_path) as im:
            image = np.array(im)
        plt.figure()
        plt.axis('off')
        plt.imshow(image)
        visualize_layers(convnet, image)

# NOTE: visualize_convnet calls visualize_layers, which is defined in a later
# cell of this notebook; on a fresh kernel run that cell first, otherwise the
# call below fails with NameError once it reaches visualize_layers.
# Stale variant (visualize_convnet has no return statement, so it returns None):
# convnet, image = visualize_convnet(seed_dir)
visualize_convnet(seed_dir)

In [18]:
def visualize_layers(convnet, image):
    """Plot intermediate activations of ``convnet`` and overlay its keypoints.

    For every layer of ``convnet`` that is itself a Keras ``Model`` or
    ``Sequential``, a model from ``convnet.inputs`` to that layer's outputs is
    built, run on ``image`` and passed to ``plot_activations``. The full
    convnet output is then reshaped to ``(-1, 2)`` and treated as keypoints:
    they are drawn on top of the last activations and on top of a downscaled
    copy of ``image``.

    Args:
        convnet: Keras model whose sub-model layers are visualized.
        image: Array with rows on axis 0 and columns on axis 1; a batch axis
            is prepended before it is fed to the network.

    Raises:
        ValueError: If ``convnet`` has no ``Model``/``Sequential`` layers, so
            there are no activations to draw the keypoints on.

    Note:
        Keypoints are mapped with ``(k + 1) / 2``, i.e. they are assumed to be
        normalized to [-1, 1], with column 0 used as x and column 1 as y --
        TODO confirm against the convnet's output layer.
    """
    layers_to_visualize = [
        layer for layer in convnet.layers
        if isinstance(layer, (tfk.Sequential, tfk.Model))]
    if not layers_to_visualize:
        raise ValueError(
            'convnet has no Sequential/Model layers to visualize')

    batch = image[None]
    layer_names = set()
    activations = None
    for layer in layers_to_visualize:
        # Rename sub-models whose name was already seen (presumably so that the
        # models built below do not clash on duplicate names -- TODO confirm).
        if layer.name in layer_names:
            layer._name = layer.name + str(len(layer_names))
        layer_names.add(layer.name)
        layer_model = tfk.Model(convnet.inputs, layer.outputs)
        activations = layer_model.predict(batch)
        plot_activations(activations)

    output_model = tfk.Model(convnet.inputs, convnet.outputs)
    keypoints = output_model.predict(batch).reshape((-1, 2))

    # Overlay the keypoints on the activations of the last visualized layer.
    plot_activations(activations, keypoints=keypoints)

    # Activations are indexed (batch, row, column, filter).
    map_rows, map_cols = activations.shape[1], activations.shape[2]

    # Import the submodules explicitly: a bare `import skimage` does not
    # guarantee that `skimage.transform` is loaded on every version.
    from skimage.transform import downscale_local_mean
    from skimage.util import img_as_float, img_as_ubyte

    # Shrink the input to (roughly) the feature-map resolution. Factors are
    # clamped to >= 1 and padded with 1s for any trailing (channel) axes.
    row_factor = max(image.shape[0] // map_rows, 1)
    col_factor = max(image.shape[1] // map_cols, 1)
    factors = (row_factor, col_factor) + (1,) * (image.ndim - 2)
    im = img_as_ubyte(downscale_local_mean(img_as_float(image), factors))

    plt.figure(figsize=(5, 5))
    plt.imshow(im)
    # Map normalized keypoints onto the pixel grid of the image being shown:
    # x runs along columns (axis 1), y along rows (axis 0).
    kp = (keypoints + 1) / 2
    kp[:, 0] *= im.shape[1]
    kp[:, 1] *= im.shape[0]
    plt.scatter(x=kp[:, 0], y=kp[:, 1])
    plt.show()

def plot_activations(activations, keypoints=None, title=''):
    """Draw every filter of an activation tensor as a grayscale image grid.

    Args:
        activations: Array indexed as ``[batch, row, column, filter]``; only
            batch element 0 is drawn.
        keypoints: Optional sequence of ``(x, y)`` pairs, one per filter. Each
            pair is mapped with ``(k + 1) / 2`` (i.e. assumed to be normalized
            to [-1, 1] -- TODO confirm) onto the filter's pixel grid and drawn
            as a red dot. Filters without a matching keypoint get no dot.
        title: Optional figure-level title; nothing is drawn when empty.

    Returns:
        None. A new matplotlib figure is created; nothing is drawn when the
        tensor has zero filters.
    """
    map_rows, map_cols = activations.shape[1], activations.shape[2]
    n_filters = activations.shape[3]
    if n_filters == 0:
        return

    # Near-square grid: floor(sqrt(n)) columns and just enough rows.
    n_columns = max(int(np.sqrt(n_filters)), 1)
    n_rows = math.ceil(n_filters / n_columns)

    fig = plt.figure(figsize=(5 * n_columns, 5 * n_rows))
    if title:
        fig.suptitle(title)

    if keypoints is not None:
        keypoints = np.asarray(keypoints)

    for i in range(n_filters):
        ax = fig.add_subplot(n_rows, n_columns, i + 1)
        ax.axis('off')
        ax.set_title(f'Filter {i}')
        ax.imshow(activations[0, :, :, i], interpolation="nearest", cmap="gray")
        if keypoints is not None and i < len(keypoints):
            kx, ky = (keypoints[i] + 1) / 2
            # x indexes columns and y indexes rows of the displayed image;
            # clip so a keypoint on the border still lands on a valid pixel.
            px = int(np.clip(map_cols * kx, 0, map_cols - 1))
            py = int(np.clip(map_rows * ky, 0, map_rows - 1))
            ax.scatter([px], [py], c='r', s=100)
        
# visualize_layers(convnet, image)