In [1]:
def make_image(x_train, iter, time_hparam, hidden_dim):
    # 結果の出力(訓練データそのまま)
    evidence = model(x_train)
    alpha = evidence + 1
    y_pred = alpha/tf.reduce_sum(alpha, axis=1, keepdims=True)
    unc = 2/tf.reduce_sum(alpha, axis=1, keepdims=True)
    # 2次元空間での不確かさの出力
    # グラデーションを作る
    # 正解の散布図
    figure = plt.figure(figsize=(12,4))
    ax = figure.add_subplot(141)
    ax.scatter(x_train[:100,0], x_train[:100,1], c="r")
    ax.scatter(x_train[100:,0], x_train[100:,1], c="b")
    ax.set_title("true")
    # 予測の散布図
    ax = figure.add_subplot(142)
    ax.set_title("pred")
    y_pred_ctg = np.argmax(y_pred, axis=1)
    for x, label in zip(x_train, y_pred_ctg):
        if label == 0:
            ax.scatter(x[0], x[1], c="r")
        elif label == 1:
            ax.scatter(x[0], x[1], c="b")
        else:
            print("exception has occured")
            sys.exit(1)
    # 不確かさの分布
    ax = figure.add_subplot(143)
    ax.set_title("unc")
    im = ax.scatter(x_train[:,0], x_train[:,1], c=unc, cmap='Blues', vmin=0, vmax=1)
    figure.colorbar(im)

    # ロスの分布
    loss_fn = EDLLoss(K=2, annealing=1)
    _y_train = tf.one_hot(y_train, depth=2)
    loss = loss_fn.call(_y_train, alpha)
    ax = figure.add_subplot(144)
    ax.set_title("loss")
    im = ax.scatter(x_train[:,0], x_train[:,1], c=loss, cmap='Blues', vmin=0, vmax=1)
    figure.colorbar(im)

    plt.legend()
    save_path = os.path.join(os.environ["sleep"], "figures", "check_uncertainty", f"{time_hparam}", f"{hidden_dim}")
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if iter < 10:
        plt.savefig(os.path.join(save_path, f"00{iter}.png"))
    elif iter < 100:
        plt.savefig(os.path.join(save_path, f"0{iter}.png"))
    elif iter < 1000:
        plt.savefig(os.path.join(save_path, f"{iter}.png"))
    plt.clf()

In [2]:
import tensorflow as tf
tf.random.set_seed(0)
import numpy as np
import matplotlib.pyplot as plt
from nn.model_base import edl_classifier4psedo_data, EDLModelBase
from nn.losses import EDLLoss
import os, sys
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# parameterの設定
epochs = 300
time_hparam = 100
HIDDEN_DIM =  32 # TODO : ここを変える(3, 8, 16, 32)

# 仮データの作成
def psedo_data(row, col, x_bias, y_bias):
    # 極座標で考える
    row, col = (100, 2)
    r_class0 = tf.random.uniform(shape=(row,), minval=0, maxval=0.5)
    theta_class0 = tf.random.uniform(shape=(row,), minval=0, maxval=np.pi*2)
    r_class1 = tf.random.uniform(shape=(row,), minval=0.5, maxval=1)
    theta_class1 = tf.random.uniform(shape=(row,), minval=0, maxval=np.pi*2)
    input_class0 = (x_bias+r_class0*np.cos(theta_class0), y_bias+r_class0*np.sin(theta_class0))
    input_class1 = (x_bias+r_class1*np.cos(theta_class1), y_bias+r_class1*np.sin(theta_class1))
    x_train = tf.concat([input_class0, input_class1], axis=1)
    x_train = tf.transpose(x_train)
    y_train_0 = [0 for _ in range(row)]
    y_train_1 = [1 for _ in range(row)]
    y_train = y_train_0 + y_train_1
    x_test = None
    y_test = None
    return (x_train, x_test), (y_train, y_test)


(x_train, x_test), (y_train, y_test) = psedo_data(row=100, col=2, x_bias=0, y_bias=0)

# カスタムトレーニング
inputs = tf.keras.Input(shape=(2,))
outputs = edl_classifier4psedo_data(x=inputs, use_bias=True, hidden_dim=HIDDEN_DIM)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
# 最適化関数の設定
optimizer = tf.keras.optimizers.Adam()
# メトリクスの作成
# true side : カテゴリカルな状態，pred side : クラスの次元数（ソフトマックスをかける前)
# CategoricalAccuracy : one-hotに対して計算してくれる
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=x_train.shape[0]).batch(batch_size=32)

# 学習
for epoch in range(epochs):
    loss_fn = EDLLoss(K=2, annealing=min(1.0, epoch/epochs*time_hparam))
    print(f"エポック:{epoch}")
    # エポック内のバッチサイズごとのループ
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # 勾配を計算
        with tf.GradientTape() as tape:
            # dataset.shuffleを入れることによってバッチサイズを設定できる
            evidence = model(x_batch_train, training=True)
            alpha = evidence+1
            y_pred = alpha/tf.reduce_sum(alpha, axis=1, keepdims=True)
            # NOTE : ここでone-hotにする
            y_batch_train = tf.one_hot(y_batch_train, depth=2)
            loss_value = loss_fn.call(y_batch_train, alpha)
            
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        train_acc_metric.update_state(y_batch_train, y_pred)
    # 画像の作成
    make_image(x_train=x_train, iter=epoch, time_hparam=time_hparam, hidden_dim=HIDDEN_DIM)
    # エポックの終わりにメトリクスを表示する
    train_acc = train_acc_metric.result()
    print(f"訓練一致率：{train_acc:.2%}")
    # エポックの終わりに訓練メトリクスを初期化
    train_acc_metric.reset_states()

致率：92.00%
エポック:14
No handles with labels found to put in legend.
訓練一致率：91.50%
エポック:15
No handles with labels found to put in legend.
訓練一致率：93.00%
エポック:16
No handles with labels found to put in legend.
訓練一致率：93.50%
エポック:17
No handles with labels found to put in legend.
訓練一致率：93.50%
エポック:18
No handles with labels found to put in legend.
訓練一致率：93.00%
エポック:19
No handles with labels found to put in legend.
訓練一致率：95.00%
エポック:20
  figure = plt.figure(figsize=(12,4))
No handles with labels found to put in legend.
訓練一致率：96.00%
エポック:21
No handles with labels found to put in legend.
訓練一致率：95.50%
エポック:22
No handles with labels found to put in legend.
訓練一致率：95.50%
エポック:23
No handles with labels found to put in legend.
訓練一致率：96.00%
エポック:24
No handles with labels found to put in legend.
訓練一致率：95.00%
エポック:25
No handles with labels found to put in legend.
訓練一致率：95.00%
エポック:26
No handles with labels found to put in legend.
訓練一致率：95.00%
エポック:27
No handles with labels found to put in legend.
訓練一致率：95.00%


<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

<Figure size 864x288 with 0 Axes>

In [3]:
# gitの作成
from PIL import Image
import glob

saved_path = os.path.join(os.environ["sleep"], "figures", "check_uncertainty", f"{time_hparam}", f"{HIDDEN_DIM}")
files = glob.glob(os.path.join(saved_path, "*.png"))
images = list(map(lambda file: Image.open(file), files))
images[0].save(os.path.join(saved_path, 'out.gif'), save_all=True, append_images=images[1:], loop=0)