In [6]:
# Enable IPython's autoreload extension in mode 2, which re-imports all modules
# before each cell execution so edits to project modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
import os
from itertools import combinations
from pathlib import Path
from datetime import datetime
import pickle
import json
from dataclasses import dataclass
from typing import Callable, List

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.losses import BinaryCrossentropy, BinaryFocalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping

In [8]:
from models.handnet_based_model import handnet_based_model
from util.training import init_device
from util.training.dataloader import split_data_for_multiple_location, concat_and_shuffle
from util.training.metrics import IntersectionOverUnion, MeanPixelAccuracy

In [9]:
# Root directories used throughout the notebook.
# PROJECT_DIRPATH: base for the TensorBoard log directory built in `train`.
# NAS_DIRPATH: base for saved training artifacts and evaluation results.
PROJECT_DIRPATH = Path('/tf/workspace/deformation-prediction-multi-environment')
NAS_DIRPATH = Path('/tf/nas/')

In [10]:
# Record the TensorFlow version this notebook was executed with.
print("TensorFlow version:", tf.__version__)

TensorFlow version: 2.11.0


# 準備
---

## デバイスの初期化

In [11]:
# Show the current GPU status so a device can be chosen for this run.
!nvidia-smi

Tue Nov  5 17:05:16 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.74       Driver Version: 470.74       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:18:00.0 Off |                  Off |
| 44%   70C    P2   192W / 300W |   8354MiB / 48682MiB |     81%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:3B:00.0 Off |                  Off |
| 50%   73C    P2   202W / 300W |   8166MiB / 48685MiB |     64%      Default |
|       

In [12]:
# Specify which GPU(s) to use; this list is passed to `init_device`.
gpu = [2]

In [13]:
# Fixed random seed for the run.
seed = 42
# Per its log output, init_device seeds Python, NumPy and TensorFlow and
# restricts the visible GPU devices to the ones listed in `gpu`.
init_device(seed, gpu)

Setting random seed: 42
Random seed set for Python, NumPy, and TensorFlow.
Detected 5 GPU(s): ['/physical_device:GPU:0', '/physical_device:GPU:1', '/physical_device:GPU:2', '/physical_device:GPU:3', '/physical_device:GPU:4']
Visible GPU devices (1): ['/physical_device:GPU:2']


## ハイパーパラメータクラスの定義

In [14]:
@dataclass
class HyperParameters:
    """Training settings consumed by `init_model` (compile) and `train` (fit)."""
    # Number of epochs passed to `model.fit`.
    epochs: int
    # Mini-batch size passed to `model.fit`.
    batch_size: int
    # Loss handed to `model.compile(loss=...)`; `main` passes a Keras loss instance.
    loss_function: Callable
    # Metrics handed to `model.compile(metrics=...)`.
    metrics: List[Callable]
    # Optimizer handed to `model.compile(optimizer=...)`; `main` passes an Adam instance.
    optimizer: Callable

## モデルの初期化関数の定義

In [15]:
def init_model(hparam):
    """Create the network via `handnet_based_model` and compile it.

    Args:
        hparam: Object exposing `optimizer`, `loss_function` and `metrics`,
            which are forwarded to `model.compile`.

    Returns:
        The compiled model.
    """
    # Fixed architecture settings for this experiment.
    architecture = dict(
        input_shape=(10, 52, 2),
        num_block1=3,
        num_block2=3,
        num_residual_blocks=14,  # more residual blocks make the model heavier
    )
    network = handnet_based_model(**architecture)

    network.compile(
        optimizer=hparam.optimizer,
        loss=hparam.loss_function,
        metrics=hparam.metrics,
    )
    return network

## コールバック準備関数の定義

In [16]:
def prepare_callbacks(hparam, log_dirpath):
    """Build the list of Keras callbacks used during fitting.

    Args:
        hparam: Currently unused; kept so the signature stays stable.
        log_dirpath: Directory that TensorBoard writes its logs to.

    Returns:
        A list containing a TensorBoard logger followed by an EarlyStopping
        callback that monitors 'val_iou' (maximised, patience 10) and restores
        the best weights when it stops.
    """
    tensorboard_logger = TensorBoard(log_dir=log_dirpath, histogram_freq=1)
    early_stopper = EarlyStopping(
        monitor='val_iou',
        mode='max',
        patience=10,
        verbose=1,
        restore_best_weights=True,
    )
    return [tensorboard_logger, early_stopper]

