In [None]:
import os
import yaml
import optuna
import json
import save_load
from src.train_models import trainTargetModel
from visualize_model import VisualizeModel

from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend

In [None]:
# ---------------------- #
#   Load training yaml   #
# ---------------------- #
# Parse the base training configuration from train.yaml. The "train", "data"
# and "run" sections of this mapping are read further down in the notebook.
with open("./train.yaml") as file:
    config = yaml.safe_load(file)

# Use single quotes for the key inside the f-string: reusing the enclosing
# double quotes is a SyntaxError on Python versions older than 3.12.
print(f"Initial training config: {config['train']}")

# ------------------------------ #
#  Load study + best trial info  #
# ------------------------------ #
# Everything belonging to one study lives under study/<study_folder>/.
study_folder = "cifar10-baseline-2-c3b5ae7b140f363a"
study_path = os.path.join("study", study_folder)

# Journal-file storage backing the Optuna study
journal_path = os.path.join(study_path, "journal.log")
journal_backend = JournalFileBackend(journal_path)
storage = JournalStorage(journal_backend)

# Study metadata (JSON); it holds the name the study was registered under
metadata_path = os.path.join(study_path, "metadata.json")
with open(metadata_path, "r") as f:
    metadata = json.loads(f.read())
study_name = metadata["study"]["study_name"]

# Re-open the existing study from the journal storage
study = optuna.load_study(study_name=study_name, storage=storage)

# Report the best trial: objective value first, then one line per parameter
best_trial = study.best_trial
print(f"Best trial value: {best_trial.value}")
print("Best trial parameters:")
for k in best_trial.params:
    v = best_trial.params[k]
    print(f"  {k}: {v}")

# ------------------------------ #
#  Update training configuration #
# ------------------------------ #
# These are references into `config`, not copies: edits to train_cfg below are
# also visible through config["train"].
train_cfg = config["train"]
data_cfg = config["data"]
run_cfg = config["run"]

# Overwrite study-optimized params. Each lookup raises KeyError if the best
# trial does not contain that parameter.
train_cfg["learning_rate"] = best_trial.params["lr"]
train_cfg["momentum"] = best_trial.params["momentum"]
train_cfg["weight_decay"] = best_trial.params["weight_decay"]
train_cfg["batch_size"] = best_trial.params["batch_size"]

# Scheduler param: take the study's T_max when it was tuned; otherwise keep
# whatever train.yaml already provides instead of clobbering it with None.
# The result is None only when neither source defines T_max.
train_cfg["T_max"] = best_trial.params.get("T_max", train_cfg.get("T_max"))

print(f"Modified train_cfg with optimal hyperparameters: {train_cfg}")

# ------------------------------ #
#  Save the updated config file  #
# ------------------------------ #
# Bundle the three config sections into one metadata object, then hand it to
# save_load.saveTarget, which returns an identifier and a directory.
train_metadata = save_load.buildTargetMetadata(train_cfg, data_cfg, run_cfg)
saved_target = save_load.saveTarget(train_metadata)
hash_id, save_dir = saved_target
print(f"Saved training metadata with {hash_id} at {save_dir}")


In [None]:
# ------------------- #
#   Prepare dataset   #
# ------------------- #
from src.dataset_handler import processDataset, loadDataset, get_dataloaders

# Re-read train.yaml to get the data section. Note that this rebinds `config`
# to a fresh mapping, so config["train"] is no longer the same object as the
# tuned train_cfg built earlier in the notebook.
with open("./train.yaml") as file:
    config = yaml.safe_load(file)

data_cfg = config['data']

print(f"Data_cfg: {data_cfg}")

# train_cfg is the tuned training config defined earlier in the notebook.
batch_size = train_cfg["batch_size"]

trainset, testset = loadDataset(data_cfg)
train_dataset, test_dataset, train_indices, test_indices = processDataset(data_cfg, trainset, testset)
# Prepare loaders from the processed datasets so they match the
# train_indices/test_indices handed to training.
# NOTE(review): the loaders were previously built from the raw
# trainset/testset, which left train_dataset/test_dataset unused -- confirm
# against get_dataloaders' expected inputs.
train_loader, test_loader = get_dataloaders(batch_size, train_dataset, test_dataset)

In [None]:
# -------------------------- #
#   Train the Target model   #
# -------------------------- #
# Use trainTargetModel, the trainer imported from src.train_models at the top
# of the notebook. The name `trainModel` is not imported or defined in the
# visible cells, so calling it fails with NameError on a fresh kernel.
train_result, test_result = trainTargetModel(train_cfg, train_loader, test_loader, train_indices, test_indices)
# Identity comparison with None: `!= None` dispatches to the objects' own
# __ne__, which may not return a plain bool for array-like results.
if train_result is not None and test_result is not None:
    VisualizeModel().visualize(train_result, test_result)