In [11]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [12]:
import tensorflow as tf
tf.enable_eager_execution()
tfe = tf.contrib.eager
import numpy as np
import skimage.transform
import pdb
import pyrednertensorflow as pyredner
import redner

import pyrednertensorflow as pyrednertorch
import torch

import utils 

# Set up the pyredner scene for rendering:
material_map, mesh_list, light_map = pyredner.load_obj('scenes/teapot.obj')
for _, mesh in mesh_list:
    # mesh.normals = pyredner.pyredner.compute_vertex_normal(mesh.vertices, mesh.indices)
    mesh.normals = pyredner.compute_vertex_normal(mesh.vertices, mesh.indices)

material_maptorch, mesh_listtorch, light_maptorch = pyrednertorch.load_obj('scenes/teapot.obj')
for _, mesh in mesh_listtorch:
    mesh.normals = pyrednertorch.compute_vertex_normal(mesh.vertices, mesh.indices)

for (_, mesh), (_, meshtf) in zip(mesh_list, mesh_listtorch):
    assert utils.is_same_tensor(mesh.normals, meshtf.normals)

# Setup camera
cam = pyredner.Camera(position = tfe.Variable([0.0, 30.0, 200.0]),
                      look_at = tfe.Variable([0.0, 30.0, 0.0]),
                      up = tfe.Variable([0.0, 1.0, 0.0], dtype=tf.float32),
                      fov = tfe.Variable([45.0], dtype=tf.float32), # in degree
                      clip_near = 1e-2, # needs to > 0
                      resolution = (256, 256),
                      fisheye = False)


# Setup camera
camtorch = pyrednertorch.Camera(position = torch.tensor([0.0, 30.0, 200.0]),
                      look_at = torch.tensor([0.0, 30.0, 0.0]),
                      up = torch.tensor([0.0, 1.0, 0.0]),
                      fov = torch.tensor([45.0]), # in degree
                      clip_near = 1e-2, # needs to > 0
                      resolution = (256, 256),
                      fisheye = False)

assert utils.is_same_camera(cam, camtorch)

material_id_map = {}
materials = []
count = 0
for key, value in material_map.items():
    material_id_map[key] = count
    count += 1
    materials.append(value)

material_id_maptorch = {}
materialstorch = []
count = 0
for key, value in material_maptorch.items():
    material_id_maptorch[key] = count
    count += 1
    materialstorch.append(value)

assert material_id_map == material_id_maptorch
assert utils.is_same_container(materials, materialstorch)

shapes = []
print(f">>> {len(mesh_list)} meshes in the session")
for mtl_name, mesh in mesh_list:
    shapes.append(pyredner.Shape(
        vertices = mesh.vertices,
        indices = mesh.indices,
        uvs = mesh.uvs,
        normals = mesh.normals,
        material_id = material_id_map[mtl_name]))

shapestorch = []
for mtl_name, mesh in mesh_listtorch:
    shapestorch.append(pyrednertorch.Shape(\
        vertices = mesh.vertices,
        indices = mesh.indices,
        uvs = mesh.uvs,
        normals = mesh.normals,
        material_id = material_id_maptorch[mtl_name]))

assert utils.is_same_container(shapes, shapestorch)

scene = pyredner.Scene(cam, shapes, materials, area_lights = [], envmap = None)
scenetorch = pyrednertorch.Scene(camtorch, shapestorch, materialstorch, area_lights = [], envmap = None)

assert utils.is_same_scene(scene, scenetorch)

scene_args = pyredner.RenderFunction.serialize_scene(
    scene = scene,
    num_samples = 16,
    max_bounces = 0,
    channels = [redner.channels.depth, redner.channels.shading_normal])

scene_argstorch = pyrednertorch.RenderFunction.serialize_scene(\
    scene = scenetorch,
    num_samples = 16,
    max_bounces = 0,
    channels = [redner.channels.depth, redner.channels.shading_normal])

assert utils.is_same_scene_args(scene_args, scene_argstorch)

