In [1]:
import tensorflow as tf
import keras

In [15]:
class MyModel(keras.Model):
    """Small fully-connected network written with the Keras subclassing API.

    Three stacked ``Dense`` layers: 4 ReLU units -> 6 ReLU units -> 2 softmax
    units. No input width is declared, so the first ``Dense`` layer infers it
    from the first batch the model is called on.
    """

    def __init__(self, **kwargs):
        # Forward `name=...` and any other base-Model options to keras.Model.
        super().__init__(**kwargs)
        self.dense0 = keras.layers.Dense(4, activation="relu")
        self.dense1 = keras.layers.Dense(6, activation="relu")
        self.dense2 = keras.layers.Dense(2, activation="softmax")

    def call(self, inputs):
        """Apply dense0, dense1 and dense2 to ``inputs`` in that order."""
        hidden = inputs
        for layer in (self.dense0, self.dense1, self.dense2):
            hidden = layer(hidden)
        return hidden

In [17]:
# Instantiate the subclassed model; `name` is forwarded to keras.Model via **kwargs.
model = MyModel(name="My Subclassing API Model")

In [19]:
# Subclassed models create their weights lazily on the first call, so at this
# point nothing is built yet: tf.keras 2 raises ValueError from summary(), and
# Keras 3 prints layers as "unbuilt" with no shapes or parameter counts.
# Run one dummy batch first. The width 8 must match the feature size of the
# real inputs fed to the model below.
model(tf.zeros([1, 8]))
model.summary()

In [21]:
# The `name` passed to the constructor is readable back as `model.name`.
model.name

'My Subclassing API Model'

In [23]:
# Forward a random batch of 3 samples with 8 features each through the model.
x = tf.random.normal(shape=(3, 8))
y = model(x)

In [25]:
# Print the output tensor's shape ("출력 텐서" means "output tensor"); the last
# dimension is 2 because the final Dense layer of MyModel has 2 units.
print(f"출력 텐서 shape: {y.shape}")

출력 텐서 shape: (3, 2)


In [None]:
# Subclassing API로 복잡한 모델 만들기

In [27]:
class ComplexModel(keras.Model):
    """Three-input, two-output model written with the Keras subclassing API.

    Input 0 and input 1 each go through their own ReLU ``Dense`` branch
    (8 and 4 units); input 2 is used unchanged. The three feature tensors are
    concatenated and fed to two independent linear heads with 8 and 2 units.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Per-input branches.
        self.dense0_0 = keras.layers.Dense(8, activation="relu")
        self.dense0_1 = keras.layers.Dense(4, activation="relu")
        # Joins the branch features with the raw third input (default axis=-1).
        self.concat = keras.layers.Concatenate()
        # Two output heads with no activation (raw linear outputs).
        self.dense1_0 = keras.layers.Dense(8)
        self.dense1_1 = keras.layers.Dense(2)

    def call(self, inputs):
        """Compute both heads from a 3-element sequence of input tensors.

        Returns a tuple ``(output0, output1)`` whose last dimensions are 8 and 2.
        """
        first, second, third = inputs
        merged = self.concat([self.dense0_0(first), self.dense0_1(second), third])
        return self.dense1_0(merged), self.dense1_1(merged)

In [29]:
# Seed the Python, NumPy and TensorFlow RNGs before constructing the model.
# This makes its weight initialisation and the random inputs drawn in the next
# cell identical on every Restart & Run All, so the printed outputs can be
# reproduced. 42 is an arbitrary fixed value.
keras.utils.set_random_seed(42)
complex_model = ComplexModel()

In [31]:
# One random batch of 3 samples per model input, with feature widths 4, 6 and 2.
x0, x1, x2 = (tf.random.normal(shape=(3, width)) for width in (4, 6, 2))

y = complex_model([x0, x1, x2])

In [33]:
# `complex_model` returns a tuple (output0, output1): one tensor per head.
y

(<tf.Tensor: shape=(3, 8), dtype=float32, numpy=
 array([[ 0.09152794,  2.027401  ,  0.70917535, -1.2605323 , -2.4480157 ,
         -0.15068951, -0.29007652,  1.454891  ],
        [-0.16378106, -0.31391037,  0.60854375, -0.15176743,  0.72444993,
         -0.12649071, -0.6155102 , -0.6483801 ],
        [-0.54178125,  0.9161054 ,  0.8939525 ,  0.6089164 ,  0.7060109 ,
          0.19168928,  1.7074941 , -0.33645108]], dtype=float32)>,
 <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
 array([[ 1.5734881e-03,  3.4883919e+00],
        [ 9.4543204e-02, -3.2733735e-01],
        [-4.5151281e-01, -1.3257071e+00]], dtype=float32)>)