In [1]:
%load_ext autoreload
%autoreload 2

In [27]:
import os
import csv
from itertools import combinations
from pathlib import Path
from datetime import datetime
import pickle
import json
from dataclasses import dataclass
from typing import Callable, List

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.losses import BinaryCrossentropy, BinaryFocalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping

In [3]:
from models.handnet_based_model import handnet_based_model
from util.training import init_device
from util.training.dataloader import split_data_for_multiple_location, concat_and_shuffle
from util.training.metrics import IntersectionOverUnion, MeanPixelAccuracy

In [4]:
# Root of the project checkout; TensorBoard logs are written under PROJECT_DIRPATH/'logs'.
# NOTE(review): absolute container paths — adjust when running in another environment.
PROJECT_DIRPATH = Path('/tf/workspace/deformation-prediction-multi-environment')
# NAS mount; experiment outputs (histories, results.csv) go under NAS_DIRPATH/'experiments'.
NAS_DIRPATH = Path('/tf/nas/')

In [5]:
# Record the TensorFlow version in the notebook output for reproducibility.
print(f"TensorFlow version: {tf.__version__}")

TensorFlow version: 2.11.0


# 準備
---

## デバイスの初期化

In [9]:
!nvidia-smi

Fri Nov 22 15:51:58 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.74       Driver Version: 470.74       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:18:00.0 Off |                  Off |
| 30%   29C    P8    17W / 300W |     28MiB / 48682MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:3B:00.0 Off |                  Off |
| 30%   33C    P8    18W / 300W |      8MiB / 48685MiB |      0%      Default |
|       

In [10]:
# GPU index/indices to use (passed to init_device below).
gpu = [2]

In [11]:
# seed=None: presumably no fixed random seed is set, so results differ between
# runs/trials — NOTE(review): confirm how init_device treats None.
seed = None
# Per the cell output, this leaves only the selected GPU visible to TensorFlow.
init_device(seed, gpu)

Detected 5 GPU(s): ['/physical_device:GPU:0', '/physical_device:GPU:1', '/physical_device:GPU:2', '/physical_device:GPU:3', '/physical_device:GPU:4']
Visible GPU devices (1): ['/physical_device:GPU:2']


## ハイパーパラメータクラスの定義

In [12]:
@dataclass()
class HyperParameters:
    """Training settings consumed by ``init_model`` and ``train``.

    Fields (in constructor order):
        epochs: Maximum number of epochs passed to ``model.fit``
            (EarlyStopping may end training sooner).
        batch_size: Mini-batch size passed to ``model.fit``.
        loss_function: Loss object passed to ``model.compile``.
        metrics: Metric objects passed to ``model.compile``.
        optimizer: Optimizer passed to ``model.compile``. In this notebook it is
            an optimizer instance (``Adam``), not a plain function, despite the
            ``Callable`` annotation.
    """

    epochs: int
    batch_size: int
    loss_function: Callable
    metrics: List[Callable]
    optimizer: Callable

## モデルの初期化関数の定義

In [13]:
def init_model(hparam):
    """Build the HandNet-based model and compile it.

    Args:
        hparam: HyperParameters supplying ``optimizer``, ``loss_function`` and
            ``metrics`` for ``model.compile``.

    Returns:
        The compiled Keras model.
    """
    # Fixed architecture settings. input_shape (10, 52, 2) is presumably
    # (frames, subcarriers, real/imag) given the 'real_and_imag' preprocessing — TODO confirm.
    architecture = dict(
        input_shape=(10, 52, 2),
        num_block1=3,
        num_block2=3,
        num_residual_blocks=14,  # more residual blocks make the model heavier
    )
    net = handnet_based_model(**architecture)

    net.compile(
        optimizer=hparam.optimizer,
        loss=hparam.loss_function,
        metrics=hparam.metrics,
    )
    return net

## コールバック準備関数の定義