render = pyredner.RenderFunction
# Render. The first argument is the seed for RNG in the renderer.
img = render.forward(0, scene_args)

rendertorch = pyrednertorch.RenderFunction.apply
# Render. The first argument is the seed for RNG in the renderer.
imgtorch = rendertorch(0, *scene_argstorch)

assert utils.is_same_tensor(img, imgtorch)

depth = img[:, :, 0]
normal = img[:, :, 1:4]
depthtorch = imgtorch[:, :, 0]
normaltorch = imgtorch[:, :, 1:4]

target_depth = pyredner.imread('results/test_g_buffer/target_depth.exr')
target_depth = target_depth[:, :, 0]
target_normal = pyredner.imread('results/test_g_buffer/target_normal.exr')

target_depthtorch = pyrednertorch.imread('results/test_g_buffer/target_depth.exr')
target_depthtorch = target_depthtorch[:, :, 0]
target_normaltorch = pyrednertorch.imread('results/test_g_buffer/target_normal.exr')

assert utils.is_same_tensor(target_depth, target_depthtorch)
assert utils.is_same_tensor(target_normal, target_normaltorch)

translation_params = tfe.Variable([0.1, -0.1, 0.1], trainable=True)
translation = translation_params * 100.0
euler_angles = tfe.Variable([0.1, -0.1, 0.1], trainable=True)

translation_paramstorch = torch.tensor([0.1, -0.1, 0.1],
    device = pyrednertorch.get_device(), requires_grad=True)
translationtorch = translation_paramstorch * 100.0
euler_anglestorch = torch.tensor([0.1, -0.1, 0.1], requires_grad=True)

assert utils.is_same_tensor(translation, translationtorch)
assert utils.is_same_tensor(euler_angles, euler_anglestorch)

shape0_vertices = tf.identity(shapes[0].vertices)
shape1_vertices = tf.identity(shapes[1].vertices)

shape0_verticestorch = shapestorch[0].vertices.clone()
shape1_verticestorch = shapestorch[1].vertices.clone()

assert utils.is_same_tensor(shape0_vertices, shape0_verticestorch)
assert utils.is_same_tensor(shape1_vertices, shape1_verticestorch)

rotation_matrix = pyredner.gen_rotate_matrix(euler_angles)
rotation_matrixtorch = pyrednertorch.gen_rotate_matrix(euler_anglestorch)

assert utils.is_same_tensor(rotation_matrix, rotation_matrixtorch)

center = tf.math.reduce_mean(
    tf.concat([shape0_vertices, shape1_vertices], axis=0), 
    axis=0
    )
centertorch = torch.mean(torch.cat([shape0_verticestorch, shape1_verticestorch]), 0)

assert utils.is_same_tensor(center, centertorch)

shapes[0].vertices = \
    (shape0_vertices - center) @ tf.transpose(rotation_matrix) + \
    center + translation
shapes[1].vertices = \
    (shape1_vertices - center) @ tf.transpose(rotation_matrix) + \
    center + translation

shapestorch[0].vertices = \
    (shape0_verticestorch - centertorch) @ torch.t(rotation_matrixtorch) + \
    centertorch + translationtorch
shapestorch[1].vertices = \
    (shape1_verticestorch - centertorch) @ torch.t(rotation_matrixtorch) + \
    centertorch + translationtorch

assert utils.is_same_tensor(shapes[0].vertices, shapestorch[0].vertices)
assert utils.is_same_tensor(shapes[1].vertices, shapestorch[1].vertices)

shapes[0].normals = pyredner.compute_vertex_normal(shapes[0].vertices, shapes[0].indices)
shapes[1].normals = pyredner.compute_vertex_normal(shapes[1].vertices, shapes[1].indices)

shapestorch[0].normals = pyrednertorch.compute_vertex_normal(shapestorch[0].vertices, shapestorch[0].indices)
shapestorch[1].normals = pyrednertorch.compute_vertex_normal(shapestorch[1].vertices, shapestorch[1].indices)

assert utils.is_same_tensor(shapes[0].normals, shapestorch[0].normals)
assert utils.is_same_tensor(shapes[1].normals, shapestorch[1].normals)

