In [1]:
import tensorflow as tf
from tensorflow import keras


(x_train, y_train), (x_val, y_val) = keras.datasets.fashion_mnist.load_data()

In [2]:
def preprocess(x, y):
    x = tf.cast(x, tf.float32) / 255.0
    y = tf.cast(y, tf.int64)

    return x, y

def create_dataset(xs, ys, n_classes=10):
    le=len(ys)
    ys = tf.one_hot(ys, depth=n_classes)
    return tf.data.Dataset.from_tensor_slices((xs, ys)) \
    .map(preprocess) \
    .shuffle(le) \
    .batch(128)


In [3]:
train_dataset = create_dataset(x_train, y_train)
val_dataset = create_dataset(x_val, y_val)

In [4]:
model = keras.Sequential([
    keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    keras.layers.Dense(units=256, activation='relu'),
    keras.layers.Dense(units=192, activation='relu'),
    keras.layers.Dense(units=128, activation='relu'),
    keras.layers.Dense(units=10, activation='softmax')
])

In [5]:
model.compile(optimizer='adam', 
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(
    train_dataset.repeat(), 
    epochs=10, 
    steps_per_epoch=500,
    validation_data=val_dataset.repeat(), 
    validation_steps=2
)

Train for 500 steps, validate for 2 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [6]:
weights=model.get_weights()

In [27]:
import numpy as np
print(len(weights[2]))
print(weights[2])
new_w=[];
for i,weight in enumerate(weights):
    if i==0:
        new_w.append(weight)
    elif i%2 ==1:
        new_w.append(np.array(weight)*2)
    else:
        new_w.append(weight*3)

# weights

256
[[ 0.16230051  0.03145216  0.22375041 ... -0.05054053  0.264603
   0.05829823]
 [-0.0686494   0.0758021  -0.19056806 ... -0.05589221 -0.01039014
  -0.11984411]
 [ 0.10782806  0.0396307  -0.14551792 ...  0.09090099 -0.10998756
   0.05305202]
 ...
 [ 0.04609544  0.0805263   0.09422747 ... -0.04625442 -0.11567344
   0.0696032 ]
 [-0.03284322  0.01683136  0.07694931 ... -0.02273826  0.09015141
  -0.03810262]
 [-0.06147341  0.02901897  0.15435362 ... -0.10796605 -0.07921346
   0.06150447]]


In [31]:
type(weights[0])

numpy.ndarray

In [30]:
new_w[2]

array([[ 0.48690152,  0.09435649,  0.67125124, ..., -0.15162158,
         0.79380894,  0.17489469],
       [-0.2059482 ,  0.2274063 , -0.57170415, ..., -0.16767663,
        -0.03117042, -0.35953233],
       [ 0.32348418,  0.1188921 , -0.43655375, ...,  0.272703  ,
        -0.3299627 ,  0.15915605],
       ...,
       [ 0.13828632,  0.2415789 ,  0.28268242, ..., -0.13876325,
        -0.34702033,  0.20880958],
       [-0.09852965,  0.05049409,  0.23084792, ..., -0.06821477,
         0.27045423, -0.11430785],
       [-0.18442024,  0.08705691,  0.46306086, ..., -0.32389814,
        -0.23764038,  0.1845134 ]], dtype=float32)