## モデルを訓練する関数の定義

In [38]:
def save_history(save_dirpath, history):
    """Pickle `history` into ``history.pkl`` inside `save_dirpath`.

    Args:
        save_dirpath: `pathlib.Path` of an existing directory.
        history: Object to serialise (the value returned by `model.fit` in `train`).
    """
    # NOTE(review): the whole object is pickled as-is; a Keras History may carry
    # a reference to the model — consider dumping `history.history` instead
    # after confirming what downstream readers expect.
    target = save_dirpath / 'history.pkl'
    with target.open('wb') as sink:
        pickle.dump(history, sink)

def save_model(save_dirpath, model):
    """Write the model architecture (JSON) and the model itself (HDF5) to disk.

    Args:
        save_dirpath: `pathlib.Path` of an existing directory.
        model: Object providing `to_json()` and `save(path)`.
    """
    # Persist the architecture description as JSON text.
    architecture_json = model.to_json()
    architecture_path = save_dirpath / 'model_architecture.json'
    with architecture_path.open('w') as sink:
        sink.write(architecture_json)

    # Persist the model to 'model.h5'.
    model.save(save_dirpath / 'model.h5')

def save_results(save_dirpath, model, test_data):
    """Evaluate `model` on `test_data` and write one ``name: value`` line per metric.

    Args:
        save_dirpath: `pathlib.Path` of an existing directory; the report is
            written to ``evaluation_results.txt`` inside it.
        model: Object providing `evaluate(x, y)` and `metrics_names`.
        test_data: Two-element sequence ``(inputs, targets)``.
    """
    inputs, targets = test_data
    scores = model.evaluate(inputs, targets)
    report_path = save_dirpath / 'evaluation_results.txt'
    with report_path.open('w') as report:
        # Values are paired with metric names positionally.
        report.writelines(
            f"{metric_name}: {score:.04f}\n"
            for metric_name, score in zip(model.metrics_names, scores)
        )
        

In [39]:
def train(hparam, train_data, valid_data, date, time, train_id):
    """Fit a freshly initialised model and save its history and weights.

    Args:
        hparam: Hyperparameters (epochs, batch size, and the compile settings
            consumed by `init_model`).
        train_data: Two-element sequence ``(inputs, targets)`` for training.
        valid_data: Two-element sequence ``(inputs, targets)`` for validation.
        date, time, train_id: Path components used to build the output
            directory under NAS_DIRPATH and the log directory under
            PROJECT_DIRPATH.

    Returns:
        Tuple ``(history, model)`` where `history` is the return value of
        `model.fit`.
    """
    # Unpack the datasets.
    train_inputs, train_targets = train_data
    valid_inputs, valid_targets = valid_data

    # Build and compile the model.
    network = init_model(hparam)

    # Where artifacts and TensorBoard logs for this run go.
    artifact_dir = NAS_DIRPATH / 'training' / date / time / train_id
    tensorboard_dir = PROJECT_DIRPATH / 'logs' / date / time / train_id
    artifact_dir.mkdir(parents=True, exist_ok=True)

    # Fit the model.
    fit_callbacks = prepare_callbacks(hparam, tensorboard_dir)
    fit_history = network.fit(
        train_inputs,
        train_targets,
        validation_data=(valid_inputs, valid_targets),
        epochs=hparam.epochs,
        batch_size=hparam.batch_size,
        verbose=1,
        callbacks=fit_callbacks,
    )

    # Persist the outcome of this run.
    save_history(artifact_dir, fit_history)
    save_model(artifact_dir, network)

    return fit_history, network

# モデルの訓練
---

In [40]:
def get_train_id(loc_list, train_loc_list):
    """Return the training ID: the indices of the training locations, concatenated.

    Each entry of `train_loc_list` is replaced by its position in `loc_list`
    and the positions are joined without a separator (e.g. '02').

    Raises:
        ValueError: If a training location is not present in `loc_list`.
    """
    return ''.join(str(loc_list.index(location)) for location in train_loc_list)