scene_args = pyredner.RenderFunction.serialize_scene(
    scene = scene,
    num_samples = 16,
    max_bounces = 0,
    channels = [redner.channels.depth, redner.channels.shading_normal])

scene_argstorch = pyrednertorch.RenderFunction.serialize_scene(\
    scene = scenetorch,
    num_samples = 16,
    max_bounces = 0,
    channels = [redner.channels.depth, redner.channels.shading_normal])

assert utils.is_same_scene_args(scene_args, scene_argstorch)

img = render.forward(1, scene_args)
depth = img[:, :, 0]
normal = img[:, :, 1:4]

imgtorch = rendertorch(1, *scene_argstorch)
depthtorch = imgtorch[:, :, 0]
normaltorch = imgtorch[:, :, 1:4]

assert utils.is_same_tensor(img, imgtorch)

diff_depth = tf.abs(target_depth - depth)
diff_normal = tf.abs(target_normal - normal)

diff_depthtorch = torch.abs(target_depthtorch - depthtorch)
diff_normaltorch = torch.abs(target_normaltorch - normaltorch)

assert utils.is_same_tensor(diff_depth, diff_depthtorch)
assert utils.is_same_tensor(diff_normal, diff_normaltorch)

optimizer = tf.train.AdamOptimizer(1e-2)
optimizertorch = torch.optim.Adam([translation_paramstorch, euler_anglestorch], lr=1e-2)


>>> 2 meshes in the session
Forward pass, time: 0.07200 s
Scene construction, time: 0.11887 s
Forward pass, time: 0.11798 s
Detach
Detach
Detach
Detach
Detach
Detach
Detach
Detach
Detach
Detach
Detach
Forward pass, time: 0.07692 s
Scene construction, time: 0.16078 s
Forward pass, time: 0.12383 s
Detach
Detach
Detach


In [13]:
t=0

In [14]:
shapes[0].vertices.shape

TensorShape([Dimension(6460), Dimension(3)])

In [5]:
with tf.GradientTape(persistent=True) as g1:
    g1.watch(translation_params)
    translation = translation_params * 100.0
    with tf.GradientTape(persistent=True) as g2:
        g2.watch(euler_angles)
        rotation_matrix = pyredner.gen_rotate_matrix(euler_angles)
        center = tf.math.reduce_mean(
            tf.concat([shape0_vertices, shape1_vertices], axis=0), 
            axis=0)
        shapes[0].vertices = \
            (shape0_vertices - center) @ tf.transpose(rotation_matrix) + \
            center + translation
        shapes[1].vertices = \
            (shape1_vertices - center) @ tf.transpose(rotation_matrix) + \
            center + translation

d_translation_params_dshape0vertices = g1.gradient(shapes[0].vertices, translation_params)
d_translation_params_dshape1vertices = g1.gradient(shapes[1].vertices, translation_params)

deuler_dshape0vertices = g2.gradient(shapes[0].vertices, euler_angles)
deuler_dshape1vertices = g2.gradient(shapes[1].vertices, euler_angles)

shapes[0].normals = pyredner.compute_vertex_normal(shapes[0].vertices, shapes[0].indices)
shapes[1].normals = pyredner.compute_vertex_normal(shapes[1].vertices, shapes[1].indices)
del g1, g2


optimizertorch.zero_grad()
translationtorch = translation_paramstorch * 100.0
rotation_matrixtorch = pyrednertorch.gen_rotate_matrix(euler_anglestorch)
centertorch = torch.mean(torch.cat([shape0_verticestorch, shape1_verticestorch]), 0)
shapestorch[0].vertices = \
    (shape0_verticestorch - centertorch) @ torch.t(rotation_matrixtorch) + \
    centertorch + translationtorch
shapestorch[1].vertices = \
    (shape1_verticestorch - centertorch) @ torch.t(rotation_matrixtorch) + \
    centertorch + translationtorch