In [14]:
def prepare_callbacks(hparam, log_dirpath):
    """Create the Keras callbacks used for one training run.

    Args:
        hparam: Currently unused; kept so the signature can grow hyperparameter-
            dependent callbacks later.
        log_dirpath: Directory where TensorBoard writes its event files.

    Returns:
        ``[TensorBoard, EarlyStopping]``.
    """
    # Weight histograms are logged every epoch.
    tensorboard_cb = TensorBoard(log_dir=log_dirpath, histogram_freq=1)

    # Stop when validation IoU has not improved for 10 epochs and roll back to
    # the best weights. NOTE(review): assumes IntersectionOverUnion reports under
    # the name 'iou' (hence 'val_iou'); otherwise there is nothing to monitor.
    early_stopping_cb = EarlyStopping(
        monitor='val_iou',
        mode='max',
        patience=10,
        verbose=1,
        restore_best_weights=True,
    )

    return [tensorboard_cb, early_stopping_cb]

## モデルを訓練する関数の定義

In [24]:
def train(hparam, train_data, valid_data, log_dirpath):
    """Build a fresh model and fit it.

    Args:
        hparam: HyperParameters used for compilation and for ``epochs`` /
            ``batch_size`` in ``fit``.
        train_data: Pair ``(X, Y)`` of training inputs and targets.
        valid_data: Pair ``(X, Y)`` passed as ``validation_data``.
        log_dirpath: Directory for TensorBoard logs.

    Returns:
        Tuple ``(history, model)``: the object returned by ``model.fit`` and the
        trained model (EarlyStopping is configured with restore_best_weights=True).
    """
    x_train, y_train = train_data
    x_valid, y_valid = valid_data

    model = init_model(hparam)
    callbacks = prepare_callbacks(hparam, log_dirpath)

    history = model.fit(
        x_train,
        y_train,
        validation_data=(x_valid, y_valid),
        epochs=hparam.epochs,
        batch_size=hparam.batch_size,
        verbose=1,
        callbacks=callbacks,
    )
    return history, model

# モデルの訓練
---

In [16]:
def save_history(save_dirpath, history):
    """Pickle ``history`` to ``<save_dirpath>/history.pkl``.

    NOTE(review): when ``history`` is a Keras ``History`` callback it also holds
    the trained model (``history.model``), so the whole model is serialized as
    well. Consider pickling ``history.history`` (the metrics dict) instead if the
    code that loads these files allows it.
    """
    target_path = save_dirpath / 'history.pkl'
    with target_path.open('wb') as fh:
        pickle.dump(history, fh)

def save_model(save_dirpath, model):
    """Save a model's architecture and the full model into ``save_dirpath``.

    Writes ``model_architecture.json`` (the text of ``model.to_json()``) and
    ``model.h5`` (via ``model.save``).
    """
    architecture_json = model.to_json()
    with open(save_dirpath / 'model_architecture.json', 'w') as out:
        out.write(architecture_json)

    model.save(save_dirpath / 'model.h5')

def save_results(csv_filepath, header, rows):
    """Append ``rows`` to a CSV file, writing ``header`` first when needed.

    The header is written only if the file does not exist yet or exists but is
    empty, so repeated calls accumulate rows under a single header line.

    Args:
        csv_filepath: Path of the CSV file (created if missing).
        header: Column names, written once at the top of the file.
        rows: Iterable of row sequences to append.
    """
    # A header is needed for a brand-new file AND for an existing empty file
    # (previously only existence was checked, so an empty file got no header).
    write_header = (
        not os.path.exists(csv_filepath) or os.path.getsize(csv_filepath) == 0
    )

    with open(csv_filepath, mode='a', newline='') as file:
        writer = csv.writer(file)
        if write_header:
            writer.writerow(header)
        writer.writerows(rows)

In [17]:
def get_train_id(loc_list, train_loc_list):
    """Encode the training locations as their concatenated indices in ``loc_list``.

    Example: loc_list ['511', '512', '514B'], train_loc_list ('511', '514B') -> '02'.
    NOTE: becomes ambiguous once loc_list has more than 10 entries (multi-digit indices).
    """
    index_digits = [str(loc_list.index(location)) for location in train_loc_list]
    return ''.join(index_digits)

def get_evaluation_id(loc_list, train_loc_list, test_location):
    """Return ``'<train_id>-<test index>'``, e.g. ('512', '514B') tested on '511' -> '12-0'."""
    train_part = get_train_id(loc_list, train_loc_list)
    test_part = loc_list.index(test_location)
    return f"{train_part}-{test_part}"

