In [2]:
import os
import json

import numpy as np
import pandas as pd

import torch
import torchvision

import matplotlib.pyplot as plt
import lightning as L

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch.utilities.model_summary import ModelSummary

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from data_module.data_module import FFTDataModule
from model.AE_model import AECNN1DModel

import optuna
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback

In [3]:
# Flag set for interactive use; it is not read by any of the cells shown here.
TEST = True

# Seed once, up front, through Lightning's `seed_everything` helper.
random_seed = 42
L.seed_everything(random_seed)

Seed set to 42


42

In [None]:
# --- Training schedule ------------------------------------------------------
n_epochs = 2000
patience = n_epochs // 10  # one tenth of the epoch budget

# --- Optimizer selection ----------------------------------------------------
# Maps a display name to (optimizer class, keyword arguments for it).
optimizer_param_dict = {
    "Adam": (optim.Adam, {"lr": 0.001}),
    "SGD": (optim.SGD, {"lr": 0.001, "momentum": 0.5}),
}
optimizer, optimizer_param = optimizer_param_dict["Adam"]

# --- Data -------------------------------------------------------------------
batch_size = 512
dataset_path = "/nfs/ksdata/tran/HAR_AE/dataset/processed_concat_data"

# --- Logging ----------------------------------------------------------------
log_save_dir = "lightning_logs"
log_save_name = "12_AE_train_optuna5"

In [None]:

def objective(trial):
    """Train one autoencoder configuration and return its validation MSE.

    Samples the architecture hyperparameters from ``trial``, builds an
    ``AECNN1DModel`` from them, fits it with ``FFTDataModule`` and returns the
    ``val_mse`` value found in ``trainer.logged_metrics`` after fitting.

    Args:
        trial: Optuna trial used to sample ``kernel_size``,
            ``first_cnn_filter``, ``last_linear_input``, ``cnn_layer_num``
            and ``linear_layer_num``.

    Returns:
        float: ``trainer.logged_metrics["val_mse"]`` converted with ``.item()``.

    NOTE(review): nothing here reports intermediate values to ``trial`` (no
    ``trial.report`` call and no pruning callback), so a pruner configured on
    the study cannot stop a trial early.
    """
    kernel_size = trial.suggest_categorical("kernel_size", [6, 8, 11, 15])
    first_cnn_filter = trial.suggest_int("first_cnn_filter", 5, 7)
    last_linear_input = trial.suggest_int("last_linear_input", 5, 8)
    cnn_layer_num = trial.suggest_int("cnn_layer_num", 1, 4)
    linear_layer_num = trial.suggest_int("linear_layer_num", 1, 4)

    # Validation only runs every `val_every_n_epoch` epochs, and EarlyStopping
    # counts *validation checks* without improvement, not epochs. Passing the
    # epoch-based `patience` straight through would ask for more checks than
    # the whole run performs, so early stopping could never fire. Convert the
    # epoch budget into a number of checks (rounded up, at least one).
    val_every_n_epoch = 100
    patience_in_checks = max(1, -(-patience // val_every_n_epoch))

    trainer = L.Trainer(
        default_root_dir=os.path.join(log_save_dir, log_save_name),
        max_epochs=n_epochs,
        callbacks=[EarlyStopping(monitor="val_mse", patience=patience_in_checks)],
        enable_checkpointing=False,
        accelerator="auto",
        check_val_every_n_epoch=val_every_n_epoch,
    )

    # One tuple per CNN layer; the channel count doubles with each layer,
    # starting from 2 ** first_cnn_filter. The trailing 0 and 3 are passed
    # through as-is; their meaning is defined by AECNN1DModel
    # (presumably padding and stride -- TODO confirm).
    cnn_channel_param = []
    input_channel = 6
    for i in range(cnn_layer_num):
        output_channel = 2 ** (first_cnn_filter + i)
        cnn_channel_param.append((input_channel, output_channel, kernel_size, 0, 3))
        input_channel = output_channel

    # Widths are listed largest first and end at 2 ** last_linear_input.
    linear_channel_param = [
        2 ** (last_linear_input + i) for i in reversed(range(linear_layer_num))
    ]

    net = AECNN1DModel(
        optimizer=optimizer,
        optimizer_param=optimizer_param,
        cnn_channel_param=cnn_channel_param,
        linear_channel_param=linear_channel_param,
    )

    model_summary = ModelSummary(net, max_depth=6)
    print("model_summary", model_summary)

    data_module = FFTDataModule(dataset_path=dataset_path, batch_size=batch_size)

    trainer.fit(model=net, datamodule=data_module)

    trainer_test_dict = trainer.logged_metrics
    return trainer_test_dict["val_mse"].item()

In [None]:
# NOTE(review): a pruner only acts on intermediate values reported by the
# objective; the `objective` defined in this notebook does not report any, so
# MedianPruner is effectively inactive here -- confirm whether that is intended.
pruner: optuna.pruners.BasePruner = optuna.pruners.MedianPruner()

# Minimise the value returned by `objective` over 50 trials, 6 at a time.
study = optuna.create_study(direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=50, n_jobs=6)

print("Best trial:")
trial = study.best_trial

for key, val in trial.params.items():
    print(f"{key}: {val}")

# Create the output directory explicitly instead of relying on the trainer's
# logger having created it as a side effect of a trial.
optuna_output_dir = os.path.join(log_save_dir, log_save_name)
os.makedirs(optuna_output_dir, exist_ok=True)
with open(os.path.join(optuna_output_dir, "optuna_params.json"), "w") as f:
    json.dump(dict(trial.params), f, indent=4)