shapestorch[0].normals = pyrednertorch.compute_vertex_normal(shapestorch[0].vertices, shapestorch[0].indices)
shapestorch[1].normals = pyrednertorch.compute_vertex_normal(shapestorch[1].vertices, shapestorch[1].indices)


assert utils.is_same_tensor(shapes[0].normals, shapestorch[0].normals)
assert utils.is_same_tensor(shapes[1].normals, shapestorch[1].normals)


scene_args = pyredner.RenderFunction.serialize_scene(
    scene = scene,
    num_samples = 16,
    max_bounces = 0,
    channels = [redner.channels.depth, redner.channels.shading_normal])

scene_argstorch = pyrednertorch.RenderFunction.serialize_scene(\
    scene = scenetorch,
    num_samples = 16,
    max_bounces = 0,
    channels = [redner.channels.depth, redner.channels.shading_normal])

assert utils.is_same_scene_args(scene_args, scene_argstorch)

img = render.forward(t+1, scene_args)

with tf.GradientTape(persistent=False) as g3:
    g3.watch(img)
    depth = img[:, :, 0]
    normal = img[:, :, 1:4]

    # Save the intermediate render.
    pyredner.imwrite(depth,
        'results/test_g_buffer/iter_depth_{}.png'.format(t),
        normalize = True)
    pyredner.imwrite(normal,
        'results/test_g_buffer/iter_normal_{}.png'.format(t),
        normalize = True)
    # Compute the loss function. Here it is L2.
    loss = tf.reduce_sum(tf.square(depth - target_depth)) / 200.0 \
        + tf.reduce_sum(tf.square(normal - target_normal))
    print('loss:', loss)

d_img = g3.gradient(loss, img)
grads = render.backward(d_img)

##################################################################
# Torch
imgtorch = rendertorch(t+1, *scene_argstorch)
depthtorch = imgtorch[:, :, 0]
normaltorch = imgtorch[:, :, 1:4]
losstorch = (depthtorch - target_depthtorch).pow(2).sum() / 200.0 + (normaltorch - target_normaltorch).pow(2).sum()
losstorch.backward()


Detach
Detach
Detach
Detach
Detach
Detach




Forward pass, time: 0.06180 s
loss: tf.Tensor(1588092.6, shape=(), dtype=float32)
Backward pass, time: 1.13145 s
Scene construction, time: 0.10722 s
Forward pass, time: 0.08823 s
Backward pass, time: 1.23783 s


In [20]:
img.shape

TensorShape([Dimension(256), Dimension(256), Dimension(4)])

In [19]:
d_img.shape

TensorShape([Dimension(256), Dimension(256), Dimension(4)])

## View grads from PyTorch

In [29]:
translation_paramstorch

tensor([ 0.1000, -0.1000,  0.1000], requires_grad=True)

In [6]:
td1 = translation_paramstorch.grad; td1

tensor([ 1892877.7500, -3885791.0000, -8072497.5000])

In [7]:
de1 = euler_anglestorch.grad; de1

tensor([ -916749.6875,    75257.5625, -1346990.6250])

In [15]:
grads.d_vertices_list[0]

