In [1]:
import tensorflow as tf
import numpy as np

In [7]:
#define model creation functions
def make_dense_nn(scope, n_inputs, n_outputs, h_layer_node_dict, loss_fn, lr=1e-4):
    """Build a fully connected network inside its own variable scope.

    Args:
        scope: Name of the ``tf.variable_scope`` that owns the network's
            variables.
        n_inputs: Width of the input placeholder ``x``.
        n_outputs: Width of the target placeholder ``y`` and of the linear
            output layer.
        h_layer_node_dict: Mapping ``layer name -> unit count``. One ReLU
            dense layer is stacked per entry, in the mapping's iteration
            order.
        loss_fn: Callable invoked as ``loss_fn(y, output)``; its result is
            reduced with ``tf.reduce_mean`` to form the loss.
        lr: Learning rate handed to the Adam optimizer.

    Returns:
        Tuple ``(x, y, output, loss, train_op)``.
    """
    # The output layer must carry an explicit name: an unnamed tf.layers.dense
    # gets a fresh, uniquified scope on every call, so AUTO_REUSE would apply
    # to the named hidden layers but not to the output layer. Trailing
    # underscores keep the name deterministic while avoiding a clash with a
    # hidden layer the caller happened to call "output".
    output_name = "output"
    while output_name in h_layer_node_dict:
        output_name += "_"

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        x = tf.placeholder(tf.float32, [None, n_inputs], name="x")
        y = tf.placeholder(tf.float32, [None, n_outputs], name="y")

        # Stack the hidden layers; each one consumes the previous activation.
        prev_layer = x
        for layer_name, layer_nodes in h_layer_node_dict.items():
            prev_layer = tf.layers.dense(
                prev_layer,
                layer_nodes,
                activation=tf.nn.relu,
                kernel_initializer=tf.contrib.layers.xavier_initializer(),
                name=layer_name,
            )

        # Linear (no activation) output layer.
        output = tf.layers.dense(prev_layer, n_outputs, name=output_name)

        loss = tf.reduce_mean(loss_fn(y, output))
        train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

        return x, y, output, loss, train_op

In [8]:
#construct models
# Both networks use the same architecture (2 inputs, one hidden layer 'h1'
# with 5 units, 1 output) but live in separate variable scopes, so each has
# its own weights. Every model is kept as a dict of its graph handles.
_MODEL_KEYS = ('x', 'y', 'output', 'loss', 'train_op')

from_model = dict(zip(
    _MODEL_KEYS,
    make_dense_nn("from_model", 2, 1, {'h1': 5}, tf.losses.mean_squared_error),
))
to_model = dict(zip(
    _MODEL_KEYS,
    make_dense_nn("to_model", 2, 1, {'h1': 5}, tf.losses.mean_squared_error),
))

In [9]:
#define constants
# Number of optimisation steps run by the training loop.
train_steps = 20000

# The full XOR truth table serves as the one and only batch:
# row i of x_batch is an input pair (a, b), row i of y_batch is a XOR b.
x_batch = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y_batch = np.array([[0], [1], [1], [0]])

In [10]:
#define model training
def train(model, train_steps, x_batch, y_batch, sess, log_every=1000):
    """Run ``train_steps`` optimisation steps on ``model`` and log progress.

    Args:
        model: Dict holding the graph handles of one network under the keys
            ``'x'``, ``'y'``, ``'output'``, ``'loss'`` and ``'train_op'``.
        train_steps: Number of training steps to execute.
        x_batch: Inputs fed to ``model['x']`` on every step.
        y_batch: Targets fed to ``model['y']`` on every step.
        sess: Session object whose ``run`` method executes the fetches.
        log_every: Print the loss and the rounded predictions every this
            many steps (starting at step 0). A falsy value disables logging.
    """
    # The same batch is fed on every step, so build the feed dict once.
    train_feed = {model['x']: x_batch, model['y']: y_batch}

    for step in range(train_steps):
        _, loss = sess.run([model['train_op'], model['loss']], feed_dict=train_feed)

        if log_every and step % log_every == 0:
            print("Loss:", loss)
            # Report the predictions of the model being trained, not of a
            # notebook-level global, so the function works for any model.
            predictions = sess.run(model['output'], feed_dict={model['x']: x_batch})
            print("Predictions:", np.round(predictions, decimals=0))

In [11]:
#train from model, copy weights, and test to model
# Open a session and initialise every variable in the graph.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

#train from model
# Only from_model is trained here; to_model keeps its initial weights.
train(
    model=from_model,
    train_steps=train_steps,
    x_batch=x_batch,
    y_batch=y_batch,
    sess=sess,
)