In [29]:
def trial(experiment_dirpath, trial_id):
    """Run one trial over every non-empty combination of training locations.

    For each combination of training locations this function:
      1. trains a fresh model on the concatenated train/valid splits of those locations,
      2. pickles the training history to
         ``experiment_dirpath/trial_id/<train_id>/history.pkl``,
      3. evaluates the model on the test split of every location (seen or unseen)
         and appends the metrics to ``experiment_dirpath/'results.csv'``.

    TensorBoard logs go to
    ``PROJECT_DIRPATH/'logs'/<experiment parent name>/<experiment name>/trial_id/<train_id>``.

    Args:
        experiment_dirpath: Path of this experiment's output directory.
        trial_id: Trial identifier (e.g. ``'trial_0'``), used in output paths and
            in the ``Trial_id`` column of the results CSV.
    """
    # Load the data.
    # NOTE(review): the split is recomputed on every trial; whether it differs
    # between calls depends on split_data_for_multiple_location — confirm intended.
    csi_preprocess_id = 'real_and_imag'
    location_list = ['511', '512', '514B']
    dataset_dict = split_data_for_multiple_location(csi_preprocess_id, location_list)
    train_data_dict, valid_data_dict, test_data_dict = dataset_dict

    csv_filepath = experiment_dirpath/'results.csv'
    header = ['Trial_id', 'Evaluation_id', 'Metric', 'Value']

    # Namespace TensorBoard logs by the experiment's own sub-path (date/time in
    # the main cell). Without it, every re-run of the notebook writes event files
    # into the same logs/<trial_id>/<train_id> directories and runs get mixed.
    log_rootpath = PROJECT_DIRPATH/'logs'/experiment_dirpath.parent.name/experiment_dirpath.name

    for r in range(1, len(location_list) + 1):
        for train_location_list in combinations(location_list, r):
            print(f'train_location: {list(train_location_list)}')
            train_id = get_train_id(location_list, train_location_list)

            train_data = concat_and_shuffle(train_data_dict, train_location_list)
            valid_data = concat_and_shuffle(valid_data_dict, train_location_list)

            # Built inside the loop so every model gets fresh optimizer, loss and
            # metric instances instead of sharing them across models.
            hparam = HyperParameters(
                epochs=150,
                batch_size=256,
                optimizer=Adam(learning_rate=0.001),
                loss_function=BinaryFocalCrossentropy(gamma=2),  # alternative: BinaryCrossentropy()
                metrics=[IntersectionOverUnion(threshold=0.5), MeanPixelAccuracy()],
            )

            # Output directories.
            train_dirpath = experiment_dirpath/trial_id/train_id
            train_dirpath.mkdir(parents=True, exist_ok=True)
            log_dirpath = log_rootpath/trial_id/train_id

            history, model = train(hparam, train_data, valid_data, log_dirpath)

            save_history(train_dirpath, history)
            # Model saving is currently disabled:
            # save_model(train_dirpath, model)

            # Evaluate on the test split of every location, including ones not
            # used for training.
            for test_location in location_list:
                evaluation_id = get_evaluation_id(location_list, train_location_list, test_location)
                print(f"evaluation_id = {evaluation_id}")

                X_test, y_test = test_data_dict[test_location]
                results = model.evaluate(X_test, y_test)

                rows = [
                    [trial_id, evaluation_id, name, f"{value:.04f}"]
                    for name, value in zip(model.metrics_names, results)
                ]
                save_results(csv_filepath, header, rows)

In [None]:
# Output directory for this experiment: <NAS>/experiments/<date>/<time>.
# Read the clock once so the date and time parts always describe the same
# instant (two separate now() calls can straddle midnight).
now = datetime.now()
date = now.strftime("%Y-%m-%d")
time = now.strftime("%H-%M-%S")
experiment_dirpath = NAS_DIRPATH/'experiments'/date/time
experiment_dirpath.mkdir(parents=True, exist_ok=True)

# Repeat the full train/evaluate sweep n_trial times.
n_trial = 30
for i in range(n_trial):
    trial_id = f'trial_{i}'
    trial(experiment_dirpath, trial_id)

train_location: ['511']
Epoch 1/150
