# In [None]:
import os
import torch
import csv
from omegaconf import OmegaConf

# create checkpoint dir
def _ensure_checkpoint_dir(exp_idx, project_root_path):
    """Create the checkpoint directory for an experiment and return its path.

    The directory is ``<project_root_path>/checkpoints/ex_<exp_idx>``. Missing
    parent directories are created as well, and an already existing directory
    is left as it is.

    Args:
        exp_idx: Experiment identifier interpolated into the directory name.
        project_root_path: Base directory under which ``checkpoints/`` lives.

    Returns:
        The path of the (now existing) checkpoint directory.
    """
    experiment_subdir = f'checkpoints/ex_{exp_idx}'
    target_dir = os.path.join(project_root_path, experiment_subdir)
    os.makedirs(target_dir, exist_ok=True)
    return target_dir

# save best model
def save_best_model_weight(
    model,
    val_loss,
    best_loss,
    project_root_path,
    exp_idx,
    filename="best_model_weight.pth",
):
    """Save the model's weights when the validation loss improves.

    The experiment's checkpoint directory is created if needed. When
    ``val_loss`` is strictly lower than ``best_loss``, ``model.state_dict()``
    is written to ``filename`` inside that directory, replacing any earlier
    file of the same name.

    Args:
        model: Object exposing ``state_dict()`` (presumably a
            ``torch.nn.Module`` -- confirm against the caller).
        val_loss: Validation loss of the current epoch.
        best_loss: Lowest validation loss seen so far.
        project_root_path: Base directory under which ``checkpoints/`` lives.
        exp_idx: Experiment identifier used in the checkpoint directory name.
        filename: Name of the weight file inside the checkpoint directory.

    Returns:
        The best loss after this call: ``val_loss`` if it improved on
        ``best_loss`` (and the weights were saved), otherwise ``best_loss``
        unchanged.
    """
    ckpt_dir = _ensure_checkpoint_dir(exp_idx, project_root_path)
    file_path = os.path.join(ckpt_dir, filename)

    if val_loss < best_loss:
        torch.save(model.state_dict(), file_path)
        # Report the improved loss back to the caller. Returning the stale
        # value would let a later, worse epoch still pass the comparison
        # and overwrite the best checkpoint.
        best_loss = val_loss

    return best_loss

# save hparam log 
def save_hparam(cfg, exp_idx, project_root_path, filename='hparam.yaml'):
    """Write the experiment configuration into its checkpoint directory.

    The checkpoint directory is created if needed and ``cfg`` is handed to
    ``OmegaConf.save`` with ``resolve=True``.

    Args:
        cfg: Configuration object accepted by ``OmegaConf.save``.
        exp_idx: Experiment identifier used in the checkpoint directory name.
        project_root_path: Base directory under which ``checkpoints/`` lives.
        filename: Name of the config file inside the checkpoint directory.
    """
    destination = os.path.join(
        _ensure_checkpoint_dir(exp_idx, project_root_path),
        filename,
    )
    OmegaConf.save(config=cfg, f=destination, resolve=True)

# save loss log
def save_loss_log(
    epoch,
    train_loss,
    val_loss,
    test_loss,
    exp_idx,
    project_root_path,
    filename="loss_log.csv",
):
    """Append one epoch's losses to the experiment's CSV loss log.

    The checkpoint directory is created if needed. A header row
    (``epoch, train_loss, val_loss, test_loss``) is written first whenever
    the log file is missing or empty, then the values are appended as one row.

    Args:
        epoch: Epoch identifier written to the first column.
        train_loss: Training loss for the epoch.
        val_loss: Validation loss for the epoch.
        test_loss: Test loss for the epoch.
        exp_idx: Experiment identifier used in the checkpoint directory name.
        project_root_path: Base directory under which ``checkpoints/`` lives.
        filename: Name of the CSV file inside the checkpoint directory.
    """
    ckpt_dir = _ensure_checkpoint_dir(exp_idx, project_root_path)
    file_path = os.path.join(ckpt_dir, filename)

    # An existing but zero-byte file (e.g. left by an interrupted run) still
    # needs the header; checking existence alone would leave it headerless.
    needs_header = (
        not os.path.exists(file_path) or os.path.getsize(file_path) == 0
    )

    with open(file_path, mode='a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(['epoch', 'train_loss', 'val_loss', 'test_loss'])
        writer.writerow([epoch, train_loss, val_loss, test_loss])