Loss: 0.3855492
Predictions: [[0.]
 [0.]
 [0.]
 [1.]]
Loss: 0.26098192
Predictions: [[0.]
 [0.]
 [1.]
 [1.]]
Loss: 0.18184265
Predictions: [[0.]
 [1.]
 [1.]
 [1.]]
Loss: 0.1201729
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 0.07017839
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 0.036435578
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 0.014771845
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 0.0037521082
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 0.00042550123
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 1.070382e-05
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 2.1192344e-08
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 2.4716673e-11
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 8.928969e-12
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 3.0377922e-12
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 1.0582646e-12
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 3.9707126e-13
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 1.2528867e-13
Predictions: [[0.]
 [1.]
 [1.]
 [0.]]
Loss: 4.1855408e-14
Predictions: [[0.]

In [14]:
#define copy function
def copy_weights(from_model_scope, to_model_scope, sess):
    """Copy every trainable variable of one scope onto another scope.

    Variables are paired by position in the lists returned by
    ``tf.trainable_variables`` for each scope, so both scopes are expected to
    have been built with the same architecture.

    Args:
        from_model_scope: Scope name whose trainable variables are read.
        to_model_scope: Scope name whose trainable variables are overwritten.
        sess: Session used to read the source values and load the targets.

    Raises:
        ValueError: If the two scopes hold a different number of trainable
            variables; a positional copy would otherwise stop silently at
            the shorter list and leave the target partially updated.
    """
    source_vars = tf.trainable_variables(from_model_scope)
    target_vars = tf.trainable_variables(to_model_scope)

    if len(source_vars) != len(target_vars):
        raise ValueError(
            "Cannot copy weights: scope %r has %d trainable variables but scope %r has %d"
            % (from_model_scope, len(source_vars), to_model_scope, len(target_vars))
        )

    if not source_vars:
        return

    # Fetch all source values in one session call; the current value of the
    # target variable is never needed because it is overwritten entirely.
    source_values = sess.run(source_vars)
    for target_var, value in zip(target_vars, source_values):
        target_var.load(value, session=sess)

In [15]:
#copy weights test
def _print_scope_weights(scope_name):
    """Print the current value of each trainable variable under ``scope_name``."""
    for variable in tf.trainable_variables(scope_name):
        print(variable.eval(session=sess))


def _print_both_models():
    """Print from_model's weights, a separator line, then to_model's weights."""
    _print_scope_weights("from_model")
    print("-------------------------------------")
    _print_scope_weights("to_model")


# Weights before the copy: the two models differ.
_print_both_models()

copy_weights("from_model", "to_model", sess)

print("\n\n\n")

#test to model
# Weights after the copy: to_model should now mirror from_model.
_print_both_models()

print("\n\n\n")

# to_model was never trained itself; its predictions come from the copied weights.
to_model_predictions = sess.run(to_model['output'], feed_dict={to_model['x']: x_batch})
print("Predictions:", np.round(to_model_predictions, decimals=0))

[[ 0.96147174 -0.8560882   1.0645647   0.2961514   0.84780455]
 [-0.9614759   0.856066   -0.37036726  0.43776488 -0.84783417]]
[-1.0758272e-05 -9.5446167e-06  3.7035573e-01  1.2271113e-01
 -2.1793145e-05]
[[ 0.7830065 ]
 [ 0.7036456 ]
 [-0.50691587]
 [ 0.47948214]
 [ 0.76059026]]
[0.1289014]
-------------------------------------
[[ 0.35380292 -0.12276626 -0.7714646  -0.6156478   0.4833982 ]
 [-0.09294921 -0.30254883 -0.21956396  0.63383996  0.00560814]]
[0. 0. 0. 0. 0.]
[[-0.5980985 ]
 [-0.0123775 ]
 [ 0.56791997]
 [ 0.9893639 ]
 [-0.1618557 ]]
[0.]




[[ 0.96147174 -0.8560882   1.0645647   0.2961514   0.84780455]
 [-0.9614759   0.856066   -0.37036726  0.43776488 -0.84783417]]
[-1.0758272e-05 -9.5446167e-06  3.7035573e-01  1.2271113e-01
 -2.1793145e-05]
[[ 0.7830065 ]
 [ 0.7036456 ]
 [-0.50691587]
 [ 0.47948214]
 [ 0.76059026]]
[0.1289014]
-------------------------------------
[[ 0.96147174 -0.8560882   1.0645647   0.2961514   0.84780455]
 [-0.9614759   0.856066   -0.37036726  0.43776