<tf.Variable 'd_vertices_0:0' shape=(6460, 3) dtype=float32, numpy=
array([[-0.00082663,  0.00044479, -0.00012241],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       ...,
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32)>

In [9]:
tf.reduce_sum(grads.d_vertices_list[0], axis=1)

<tf.Tensor: id=5103, shape=(6460,), dtype=float32, numpy=
array([-0.00050425,  0.        ,  0.        , ...,  0.        ,
        0.        ,  0.        ], dtype=float32)>

In [10]:
tf.reduce_sum(grads.d_vertices_list[0])

<tf.Tensor: id=5111, shape=(), dtype=float32, numpy=-87772.555>

## View grads from TF

In [22]:
grads.d_position

<tf.Variable 'd_position:0' shape=(3,) dtype=float32, numpy=array([-4303.9756,  6767.5195, 80900.06  ], dtype=float32)>

In [16]:
d_translation_params_dshape0vertices.shape

TensorShape([Dimension(3)])

In [25]:
shapes[0].vertices

<tf.Tensor: id=6280, shape=(6460, 3), dtype=float32, numpy=
array([[41.045395 , 53.686386 ,  8.889906 ],
       [41.378353 , 53.6588   ,  8.908552 ],
       [41.26592  , 53.17974  ,  4.262364 ],
       ...,
       [81.39684  , 55.873413 , -3.1109543],
       [83.08425  , 56.150414 , -3.173088 ],
       [83.08425  , 56.150414 , -3.173088 ]], dtype=float32)>

In [23]:
grads.d_vertices_list[0]

<tf.Variable 'd_vertices_0:0' shape=(6460, 3) dtype=float32, numpy=
array([[-0.00082663,  0.00044479, -0.00012241],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       ...,
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32)>

In [26]:
tf.reduce_mean(grads.d_vertices_list[0], axis=1)

<tf.Tensor: id=7013, shape=(6460,), dtype=float32, numpy=
array([-0.00016808,  0.        ,  0.        , ...,  0.        ,
        0.        ,  0.        ], dtype=float32)>

In [27]:
tf.reduce_mean(grads.d_vertices_list[0])

<tf.Tensor: id=7021, shape=(), dtype=float32, numpy=-4.5290275>

In [24]:
grads.d_vertices_list[1]

<tf.Variable 'd_vertices_1:0' shape=(1874, 3) dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       ...,
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

In [7]:
# sum or mean?
tf.reduce_sum(d_translation_params_dshape0vertices * grads.d_vertices_list[0], axis=0)

<tf.Tensor: id=3640, shape=(3,), dtype=float32, numpy=array([ 6.6481085e+09, -1.4127784e+10, -4.9221358e+10], dtype=float32)>

In [8]:
tf.reduce_sum(d_translation_params_dshape1vertices * grads.d_vertices_list[1], axis=0)

<tf.Tensor: id=3645, shape=(3,), dtype=float32, numpy=array([ 1.6186853e+09, -3.1835960e+09, -8.4908710e+08], dtype=float32)>

In [13]:
tf.concat([
    tf.reduce_sum(d_translation_params_dshape0vertices * grads.d_vertices_list[0], axis=0),
    tf.reduce_sum(d_translation_params_dshape1vertices * grads.d_vertices_list[1], axis=0)
], axis=0)

<tf.Tensor: id=3689, shape=(6,), dtype=float32, numpy=
array([ 6.6481085e+09, -1.4127784e+10, -4.9221358e+10,  1.6186853e+09,
       -3.1835960e+09, -8.4908710e+08], dtype=float32)>

In [14]:
tf.reduce_sum(d_translation_params_dshape0vertices * grads.d_vertices_list[0], axis=0) + tf.reduce_sum(d_translation_params_dshape1vertices * grads.d_vertices_list[1], axis=0)

<tf.Tensor: id=3699, shape=(3,), dtype=float32, numpy=array([ 8.2667940e+09, -1.7311379e+10, -5.0070446e+10], dtype=float32)>

In [16]:
td2 = tf.reduce_mean(d_translation_params_dshape0vertices * grads.d_vertices_list[0], axis=0) + tf.reduce_mean(d_translation_params_dshape1vertices * grads.d_vertices_list[1], axis=0)

## Compare d_translation_params

In [19]:
td1.numpy() - td2.numpy()

array([-0.75, -3.5 , -4.5 ], dtype=float32)

## Compare d_euler_angles

In [22]:
de2 = tf.reduce_mean(deuler_dshape0vertices * grads.d_vertices_list[0], axis=0) + tf.reduce_mean(deuler_dshape1vertices * grads.d_vertices_list[1], axis=0); de2

<tf.Tensor: id=3749, shape=(3,), dtype=float32, numpy=array([-185311.  ,  -56637.47, -495857.06], dtype=float32)>

In [25]:
de1.numpy() - de2.numpy()

array([-731438.7 ,  131895.03, -851133.56], dtype=float32)

In [28]:
shapes[0].vertices.shape, shapes[1].vertices.shape

(TensorShape([Dimension(6460), Dimension(3)]),
 TensorShape([Dimension(1874), Dimension(3)]))

In [30]:
d_translation_params_dshape0vertices

<tf.Tensor: id=1989, shape=(3,), dtype=float32, numpy=array([646000., 646000., 646000.], dtype=float32)>

In [31]:
grads.d_vertices_list[0].shape

TensorShape([Dimension(6460), Dimension(3)])

In [53]:

def render(scene_args):
    print(len(scene_args))
    print(scene_args)
    if scene_args[0]:
        return scene_args[1] * 2
    else:
        return scene_args[1] * 40

def backward(d_img):
    return 2
    return 2 * d_img
    
@tf.custom_gradient
def render_op(scene_args):
    print(scene_args)
    img = render(scene_args)
    def grad(d_img):
        return None, backward(d_img)
    
    return img, grad

x = [
    tf.constant([8]),
    tf.constant([8])
]

# img, grad = render_op(x)
img = render_op(x)

with tf.GradientTape() as g:
    g.watch(x)
    img = render_op(x)
    
print(g.gradient(img, x))

W0705 07:07:46.508325 139977715046144 backprop.py:820] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32
W0705 07:07:46.509212 139977715046144 backprop.py:820] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32
W0705 07:07:46.511441 139977715046144 backprop.py:954] The dtype of the target tensor must be floating (e.g. tf.float32) when calling GradientTape.gradient, got tf.int32
W0705 07:07:46.512087 139977715046144 backprop.py:968] The dtype of the source tensor must be floating (e.g. tf.float32) when calling GradientTape.gradient, got tf.int32
W0705 07:07:46.512669 139977715046144 backprop.py:968] The dtype of the source tensor must be floating (e.g. tf.float32) when calling GradientTape.gradient, got tf.int32


