Skip to content

Commit

Permalink
fix to prevent BS prefactor cache from breaking things if using two g…
Browse files Browse the repository at this point in the history
…raphs
  • Loading branch information
co9olguy committed Jun 7, 2018
1 parent 1cbb40f commit 3e5df74
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions strawberryfields/backends/tfbackend/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,18 @@ def beamsplitter_matrix(t, r, D, batched=False, save=False, directory=None):
N_minus_k = tf.where(tf.greater(N, k), N_minus_k, tf.zeros_like(N_minus_k))
M_minus_n_plus_k = tf.where(tf.greater(M_minus_n_plus_k, 0), M_minus_n_plus_k, tf.zeros_like(M_minus_n_plus_k))

powers = tf.cast(tf.pow(mag_t, k) * tf.pow(mag_r, n_minus_k) * tf.pow(mag_r, N_minus_k) * tf.pow(mag_t, M_minus_n_plus_k), def_type)
phase = tf.exp(1j * tf.cast(phase_r * (n - N), def_type))

# load parameter-independent prefactors
prefac = get_prefac_tensor(D, directory, save)

powers = tf.cast(tf.pow(mag_t, k) * tf.pow(mag_r, n_minus_k) * tf.pow(mag_r, N_minus_k) * tf.pow(mag_t, M_minus_n_plus_k), def_type)
phase = tf.exp(1j * tf.cast(phase_r * (n - N), def_type))
if prefac.graph != phase.graph:
# if cached prefactors live on another graph, we'll have to reload them into this graph.
# In future versions, if 'copy_variable_to_graph' comes out of contrib, consider using that
get_prefac_tensor.cache_clear()
prefac = get_prefac_tensor(D, directory, save)

BS_matrix = tf.reduce_sum(phase * powers * prefac, -1)

if not batched:
Expand Down

0 comments on commit 3e5df74

Please sign in to comment.