# TensorFlow 2.0 Template Training Code

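## Overview
* This program trains an image classifier on MNIST.
    * MNIST: an extremely well-known dataset of handwritten digit images (0 to 9).
    * Image classification: the task of predicting what object appears in an image.
        * Other well-known image tasks include object detection, image generation, and image captioning.
* I have tried to comment the code as clearly as possible.
    * If anything is unclear, ask on Teams.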

## Importing libraries

In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import tensorflow as tf

# Project-local modules from the template repository
from config import Config
from model.lenet import LeNet
from model.saver import Saver
from dataset.entity import Entity
from dataset.generator import Generator

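## Loading the config
* The `Config` class is defined in `~/work/template/config.py`.
* Collecting settings such as path definitions in a config like this keeps the code tidy.
* Setting the random seeds is recommended so that runs are reproducible.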

In [2]:
config = Config()

np.random.seed(config.seed)
tf.random.set_seed(config.seed)

/home/jovyan/work/_checkpoint/template is not found. so, created.
/home/jovyan/work/_log/template/train is not found. so, created.
/home/jovyan/work/_log/template/valid is not found. so, created.
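`config.py` itself is not shown in this notebook. As a rough sketch of what such a class might look like, the snippet below groups the attributes the notebook actually uses (`seed`, `dataset.train_csv_path`, `dataset.valid_csv_path`, `model.batch_size`, `model.max_epoch`, `model.checkpoint_dir`, `tfboard.train_log_dir`, `tfboard.valid_log_dir`) into nested dataclasses, and creates missing output directories, which is one plausible source of the "is not found. so, created." messages above. The CSV file names, the default values, and the `base_dir`/`name` parameters are assumptions, not the template's real settings.

```python
import os
from dataclasses import dataclass

# Hypothetical sketch: attribute names mirror how the notebook uses `config`;
# file names and default values are assumptions.


def ensure_dir(path: str) -> str:
    """Create `path` if it is missing and report it, like the log above."""
    if not os.path.isdir(path):
        os.makedirs(path)
        print(f"{path} is not found. so, created.")
    return path


@dataclass
class DatasetConfig:
    train_csv_path: str
    valid_csv_path: str


@dataclass
class ModelConfig:
    checkpoint_dir: str
    batch_size: int = 64  # assumption: 60000 rows / 64 rounds up to 938 batches
    max_epoch: int = 10   # assumption: range(1, 10) runs epochs 1..9


@dataclass
class TFBoardConfig:
    train_log_dir: str
    valid_log_dir: str


class Config:
    def __init__(self, base_dir: str = "/home/jovyan/work", name: str = "template", seed: int = 42):
        self.seed = seed  # assumption: any fixed integer works
        self.dataset = DatasetConfig(
            # Hypothetical file names for the local MNIST CSVs
            train_csv_path=os.path.join(base_dir, "data", "mnist_train.csv"),
            valid_csv_path=os.path.join(base_dir, "data", "mnist_test.csv"),
        )
        self.model = ModelConfig(
            checkpoint_dir=ensure_dir(os.path.join(base_dir, "_checkpoint", name)),
        )
        self.tfboard = TFBoardConfig(
            train_log_dir=ensure_dir(os.path.join(base_dir, "_log", name, "train")),
            valid_log_dir=ensure_dir(os.path.join(base_dir, "_log", name, "valid")),
        )
```

Nesting the settings by concern (`dataset`, `model`, `tfboard`) keeps call sites such as `config.model.batch_size` self-describing, and creating directories up front means the checkpoint manager and summary writers never fail on a missing path.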


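## Loading the dataset
* With adapting the template to your own datasets in mind, the code deliberately takes the roundabout route of reading MNIST from local CSV files.
* `Generator` is a class that produces the mini-batches for each epoch.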

In [3]:
# Training data
train_df = pd.read_csv(config.dataset.train_csv_path, header=None)
train_np = train_df.values
train_data = [Entity(i) for i in train_np]
train_gen = Generator(train_data, config.model.batch_size, shuffle=True, random_state=config.seed)

# Validation data
valid_df = pd.read_csv(config.dataset.valid_csv_path, header=None)
valid_np = valid_df.values
valid_data = [Entity(i) for i in valid_np]
# The validation loss is averaged over all batches, so shuffling is unnecessary here
valid_gen = Generator(valid_data, config.model.batch_size, shuffle=False, random_state=config.seed)
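`Entity` and `Generator` live in the template's `dataset` package and are not shown here. The following is a minimal sketch assuming each CSV row uses the common MNIST-CSV layout (label in the first column, then 784 pixel values); the real classes may differ. This `Generator` reshuffles once per epoch with a seeded `numpy.random.RandomState` and keeps the last partial batch, which would explain the 938 iterations per epoch in the training log further down, assuming 60,000 training rows and a batch size of 64 (60000 / 64 = 937.5, rounded up).

```python
import math

import numpy as np

# Hypothetical sketch of dataset.entity.Entity and dataset.generator.Generator.
# Assumed CSV layout: column 0 = label, columns 1..784 = 28x28 pixel values.


class Entity:
    def __init__(self, row):
        row = np.asarray(row)
        self.y = int(row[0])  # class label 0..9
        # Scale pixels to [0, 1] and add a channel axis for Conv2D: (28, 28, 1)
        self.x = row[1:].astype(np.float32).reshape(28, 28, 1) / 255.0


class Generator:
    def __init__(self, data, batch_size, shuffle=True, random_state=None):
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.RandomState(random_state)

    def __len__(self):
        # Number of mini-batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.data) / self.batch_size)

    def __iter__(self):
        # A fresh order on every pass, so each epoch's `for` loop sees a new shuffle.
        order = np.arange(len(self.data))
        if self.shuffle:
            self.rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            batch = [self.data[i] for i in order[start:start + self.batch_size]]
            x = np.stack([e.x for e in batch])
            y = np.array([e.y for e in batch], dtype=np.int64)
            yield x, y
```

Because `__len__` is defined, `tqdm(train_gen)` can show a total such as `0/938` without being told one.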

## Building the model

In [4]:
model = LeNet()
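`LeNet` is defined in `model/lenet.py`, which is not included here. Assuming a classic LeNet-style stack of `Conv2D`, `MaxPooling2D`, `Flatten`, and `Dense` layers with LeNet-5-like hyperparameters (two 5x5 convolutions with 6 and 16 filters and 'valid' padding, each followed by 2x2 max pooling), the feature-map sizes for a 28x28x1 MNIST input follow from the standard output-size formula out = floor((n + 2p - k) / s) + 1. The filter counts and kernel sizes below are assumptions, not values read from the template.

```python
# Worked shape calculation for a hypothetical LeNet-style network on MNIST.
# Layer hyperparameters are assumptions, not read from model/lenet.py.


def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


def lenet_shapes(size=28, channels=1, filters=(6, 16)):
    shapes = [("input", (size, size, channels))]
    for i, f in enumerate(filters, start=1):
        size = out_size(size, kernel=5)            # Conv2D 5x5, 'valid' padding
        shapes.append((f"conv{i}", (size, size, f)))
        size = out_size(size, kernel=2, stride=2)  # MaxPooling2D 2x2
        shapes.append((f"pool{i}", (size, size, f)))
    shapes.append(("flatten", (size * size * filters[-1],)))
    return shapes


for name, shape in lenet_shapes():
    print(name, shape)
```

Under these assumptions the flattened vector fed to the first `Dense` layer has 4 x 4 x 16 = 256 elements. If the first convolution instead uses 'same' padding (`out_size(28, kernel=5, padding=2)` gives 28), as some MNIST ports of LeNet do, the size becomes 5 x 5 x 16 = 400.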

## Setting up training checkpoints

In [5]:
saver = Saver()
checkpoint = tf.train.Checkpoint(step=tf.Variable(1), model=model)
manager = tf.train.CheckpointManager(checkpoint, 
                                     directory=config.model.checkpoint_dir,
                                     max_to_keep=5,
                                     checkpoint_name="chkp")
checkpoint.restore(manager.latest_checkpoint)

if manager.latest_checkpoint:
    saver.old_loss = model.valid_loss.result()
    print("restored!")
else:
    print("init from scratch!")

init from scratch!
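`tf.train.CheckpointManager` with `max_to_keep=5` keeps only the five most recent checkpoints and deletes older ones, and `checkpoint.step` records where a restored run resumes. The plain-Python sketch below mimics just that bookkeeping (it is not TensorFlow code) to show one consequence of this notebook's loop: `step` is incremented before the saver runs, so a checkpoint written at the end of epoch `e` stores `step = e + 1`, and `range(int(checkpoint.step), max_epoch)` restarts at the epoch after the last *saved* one. If the saver only calls `manager.save()` when validation loss improves, epochs after the last improvement are re-run on resume. `max_epoch = 10` is an assumption inferred from the training log, which stops at epoch 9.

```python
from collections import deque

# Plain-Python model of the checkpoint bookkeeping (not the TensorFlow API).
# Assumptions: the saver writes only on improved validation loss, and
# max_epoch = 10 (the log stops at epoch 9).


class FakeCheckpointManager:
    def __init__(self, max_to_keep=5):
        # A bounded deque drops the oldest entry once max_to_keep is exceeded.
        self.kept = deque(maxlen=max_to_keep)

    def save(self, step):
        self.kept.append(step)

    @property
    def latest_step(self):
        return self.kept[-1] if self.kept else None


def resume_epochs(saved_epochs, max_epoch=10, max_to_keep=5):
    """Return (steps still on disk, epochs a restored run would execute)."""
    manager = FakeCheckpointManager(max_to_keep)
    step = 1  # tf.Variable(1) in the notebook
    for epoch in range(step, max_epoch):
        step += 1  # checkpoint.step.assign_add(1) runs first...
        if epoch in saved_epochs:
            manager.save(step)  # ...so the saved step is epoch + 1
    start = manager.latest_step or 1
    return list(manager.kept), list(range(start, max_epoch))
```

With the epochs that report "model was saved!" in the log (2, 3, 4, 7, 8), `resume_epochs({2, 3, 4, 7, 8})` keeps steps `[3, 4, 5, 8, 9]` and resumes at epoch 9, so that non-improving epoch would run again after a restart. The restored weights are the best ones seen, which is usually what you want.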


## Setting up TensorBoard

In [6]:
train_summary_writer = tf.summary.create_file_writer(config.tfboard.train_log_dir)
valid_summary_writer = tf.summary.create_file_writer(config.tfboard.valid_log_dir)

## Training

In [7]:
for epoch in range(int(checkpoint.step), config.model.max_epoch):
    # Train on the training data
    print("Epoch:{}".format(epoch))
    for (x, y) in tqdm(train_gen):
        model.train_step(x, y)
        
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', model.train_loss.result(), step=epoch)
        
    # Evaluate on the validation data
    for (x, y) in valid_gen:
        model.valid_step(x, y)
    with valid_summary_writer.as_default():
        tf.summary.scalar('loss', model.valid_loss.result(), step=epoch)
    
    print('Train Loss:{:.4f}, Validation Loss:{:.4f}'.format(model.train_loss.result(), model.valid_loss.result()))
    print('Train Acc:{:.4f}, Validation Acc:{:.4f}'.format(model.train_acc.result(), model.valid_acc.result()))
    checkpoint.step.assign_add(1)
    print(saver.save_model_good_valloss(model, config, manager, epoch))

    model.train_loss.reset_states()
    model.valid_loss.reset_states()
    model.train_acc.reset_states()
    model.valid_acc.reset_states()

  0%|          | 0/938 [00:00<?, ?it/s]

Epoch:1


100%|██████████| 938/938 [00:08<00:00, 109.39it/s]
  1%|▏         | 12/938 [00:00<00:07, 119.90it/s]

Train Loss:0.2491, Validation Loss:0.0877
Train Acc:0.9265, Validation Acc:0.9724
init saver loss.
Epoch:2


100%|██████████| 938/938 [00:07<00:00, 124.24it/s]
  1%|▏         | 13/938 [00:00<00:07, 123.10it/s]

Train Loss:0.0775, Validation Loss:0.0513
Train Acc:0.9757, Validation Acc:0.9840
model was saved!0.0877 -> 0.0513
Epoch:3


100%|██████████| 938/938 [00:07<00:00, 124.76it/s]
  1%|▏         | 13/938 [00:00<00:07, 125.98it/s]

Train Loss:0.0545, Validation Loss:0.0463
Train Acc:0.9827, Validation Acc:0.9846
model was saved!0.0513 -> 0.0463
Epoch:4


100%|██████████| 938/938 [00:07<00:00, 123.35it/s]
  1%|▏         | 13/938 [00:00<00:07, 122.35it/s]

Train Loss:0.0436, Validation Loss:0.0395
Train Acc:0.9859, Validation Acc:0.9872
model was saved!0.0463 -> 0.0395
Epoch:5


100%|██████████| 938/938 [00:07<00:00, 124.53it/s]
  1%|▏         | 13/938 [00:00<00:07, 122.71it/s]

Train Loss:0.0351, Validation Loss:0.0449
Train Acc:0.9884, Validation Acc:0.9859
pass. Best Validation loss:0.0395
Epoch:6


100%|██████████| 938/938 [00:07<00:00, 124.05it/s]
  1%|▏         | 13/938 [00:00<00:07, 121.87it/s]

Train Loss:0.0294, Validation Loss:0.0451
Train Acc:0.9905, Validation Acc:0.9864
pass. Best Validation loss:0.0395
Epoch:7


100%|██████████| 938/938 [00:07<00:00, 125.75it/s]
  1%|▏         | 13/938 [00:00<00:07, 126.89it/s]

Train Loss:0.0253, Validation Loss:0.0388
Train Acc:0.9917, Validation Acc:0.9897
model was saved!0.0395 -> 0.0388
Epoch:8


100%|██████████| 938/938 [00:07<00:00, 125.13it/s]
  1%|▏         | 13/938 [00:00<00:07, 127.14it/s]

Train Loss:0.0224, Validation Loss:0.0342
Train Acc:0.9926, Validation Acc:0.9895
model was saved!0.0388 -> 0.0342
Epoch:9


100%|██████████| 938/938 [00:07<00:00, 124.74it/s]


Train Loss:0.0194, Validation Loss:0.0423
Train Acc:0.9935, Validation Acc:0.9878
pass. Best Validation loss:0.0342
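In the log, epoch 1's training loss (0.2491) is well above its validation loss (0.0877). One likely reason, assuming `model.train_loss` is a running-mean metric such as `tf.keras.metrics.Mean`, is that the training figure averages every batch of the epoch, including early batches seen before the model had learned much, whereas validation runs only after all of the epoch's updates. The same accumulation is why the loop calls `reset_states()` at the end of each epoch; without it, the average would carry over across epochs. The minimal pure-Python stand-in below (not the Keras class) illustrates this; the per-batch losses are toy values chosen for illustration.

```python
class RunningMean:
    """Minimal stand-in for a Keras Mean metric: a weighted running average."""

    def __init__(self):
        self.reset_states()

    def update_state(self, value, weight=1.0):
        self.total += value * weight
        self.count += weight

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        self.total = 0.0
        self.count = 0.0


# Toy per-batch losses that fall as the model learns during one epoch.
epoch_losses = [2.0, 1.0, 0.5, 0.2, 0.1]

train_loss = RunningMean()
for loss in epoch_losses:
    train_loss.update_state(loss)

print(train_loss.result())  # mean over the whole epoch, far above the last batch's 0.1
train_loss.reset_states()   # start the next epoch from a clean average
```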


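The `Saver` class (`model/saver.py`) is not shown either. Judging only from its messages in the log ("init saver loss.", "model was saved!old -> new", "pass. Best Validation loss:..."), `save_model_good_valloss` appears to remember the best validation loss so far and save a checkpoint only when the current loss beats it. A hypothetical reconstruction of that logic follows; `BestLossSaver`, `save_if_improved`, and the `save_fn` callback (standing in for `manager.save()` so the sketch runs without TensorFlow) are illustrative names, not the template's API.

```python
# Hypothetical reconstruction of Saver.save_model_good_valloss, inferred from
# its log messages. `save_fn` stands in for `manager.save()`.


class BestLossSaver:
    def __init__(self):
        self.old_loss = None  # best validation loss so far (the notebook sets saver.old_loss on restore)

    def save_if_improved(self, valid_loss, save_fn):
        if self.old_loss is None:
            # First epoch: record the baseline (the real Saver may also save here).
            self.old_loss = valid_loss
            return "init saver loss."
        if valid_loss < self.old_loss:
            save_fn()
            message = "model was saved!{:.4f} -> {:.4f}".format(self.old_loss, valid_loss)
            self.old_loss = valid_loss
            return message
        return "pass. Best Validation loss:{:.4f}".format(self.old_loss)


# Replay the (rounded) validation losses printed in the log for epochs 1..9.
valid_losses = [0.0877, 0.0513, 0.0463, 0.0395, 0.0449, 0.0451, 0.0388, 0.0342, 0.0423]
saved_epochs = []
saver = BestLossSaver()
for epoch, loss in enumerate(valid_losses, start=1):
    print(saver.save_if_improved(loss, lambda: saved_epochs.append(epoch)))
```

Replaying the logged losses saves at epochs 2, 3, 4, 7, and 8, matching the "model was saved!" lines, and leaves 0.0342 as the best loss. Saving only on improvement is a simple form of keeping the best model; it does not stop training early.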
In [8]:
print('TRAINING complete! Good job!')

TRAINING complete! Good job!