[<tf.Tensor: id=7235, shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>, <tf.Tensor: id=7236, shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>]
2
[<tf.Tensor: id=7235, shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>, <tf.Tensor: id=7236, shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>]
[<tf.Tensor: id=7235, shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>, <tf.Tensor: id=7236, shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>]
2
[<tf.Tensor: id=7235, shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>, <tf.Tensor: id=7236, shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>]
[None, None]


In [69]:
@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    return dy * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

def grad_log1pexp(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    value = log1pexp(x)
  return tape.gradient(value, x)


# As before, the gradient computation works fine at x = 0.
grad_log1pexp(tf.constant(0.)).numpy()

0.5

In [75]:
@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
        return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad


def get_grad(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = log1pexp(x)
    
    return tape.gradient(y, x)

get_grad(tf.constant(2.))

<tf.Tensor: id=7408, shape=(), dtype=float32, numpy=0.8807971>

In [84]:
@tf.custom_gradient
def log1pexp(*x):
    print(b)
    x = tf.stack(x)
    e = tf.exp(x)
    def grad(dy):
        print(dy.numpy())
        return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad

def grad_log1pexp(*x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = log1pexp(*x)
    
    return y, tape.gradient(y, x)


# As before, the gradient computation works fine at x = 0.
grad_log1pexp(tf.constant(0.))

False
1.0


ValueError: ('custom_gradient function expected to return', 2, 'gradients but returned', 1, 'instead.')

In [83]:
np.log(2)

0.6931471805599453

In [97]:
x = tf.constant(0.)
x1 = tf.constant(10.)
x2 = tf.constant(2.)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = x + 2
    y2 = y * x1
    y3 = y2 * x2
    
print(tape.watched_variables())
print(tape.gradient(y, x).numpy())
print(tape.gradient(y2, x).numpy())
print(tape.gradient(y3, x).numpy())
del tape

()
1.0
10.0
20.0


In [101]:
x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
    g.watch(x)
    y = x * x
    z = y * y
dz_dx = g.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
dy_dx = g.gradient(y, x)  # 6.0
print(dz_dx)
print(dy_dx)
del g  # Drop the reference to the tape

None
None
