Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fixes #309

Merged
merged 3 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 74 additions & 21 deletions avae/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import pathlib

import yaml
from pydantic import (
Expand Down Expand Up @@ -110,7 +111,7 @@ class AffinityConfig(BaseModel):
)
pose_dims: int = Field(1, description="Pose dimensions")

rescale: float = Field(None, description="Rescale data")
rescale: int | None = Field(None, description="Rescale data")
restart: bool = Field(False, description="Restart training")
shift_min: bool = Field(
False, description="Scale data with min-max transformation"
Expand Down Expand Up @@ -291,12 +292,14 @@ def load_config_params(
+ key
+ " in config file or command line arguments. Default values will be used."
)

# turning all path into absolute paths and strings for output saving
if type(val) == pathlib.PosixPath:
setattr(data, key, str(val.absolute()))
# return data as dictionary
return data.model_dump()


def write_config_file(time_stamp_name, data):
def write_config_file(time_stamp_name, data) -> None:
# record final configuration in logger and save to yaml file
for key, val in data.items():
logging.info("Parameter " + key + " set to value: " + str(data[key]))
Expand All @@ -310,31 +313,62 @@ def write_config_file(time_stamp_name, data):
logging.info("YAML File saved!\n")


def setup_visualisation_config(data: dict) -> None:
settings.VIS_LOS = (
def setup_visualisation_config(data: dict) -> dict:
"""
Set up visualisation configurations. Logic is the following: if a specific visualisation is not set, it will be set to the value of vis_all.
For frequency, if it is set to 0 (the default), it will be set to the value of freq_all.

Parameters
----------
data : dict
Dictionary of configuration parameters.

Returns
-------
data : dict
Dictionary of configuration parameters updated with logic.
"""
data["vis_los"] = (
data["vis_los"] if data["vis_los"] is not None else data["vis_all"]
)
settings.VIS_ACC = (
settings.VIS_LOS = data["vis_los"]

data["vis_acc"] = (
data["vis_acc"] if data["vis_acc"] is not None else data["vis_all"]
)
settings.VIS_REC = (

settings.VIS_ACC = data["vis_acc"]

data["vis_rec"] = (
data["vis_rec"] if data["vis_rec"] is not None else data["vis_all"]
)
settings.VIS_CYC = (
settings.VIS_REC = data["vis_rec"]

data["vis_cyc"] = (
data["vis_cyc"] if data["vis_cyc"] is not None else data["vis_all"]
)
settings.VIS_AFF = (
settings.VIS_CYC = data["vis_cyc"]

data["vis_aff"] = (
data["vis_aff"] if data["vis_aff"] is not None else data["vis_all"]
)
settings.VIS_EMB = (
settings.VIS_AFF = data['vis_aff']

data["vis_emb"] = (
data["vis_emb"] if data["vis_emb"] is not None else data["vis_all"]
)
settings.VIS_INT = (
settings.VIS_EMB = data["vis_emb"]

data["vis_int"] = (
data["vis_int"] if data["vis_int"] is not None else data["vis_all"]
)
settings.VIS_DIS = (
settings.VIS_INT = data["vis_int"]

data["vis_dis"] = (
data["vis_dis"] if data["vis_dis"] is not None else data["vis_all"]
)
settings.VIS_DIS = data["vis_dis"]

settings.VIS_POS = (
data["vis_pos"] if data["vis_pos"] is not None else data["vis_all"]
)
Expand All @@ -351,30 +385,49 @@ def setup_visualisation_config(data: dict) -> None:
settings.VIS_FORMAT = data["vis_format"]
settings.VIS_Z_N_INT = data["vis_z_n_int"]

settings.FREQ_EVAL = (
data["freq_eval"] = (
data["freq_eval"] if data["freq_eval"] != 0 else data["freq_all"]
)
settings.FREQ_REC = (
settings.FREQ_EVAL = data["freq_eval"]

data["freq_rec"] = (
data["freq_rec"] if data["freq_rec"] != 0 else data["freq_all"]
)
settings.FREQ_EMB = (
settings.FREQ_REC = data["freq_rec"]

data["freq_emb"] = (
data["freq_emb"] if data["freq_emb"] != 0 else data["freq_all"]
)
settings.FREQ_INT = (
settings.FREQ_EMB = data["freq_emb"]

data["freq_int"] = (
data["freq_int"] if data["freq_int"] != 0 else data["freq_all"]
)
settings.FREQ_DIS = (
settings.FREQ_INT = data["freq_int"]

data["freq_dis"] = (
data["freq_dis"] if data["freq_dis"] != 0 else data["freq_all"]
)
settings.FREQ_POS = (
settings.FREQ_DIS = data["freq_dis"]

data["freq_pos"] = (
data["freq_pos"] if data["freq_pos"] != 0 else data["freq_all"]
)
settings.FREQ_ACC = (
settings.FREQ_POS = data["freq_pos"]

data["freq_acc"] = (
data["freq_acc"] if data["freq_acc"] != 0 else data["freq_all"]
)
settings.FREQ_STA = (
settings.FREQ_ACC = data["freq_acc"]

data['freq_sta'] = (
data["freq_sta"] if data["freq_sta"] != 0 else data["freq_all"]
)
settings.FREQ_SIM = (
settings.FREQ_STA = data["freq_sta"]

data["freq_sim"] = (
data["freq_sim"] if data["freq_sim"] != 0 else data["freq_all"]
)
settings.FREQ_SIM = data["freq_sim"]

return data
8 changes: 4 additions & 4 deletions avae/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def load_data(
gaussian_blur: bool = False,
normalise: bool = False,
shift_min: bool = False,
rescale: bool | None = None,
rescale: int | None = None,
) -> tuple[DataLoader, int]:
...

Expand All @@ -53,7 +53,7 @@ def load_data(
gaussian_blur: bool = False,
normalise: bool = False,
shift_min: bool = False,
rescale: bool | None = None,
rescale: int | None = None,
) -> tuple[DataLoader, DataLoader, DataLoader, pd.DataFrame, int]:
...

Expand All @@ -72,7 +72,7 @@ def load_data(
gaussian_blur: bool = False,
normalise: bool = False,
shift_min: bool = False,
rescale: bool | None = None,
rescale: int | None = None,
) -> tuple[DataLoader, DataLoader, DataLoader, pd.DataFrame, int] | tuple[
DataLoader, int
]:
Expand Down Expand Up @@ -106,7 +106,7 @@ def load_data(
In True, the input data is normalised before being passed to the model.
shift_min: bool
If True, the minimum value of the input data is shifted to 0 and maximum to 1.
rescale: int
rescale: int | None
If not None, the input data is rescaled to the given value.


Expand Down
6 changes: 3 additions & 3 deletions avae/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def train(
gaussian_blur: bool,
normalise: bool,
shift_min: bool,
rescale: bool,
rescale: int | None,
tensorboard: bool,
classifier: str,
strategy: str,
Expand Down Expand Up @@ -160,8 +160,8 @@ def train(
Path to the beta values to load.
gamma_load: str
Path to the gamma values to load.
rescale: bool
If True, the input data is rescaled to have a mean of 0 and std of 1.
rescale: int | None
If provided, the data is rescaled by the value.
"""
lt.pytorch.seed_everything(42)

Expand Down
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def run(
logging.getLogger("matplotlib.font_manager").disabled = True

# visualisation global settings defined from config file
setup_visualisation_config(data)
data = setup_visualisation_config(data)

if data["new_out"]:
dir_name = f'results_{settings.date_time_run}_model_{data["model"]}_lat{data["latent_dims"]}_pose{data["pose_dims"]}_lr{data["learning"]}_beta{data["beta"]}_gamma{data["gamma"]}'
Expand Down
Loading