In [41]:
def get_experiment_id(loc_list, train_loc_list, test_location):
    """Return ``'<train_id>-<test_index>'`` for one train/test pairing.

    The train ID comes from `get_train_id`; the test index is the position of
    `test_location` in `loc_list`.

    Raises:
        ValueError: If any location is not present in `loc_list`.
    """
    train_part = get_train_id(loc_list, train_loc_list)
    test_part = loc_list.index(test_location)
    return f"{train_part}-{test_part}"

In [50]:
def main():
    """Train on every non-empty combination of locations and evaluate each model.

    For each combination of entries in `location_list`, a model is trained on
    the concatenated training/validation data of those locations and then
    evaluated on the test data of every location. Training artifacts are
    written by `train`; evaluation reports go under
    ``NAS_DIRPATH/'results'/<date>/<time>/<experiment_id>``.
    """
    # Take one timestamp so the date and time path components always describe
    # the same instant (two separate now() calls can straddle midnight).
    started_at = datetime.now()
    date = started_at.strftime("%Y-%m-%d")
    time = started_at.strftime("%H-%M-%S")

    # Load the data
    csi_preprocess_id = 'real_and_imag'
    location_list = ['511', '512', '514B']
    dataset_dict = split_data_for_multiple_location(csi_preprocess_id, location_list)
    train_data_dict, valid_data_dict, test_data_dict = dataset_dict

    for r in range(1, len(location_list) + 1):
        for train_location_list in combinations(location_list, r):
            print(f'train_location: {list(train_location_list)}')

            train_id = get_train_id(location_list, train_location_list)
            train_data = concat_and_shuffle(train_data_dict, train_location_list)
            valid_data = concat_and_shuffle(valid_data_dict, train_location_list)

            # Hyperparameters; built inside the loop so every run gets a fresh
            # optimizer, loss and metric instances.
            hparam = HyperParameters(
                epochs = 150,
                batch_size = 256,
                optimizer = Adam(learning_rate=0.001),
                loss_function = BinaryFocalCrossentropy(gamma=2), # BinaryCrossentropy()
                metrics = [IntersectionOverUnion(threshold=0.5), MeanPixelAccuracy()],
            )

            history, model = train(hparam, train_data, valid_data, date, time, train_id)

            # Evaluate the trained model on the test split of every location.
            for test_location in location_list:
                experiment_id = get_experiment_id(location_list, train_location_list, test_location)
                print(f"experiment_id = {experiment_id}")
                test_data = test_data_dict[test_location]

                experiment_dirpath = NAS_DIRPATH/'results'/date/time/experiment_id
                experiment_dirpath.mkdir(parents=True, exist_ok=True)
                save_results(experiment_dirpath, model, test_data)

In [51]:
# Run the full train-and-evaluate sweep over all location combinations.
main()

train_location: ['511']
Epoch 1/150
Epoch 2/150
Epoch 3/150
Epoch 4/150
Epoch 5/150
Epoch 6/150
Epoch 7/150
Epoch 8/150
Epoch 9/150
Epoch 10/150
Epoch 11/150
Epoch 12/150
Epoch 13/150
Epoch 14/150
Epoch 15/150
Epoch 16/150
Epoch 17/150
Epoch 18/150
Epoch 19/150
Epoch 20/150
Epoch 21/150
Epoch 22/150
Epoch 23/150
Epoch 24/150
Epoch 25/150
Epoch 26/150
Epoch 27/150
Epoch 28/150
Epoch 29/150
Epoch 30/150
Epoch 31/150
Epoch 32/150
Epoch 33/150
Epoch 33: early stopping
Keras weights file (<HDF5 file "variables.h5" (mode r+)>) saving:
...layers
......activation
.........vars
......activation_1
.........vars
......activation_10
.........vars
......activation_11
.........vars
......activation_12
.........vars
......activation_13
.........vars
......activation_14
.........vars
......activation_15
.........vars
......activation_16
.........vars
......activation_17
.........vars
......activation_18
.........vars
......activation_19
.........vars
......activation_2
.........vars
......activation_2