From 536a24906e3114b2ed5fb4ed37af7c3a2de92b60 Mon Sep 17 00:00:00 2001 From: crangelsmith Date: Wed, 24 Apr 2024 13:46:08 +0100 Subject: [PATCH 1/3] making all paths absolutes --- avae/config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/avae/config.py b/avae/config.py index 4f66b5a..77580bb 100644 --- a/avae/config.py +++ b/avae/config.py @@ -1,5 +1,6 @@ import logging import os +import pathlib import yaml from pydantic import ( @@ -291,7 +292,9 @@ def load_config_params( + key + " in config file or command line arguments. Default values will be used." ) - + # turning all path into absolute paths + if type(val) == pathlib.PosixPath: + setattr(data, key, val.absolute()) # return data as dictionary return data.model_dump() From d6774cbaa8fba526a7ef8dce6abca62e56d5e11a Mon Sep 17 00:00:00 2001 From: crangelsmith Date: Wed, 24 Apr 2024 14:26:13 +0100 Subject: [PATCH 2/3] fixing typing inconsistencies on rescale --- avae/config.py | 6 +++--- avae/data.py | 8 ++++---- avae/train.py | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/avae/config.py b/avae/config.py index 77580bb..871c4d8 100644 --- a/avae/config.py +++ b/avae/config.py @@ -111,7 +111,7 @@ class AffinityConfig(BaseModel): ) pose_dims: int = Field(1, description="Pose dimensions") - rescale: float = Field(None, description="Rescale data") + rescale: int | None = Field(None, description="Rescale data") restart: bool = Field(False, description="Restart training") shift_min: bool = Field( False, description="Scale data with min-max transformation" @@ -292,9 +292,9 @@ def load_config_params( + key + " in config file or command line arguments. Default values will be used." ) - # turning all path into absolute paths + # turning all path into absolute paths and strings for output saving if type(val) == pathlib.PosixPath: - setattr(data, key, val.absolute()) + setattr(data, key, str(val.absolute())) # return data as dictionary return data.model_dump() diff --git a/avae/data.py b/avae/data.py index 87698f9..cf7a570 100644 --- a/avae/data.py +++ b/avae/data.py @@ -33,7 +33,7 @@ def load_data( gaussian_blur: bool = False, normalise: bool = False, shift_min: bool = False, - rescale: bool | None = None, + rescale: int | None = None, ) -> tuple[DataLoader, int]: ... @@ -53,7 +53,7 @@ def load_data( gaussian_blur: bool = False, normalise: bool = False, shift_min: bool = False, - rescale: bool | None = None, + rescale: int | None = None, ) -> tuple[DataLoader, DataLoader, DataLoader, pd.DataFrame, int]: ... @@ -72,7 +72,7 @@ def load_data( gaussian_blur: bool = False, normalise: bool = False, shift_min: bool = False, - rescale: bool | None = None, + rescale: int | None = None, ) -> tuple[DataLoader, DataLoader, DataLoader, pd.DataFrame, int] | tuple[ DataLoader, int ]: @@ -106,7 +106,7 @@ def load_data( In True, the input data is normalised before being passed to the model. shift_min: bool If True, the minimum value of the input data is shifted to 0 and maximum to 1. - rescale: int + rescale: int | None If not None, the input data is rescaled to the given value. diff --git a/avae/train.py b/avae/train.py index 4845f3a..88766eb 100644 --- a/avae/train.py +++ b/avae/train.py @@ -58,7 +58,7 @@ def train( gaussian_blur: bool, normalise: bool, shift_min: bool, - rescale: bool, + rescale: int | None, tensorboard: bool, classifier: str, strategy: str, @@ -160,8 +160,8 @@ def train( Path to the beta values to load. gamma_load: str Path to the gamma values to load. - rescale: bool - If True, the input data is rescaled to have a mean of 0 and std of 1. + rescale: int | None + If provided, the data is rescaled by the value. """ lt.pytorch.seed_everything(42) From 4d21464cf7463371c8ec54505eb327fd0766f1de Mon Sep 17 00:00:00 2001 From: crangelsmith Date: Fri, 26 Apr 2024 15:18:01 +0100 Subject: [PATCH 3/3] making sure the visualsiation and frequency logic is reflected in the output config file --- avae/config.py | 88 +++++++++++++++++++++++++++++++++++++++----------- run.py | 2 +- 2 files changed, 70 insertions(+), 20 deletions(-) diff --git a/avae/config.py b/avae/config.py index 871c4d8..b36035e 100644 --- a/avae/config.py +++ b/avae/config.py @@ -299,7 +299,7 @@ def load_config_params( return data.model_dump() -def write_config_file(time_stamp_name, data): +def write_config_file(time_stamp_name, data) -> None: # record final configuration in logger and save to yaml file for key, val in data.items(): logging.info("Parameter " + key + " set to value: " + str(data[key])) @@ -313,31 +313,62 @@ def write_config_file(time_stamp_name, data): logging.info("YAML File saved!\n") -def setup_visualisation_config(data: dict) -> None: - settings.VIS_LOS = ( +def setup_visualisation_config(data: dict) -> dict: + """ + Set up visualisation configurations. Logic is the following: if a specific visualisation is not set, it will be set to the value of vis_all. + For frequency, if it is set to 0 (the default), it will be set to the value of freq_all. + + Parameters + ---------- + data : dict + Dictionary of configuration parameters. + + Returns + ------- + data : dict + Dictionary of configuration parameters updated with logic. + """ + data["vis_los"] = ( data["vis_los"] if data["vis_los"] is not None else data["vis_all"] ) - settings.VIS_ACC = ( + settings.VIS_LOS = data["vis_los"] + + data["vis_acc"] = ( data["vis_acc"] if data["vis_acc"] is not None else data["vis_all"] ) - settings.VIS_REC = ( + + settings.VIS_ACC = data["vis_acc"] + + data["vis_rec"] = ( data["vis_rec"] if data["vis_rec"] is not None else data["vis_all"] ) - settings.VIS_CYC = ( + settings.VIS_REC = data["vis_rec"] + + data["vis_cyc"] = ( data["vis_cyc"] if data["vis_cyc"] is not None else data["vis_all"] ) - settings.VIS_AFF = ( + settings.VIS_CYC = data["vis_cyc"] + + data["vis_aff"] = ( data["vis_aff"] if data["vis_aff"] is not None else data["vis_all"] ) - settings.VIS_EMB = ( + settings.VIS_AFF = data['vis_aff'] + + data["vis_emb"] = ( data["vis_emb"] if data["vis_emb"] is not None else data["vis_all"] ) - settings.VIS_INT = ( + settings.VIS_EMB = data["vis_emb"] + + data["vis_int"] = ( data["vis_int"] if data["vis_int"] is not None else data["vis_all"] ) - settings.VIS_DIS = ( + settings.VIS_INT = data["vis_int"] + + data["vis_dis"] = ( data["vis_dis"] if data["vis_dis"] is not None else data["vis_all"] ) + settings.VIS_DIS = data["vis_dis"] + settings.VIS_POS = ( data["vis_pos"] if data["vis_pos"] is not None else data["vis_all"] ) @@ -354,30 +385,49 @@ def setup_visualisation_config(data: dict) -> None: settings.VIS_FORMAT = data["vis_format"] settings.VIS_Z_N_INT = data["vis_z_n_int"] - settings.FREQ_EVAL = ( + data["freq_eval"] = ( data["freq_eval"] if data["freq_eval"] != 0 else data["freq_all"] ) - settings.FREQ_REC = ( + settings.FREQ_EVAL = data["freq_eval"] + + data["freq_rec"] = ( data["freq_rec"] if data["freq_rec"] != 0 else data["freq_all"] ) - settings.FREQ_EMB = ( + settings.FREQ_REC = data["freq_rec"] + + data["freq_emb"] = ( data["freq_emb"] if data["freq_emb"] != 0 else data["freq_all"] ) - settings.FREQ_INT = ( + settings.FREQ_EMB = data["freq_emb"] + + data["freq_int"] = ( data["freq_int"] if data["freq_int"] != 0 else data["freq_all"] ) - settings.FREQ_DIS = ( + settings.FREQ_INT = data["freq_int"] + + data["freq_dis"] = ( data["freq_dis"] if data["freq_dis"] != 0 else data["freq_all"] ) - settings.FREQ_POS = ( + settings.FREQ_DIS = data["freq_dis"] + + data["freq_pos"] = ( data["freq_pos"] if data["freq_pos"] != 0 else data["freq_all"] ) - settings.FREQ_ACC = ( + settings.FREQ_POS = data["freq_pos"] + + data["freq_acc"] = ( data["freq_acc"] if data["freq_acc"] != 0 else data["freq_all"] ) - settings.FREQ_STA = ( + settings.FREQ_ACC = data["freq_acc"] + + data['freq_sta'] = ( data["freq_sta"] if data["freq_sta"] != 0 else data["freq_all"] ) - settings.FREQ_SIM = ( + settings.FREQ_STA = data["freq_sta"] + + data["freq_sim"] = ( data["freq_sim"] if data["freq_sim"] != 0 else data["freq_all"] ) + settings.FREQ_SIM = data["freq_sim"] + + return data diff --git a/run.py b/run.py index e7911f4..45283f7 100644 --- a/run.py +++ b/run.py @@ -661,7 +661,7 @@ def run( logging.getLogger("matplotlib.font_manager").disabled = True # visualisation global settings defined from config file - setup_visualisation_config(data) + data = setup_visualisation_config(data) if data["new_out"]: dir_name = f'results_{settings.date_time_run}_model_{data["model"]}_lat{data["latent_dims"]}_pose{data["pose_dims"]}_lr{data["learning"]}_beta{data["beta"]}_gamma{data["gamma"]}'