# Full Deep Learning Workflow (Image Classification) with TensorFlow 2.x

## Import Libraries
- TensorFlow 2.x (the run below used 2.5.2)
- tqdm (progress bars)

In [23]:
import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

from tqdm import tqdm
print(tf.__version__)

# Fix TensorFlow's global random seed so that dataset shuffling and weight
# initialization repeat identically on "Restart Kernel -> Run All"
# (some GPU kernels may still introduce nondeterminism).
SEED = 42
tf.random.set_seed(SEED)

2.5.2


## Load Data: MNIST Toy Dataset

In [6]:
mnist = tf.keras.datasets.mnist

# Load the MNIST train/test splits (images plus integer class labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Normalize pixel values into [0, 1] by dividing by the 8-bit maximum, 255.
# Casting to float32 *before* dividing avoids materializing a float64 copy
# of every image, and yields the same float32 values as dividing first.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

print("Before add a channel x_train shape : {}".format(x_train.shape))
print("Before add a channel x_test shape : {}".format(x_test.shape))
print("y_train: {}, y_test: {}".format(y_train.shape, y_test.shape))

# Add a trailing channel dimension (N, 28, 28) -> (N, 28, 28, 1) as Conv2D
# expects. The arrays are already float32, so no second cast is needed.
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

print("After add a channel x_train shape : {}".format(x_train.shape))
print("After add a channel x_test shape : {}".format(x_test.shape))
print("y_train: {}, y_test: {}".format(y_train.shape, y_test.shape))

Before add a channel x_train shape : (60000, 28, 28)
Before add a channel x_test shape : (10000, 28, 28)
y_train: (60000,), y_test: (10000,)
After add a channel x_train shape : (60000, 28, 28, 1)
After add a channel x_test shape : (10000, 28, 28, 1)
y_train: (60000,), y_test: (10000,)


## Shuffle and Batch the Data

In [12]:
# Named so they can be tuned in one place. A shuffle buffer smaller than the
# 60,000 training examples gives only an approximate shuffle; raise
# SHUFFLE_BUFFER toward len(x_train) for a full uniform shuffle.
BATCH_SIZE = 32
SHUFFLE_BUFFER = 10000

# Only the training data is shuffled; the test data keeps its original order.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
train_ds

<BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.uint8)>

## Modeling with the Keras Subclassing API

In [13]:
class MyModel(Model):
    """Small CNN classifier: Conv2D -> Flatten -> Dense(128) -> Dense(10).

    The last Dense layer has no activation, so the model returns raw logits
    (one per class); the loss is built with ``from_logits=True`` to match.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')  # 32 filters, 3x3 kernel
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10)  # logits, no activation

    def call(self, x):
        """Forward pass through the layers in declaration order."""
        for layer in (self.conv1, self.flatten, self.d1, self.d2):
            x = layer(x)
        return x

In [14]:
# Instantiate the model; train_step and test_step below both use this global.
model = MyModel()

## Training
### Loss Function, Optimizer, and Metrics

In [15]:
# Adam with its learning rate spelled out (0.001 is the Keras default).
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# The model emits raw logits and the labels are integer class ids,
# hence the sparse cross-entropy with from_logits=True.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [16]:
# Streaming metrics: updated batch by batch in train_step / test_step.
# Loss metrics keep a running mean of the per-batch loss values.
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')

# Accuracy metrics compare integer labels against the argmax of the logits.
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


### Train Function

In [17]:
@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and update the training metrics.

    Uses the module-level ``model``, ``loss_object``, ``optimizer``,
    ``train_loss`` and ``train_accuracy``.
    """
    with tf.GradientTape() as tape:
        # training=True only matters for layers that behave differently in
        # training vs. inference (e.g. Dropout); kept for correctness if added.
        logits = model(images, training=True)
        batch_loss = loss_object(labels, logits)

    params = model.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))

    train_loss(batch_loss)
    train_accuracy(labels, logits)

### Test Function

In [18]:
@tf.function
def test_step(images, labels):
    """Evaluate one batch in inference mode and update the test metrics.

    No gradients are computed and no weights change.
    """
    logits = model(images, training=False)
    test_loss(loss_object(labels, logits))
    test_accuracy(labels, logits)

### Training Loop

In [25]:
EPOCHS = 5
for epoch in tqdm(range(EPOCHS), desc='epoch'):
    # Reset every streaming metric so each epoch reports only its own batches.
    # Looping guarantees reset_states() is actually *called* on each one -- a
    # bare `metric.reset_states` reference (no parentheses) is a silent no-op
    # that lets the metric keep accumulating across epochs.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

    # leave=False clears each finished inner bar so nested bars don't pile up.
    for images, labels in tqdm(train_ds, desc='train', leave=False):
        train_step(images, labels)
    for test_images, test_labels in tqdm(test_ds, desc='test', leave=False):
        test_step(test_images, test_labels)

    # tqdm.write prints above the active progress bar instead of corrupting it.
    tqdm.write(
        f'Epoch {epoch + 1}, '
        f'Loss: {train_loss.result()}, '
        f'Accuracy: {train_accuracy.result() * 100}, '
        f'Test Loss: {test_loss.result()}, '
        f'Test Accuracy: {test_accuracy.result() * 100}'
    )

  0%|                                                               | 0/5 [00:00<?, ?it/s]
  0%|                                                            | 0/1875 [00:00<?, ?it/s][A
  0%|                                                    | 1/1875 [00:00<05:06,  6.11it/s][A
  0%|▏                                                   | 8/1875 [00:00<00:52, 35.40it/s][A
  1%|▍                                                  | 15/1875 [00:00<00:39, 47.26it/s][A
  1%|▌                                                  | 22/1875 [00:00<00:34, 54.45it/s][A
  2%|▊                                                  | 29/1875 [00:00<00:31, 58.87it/s][A
  2%|▉                                                  | 36/1875 [00:00<00:29, 62.25it/s][A
  2%|█▏                                                 | 43/1875 [00:00<00:29, 61.17it/s][A
  3%|█▎                                                 | 50/1875 [00:00<00:28, 63.77it/s][A
  3%|█▌                                                 | 57/18

KeyboardInterrupt: 

In [21]:
# Scratch check: an f-string and str.format() render the same text.
a = 1
for line in (f'EPOCH {a+1}', 'EPOCH {}'.format(a+1)):
    print(line)

EPOCH 2
EPOCH 2
