From 43e216e7e2882fa6280b496d663705bd49554c0e Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 25 Jul 2025 14:26:35 -0400 Subject: [PATCH 01/39] Making train and trainer modules, needs typing and docstrings --- src/midst_toolkit/models/clavaddpm/model.py | 449 +----------------- src/midst_toolkit/models/clavaddpm/train.py | 373 +++++++++++++++ src/midst_toolkit/models/clavaddpm/trainer.py | 105 ++++ .../models/clavaddpm/test_model.py | 2 +- 4 files changed, 481 insertions(+), 448 deletions(-) create mode 100644 src/midst_toolkit/models/clavaddpm/train.py create mode 100644 src/midst_toolkit/models/clavaddpm/trainer.py diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index 0d312107..f636b552 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -1,11 +1,10 @@ import hashlib import json import math -import os import pickle from abc import ABC, abstractmethod from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterator +from collections.abc import Callable, Generator from copy import deepcopy from dataclasses import astuple, dataclass, replace from enum import Enum @@ -34,7 +33,7 @@ QuantileTransformer, StandardScaler, ) -from torch import Tensor, nn, optim +from torch import Tensor, nn from midst_toolkit.core import logger from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion @@ -228,233 +227,6 @@ def _get_labels_and_probs( return labels.astype("int64"), probs -def clava_clustering(tables, relation_order, save_dir, configs): - relation_order_reversed = relation_order[::-1] - all_group_lengths_prob_dicts = {} - - # Clustering - if os.path.exists(os.path.join(save_dir, "cluster_ckpt.pkl")): - print("Clustering checkpoint found, loading...") - cluster_ckpt = pickle.load(open(os.path.join(save_dir, "cluster_ckpt.pkl"), "rb")) - # ruff: noqa: SIM115 - tables = cluster_ckpt["tables"] - all_group_lengths_prob_dicts = cluster_ckpt["all_group_lengths_prob_dicts"] - else: - for parent, child in relation_order_reversed: - if parent is not None: - print(f"Clustering {parent} -> {child}") - if isinstance(configs["clustering"]["num_clusters"], dict): - num_clusters = configs["clustering"]["num_clusters"][child] - else: - num_clusters = configs["clustering"]["num_clusters"] - ( - parent_df_with_cluster, - child_df_with_cluster, - group_lengths_prob_dicts, - ) = pair_clustering_keep_id( - tables[child]["df"], - tables[child]["domain"], - tables[parent]["df"], - tables[parent]["domain"], - f"{child}_id", - f"{parent}_id", - num_clusters, - configs["clustering"]["parent_scale"], - 1, # not used for now - parent, - child, - clustering_method=configs["clustering"]["clustering_method"], - ) - tables[parent]["df"] = parent_df_with_cluster - tables[child]["df"] = child_df_with_cluster - all_group_lengths_prob_dicts[(parent, child)] = group_lengths_prob_dicts - - cluster_ckpt = { - "tables": tables, - "all_group_lengths_prob_dicts": all_group_lengths_prob_dicts, - } - pickle.dump(cluster_ckpt, open(os.path.join(save_dir, "cluster_ckpt.pkl"), "wb")) - # ruff: noqa: SIM115 - - for parent, child in relation_order: - if parent is None: - tables[child]["df"]["placeholder"] = list(range(len(tables[child]["df"]))) - - return tables, all_group_lengths_prob_dicts - - -def clava_training(tables, relation_order, save_dir, configs, device="cuda"): - models = {} - for parent, child in relation_order: - print(f"Training 
{parent} -> {child} model from scratch") - df_with_cluster = tables[child]["df"] - id_cols = [col for col in df_with_cluster.columns if "_id" in col] - df_without_id = df_with_cluster.drop(columns=id_cols) - - result = child_training( - df_without_id, - tables[child]["domain"], - parent, - child, - configs, - device, - ) - - models[(parent, child)] = result - pickle.dump( - result, - open(os.path.join(save_dir, f"models/{parent}_{child}_ckpt.pkl"), "wb"), - # ruff: noqa: SIM115 - ) - - return models - - -def child_training( - child_df_with_cluster: pd.DataFrame, - child_domain_dict: dict[str, Any], - parent_name: str | None, - child_name: str, - configs: dict[str, Any], - device: str = "cuda", -) -> dict[str, Any]: - if parent_name is None: - y_col = "placeholder" - child_df_with_cluster["placeholder"] = list(range(len(child_df_with_cluster))) - else: - y_col = f"{parent_name}_{child_name}_cluster" - child_info = get_table_info(child_df_with_cluster, child_domain_dict, y_col) - child_model_params = get_model_params( - { - "d_layers": configs["diffusion"]["d_layers"], - "dropout": configs["diffusion"]["dropout"], - } - ) - child_T_dict = get_T_dict() - - child_result = train_model( - child_df_with_cluster, - child_info, - child_model_params, - child_T_dict, - configs["diffusion"]["iterations"], - configs["diffusion"]["batch_size"], - configs["diffusion"]["model_type"], - configs["diffusion"]["gaussian_loss_type"], - configs["diffusion"]["num_timesteps"], - configs["diffusion"]["scheduler"], - configs["diffusion"]["lr"], - configs["diffusion"]["weight_decay"], - device=device, - ) - - if parent_name is None: - child_result["classifier"] = None - elif configs["classifier"]["iterations"] > 0: - child_classifier = train_classifier( - child_df_with_cluster, - child_info, - child_model_params, - child_T_dict, - configs["classifier"]["iterations"], - configs["classifier"]["batch_size"], - configs["diffusion"]["gaussian_loss_type"], - configs["diffusion"]["num_timesteps"], - configs["diffusion"]["scheduler"], - cluster_col=y_col, - d_layers=configs["classifier"]["d_layers"], - dim_t=configs["classifier"]["dim_t"], - lr=configs["classifier"]["lr"], - device=device, - ) - child_result["classifier"] = child_classifier - - child_result["df_info"] = child_info - child_result["model_params"] = child_model_params - child_result["T_dict"] = child_T_dict - return child_result - - -def train_model( - df: pd.DataFrame, - df_info: pd.DataFrame, - model_params: dict[str, Any], - T_dict: dict[str, Any], - steps: int, - batch_size: int, - model_type: str, - gaussian_loss_type: str, - num_timesteps: int, - scheduler: str, - lr: float, - weight_decay: float, - device: str = "cuda", -) -> dict[str, Any]: - T = Transformations(**T_dict) - dataset, label_encoders, column_orders = make_dataset_from_df( - df, - T, - is_y_cond=model_params["is_y_cond"], - ratios=[0.99, 0.005, 0.005], - df_info=df_info, - std=0, - ) - # print(dataset.n_features) - train_loader = prepare_fast_dataloader(dataset, split="train", batch_size=batch_size, y_type="long") - - num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 - - K = np.array(dataset.get_category_sizes("train")) - if len(K) == 0 or T_dict["cat_encoding"] == "one-hot": - K = np.array([0]) - # print(K) - - num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 - d_in = np.sum(K) + num_numerical_features - model_params["d_in"] = d_in - # print(d_in) - - print("Model params: {}".format(model_params)) - 
model = get_model(model_type, model_params) - model.to(device) - - train_loader = prepare_fast_dataloader(dataset, split="train", batch_size=batch_size) - - diffusion = GaussianMultinomialDiffusion( - num_classes=K, - num_numerical_features=num_numerical_features, - denoise_fn=model, - gaussian_loss_type=gaussian_loss_type, - num_timesteps=num_timesteps, - scheduler=scheduler, - device=torch.device(device), - ) - diffusion.to(device) - diffusion.train() - - trainer = Trainer( - diffusion, - train_loader, - lr=lr, - weight_decay=weight_decay, - steps=steps, - device=device, - ) - trainer.run_loop() - - if model_params["is_y_cond"] == "concat": - column_orders = column_orders[1:] + [column_orders[0]] - else: - column_orders = column_orders + [df_info["y_col"]] - - return { - "diffusion": diffusion, - "label_encoders": label_encoders, - "dataset": dataset, - "column_orders": column_orders, - } - - class Classifier(nn.Module): def __init__( self, @@ -504,127 +276,6 @@ def forward(self, x, timesteps): return self.model(x) -def train_classifier( - df: pd.DataFrame, - df_info: pd.DataFrame, - model_params: dict[str, Any], - T_dict: dict[str, Any], - classifier_steps: int, - batch_size: int, - gaussian_loss_type: str, - num_timesteps: int, - scheduler: str, - d_layers: list[int], - device: str = "cuda", - cluster_col: str = "cluster", - dim_t: int = 128, - lr: float = 0.0001, -) -> Classifier: - T = Transformations(**T_dict) - dataset, label_encoders, column_orders = make_dataset_from_df( - df, - T, - is_y_cond=model_params["is_y_cond"], - ratios=[0.99, 0.005, 0.005], - df_info=df_info, - std=0, - ) - print(dataset.n_features) - train_loader = prepare_fast_dataloader(dataset, split="train", batch_size=batch_size, y_type="long") - val_loader = prepare_fast_dataloader(dataset, split="val", batch_size=batch_size, y_type="long") - test_loader = prepare_fast_dataloader(dataset, split="test", batch_size=batch_size, y_type="long") - - eval_interval = 5 - # log_interval = 10 - - K = np.array(dataset.get_category_sizes("train")) - if len(K) == 0 or T_dict["cat_encoding"] == "one-hot": - K = np.array([0]) - print(K) - - num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 - if model_params["is_y_cond"] == "concat": - num_numerical_features -= 1 - - classifier = Classifier( - d_in=num_numerical_features, - d_out=int(max(df[cluster_col].values) + 1), - dim_t=dim_t, - hidden_sizes=d_layers, - ).to(device) - - classifier_optimizer = optim.AdamW(classifier.parameters(), lr=lr) - - empty_diffusion = GaussianMultinomialDiffusion( - num_classes=K, - num_numerical_features=num_numerical_features, - denoise_fn=None, # type: ignore[arg-type] - gaussian_loss_type=gaussian_loss_type, - num_timesteps=num_timesteps, - scheduler=scheduler, - device=torch.device(device), - ) - empty_diffusion.to(device) - - schedule_sampler = create_named_schedule_sampler("uniform", empty_diffusion) - - classifier.train() - resume_step = 0 - for step in range(classifier_steps): - logger.logkv("step", step + resume_step) - logger.logkv( - "samples", - (step + resume_step + 1) * batch_size, - ) - numerical_forward_backward_log( - classifier, - classifier_optimizer, - train_loader, - dataset, - schedule_sampler, - empty_diffusion, - prefix="train", - device=device, - ) - - classifier_optimizer.step() - if not step % eval_interval: - with torch.no_grad(): - classifier.eval() - numerical_forward_backward_log( - classifier, - classifier_optimizer, - val_loader, - dataset, - schedule_sampler, - 
empty_diffusion, - prefix="val", - device=device, - ) - classifier.train() - - # Removed because it's too verbose - # if not step % log_interval: - # logger.dumpkvs() - - # # test classifier - classifier.eval() - - correct = 0 - for _ in range(3000): - test_x, test_y = next(test_loader) - test_y = test_y.long().to(device) - test_x = test_x[:, 1:].to(device) if model_params["is_y_cond"] == "concat" else test_x.to(device) - with torch.no_grad(): - pred = classifier(test_x, timesteps=torch.zeros(test_x.shape[0]).to(device)) - correct += (pred.argmax(dim=1) == test_y).sum().item() - - acc = correct / (3000 * batch_size) - print(acc) - - return classifier - - def pair_clustering_keep_id( child_df: pd.DataFrame, child_domain_dict: dict[str, Any], @@ -1222,22 +873,6 @@ def get_model( raise ValueError("Unknown model!") -def update_ema( - target_params: Iterator[nn.Parameter], - source_params: Iterator[nn.Parameter], - rate: float = 0.999, -) -> None: - """ - Update target parameters to be closer to those of source parameters using - an exponential moving average. - :param target_params: the target parameter sequence. - :param source_params: the source parameter sequence. - :param rate: the EMA rate (closer to 1 means slower). - """ - for targ, src in zip(target_params, source_params): - targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) - - class ScheduleSampler(ABC): """ A distribution over timesteps in the diffusion process, intended to reduce @@ -1718,86 +1353,6 @@ def build_target(y: ArrayDict, policy: YPolicy | None, task_type: TaskType) -> t return y, info -class Trainer: - def __init__( - self, - diffusion: GaussianMultinomialDiffusion, - train_iter: Generator[tuple[Tensor, ...]], - lr: float, - weight_decay: float, - steps: int, - device: str = "cuda", - ): - self.diffusion = diffusion - self.ema_model = deepcopy(self.diffusion._denoise_fn) - for param in self.ema_model.parameters(): - param.detach_() - - self.train_iter = train_iter - self.steps = steps - self.init_lr = lr - self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay) - self.device = device - self.loss_history = pd.DataFrame(columns=["step", "mloss", "gloss", "loss"]) - self.log_every = 100 - self.print_every = 500 - self.ema_every = 1000 - - def _anneal_lr(self, step: int) -> None: - frac_done = step / self.steps - lr = self.init_lr * (1 - frac_done) - for param_group in self.optimizer.param_groups: - param_group["lr"] = lr - - def _run_step(self, x: Tensor, out_dict: dict[str, Tensor]) -> tuple[Tensor, Tensor]: - x = x.to(self.device) - for k, v in out_dict.items(): - out_dict[k] = v.long().to(self.device) - self.optimizer.zero_grad() - loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict) - loss = loss_multi + loss_gauss - loss.backward() # type: ignore[no-untyped-call] - self.optimizer.step() - - return loss_multi, loss_gauss - - def run_loop(self) -> None: - step = 0 - curr_loss_multi = 0.0 - curr_loss_gauss = 0.0 - - curr_count = 0 - while step < self.steps: - x, out = next(self.train_iter) - out_dict = {"y": out} - batch_loss_multi, batch_loss_gauss = self._run_step(x, out_dict) - - self._anneal_lr(step) - - curr_count += len(x) - curr_loss_multi += batch_loss_multi.item() * len(x) - curr_loss_gauss += batch_loss_gauss.item() * len(x) - - if (step + 1) % self.log_every == 0: - mloss = np.around(curr_loss_multi / curr_count, 4) - gloss = np.around(curr_loss_gauss / curr_count, 4) - if (step + 1) % self.print_every == 0: - print(f"Step {(step + 1)}/{self.steps} 
MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}") - self.loss_history.loc[len(self.loss_history)] = [ - step + 1, - mloss, - gloss, - mloss + gloss, - ] - curr_count = 0 - curr_loss_gauss = 0.0 - curr_loss_multi = 0.0 - - update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters()) - - step += 1 - - class FastTensorDataLoader: """ Defines a faster dataloader for PyTorch tensors. diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py new file mode 100644 index 00000000..0e25b2b5 --- /dev/null +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -0,0 +1,373 @@ +import os +import pickle +from typing import Any + +import numpy as np +import pandas as pd +import torch +from torch import optim + +from midst_toolkit.core import logger +from midst_toolkit.models.clavaddpm.model import ( + Classifier, + Transformations, + create_named_schedule_sampler, + get_model, + get_model_params, + get_T_dict, + get_table_info, + make_dataset_from_df, + numerical_forward_backward_log, + pair_clustering_keep_id, + prepare_fast_dataloader, +) +from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion +from midst_toolkit.models.clavaddpm.trainer import Trainer + + +def clava_clustering(tables, relation_order, save_dir, configs): + relation_order_reversed = relation_order[::-1] + all_group_lengths_prob_dicts = {} + + # Clustering + if os.path.exists(os.path.join(save_dir, "cluster_ckpt.pkl")): + print("Clustering checkpoint found, loading...") + cluster_ckpt = pickle.load(open(os.path.join(save_dir, "cluster_ckpt.pkl"), "rb")) + # ruff: noqa: SIM115 + tables = cluster_ckpt["tables"] + all_group_lengths_prob_dicts = cluster_ckpt["all_group_lengths_prob_dicts"] + else: + for parent, child in relation_order_reversed: + if parent is not None: + print(f"Clustering {parent} -> {child}") + if isinstance(configs["clustering"]["num_clusters"], dict): + num_clusters = configs["clustering"]["num_clusters"][child] + else: + num_clusters = configs["clustering"]["num_clusters"] + ( + parent_df_with_cluster, + child_df_with_cluster, + group_lengths_prob_dicts, + ) = pair_clustering_keep_id( + tables[child]["df"], + tables[child]["domain"], + tables[parent]["df"], + tables[parent]["domain"], + f"{child}_id", + f"{parent}_id", + num_clusters, + configs["clustering"]["parent_scale"], + 1, # not used for now + parent, + child, + clustering_method=configs["clustering"]["clustering_method"], + ) + tables[parent]["df"] = parent_df_with_cluster + tables[child]["df"] = child_df_with_cluster + all_group_lengths_prob_dicts[(parent, child)] = group_lengths_prob_dicts + + cluster_ckpt = { + "tables": tables, + "all_group_lengths_prob_dicts": all_group_lengths_prob_dicts, + } + pickle.dump(cluster_ckpt, open(os.path.join(save_dir, "cluster_ckpt.pkl"), "wb")) + # ruff: noqa: SIM115 + + for parent, child in relation_order: + if parent is None: + tables[child]["df"]["placeholder"] = list(range(len(tables[child]["df"]))) + + return tables, all_group_lengths_prob_dicts + + +def clava_training(tables, relation_order, save_dir, configs, device="cuda"): + models = {} + for parent, child in relation_order: + print(f"Training {parent} -> {child} model from scratch") + df_with_cluster = tables[child]["df"] + id_cols = [col for col in df_with_cluster.columns if "_id" in col] + df_without_id = df_with_cluster.drop(columns=id_cols) + + result = child_training( + df_without_id, + tables[child]["domain"], + parent, + child, + configs, + device, 
+ ) + + models[(parent, child)] = result + pickle.dump( + result, + open(os.path.join(save_dir, f"models/{parent}_{child}_ckpt.pkl"), "wb"), + # ruff: noqa: SIM115 + ) + + return models + + +def child_training( + child_df_with_cluster: pd.DataFrame, + child_domain_dict: dict[str, Any], + parent_name: str | None, + child_name: str, + configs: dict[str, Any], + device: str = "cuda", +) -> dict[str, Any]: + if parent_name is None: + y_col = "placeholder" + child_df_with_cluster["placeholder"] = list(range(len(child_df_with_cluster))) + else: + y_col = f"{parent_name}_{child_name}_cluster" + child_info = get_table_info(child_df_with_cluster, child_domain_dict, y_col) + child_model_params = get_model_params( + { + "d_layers": configs["diffusion"]["d_layers"], + "dropout": configs["diffusion"]["dropout"], + } + ) + child_T_dict = get_T_dict() + + child_result = train_model( + child_df_with_cluster, + child_info, + child_model_params, + child_T_dict, + configs["diffusion"]["iterations"], + configs["diffusion"]["batch_size"], + configs["diffusion"]["model_type"], + configs["diffusion"]["gaussian_loss_type"], + configs["diffusion"]["num_timesteps"], + configs["diffusion"]["scheduler"], + configs["diffusion"]["lr"], + configs["diffusion"]["weight_decay"], + device=device, + ) + + if parent_name is None: + child_result["classifier"] = None + elif configs["classifier"]["iterations"] > 0: + child_classifier = train_classifier( + child_df_with_cluster, + child_info, + child_model_params, + child_T_dict, + configs["classifier"]["iterations"], + configs["classifier"]["batch_size"], + configs["diffusion"]["gaussian_loss_type"], + configs["diffusion"]["num_timesteps"], + configs["diffusion"]["scheduler"], + cluster_col=y_col, + d_layers=configs["classifier"]["d_layers"], + dim_t=configs["classifier"]["dim_t"], + lr=configs["classifier"]["lr"], + device=device, + ) + child_result["classifier"] = child_classifier + + child_result["df_info"] = child_info + child_result["model_params"] = child_model_params + child_result["T_dict"] = child_T_dict + return child_result + + +def train_model( + df: pd.DataFrame, + df_info: pd.DataFrame, + model_params: dict[str, Any], + T_dict: dict[str, Any], + steps: int, + batch_size: int, + model_type: str, + gaussian_loss_type: str, + num_timesteps: int, + scheduler: str, + lr: float, + weight_decay: float, + device: str = "cuda", +) -> dict[str, Any]: + T = Transformations(**T_dict) + dataset, label_encoders, column_orders = make_dataset_from_df( + df, + T, + is_y_cond=model_params["is_y_cond"], + ratios=[0.99, 0.005, 0.005], + df_info=df_info, + std=0, + ) + # print(dataset.n_features) + train_loader = prepare_fast_dataloader(dataset, split="train", batch_size=batch_size, y_type="long") + + num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 + + K = np.array(dataset.get_category_sizes("train")) + if len(K) == 0 or T_dict["cat_encoding"] == "one-hot": + K = np.array([0]) + # print(K) + + num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 + d_in = np.sum(K) + num_numerical_features + model_params["d_in"] = d_in + # print(d_in) + + print("Model params: {}".format(model_params)) + model = get_model(model_type, model_params) + model.to(device) + + train_loader = prepare_fast_dataloader(dataset, split="train", batch_size=batch_size) + + diffusion = GaussianMultinomialDiffusion( + num_classes=K, + num_numerical_features=num_numerical_features, + denoise_fn=model, + 
gaussian_loss_type=gaussian_loss_type, + num_timesteps=num_timesteps, + scheduler=scheduler, + device=torch.device(device), + ) + diffusion.to(device) + diffusion.train() + + trainer = Trainer( + diffusion, + train_loader, + lr=lr, + weight_decay=weight_decay, + steps=steps, + device=device, + ) + trainer.run_loop() + + if model_params["is_y_cond"] == "concat": + column_orders = column_orders[1:] + [column_orders[0]] + else: + column_orders = column_orders + [df_info["y_col"]] + + return { + "diffusion": diffusion, + "label_encoders": label_encoders, + "dataset": dataset, + "column_orders": column_orders, + } + + +def train_classifier( + df: pd.DataFrame, + df_info: pd.DataFrame, + model_params: dict[str, Any], + T_dict: dict[str, Any], + classifier_steps: int, + batch_size: int, + gaussian_loss_type: str, + num_timesteps: int, + scheduler: str, + d_layers: list[int], + device: str = "cuda", + cluster_col: str = "cluster", + dim_t: int = 128, + lr: float = 0.0001, +) -> Classifier: + T = Transformations(**T_dict) + dataset, label_encoders, column_orders = make_dataset_from_df( + df, + T, + is_y_cond=model_params["is_y_cond"], + ratios=[0.99, 0.005, 0.005], + df_info=df_info, + std=0, + ) + print(dataset.n_features) + train_loader = prepare_fast_dataloader(dataset, split="train", batch_size=batch_size, y_type="long") + val_loader = prepare_fast_dataloader(dataset, split="val", batch_size=batch_size, y_type="long") + test_loader = prepare_fast_dataloader(dataset, split="test", batch_size=batch_size, y_type="long") + + eval_interval = 5 + # log_interval = 10 + + K = np.array(dataset.get_category_sizes("train")) + if len(K) == 0 or T_dict["cat_encoding"] == "one-hot": + K = np.array([0]) + print(K) + + num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 + if model_params["is_y_cond"] == "concat": + num_numerical_features -= 1 + + classifier = Classifier( + d_in=num_numerical_features, + d_out=int(max(df[cluster_col].values) + 1), + dim_t=dim_t, + hidden_sizes=d_layers, + ).to(device) + + classifier_optimizer = optim.AdamW(classifier.parameters(), lr=lr) + + empty_diffusion = GaussianMultinomialDiffusion( + num_classes=K, + num_numerical_features=num_numerical_features, + denoise_fn=None, # type: ignore[arg-type] + gaussian_loss_type=gaussian_loss_type, + num_timesteps=num_timesteps, + scheduler=scheduler, + device=torch.device(device), + ) + empty_diffusion.to(device) + + schedule_sampler = create_named_schedule_sampler("uniform", empty_diffusion) + + classifier.train() + resume_step = 0 + for step in range(classifier_steps): + logger.logkv("step", step + resume_step) + logger.logkv( + "samples", + (step + resume_step + 1) * batch_size, + ) + numerical_forward_backward_log( + classifier, + classifier_optimizer, + train_loader, + dataset, + schedule_sampler, + empty_diffusion, + prefix="train", + device=device, + ) + + classifier_optimizer.step() + if not step % eval_interval: + with torch.no_grad(): + classifier.eval() + numerical_forward_backward_log( + classifier, + classifier_optimizer, + val_loader, + dataset, + schedule_sampler, + empty_diffusion, + prefix="val", + device=device, + ) + classifier.train() + + # Removed because it's too verbose + # if not step % log_interval: + # logger.dumpkvs() + + # # test classifier + classifier.eval() + + correct = 0 + for _ in range(3000): + test_x, test_y = next(test_loader) + test_y = test_y.long().to(device) + test_x = test_x[:, 1:].to(device) if model_params["is_y_cond"] == "concat" else test_x.to(device) + 
with torch.no_grad(): + pred = classifier(test_x, timesteps=torch.zeros(test_x.shape[0]).to(device)) + correct += (pred.argmax(dim=1) == test_y).sum().item() + + acc = correct / (3000 * batch_size) + print(acc) + + return classifier diff --git a/src/midst_toolkit/models/clavaddpm/trainer.py b/src/midst_toolkit/models/clavaddpm/trainer.py new file mode 100644 index 00000000..4cc75445 --- /dev/null +++ b/src/midst_toolkit/models/clavaddpm/trainer.py @@ -0,0 +1,105 @@ +from collections.abc import Generator, Iterator +from copy import deepcopy + +import numpy as np +import pandas as pd +import torch +from torch import Tensor, nn + +from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion + + +class Trainer: + def __init__( + self, + diffusion: GaussianMultinomialDiffusion, + train_iter: Generator[tuple[Tensor, ...]], + lr: float, + weight_decay: float, + steps: int, + device: str = "cuda", + ): + self.diffusion = diffusion + self.ema_model = deepcopy(self.diffusion._denoise_fn) + for param in self.ema_model.parameters(): + param.detach_() + + self.train_iter = train_iter + self.steps = steps + self.init_lr = lr + self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay) + self.device = device + self.loss_history = pd.DataFrame(columns=["step", "mloss", "gloss", "loss"]) + self.log_every = 100 + self.print_every = 500 + self.ema_every = 1000 + + def _anneal_lr(self, step: int) -> None: + frac_done = step / self.steps + lr = self.init_lr * (1 - frac_done) + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + + def _run_step(self, x: Tensor, out_dict: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + x = x.to(self.device) + for k, v in out_dict.items(): + out_dict[k] = v.long().to(self.device) + self.optimizer.zero_grad() + loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict) + loss = loss_multi + loss_gauss + loss.backward() # type: ignore[no-untyped-call] + self.optimizer.step() + + return loss_multi, loss_gauss + + def run_loop(self) -> None: + step = 0 + curr_loss_multi = 0.0 + curr_loss_gauss = 0.0 + + curr_count = 0 + while step < self.steps: + x, out = next(self.train_iter) + out_dict = {"y": out} + batch_loss_multi, batch_loss_gauss = self._run_step(x, out_dict) + + self._anneal_lr(step) + + curr_count += len(x) + curr_loss_multi += batch_loss_multi.item() * len(x) + curr_loss_gauss += batch_loss_gauss.item() * len(x) + + if (step + 1) % self.log_every == 0: + mloss = np.around(curr_loss_multi / curr_count, 4) + gloss = np.around(curr_loss_gauss / curr_count, 4) + if (step + 1) % self.print_every == 0: + print(f"Step {(step + 1)}/{self.steps} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}") + self.loss_history.loc[len(self.loss_history)] = [ + step + 1, + mloss, + gloss, + mloss + gloss, + ] + curr_count = 0 + curr_loss_gauss = 0.0 + curr_loss_multi = 0.0 + + update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters()) + + step += 1 + + +def update_ema( + target_params: Iterator[nn.Parameter], + source_params: Iterator[nn.Parameter], + rate: float = 0.999, +) -> None: + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). 
+ """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index d7980ead..77995c24 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -4,7 +4,7 @@ import pytest from midst_toolkit.core.data_loaders import load_multi_table -from midst_toolkit.models.clavaddpm.model import clava_clustering, clava_training +from midst_toolkit.models.clavaddpm.train import clava_clustering, clava_training CLUSTERING_CONFIG = { From 9d6458a105bcf551a7e90c6b4fa99c99065afd6e Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 25 Jul 2025 14:52:05 -0400 Subject: [PATCH 02/39] Adding docstrings to clava_clusteting --- src/midst_toolkit/models/clavaddpm/train.py | 59 +++++++++++++++---- src/midst_toolkit/models/clavaddpm/trainer.py | 2 + .../models/clavaddpm/test_model.py | 11 ++-- 3 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 0e25b2b5..8d3432b9 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -1,5 +1,8 @@ +"""Defines the training functions for the ClavaDDPM model.""" + import os import pickle +from pathlib import Path from typing import Any import numpy as np @@ -8,6 +11,7 @@ from torch import optim from midst_toolkit.core import logger +from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion from midst_toolkit.models.clavaddpm.model import ( Classifier, Transformations, @@ -21,18 +25,53 @@ pair_clustering_keep_id, prepare_fast_dataloader, ) -from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion from midst_toolkit.models.clavaddpm.trainer import Trainer -def clava_clustering(tables, relation_order, save_dir, configs): +Tables = dict[str, dict[str, Any]] +RelationOrder = list[tuple[str, str]] +Configs = dict[str, Any] + + +def clava_clustering( + tables: Tables, + relation_order: RelationOrder, + save_dir: Path, + configs: Configs, +) -> tuple[dict[str, Any], dict[tuple[str, str], dict[int, float]]]: + """ + Clustering function for the mutli-table function of theClavaDDPM model. + + Args: + tables: Definition of the tables and their relations. Example: + { + "table1": { + "children": ["table2"], + "parents": [] + }, + "table2": { + "children": [], + "parents": ["table1"] + } + } + relation_order: List of tuples of parent and child tables. Example: + [("table1", "table2")] + save_dir: Directory to save the clustering checkpoint. + configs: Dictionary of configurations. 
The following config keys are required: + { + num_clusters = int | dict, + parent_scale = float, + clustering_method = str["kmeans" | "both" | "variational" | "gmm"], + } + + """ relation_order_reversed = relation_order[::-1] all_group_lengths_prob_dicts = {} # Clustering - if os.path.exists(os.path.join(save_dir, "cluster_ckpt.pkl")): + if os.path.exists(save_dir / "cluster_ckpt.pkl"): print("Clustering checkpoint found, loading...") - cluster_ckpt = pickle.load(open(os.path.join(save_dir, "cluster_ckpt.pkl"), "rb")) + cluster_ckpt = pickle.load(open(save_dir / "cluster_ckpt.pkl", "rb")) # ruff: noqa: SIM115 tables = cluster_ckpt["tables"] all_group_lengths_prob_dicts = cluster_ckpt["all_group_lengths_prob_dicts"] @@ -40,10 +79,10 @@ def clava_clustering(tables, relation_order, save_dir, configs): for parent, child in relation_order_reversed: if parent is not None: print(f"Clustering {parent} -> {child}") - if isinstance(configs["clustering"]["num_clusters"], dict): - num_clusters = configs["clustering"]["num_clusters"][child] + if isinstance(configs["num_clusters"], dict): + num_clusters = configs["num_clusters"][child] else: - num_clusters = configs["clustering"]["num_clusters"] + num_clusters = configs["num_clusters"] ( parent_df_with_cluster, child_df_with_cluster, @@ -56,11 +95,11 @@ def clava_clustering(tables, relation_order, save_dir, configs): f"{child}_id", f"{parent}_id", num_clusters, - configs["clustering"]["parent_scale"], + configs["parent_scale"], 1, # not used for now parent, child, - clustering_method=configs["clustering"]["clustering_method"], + clustering_method=configs["clustering_method"], ) tables[parent]["df"] = parent_df_with_cluster tables[child]["df"] = child_df_with_cluster @@ -70,7 +109,7 @@ def clava_clustering(tables, relation_order, save_dir, configs): "tables": tables, "all_group_lengths_prob_dicts": all_group_lengths_prob_dicts, } - pickle.dump(cluster_ckpt, open(os.path.join(save_dir, "cluster_ckpt.pkl"), "wb")) + pickle.dump(cluster_ckpt, open(save_dir / "cluster_ckpt.pkl", "wb")) # ruff: noqa: SIM115 for parent, child in relation_order: diff --git a/src/midst_toolkit/models/clavaddpm/trainer.py b/src/midst_toolkit/models/clavaddpm/trainer.py index 4cc75445..c9ab3c7d 100644 --- a/src/midst_toolkit/models/clavaddpm/trainer.py +++ b/src/midst_toolkit/models/clavaddpm/trainer.py @@ -1,3 +1,5 @@ +"""Trainer class for the ClavaDDPM model.""" + from collections.abc import Generator, Iterator from copy import deepcopy diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index 77995c24..77b815c3 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -38,7 +38,7 @@ @pytest.mark.integration_test() def test_train_single_table(tmp_path: Path): os.makedirs(tmp_path / "models") - configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG} + configs = {"diffusion": DIFFUSION_CONFIG} tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/single_table/") models = clava_training(tables, relation_order, tmp_path, configs, device="cpu") @@ -49,10 +49,10 @@ def test_train_single_table(tmp_path: Path): @pytest.mark.integration_test() def test_train_multi_table(tmp_path: Path): os.makedirs(tmp_path / "models") - configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG} + configs = {"diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG} tables, 
relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/") - tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs) + tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) models = clava_training(tables, relation_order, tmp_path, configs, device="cpu") assert models @@ -61,14 +61,13 @@ def test_train_multi_table(tmp_path: Path): @pytest.mark.integration_test() def test_clustering_reload(tmp_path: Path): os.makedirs(tmp_path / "models") - configs = {"clustering": CLUSTERING_CONFIG} tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/") - tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs) + tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) assert all_group_lengths_prob_dicts # loading from previously saved clustering - tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs) + tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) assert all_group_lengths_prob_dicts From 4e58fec83beb58ace907c5527d91eaf14de64e13 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 28 Jul 2025 12:28:04 -0400 Subject: [PATCH 03/39] Adding docstring to the train.py module --- src/midst_toolkit/models/clavaddpm/train.py | 226 +++++++++++++++--- .../models/clavaddpm/test_model.py | 13 +- 2 files changed, 191 insertions(+), 48 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 8d3432b9..a51095e9 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -55,7 +55,7 @@ def clava_clustering( } } relation_order: List of tuples of parent and child tables. Example: - [("table1", "table2")] + [("table1", "table2"), ("table1", "table3")] save_dir: Directory to save the clustering checkpoint. configs: Dictionary of configurations. The following config keys are required: { @@ -119,7 +119,59 @@ def clava_clustering( return tables, all_group_lengths_prob_dicts -def clava_training(tables, relation_order, save_dir, configs, device="cuda"): +def clava_training( + tables: Tables, + relation_order: RelationOrder, + save_dir: Path, + diffusion_config: Configs, + classifier_config: Configs | None, + device: str = "cuda", +) -> dict[tuple[str, str], dict[str, Any]]: + """ + Training function for the ClavaDDPM model. + + Args: + tables: Definition of the tables and their relations. Example: + { + "table1": { + "children": ["table2"], + "parents": [] + }, + "table2": { + "children": [], + "parents": ["table1"] + } + } + relation_order: List of tuples of parent and child tables. Example: + [("table1", "table2"), ("table1", "table3")] + save_dir: Directory to save the clustering checkpoint. + diffusion_config: Dictionary of configurations for the diffusion model. The following config keys are required: + { + d_layers = list[int], + dropout = float, + iterations = int, + batch_size = int, + model_type = str["mlp" | "resnet"], + gaussian_loss_type = str["mse" | "cross_entropy"], + num_timesteps = int, + scheduler = str["cosine" | "linear"], + lr = float, + weight_decay = float, + } + classifier_config: Dictionary of configurations for the classifier model. Not required for single table + training. 
The following config keys are required for multi-table training: + { + iterations = int, + batch_size = int, + d_layers = list[int], + dim_t = int, + lr = float, + } + device: Device to use for training. Default is `"cuda"`. + + Returns: + Dictionary of models for each parent-child pair. + """ models = {} for parent, child in relation_order: print(f"Training {parent} -> {child} model from scratch") @@ -132,16 +184,17 @@ def clava_training(tables, relation_order, save_dir, configs, device="cuda"): tables[child]["domain"], parent, child, - configs, + diffusion_config, + classifier_config, device, ) models[(parent, child)] = result - pickle.dump( - result, - open(os.path.join(save_dir, f"models/{parent}_{child}_ckpt.pkl"), "wb"), - # ruff: noqa: SIM115 - ) + + target_folder = save_dir / "models" + target_folder.mkdir(parents=True, exist_ok=True) + with open(target_folder / f"{parent}_{child}_ckpt.pkl", "wb") as f: + pickle.dump(result, f) return models @@ -151,9 +204,50 @@ def child_training( child_domain_dict: dict[str, Any], parent_name: str | None, child_name: str, - configs: dict[str, Any], + diffusion_config: Configs, + classifier_config: Configs | None, device: str = "cuda", ) -> dict[str, Any]: + """ + Training function for a single child table. + + Args: + child_df_with_cluster: DataFrame with the cluster column. + child_domain_dict: Dictionary of the child table domain. It should contain size and type for each + column of the table. For example: + { + "frequency": {"size": 3, "type": "discrete"}, + "account_date": {"size": 1535, "type": "continuous"}, + } + parent_name: Name of the parent table, or None if there is no parent. + child_name: Name of the child table. + diffusion_config: Dictionary of configurations for the diffusion model. The following config keys are required: + { + d_layers = list[int], + dropout = float, + iterations = int, + batch_size = int, + model_type = str["mlp" | "resnet"], + gaussian_loss_type = str["mse" | "cross_entropy"], + num_timesteps = int, + scheduler = str["cosine" | "linear"], + lr = float, + weight_decay = float, + } + classifier_config: Dictionary of configurations for the classifier model. Not required for single table + training. The following config keys are required for multi-table training: + { + iterations = int, + batch_size = int, + d_layers = list[int], + dim_t = int, + lr = float, + } + device: Device to use for training. Default is `"cuda"`. + + Returns: + Dictionary of the training results. 
+ """ if parent_name is None: y_col = "placeholder" child_df_with_cluster["placeholder"] = list(range(len(child_df_with_cluster))) @@ -162,48 +256,51 @@ def child_training( child_info = get_table_info(child_df_with_cluster, child_domain_dict, y_col) child_model_params = get_model_params( { - "d_layers": configs["diffusion"]["d_layers"], - "dropout": configs["diffusion"]["dropout"], + "d_layers": diffusion_config["d_layers"], + "dropout": diffusion_config["dropout"], } ) child_T_dict = get_T_dict() + # ruff: noqa: N806 child_result = train_model( child_df_with_cluster, child_info, child_model_params, child_T_dict, - configs["diffusion"]["iterations"], - configs["diffusion"]["batch_size"], - configs["diffusion"]["model_type"], - configs["diffusion"]["gaussian_loss_type"], - configs["diffusion"]["num_timesteps"], - configs["diffusion"]["scheduler"], - configs["diffusion"]["lr"], - configs["diffusion"]["weight_decay"], + diffusion_config["iterations"], + diffusion_config["batch_size"], + diffusion_config["model_type"], + diffusion_config["gaussian_loss_type"], + diffusion_config["num_timesteps"], + diffusion_config["scheduler"], + diffusion_config["lr"], + diffusion_config["weight_decay"], device=device, ) if parent_name is None: child_result["classifier"] = None - elif configs["classifier"]["iterations"] > 0: - child_classifier = train_classifier( - child_df_with_cluster, - child_info, - child_model_params, - child_T_dict, - configs["classifier"]["iterations"], - configs["classifier"]["batch_size"], - configs["diffusion"]["gaussian_loss_type"], - configs["diffusion"]["num_timesteps"], - configs["diffusion"]["scheduler"], - cluster_col=y_col, - d_layers=configs["classifier"]["d_layers"], - dim_t=configs["classifier"]["dim_t"], - lr=configs["classifier"]["lr"], - device=device, - ) - child_result["classifier"] = child_classifier + else: + assert classifier_config is not None, "Classifier config is required for multi-table training" + if classifier_config["iterations"] > 0: + child_classifier = train_classifier( + child_df_with_cluster, + child_info, + child_model_params, + child_T_dict, + classifier_config["iterations"], + classifier_config["batch_size"], + diffusion_config["gaussian_loss_type"], + diffusion_config["num_timesteps"], + diffusion_config["scheduler"], + cluster_col=y_col, + d_layers=classifier_config["d_layers"], + dim_t=classifier_config["dim_t"], + lr=classifier_config["lr"], + device=device, + ) + child_result["classifier"] = child_classifier child_result["df_info"] = child_info child_result["model_params"] = child_model_params @@ -216,6 +313,7 @@ def train_model( df_info: pd.DataFrame, model_params: dict[str, Any], T_dict: dict[str, Any], + # ruff: noqa: N803 steps: int, batch_size: int, model_type: str, @@ -226,7 +324,33 @@ def train_model( weight_decay: float, device: str = "cuda", ) -> dict[str, Any]: + """ + Training function for the diffusion model. + + Args: + df: DataFrame to train the model on. + df_info: Dictionary of the table information. + model_params: Dictionary of the model parameters. + T_dict: Dictionary of the transformations. + steps: Number of steps to train the model. + batch_size: Batch size to use for training. + model_type: Type of the model to use. + gaussian_loss_type: Type of the gaussian loss to use. + num_timesteps: Number of timesteps to use for the diffusion model. + scheduler: Scheduler to use for the diffusion model. + lr: Learning rate to use for the diffusion model. + weight_decay: Weight decay to use for the diffusion model. 
+ device: Device to use for training. Default is `"cuda"`. + + Returns: + Dictionary of the training results. It will contain the following keys: + - diffusion: The diffusion model. + - label_encoders: The label encoders. + - dataset: The dataset. + - column_orders: The column orders. + """ T = Transformations(**T_dict) + # ruff: noqa: N806 dataset, label_encoders, column_orders = make_dataset_from_df( df, T, @@ -241,8 +365,10 @@ def train_model( num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 K = np.array(dataset.get_category_sizes("train")) + # ruff: noqa: N806 if len(K) == 0 or T_dict["cat_encoding"] == "one-hot": K = np.array([0]) + # ruff: noqa: N806 # print(K) num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 @@ -296,6 +422,7 @@ def train_classifier( df_info: pd.DataFrame, model_params: dict[str, Any], T_dict: dict[str, Any], + # ruff: noqa: N803 classifier_steps: int, batch_size: int, gaussian_loss_type: str, @@ -307,7 +434,30 @@ def train_classifier( dim_t: int = 128, lr: float = 0.0001, ) -> Classifier: + """ + Training function for the classifier model. + + Args: + df: DataFrame to train the model on. + df_info: Dictionary of the table information. + model_params: Dictionary of the model parameters. + T_dict: Dictionary of the transformations. + classifier_steps: Number of steps to train the classifier. + batch_size: Batch size to use for training. + gaussian_loss_type: Type of the gaussian loss to use. + num_timesteps: Number of timesteps to use for the diffusion model. + scheduler: Scheduler to use for the diffusion model. + d_layers: List of the hidden sizes of the classifier. + device: Device to use for training. Default is `"cuda"`. + cluster_col: Name of the cluster column. Default is `"cluster"`. + dim_t: Dimension of the transformer. Default is 128. + lr: Learning rate to use for the classifier. Default is 0.0001. + + Returns: + The classifier model. 
+ """ T = Transformations(**T_dict) + # ruff: noqa: N806 dataset, label_encoders, column_orders = make_dataset_from_df( df, T, @@ -325,8 +475,10 @@ def train_classifier( # log_interval = 10 K = np.array(dataset.get_category_sizes("train")) + # ruff: noqa: N806 if len(K) == 0 or T_dict["cat_encoding"] == "one-hot": K = np.array([0]) + # ruff: noqa: N806 print(K) num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index 77b815c3..3538ef83 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -1,4 +1,3 @@ -import os from pathlib import Path import pytest @@ -37,31 +36,23 @@ @pytest.mark.integration_test() def test_train_single_table(tmp_path: Path): - os.makedirs(tmp_path / "models") - configs = {"diffusion": DIFFUSION_CONFIG} - tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/single_table/") - models = clava_training(tables, relation_order, tmp_path, configs, device="cpu") + models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, {}, device="cpu") assert models @pytest.mark.integration_test() def test_train_multi_table(tmp_path: Path): - os.makedirs(tmp_path / "models") - configs = {"diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG} - tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/") tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) - models = clava_training(tables, relation_order, tmp_path, configs, device="cpu") + models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, None, device="cpu") assert models @pytest.mark.integration_test() def test_clustering_reload(tmp_path: Path): - os.makedirs(tmp_path / "models") - tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/") tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) From 08b9a107ed8bab20f69e73a843a2ec3852bb6c15 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 28 Jul 2025 12:32:04 -0400 Subject: [PATCH 04/39] fixing test, adding docstrings for the trainer module --- src/midst_toolkit/models/clavaddpm/trainer.py | 34 +++++++++++++++++-- .../models/clavaddpm/test_model.py | 4 +-- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/trainer.py b/src/midst_toolkit/models/clavaddpm/trainer.py index c9ab3c7d..43736e46 100644 --- a/src/midst_toolkit/models/clavaddpm/trainer.py +++ b/src/midst_toolkit/models/clavaddpm/trainer.py @@ -21,6 +21,18 @@ def __init__( steps: int, device: str = "cuda", ): + """ + Trainer class for the ClavaDDPM model. + + Args: + diffusion: The diffusion model. + train_iter: The training iterator. It should yield a tuple of tensors. The first tensor is the input + tensor and the second tensor is the output tensor. + lr: The learning rate. + weight_decay: The weight decay. + steps: The number of steps to train. + device: The device to use. Default is `"cuda"`. + """ self.diffusion = diffusion self.ema_model = deepcopy(self.diffusion._denoise_fn) for param in self.ema_model.parameters(): @@ -37,12 +49,25 @@ def __init__( self.ema_every = 1000 def _anneal_lr(self, step: int) -> None: + """ + Anneal the learning rate. + + Args: + step: The current step. 
+ """ frac_done = step / self.steps lr = self.init_lr * (1 - frac_done) for param_group in self.optimizer.param_groups: param_group["lr"] = lr def _run_step(self, x: Tensor, out_dict: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """ + Run a single step of the training loop. + + Args: + x: The input tensor. + out_dict: The output dictionary. + """ x = x.to(self.device) for k, v in out_dict.items(): out_dict[k] = v.long().to(self.device) @@ -55,6 +80,7 @@ def _run_step(self, x: Tensor, out_dict: dict[str, Tensor]) -> tuple[Tensor, Ten return loss_multi, loss_gauss def run_loop(self) -> None: + """Run the training loop.""" step = 0 curr_loss_multi = 0.0 curr_loss_gauss = 0.0 @@ -99,9 +125,11 @@ def update_ema( """ Update target parameters to be closer to those of source parameters using an exponential moving average. - :param target_params: the target parameter sequence. - :param source_params: the source parameter sequence. - :param rate: the EMA rate (closer to 1 means slower). + + Args: + target_params: the target parameter sequence. + source_params: the source parameter sequence. + rate: the EMA rate (closer to 1 means slower). """ for targ, src in zip(target_params, source_params): targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index 3538ef83..4d6158ed 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -37,7 +37,7 @@ @pytest.mark.integration_test() def test_train_single_table(tmp_path: Path): tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/single_table/") - models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, {}, device="cpu") + models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, None, device="cpu") assert models @@ -46,7 +46,7 @@ def test_train_single_table(tmp_path: Path): def test_train_multi_table(tmp_path: Path): tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/") tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) - models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, None, device="cpu") + models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device="cpu") assert models From c7b50401f602bd0c4af2f6e18a5a616b8f573cb1 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 31 Jul 2025 16:33:21 -0400 Subject: [PATCH 05/39] WIP --- .../models/clavaddpm/clustering.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 src/midst_toolkit/models/clavaddpm/clustering.py diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py new file mode 100644 index 00000000..ebd8f444 --- /dev/null +++ b/src/midst_toolkit/models/clavaddpm/clustering.py @@ -0,0 +1,84 @@ +def clava_clustering( + tables: Tables, + relation_order: RelationOrder, + save_dir: Path, + configs: Configs, +) -> tuple[dict[str, Any], dict[tuple[str, str], dict[int, float]]]: + """ + Clustering function for the mutli-table function of theClavaDDPM model. + + Args: + tables: Definition of the tables and their relations. Example: + { + "table1": { + "children": ["table2"], + "parents": [] + }, + "table2": { + "children": [], + "parents": ["table1"] + } + } + relation_order: List of tuples of parent and child tables. 
Example: + [("table1", "table2"), ("table1", "table3")] + save_dir: Directory to save the clustering checkpoint. + configs: Dictionary of configurations. The following config keys are required: + { + num_clusters = int | dict, + parent_scale = float, + clustering_method = str["kmeans" | "both" | "variational" | "gmm"], + } + + """ + relation_order_reversed = relation_order[::-1] + all_group_lengths_prob_dicts = {} + + # Clustering + if os.path.exists(save_dir / "cluster_ckpt.pkl"): + print("Clustering checkpoint found, loading...") + cluster_ckpt = pickle.load(open(save_dir / "cluster_ckpt.pkl", "rb")) + # ruff: noqa: SIM115 + tables = cluster_ckpt["tables"] + all_group_lengths_prob_dicts = cluster_ckpt["all_group_lengths_prob_dicts"] + else: + for parent, child in relation_order_reversed: + if parent is not None: + print(f"Clustering {parent} -> {child}") + if isinstance(configs["num_clusters"], dict): + num_clusters = configs["num_clusters"][child] + else: + num_clusters = configs["num_clusters"] + ( + parent_df_with_cluster, + child_df_with_cluster, + group_lengths_prob_dicts, + ) = pair_clustering_keep_id( + tables[child]["df"], + tables[child]["domain"], + tables[parent]["df"], + tables[parent]["domain"], + f"{child}_id", + f"{parent}_id", + num_clusters, + configs["parent_scale"], + 1, # not used for now + parent, + child, + clustering_method=configs["clustering_method"], + ) + tables[parent]["df"] = parent_df_with_cluster + tables[child]["df"] = child_df_with_cluster + all_group_lengths_prob_dicts[(parent, child)] = group_lengths_prob_dicts + + cluster_ckpt = { + "tables": tables, + "all_group_lengths_prob_dicts": all_group_lengths_prob_dicts, + } + pickle.dump(cluster_ckpt, open(save_dir / "cluster_ckpt.pkl", "wb")) + # ruff: noqa: SIM115 + + for parent, child in relation_order: + if parent is None: + tables[child]["df"]["placeholder"] = list(range(len(tables[child]["df"]))) + + return tables, all_group_lengths_prob_dicts From 0d8c49969a2dc8fc1996cf34d3436160f825520d Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 31 Jul 2025 17:24:58 -0400 Subject: [PATCH 06/39] Moving clustering code to its own module --- .../models/clavaddpm/clustering.py | 471 +++++++++++++++++- src/midst_toolkit/models/clavaddpm/model.py | 375 +------------- src/midst_toolkit/models/clavaddpm/params.py | 7 + src/midst_toolkit/models/clavaddpm/train.py | 94 +--- .../models/clavaddpm/test_model.py | 3 +- 5 files changed, 480 insertions(+), 470 deletions(-) create mode 100644 src/midst_toolkit/models/clavaddpm/params.py diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py index ebd8f444..ddfefa09 100644 --- a/src/midst_toolkit/models/clavaddpm/clustering.py +++ b/src/midst_toolkit/models/clavaddpm/clustering.py @@ -1,9 +1,24 @@ +import os +import pickle +from collections import defaultdict +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import pandas as pd +from sklearn.cluster import KMeans +from sklearn.mixture import BayesianGaussianMixture, GaussianMixture +from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, QuantileTransformer + +from midst_toolkit.models.clavaddpm.params import Configs, RelationOrder, Tables + + def clava_clustering( tables: Tables, relation_order: RelationOrder, save_dir: Path, configs: Configs, -) -> tuple[dict[str, Any], dict[tuple[str, str], dict[int, float]]]: +) -> tuple[Tables, dict[tuple[str, str], dict[int, float]]]: """ Clustering function for the 
mutli-table function of theClavaDDPM model. @@ -29,6 +44,8 @@ def clava_clustering( clustering_method = str["kmeans" | "both" | "variational" | "gmm"], } + Returns: + Tuple of the tables and the dictionary of group lengths and probabilities. """ relation_order_reversed = relation_order[::-1] all_group_lengths_prob_dicts = {} @@ -52,7 +69,7 @@ def clava_clustering( parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts, - ) = pair_clustering_keep_id( + ) = _pair_clustering_keep_id( tables[child]["df"], tables[child]["domain"], tables[parent]["df"], @@ -82,3 +99,453 @@ def clava_clustering( tables[child]["df"]["placeholder"] = list(range(len(tables[child]["df"]))) return tables, all_group_lengths_prob_dicts + + +def _pair_clustering_keep_id( + # ruff: noqa: PLR0912, PLR0915 + child_df: pd.DataFrame, + child_domain_dict: dict[str, Any], + parent_df: pd.DataFrame, + parent_domain_dict: dict[str, Any], + child_primary_key: str, + parent_primary_key: str, + num_clusters: int, + parent_scale: float, + key_scale: float, + parent_name: str, + child_name: str, + clustering_method: Literal["kmeans", "both", "variational", "gmm"] = "kmeans", +) -> tuple[pd.DataFrame, pd.DataFrame, dict[int, dict[int, float]]]: + """ + Pairs clustering information to the parent and child dataframes. + + Used by the mutli-table function of the ClavaDDPM model. + + Args: + child_df: DataFrame of the child table, as provided by the load_multi_table function. + child_domain_dict: Dictionary of the child table domain, as provided by the load_multi_table function. + parent_df: DataFrame of the parent table, as provided by the load_multi_table function. + parent_domain_dict: Dictionary of the parent table domain, as provided by the load_multi_table function. + child_primary_key: Name of the child primary key. + parent_primary_key: Name of the parent primary key. + num_clusters: Number of clusters. + parent_scale: Scale of the parent table, provided by the config. + key_scale: Scale of the key. + parent_name: Name of the parent table. + child_name: Name of the child table. + clustering_method: Method of clustering. Has to be one of ["kmeans", "both", "variational", "gmm"]. + Default is "kmeans". + + Returns: + Tuple with 3 elements: + - parent_df_with_cluster: DataFrame of the parent table with the cluster column. + - child_df_with_cluster: DataFrame of the child table with the cluster column. + - group_lengths_prob_dicts: Dictionary of group lengths and probabilities. 
+ """ + original_child_cols = list(child_df.columns) + original_parent_cols = list(parent_df.columns) + + relation_cluster_name = f"{parent_name}_{child_name}_cluster" + + child_data = child_df.to_numpy() + parent_data = parent_df.to_numpy() + + child_num_cols = [] + child_cat_cols = [] + + parent_num_cols = [] + parent_cat_cols = [] + + for col_index, col in enumerate(original_child_cols): + if col in child_domain_dict: + if child_domain_dict[col]["type"] == "discrete": + child_cat_cols.append((col_index, col)) + else: + child_num_cols.append((col_index, col)) + + for col_index, col in enumerate(original_parent_cols): + if col in parent_domain_dict: + if parent_domain_dict[col]["type"] == "discrete": + parent_cat_cols.append((col_index, col)) + else: + parent_num_cols.append((col_index, col)) + + parent_primary_key_index = original_parent_cols.index(parent_primary_key) + foreing_key_index = original_child_cols.index(parent_primary_key) + + # sort child data by foreign key + sorted_child_data = child_data[np.argsort(child_data[:, foreing_key_index])] + child_group_data_dict = _get_group_data_dict(sorted_child_data, [foreing_key_index]) + + # sort parent data by primary key + sorted_parent_data = parent_data[np.argsort(parent_data[:, parent_primary_key_index])] + + group_lengths = [] + unique_group_ids = sorted_parent_data[:, parent_primary_key_index] + for group_id in unique_group_ids: + group_id = tuple([group_id]) + # ruff: noqa: C409 + if group_id not in child_group_data_dict: + group_lengths.append(0) + else: + group_lengths.append(len(child_group_data_dict[group_id])) + + group_lengths_np = np.array(group_lengths, dtype=int) + + sorted_parent_data_repeated = np.repeat(sorted_parent_data, group_lengths_np, axis=0) + assert (sorted_parent_data_repeated[:, parent_primary_key_index] == sorted_child_data[:, foreing_key_index]).all() + + sorted_child_num_data = sorted_child_data[:, [col_index for col_index, col in child_num_cols]] + sorted_child_cat_data = sorted_child_data[:, [col_index for col_index, col in child_cat_cols]] + sorted_parent_num_data = sorted_parent_data_repeated[:, [col_index for col_index, col in parent_num_cols]] + sorted_parent_cat_data = sorted_parent_data_repeated[:, [col_index for col_index, col in parent_cat_cols]] + + joint_num_matrix = np.concatenate([sorted_child_num_data, sorted_parent_num_data], axis=1) + joint_cat_matrix = np.concatenate([sorted_child_cat_data, sorted_parent_cat_data], axis=1) + + if joint_cat_matrix.shape[1] > 0: + joint_cat_matrix_p_index = sorted_child_cat_data.shape[1] + joint_num_matrix_p_index = sorted_child_num_data.shape[1] + + cat_converted = [] + label_encoders = [] + for i in range(joint_cat_matrix.shape[1]): + # A threshold of 1000 unique values is used to prevent the one-hot encoding of large categorical columns + if len(np.unique(joint_cat_matrix[:, i])) > 1000: + continue + label_encoder = LabelEncoder() + cat_converted.append(label_encoder.fit_transform(joint_cat_matrix[:, i]).astype(float)) + label_encoders.append(label_encoder) + + cat_converted_transposed = np.vstack(cat_converted).T + + # Initialize an empty array to store the encoded values + cat_one_hot = np.empty((cat_converted_transposed.shape[0], 0)) + + # Loop through each column in the data and encode it + for col in range(cat_converted_transposed.shape[1]): + encoder = OneHotEncoder(sparse_output=False) + column = cat_converted_transposed[:, col].reshape(-1, 1) + encoded_column = encoder.fit_transform(column) + cat_one_hot = np.concatenate((cat_one_hot, 
encoded_column), axis=1) + + cat_one_hot[:, joint_cat_matrix_p_index:] = parent_scale * cat_one_hot[:, joint_cat_matrix_p_index:] + + # Perform quantile normalization using QuantileTransformer + num_quantile = _quantile_normalize_sklearn(joint_num_matrix) + num_min_max = _min_max_normalize_sklearn(joint_num_matrix) + + # key_quantile = + # quantile_normalize_sklearn(sorted_parent_data_repeated[:, parent_primary_key_index].reshape(-1, 1)) + key_min_max = _min_max_normalize_sklearn(sorted_parent_data_repeated[:, parent_primary_key_index].reshape(-1, 1)) + + # key_scaled = key_scaler * key_quantile + key_scaled = key_scale * key_min_max + + num_quantile[:, joint_num_matrix_p_index:] = parent_scale * num_quantile[:, joint_num_matrix_p_index:] + num_min_max[:, joint_num_matrix_p_index:] = parent_scale * num_min_max[:, joint_num_matrix_p_index:] + + if joint_cat_matrix.shape[1] > 0: + cluster_data = np.concatenate((num_min_max, cat_one_hot, key_scaled), axis=1) + else: + cluster_data = np.concatenate((num_min_max, key_scaled), axis=1) + + child_group_data = _get_group_data(sorted_child_data, [foreing_key_index]) + child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int) + num_clusters = min(num_clusters, len(cluster_data)) + + # print('clustering') + if clustering_method == "kmeans": + kmeans = KMeans(n_clusters=num_clusters, n_init="auto", init="k-means++") + kmeans.fit(cluster_data) + cluster_labels = kmeans.labels_ + elif clustering_method == "both": + gmm = GaussianMixture( + n_components=num_clusters, + verbose=1, + covariance_type="diag", + init_params="k-means++", + tol=0.0001, + ) + gmm.fit(cluster_data) + cluster_labels = gmm.predict(cluster_data) + elif clustering_method == "variational": + gmm = BayesianGaussianMixture( + n_components=num_clusters, + verbose=1, + covariance_type="diag", + init_params="k-means++", + tol=0.0001, + ) + gmm.fit(cluster_data) + cluster_labels = gmm.predict_proba(cluster_data) + elif clustering_method == "gmm": + gmm = GaussianMixture( + n_components=num_clusters, + verbose=1, + covariance_type="diag", + ) + gmm.fit(cluster_data) + cluster_labels = gmm.predict(cluster_data) + + if clustering_method == "variational": + group_cluster_labels, agree_rates = _aggregate_and_sample(cluster_labels, child_group_lengths) + else: + # voting to determine the cluster label for each parent + group_cluster_labels = [] + curr_index = 0 + agree_rates = [] + for group_length in child_group_lengths: + # First, determine the most common label in the current group + most_common_label_count = np.max(np.bincount(cluster_labels[curr_index : curr_index + group_length])) + group_cluster_label = np.argmax(np.bincount(cluster_labels[curr_index : curr_index + group_length])) + group_cluster_labels.append(int(group_cluster_label)) + + # Compute agree rate using the most common label count + agree_rate = most_common_label_count / group_length + agree_rates.append(agree_rate) + + # Then, update the curr_index for the next iteration + curr_index += group_length + + # Compute the average agree rate across all groups + average_agree_rate = np.mean(agree_rates) + print("Average agree rate: ", average_agree_rate) + + group_assignment = np.repeat(group_cluster_labels, child_group_lengths, axis=0).reshape((-1, 1)) + + # obtain the child data with clustering + sorted_child_data_with_cluster = np.concatenate([sorted_child_data, group_assignment], axis=1) + + group_labels_list = group_cluster_labels + group_lengths_list = child_group_lengths.tolist() + + 
group_lengths_dict: dict[int, dict[int, int]] = {} + for i in range(len(group_labels_list)): + group_label = group_labels_list[i] + if group_label not in group_lengths_dict: + group_lengths_dict[group_label] = defaultdict(int) + group_lengths_dict[group_label][group_lengths_list[i]] += 1 + + group_lengths_prob_dicts: dict[int, dict[int, float]] = {} + for group_label, freq_dict in group_lengths_dict.items(): + group_lengths_prob_dicts[group_label] = _freq_to_prob(freq_dict) + + # recover the preprocessed data back to dataframe + child_df_with_cluster = pd.DataFrame( + sorted_child_data_with_cluster, + columns=original_child_cols + [relation_cluster_name], + ) + + # recover child df order + child_df_with_cluster = pd.merge( + child_df[[child_primary_key]], + child_df_with_cluster, + on=child_primary_key, + how="left", + ) + + parent_id_to_cluster: dict[Any, Any] = {} + for i in range(len(sorted_child_data)): + parent_id = sorted_child_data[i, foreing_key_index] + if parent_id in parent_id_to_cluster: + assert parent_id_to_cluster[parent_id] == sorted_child_data_with_cluster[i, -1] + continue + parent_id_to_cluster[parent_id] = sorted_child_data_with_cluster[i, -1] + + max_cluster_label = max(parent_id_to_cluster.values()) + + parent_data_clusters = [] + for i in range(len(parent_data)): + if parent_data[i, parent_primary_key_index] in parent_id_to_cluster: + parent_data_clusters.append(parent_id_to_cluster[parent_data[i, parent_primary_key_index]]) + else: + parent_data_clusters.append(max_cluster_label + 1) + + parent_data_clusters_np = np.array(parent_data_clusters).reshape(-1, 1) + parent_data_with_cluster = np.concatenate([parent_data, parent_data_clusters_np], axis=1) + parent_df_with_cluster = pd.DataFrame( + parent_data_with_cluster, columns=original_parent_cols + [relation_cluster_name] + ) + + new_col_entry = { + "type": "discrete", + "size": len(set(parent_data_clusters_np.flatten())), + } + + print("Number of cluster centers: ", len(set(parent_data_clusters_np.flatten()))) + + parent_domain_dict[relation_cluster_name] = new_col_entry.copy() + child_domain_dict[relation_cluster_name] = new_col_entry.copy() + + return parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts + + +def _get_group_data_dict( + np_data: np.ndarray, + group_id_attrs: list[int] | None = None, +) -> dict[tuple[Any, ...], list[np.ndarray]]: + """ + Get the group data dictionary. + + Args: + np_data: Numpy array of the data. + group_id_attrs: List of attributes to group by. + + Returns: + Dictionary of group data. + """ + if group_id_attrs is None: + group_id_attrs = [0] + + group_data_dict: dict[tuple[Any, ...], list[np.ndarray]] = {} + data_len = len(np_data) + for i in range(data_len): + row_id = tuple(np_data[i, group_id_attrs]) + if row_id not in group_data_dict: + group_data_dict[row_id] = [] + group_data_dict[row_id].append(np_data[i]) + + return group_data_dict + + +def _get_group_data( + np_data: np.ndarray, + group_id_attrs: list[int] | None = None, +) -> np.ndarray: + """ + Get the group data. + + Args: + np_data: Numpy array of the data. + group_id_attrs: List of attributes to group by. + + Returns: + Numpy array of the group data. 
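+
+    For illustration (hypothetical data, assuming the input is already sorted by the grouping
+    attribute): grouping [[1, "a"], [1, "b"], [2, "c"]] on column 0 yields two groups,
+    of sizes 2 and 1.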
+ """ + if group_id_attrs is None: + group_id_attrs = [0] + + group_data_list = [] + data_len = len(np_data) + i = 0 + while i < data_len: + group = [] + row_id = np_data[i, group_id_attrs] + + while (np_data[i, group_id_attrs] == row_id).all(): + group.append(np_data[i]) + i += 1 + if i >= data_len: + break + group_data_list.append(np.array(group)) + + return np.array(group_data_list, dtype=object) + + +# TODO: Refactor the functions below to be a single one with a "method" parameter. + + +def _quantile_normalize_sklearn(matrix: np.ndarray) -> np.ndarray: + """ + Quantile normalize the input matrix using Sklearn's QuantileTransformer. + + Args: + matrix: Numpy array of the matrix data. + + Returns: + Numpy array of the normalized data. + """ + transformer = QuantileTransformer( + output_distribution="normal", + random_state=42, # TODO: do we really need to hardcode the random state? + ) # Change output_distribution as needed + + normalized_data = np.empty((matrix.shape[0], 0)) + + # Apply QuantileTransformer to each column and concatenate the results + for col in range(matrix.shape[1]): + column = matrix[:, col].reshape(-1, 1) + transformed_column = transformer.fit_transform(column) + normalized_data = np.concatenate((normalized_data, transformed_column), axis=1) + + return normalized_data + + +def _min_max_normalize_sklearn(matrix: np.ndarray) -> np.ndarray: + """ + Min-max normalize the input matrix using Sklearn's MinMaxScaler. + + Args: + matrix: Numpy array of the matrix data. + + Returns: + Numpy array of the normalized data. + """ + scaler = MinMaxScaler(feature_range=(-1, 1)) + + normalized_data = np.empty((matrix.shape[0], 0)) + + # Apply MinMaxScaler to each column and concatenate the results + for col in range(matrix.shape[1]): + column = matrix[:, col].reshape(-1, 1) + transformed_column = scaler.fit_transform(column) + normalized_data = np.concatenate((normalized_data, transformed_column), axis=1) + + return normalized_data + + +def _aggregate_and_sample( + cluster_probabilities: np.ndarray, + child_group_lengths: np.ndarray, +) -> tuple[list[int], list[float]]: + """ + Aggregate the cluster probabilities and sample the labels. + + Used by the variational clustering method. + + Args: + cluster_probabilities: Numpy array of the cluster probabilities. + child_group_lengths: Numpy array of the child group lengths. + + Returns: + Tuple of the group cluster labels and the agree rates. + """ + group_cluster_labels = [] + curr_index = 0 + agree_rates = [] + + for group_length in child_group_lengths: + # Aggregate the probability distributions by taking the mean + group_probability_distribution = np.mean(cluster_probabilities[curr_index : curr_index + group_length], axis=0) + + # Sample the label from the aggregated distribution + group_cluster_label = np.random.choice( + range(len(group_probability_distribution)), p=group_probability_distribution + ) + group_cluster_labels.append(group_cluster_label) + + # Compute the max probability as the agree rate + max_probability = np.max(group_probability_distribution) + agree_rates.append(max_probability) + + # Update the curr_index for the next iteration + curr_index += group_length + + return group_cluster_labels, agree_rates + + +def _freq_to_prob(freq_dict: dict[int, int]) -> dict[int, float]: + """ + Convert a frequency dictionary to a probability dictionary. + + Args: + freq_dict: Dictionary of frequencies. + + Returns: + Dictionary of probabilities. 
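+
+    For illustration (hypothetical counts): {2: 3, 4: 1} is converted to {2: 0.75, 4: 0.25}.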
+ """ + prob_dict: dict[Any, float] = {} + for key, freq in freq_dict.items(): + prob_dict[key] = freq / sum(list(freq_dict.values())) + return prob_dict diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index f636b552..f67c1bc2 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -3,7 +3,7 @@ import math import pickle from abc import ABC, abstractmethod -from collections import Counter, defaultdict +from collections import Counter from collections.abc import Callable, Generator from copy import deepcopy from dataclasses import astuple, dataclass, replace @@ -19,10 +19,8 @@ # ruff: noqa: N812 from category_encoders import LeaveOneOutEncoder from scipy.special import expit, softmax -from sklearn.cluster import KMeans from sklearn.impute import SimpleImputer from sklearn.metrics import classification_report, mean_squared_error, r2_score, roc_auc_score -from sklearn.mixture import BayesianGaussianMixture, GaussianMixture from sklearn.model_selection import train_test_split from sklearn.pipeline import make_pipeline from sklearn.preprocessing import ( @@ -276,377 +274,6 @@ def forward(self, x, timesteps): return self.model(x) -def pair_clustering_keep_id( - child_df: pd.DataFrame, - child_domain_dict: dict[str, Any], - parent_df: pd.DataFrame, - parent_domain_dict: dict[str, Any], - child_primary_key: str, - parent_primary_key: str, - num_clusters: int, - parent_scale: float, - key_scale: float, - parent_name: str, - child_name: str, - clustering_method: str = "kmeans", -) -> tuple[pd.DataFrame, pd.DataFrame, dict[int, dict[int, float]]]: - original_child_cols = list(child_df.columns) - original_parent_cols = list(parent_df.columns) - - relation_cluster_name = f"{parent_name}_{child_name}_cluster" - - child_data = child_df.to_numpy() - parent_data = parent_df.to_numpy() - - child_num_cols = [] - child_cat_cols = [] - - parent_num_cols = [] - parent_cat_cols = [] - - for col_index, col in enumerate(original_child_cols): - if col in child_domain_dict: - if child_domain_dict[col]["type"] == "discrete": - child_cat_cols.append((col_index, col)) - else: - child_num_cols.append((col_index, col)) - - for col_index, col in enumerate(original_parent_cols): - if col in parent_domain_dict: - if parent_domain_dict[col]["type"] == "discrete": - parent_cat_cols.append((col_index, col)) - else: - parent_num_cols.append((col_index, col)) - - parent_primary_key_index = original_parent_cols.index(parent_primary_key) - foreing_key_index = original_child_cols.index(parent_primary_key) - - # sort child data by foreign key - sorted_child_data = child_data[np.argsort(child_data[:, foreing_key_index])] - child_group_data_dict = get_group_data_dict( - sorted_child_data, - [ - foreing_key_index, - ], - ) - - # sort parent data by primary key - sorted_parent_data = parent_data[np.argsort(parent_data[:, parent_primary_key_index])] - - group_lengths = [] - unique_group_ids = sorted_parent_data[:, parent_primary_key_index] - for group_id in unique_group_ids: - group_id = tuple([group_id]) - # ruff: noqa: C409 - if group_id not in child_group_data_dict: - group_lengths.append(0) - else: - group_lengths.append(len(child_group_data_dict[group_id])) - - group_lengths_np = np.array(group_lengths, dtype=int) - - sorted_parent_data_repeated = np.repeat(sorted_parent_data, group_lengths_np, axis=0) - assert (sorted_parent_data_repeated[:, parent_primary_key_index] == sorted_child_data[:, foreing_key_index]).all() - - 
child_group_data = get_group_data( - sorted_child_data, - [ - foreing_key_index, - ], - ) - - sorted_child_num_data = sorted_child_data[:, [col_index for col_index, col in child_num_cols]] - sorted_child_cat_data = sorted_child_data[:, [col_index for col_index, col in child_cat_cols]] - sorted_parent_num_data = sorted_parent_data_repeated[:, [col_index for col_index, col in parent_num_cols]] - sorted_parent_cat_data = sorted_parent_data_repeated[:, [col_index for col_index, col in parent_cat_cols]] - - joint_num_matrix = np.concatenate([sorted_child_num_data, sorted_parent_num_data], axis=1) - joint_cat_matrix = np.concatenate([sorted_child_cat_data, sorted_parent_cat_data], axis=1) - - if joint_cat_matrix.shape[1] > 0: - joint_cat_matrix_p_index = sorted_child_cat_data.shape[1] - joint_num_matrix_p_index = sorted_child_num_data.shape[1] - - cat_converted = [] - label_encoders = [] - for i in range(joint_cat_matrix.shape[1]): - # A threshold of 1000 unique values is used to prevent the one-hot encoding of large categorical columns - if len(np.unique(joint_cat_matrix[:, i])) > 1000: - continue - label_encoder = LabelEncoder() - cat_converted.append(label_encoder.fit_transform(joint_cat_matrix[:, i]).astype(float)) - label_encoders.append(label_encoder) - - cat_converted_transposed = np.vstack(cat_converted).T - - # Initialize an empty array to store the encoded values - cat_one_hot = np.empty((cat_converted_transposed.shape[0], 0)) - - # Loop through each column in the data and encode it - for col in range(cat_converted_transposed.shape[1]): - encoder = OneHotEncoder(sparse_output=False) - column = cat_converted_transposed[:, col].reshape(-1, 1) - encoded_column = encoder.fit_transform(column) - cat_one_hot = np.concatenate((cat_one_hot, encoded_column), axis=1) - - cat_one_hot[:, joint_cat_matrix_p_index:] = parent_scale * cat_one_hot[:, joint_cat_matrix_p_index:] - - # Perform quantile normalization using QuantileTransformer - num_quantile = quantile_normalize_sklearn(joint_num_matrix) - num_min_max = min_max_normalize_sklearn(joint_num_matrix) - - # key_quantile = - # quantile_normalize_sklearn(sorted_parent_data_repeated[:, parent_primary_key_index].reshape(-1, 1)) - key_min_max = min_max_normalize_sklearn(sorted_parent_data_repeated[:, parent_primary_key_index].reshape(-1, 1)) - - # key_scaled = key_scaler * key_quantile - key_scaled = key_scale * key_min_max - - num_quantile[:, joint_num_matrix_p_index:] = parent_scale * num_quantile[:, joint_num_matrix_p_index:] - num_min_max[:, joint_num_matrix_p_index:] = parent_scale * num_min_max[:, joint_num_matrix_p_index:] - - if joint_cat_matrix.shape[1] > 0: - cluster_data = np.concatenate((num_min_max, cat_one_hot, key_scaled), axis=1) - else: - cluster_data = np.concatenate((num_min_max, key_scaled), axis=1) - - child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int) - num_clusters = min(num_clusters, len(cluster_data)) - - # print('clustering') - if clustering_method == "kmeans": - kmeans = KMeans(n_clusters=num_clusters, n_init="auto", init="k-means++") - kmeans.fit(cluster_data) - cluster_labels = kmeans.labels_ - elif clustering_method == "both": - gmm = GaussianMixture( - n_components=num_clusters, - verbose=1, - covariance_type="diag", - init_params="k-means++", - tol=0.0001, - ) - gmm.fit(cluster_data) - cluster_labels = gmm.predict(cluster_data) - elif clustering_method == "variational": - gmm = BayesianGaussianMixture( - n_components=num_clusters, - verbose=1, - covariance_type="diag", - 
init_params="k-means++", - tol=0.0001, - ) - gmm.fit(cluster_data) - cluster_labels = gmm.predict_proba(cluster_data) - elif clustering_method == "gmm": - gmm = GaussianMixture( - n_components=num_clusters, - verbose=1, - covariance_type="diag", - ) - gmm.fit(cluster_data) - cluster_labels = gmm.predict(cluster_data) - - if clustering_method == "variational": - group_cluster_labels, agree_rates = aggregate_and_sample(cluster_labels, child_group_lengths) - else: - # voting to determine the cluster label for each parent - group_cluster_labels = [] - curr_index = 0 - agree_rates = [] - for group_length in child_group_lengths: - # First, determine the most common label in the current group - most_common_label_count = np.max(np.bincount(cluster_labels[curr_index : curr_index + group_length])) - group_cluster_label = np.argmax(np.bincount(cluster_labels[curr_index : curr_index + group_length])) - group_cluster_labels.append(int(group_cluster_label)) - - # Compute agree rate using the most common label count - agree_rate = most_common_label_count / group_length - agree_rates.append(agree_rate) - - # Then, update the curr_index for the next iteration - curr_index += group_length - - # Compute the average agree rate across all groups - average_agree_rate = np.mean(agree_rates) - print("Average agree rate: ", average_agree_rate) - - group_assignment = np.repeat(group_cluster_labels, child_group_lengths, axis=0).reshape((-1, 1)) - - # obtain the child data with clustering - sorted_child_data_with_cluster = np.concatenate([sorted_child_data, group_assignment], axis=1) - - group_labels_list = group_cluster_labels - group_lengths_list = child_group_lengths.tolist() - - group_lengths_dict: dict[int, dict[int, int]] = {} - for i in range(len(group_labels_list)): - group_label = group_labels_list[i] - if group_label not in group_lengths_dict: - group_lengths_dict[group_label] = defaultdict(int) - group_lengths_dict[group_label][group_lengths_list[i]] += 1 - - group_lengths_prob_dicts: dict[int, dict[int, float]] = {} - for group_label, freq_dict in group_lengths_dict.items(): - group_lengths_prob_dicts[group_label] = freq_to_prob(freq_dict) - - # recover the preprocessed data back to dataframe - child_df_with_cluster = pd.DataFrame( - sorted_child_data_with_cluster, - columns=original_child_cols + [relation_cluster_name], - ) - - # recover child df order - child_df_with_cluster = pd.merge( - child_df[[child_primary_key]], - child_df_with_cluster, - on=child_primary_key, - how="left", - ) - - parent_id_to_cluster: dict[Any, Any] = {} - for i in range(len(sorted_child_data)): - parent_id = sorted_child_data[i, foreing_key_index] - if parent_id in parent_id_to_cluster: - assert parent_id_to_cluster[parent_id] == sorted_child_data_with_cluster[i, -1] - continue - parent_id_to_cluster[parent_id] = sorted_child_data_with_cluster[i, -1] - - max_cluster_label = max(parent_id_to_cluster.values()) - - parent_data_clusters = [] - for i in range(len(parent_data)): - if parent_data[i, parent_primary_key_index] in parent_id_to_cluster: - parent_data_clusters.append(parent_id_to_cluster[parent_data[i, parent_primary_key_index]]) - else: - parent_data_clusters.append(max_cluster_label + 1) - - parent_data_clusters_np = np.array(parent_data_clusters).reshape(-1, 1) - parent_data_with_cluster = np.concatenate([parent_data, parent_data_clusters_np], axis=1) - parent_df_with_cluster = pd.DataFrame( - parent_data_with_cluster, columns=original_parent_cols + [relation_cluster_name] - ) - - new_col_entry = { - "type": 
"discrete", - "size": len(set(parent_data_clusters_np.flatten())), - } - - print("Number of cluster centers: ", len(set(parent_data_clusters_np.flatten()))) - - parent_domain_dict[relation_cluster_name] = new_col_entry.copy() - child_domain_dict[relation_cluster_name] = new_col_entry.copy() - - return parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts - - -def get_group_data_dict( - np_data: np.ndarray, - group_id_attrs: list[int] | None = None, -) -> dict[tuple[Any, ...], list[np.ndarray]]: - if group_id_attrs is None: - group_id_attrs = [0] - - group_data_dict: dict[tuple[Any, ...], list[np.ndarray]] = {} - data_len = len(np_data) - for i in range(data_len): - row_id = tuple(np_data[i, group_id_attrs]) - if row_id not in group_data_dict: - group_data_dict[row_id] = [] - group_data_dict[row_id].append(np_data[i]) - - return group_data_dict - - -def get_group_data( - np_data: np.ndarray, - group_id_attrs: list[int] | None = None, -) -> np.ndarray: - if group_id_attrs is None: - group_id_attrs = [0] - - group_data_list = [] - data_len = len(np_data) - i = 0 - while i < data_len: - group = [] - row_id = np_data[i, group_id_attrs] - - while (np_data[i, group_id_attrs] == row_id).all(): - group.append(np_data[i]) - i += 1 - if i >= data_len: - break - group_data_list.append(np.array(group)) - - return np.array(group_data_list, dtype=object) - - -def quantile_normalize_sklearn(matrix: np.ndarray) -> np.ndarray: - transformer = QuantileTransformer( - output_distribution="normal", random_state=42 - ) # Change output_distribution as needed - - normalized_data = np.empty((matrix.shape[0], 0)) - - # Apply QuantileTransformer to each column and concatenate the results - for col in range(matrix.shape[1]): - column = matrix[:, col].reshape(-1, 1) - transformed_column = transformer.fit_transform(column) - normalized_data = np.concatenate((normalized_data, transformed_column), axis=1) - - return normalized_data - - -def min_max_normalize_sklearn(matrix: np.ndarray) -> np.ndarray: - scaler = MinMaxScaler(feature_range=(-1, 1)) - - normalized_data = np.empty((matrix.shape[0], 0)) - - # Apply MinMaxScaler to each column and concatenate the results - for col in range(matrix.shape[1]): - column = matrix[:, col].reshape(-1, 1) - transformed_column = scaler.fit_transform(column) - normalized_data = np.concatenate((normalized_data, transformed_column), axis=1) - - return normalized_data - - -def aggregate_and_sample( - cluster_probabilities: np.ndarray, - child_group_lengths: np.ndarray, -) -> tuple[list[int], list[float]]: - group_cluster_labels = [] - curr_index = 0 - agree_rates = [] - - for group_length in child_group_lengths: - # Aggregate the probability distributions by taking the mean - group_probability_distribution = np.mean(cluster_probabilities[curr_index : curr_index + group_length], axis=0) - - # Sample the label from the aggregated distribution - group_cluster_label = np.random.choice( - range(len(group_probability_distribution)), p=group_probability_distribution - ) - group_cluster_labels.append(group_cluster_label) - - # Compute the max probability as the agree rate - max_probability = np.max(group_probability_distribution) - agree_rates.append(max_probability) - - # Update the curr_index for the next iteration - curr_index += group_length - - return group_cluster_labels, agree_rates - - -def freq_to_prob(freq_dict: dict[int, int]) -> dict[int, float]: - prob_dict: dict[Any, float] = {} - for key, freq in freq_dict.items(): - prob_dict[key] = freq / 
sum(list(freq_dict.values())) - return prob_dict - - def get_table_info(df: pd.DataFrame, domain_dict: dict[str, Any], y_col: str) -> dict[str, Any]: cat_cols = [] num_cols = [] diff --git a/src/midst_toolkit/models/clavaddpm/params.py b/src/midst_toolkit/models/clavaddpm/params.py new file mode 100644 index 00000000..38217fc0 --- /dev/null +++ b/src/midst_toolkit/models/clavaddpm/params.py @@ -0,0 +1,7 @@ +from typing import Any + + +# TODO: Temporary, will wtich to classes later +Configs = dict[str, Any] +Tables = dict[str, dict[str, Any]] +RelationOrder = list[tuple[str, str]] diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index a51095e9..7048dccb 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -1,6 +1,5 @@ """Defines the training functions for the ClavaDDPM model.""" -import os import pickle from pathlib import Path from typing import Any @@ -22,103 +21,12 @@ get_table_info, make_dataset_from_df, numerical_forward_backward_log, - pair_clustering_keep_id, prepare_fast_dataloader, ) +from midst_toolkit.models.clavaddpm.params import Configs, RelationOrder, Tables from midst_toolkit.models.clavaddpm.trainer import Trainer -Tables = dict[str, dict[str, Any]] -RelationOrder = list[tuple[str, str]] -Configs = dict[str, Any] - - -def clava_clustering( - tables: Tables, - relation_order: RelationOrder, - save_dir: Path, - configs: Configs, -) -> tuple[dict[str, Any], dict[tuple[str, str], dict[int, float]]]: - """ - Clustering function for the mutli-table function of theClavaDDPM model. - - Args: - tables: Definition of the tables and their relations. Example: - { - "table1": { - "children": ["table2"], - "parents": [] - }, - "table2": { - "children": [], - "parents": ["table1"] - } - } - relation_order: List of tuples of parent and child tables. Example: - [("table1", "table2"), ("table1", "table3")] - save_dir: Directory to save the clustering checkpoint. - configs: Dictionary of configurations. 
The following config keys are required: - { - num_clusters = int | dict, - parent_scale = float, - clustering_method = str["kmeans" | "both" | "variational" | "gmm"], - } - - """ - relation_order_reversed = relation_order[::-1] - all_group_lengths_prob_dicts = {} - - # Clustering - if os.path.exists(save_dir / "cluster_ckpt.pkl"): - print("Clustering checkpoint found, loading...") - cluster_ckpt = pickle.load(open(save_dir / "cluster_ckpt.pkl", "rb")) - # ruff: noqa: SIM115 - tables = cluster_ckpt["tables"] - all_group_lengths_prob_dicts = cluster_ckpt["all_group_lengths_prob_dicts"] - else: - for parent, child in relation_order_reversed: - if parent is not None: - print(f"Clustering {parent} -> {child}") - if isinstance(configs["num_clusters"], dict): - num_clusters = configs["num_clusters"][child] - else: - num_clusters = configs["num_clusters"] - ( - parent_df_with_cluster, - child_df_with_cluster, - group_lengths_prob_dicts, - ) = pair_clustering_keep_id( - tables[child]["df"], - tables[child]["domain"], - tables[parent]["df"], - tables[parent]["domain"], - f"{child}_id", - f"{parent}_id", - num_clusters, - configs["parent_scale"], - 1, # not used for now - parent, - child, - clustering_method=configs["clustering_method"], - ) - tables[parent]["df"] = parent_df_with_cluster - tables[child]["df"] = child_df_with_cluster - all_group_lengths_prob_dicts[(parent, child)] = group_lengths_prob_dicts - - cluster_ckpt = { - "tables": tables, - "all_group_lengths_prob_dicts": all_group_lengths_prob_dicts, - } - pickle.dump(cluster_ckpt, open(save_dir / "cluster_ckpt.pkl", "wb")) - # ruff: noqa: SIM115 - - for parent, child in relation_order: - if parent is None: - tables[child]["df"]["placeholder"] = list(range(len(tables[child]["df"]))) - - return tables, all_group_lengths_prob_dicts - - def clava_training( tables: Tables, relation_order: RelationOrder, diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index 4d6158ed..1d81e7dc 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -3,7 +3,8 @@ import pytest from midst_toolkit.core.data_loaders import load_multi_table -from midst_toolkit.models.clavaddpm.train import clava_clustering, clava_training +from midst_toolkit.models.clavaddpm.clustering import clava_clustering +from midst_toolkit.models.clavaddpm.train import clava_training CLUSTERING_CONFIG = { From 63b93d801f3fa0f3645476c8f6330d6b68099fe3 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 31 Jul 2025 17:35:20 -0400 Subject: [PATCH 07/39] WIP --- src/midst_toolkit/models/clavaddpm/model.py | 148 ------------------ src/midst_toolkit/models/clavaddpm/sampler.py | 146 +++++++++++++++++ 2 files changed, 146 insertions(+), 148 deletions(-) create mode 100644 src/midst_toolkit/models/clavaddpm/sampler.py diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index f67c1bc2..9861743e 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -500,154 +500,6 @@ def get_model( raise ValueError("Unknown model!") -class ScheduleSampler(ABC): - """ - A distribution over timesteps in the diffusion process, intended to reduce - variance of the objective. - - By default, samplers perform unbiased importance sampling, in which the - objective's mean is unchanged. 
- However, subclasses may override sample() to change how the resampled - terms are reweighted, allowing for actual changes in the objective. - """ - - @abstractmethod - def weights(self) -> Tensor: - """ - Get a numpy array of weights, one per diffusion step. - - The weights needn't be normalized, but must be positive. - """ - - def sample(self, batch_size: int, device: str) -> tuple[Tensor, Tensor]: - """ - Importance-sample timesteps for a batch. - - :param batch_size: the number of timesteps. - :param device: the torch device to save to. - :return: a tuple (timesteps, weights): - - timesteps: a tensor of timestep indices. - - weights: a tensor of weights to scale the resulting losses. - """ - w = self.weights().cpu().numpy() - p = w / np.sum(w) - indices_np = np.random.choice(len(p), size=(batch_size,), p=p) - indices = torch.from_numpy(indices_np).long().to(device) - weights_np = 1 / (len(p) * p[indices_np]) - weights = torch.from_numpy(weights_np).float().to(device) - return indices, weights - - -class UniformSampler(ScheduleSampler): - def __init__(self, diffusion: GaussianMultinomialDiffusion): - self.diffusion = diffusion - self._weights = torch.from_numpy(np.ones([diffusion.num_timesteps])) - - def weights(self) -> Tensor: - return self._weights - - -class LossAwareSampler(ScheduleSampler): - def update_with_local_losses(self, local_ts: Tensor, local_losses: Tensor) -> None: - """ - Update the reweighting using losses from a model. - - Call this method from each rank with a batch of timesteps and the - corresponding losses for each of those timesteps. - This method will perform synchronization to make sure all of the ranks - maintain the exact same reweighting. - - :param local_ts: an integer Tensor of timesteps. - :param local_losses: a 1D Tensor of losses. - """ - batch_sizes = [ - torch.tensor([0], dtype=torch.int32, device=local_ts.device) - for _ in range(torch.distributed.get_world_size()) - ] - torch.distributed.all_gather( - batch_sizes, - torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device), - ) - - # Pad all_gather batches to be the maximum batch size. - max_bs = max([int(x.item()) for x in batch_sizes]) - - timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes] - loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes] - torch.distributed.all_gather(timestep_batches, local_ts) - torch.distributed.all_gather(loss_batches, local_losses) - timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]] - losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] - self.update_with_all_losses(timesteps, losses) - - @abstractmethod - def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None: - """ - Update the reweighting using losses from a model. - - Sub-classes should override this method to update the reweighting - using losses from the model. - - This method directly updates the reweighting without synchronizing - between workers. It is called by update_with_local_losses from all - ranks with identical arguments. Thus, it should have deterministic - behavior to maintain state across workers. - - :param ts: a list of int timesteps. - :param losses: a list of float losses, one per timestep. 
- """ - - -class LossSecondMomentResampler(LossAwareSampler): - def __init__( - self, - diffusion: GaussianMultinomialDiffusion, - history_per_term: int = 10, - uniform_prob: float = 0.001, - ): - self.diffusion = diffusion - self.history_per_term = history_per_term - self.uniform_prob = uniform_prob - self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64) - self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.uint) - - def weights(self): - if not self._warmed_up(): - return np.ones([self.diffusion.num_timesteps], dtype=np.float64) - weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) - weights /= np.sum(weights) - weights *= 1 - self.uniform_prob - weights += self.uniform_prob / len(weights) - return weights - - def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None: - for t, loss in zip(ts, losses): - if self._loss_counts[t] == self.history_per_term: - # Shift out the oldest loss term. - self._loss_history[t, :-1] = self._loss_history[t, 1:] - self._loss_history[t, -1] = loss - else: - self._loss_history[t, self._loss_counts[t]] = loss - self._loss_counts[t] += 1 - - def _warmed_up(self) -> bool: - return (self._loss_counts == self.history_per_term).all() - - -def create_named_schedule_sampler(name: str, diffusion: GaussianMultinomialDiffusion) -> ScheduleSampler: - """ - Create a ScheduleSampler from a library of pre-defined samplers. - - :param name: the name of the sampler. - :param diffusion: the diffusion object to sample for. - """ - if name == "uniform": - return UniformSampler(diffusion) - if name == "loss-second-moment": - return LossSecondMomentResampler(diffusion) - raise NotImplementedError(f"unknown schedule sampler: {name}") - - def split_microbatches( microbatch: int, batch: Tensor, diff --git a/src/midst_toolkit/models/clavaddpm/sampler.py b/src/midst_toolkit/models/clavaddpm/sampler.py new file mode 100644 index 00000000..e60fa7bf --- /dev/null +++ b/src/midst_toolkit/models/clavaddpm/sampler.py @@ -0,0 +1,146 @@ +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self) -> Tensor: + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size: int, device: str) -> tuple[Tensor, Tensor]: + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. 
+ """ + w = self.weights().cpu().numpy() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = torch.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = torch.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion: GaussianMultinomialDiffusion): + self.diffusion = diffusion + self._weights = torch.from_numpy(np.ones([diffusion.num_timesteps])) + + def weights(self) -> Tensor: + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts: Tensor, local_losses: Tensor) -> None: + """ + Update the reweighting using losses from a model. + + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + torch.tensor([0], dtype=torch.int32, device=local_ts.device) + for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather( + batch_sizes, + torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + max_bs = max([int(x.item()) for x in batch_sizes]) + + timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes] + torch.distributed.all_gather(timestep_batches, local_ts) + torch.distributed.all_gather(loss_batches, local_losses) + timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None: + """ + Update the reweighting using losses from a model. + + Sub-classes should override this method to update the reweighting + using losses from the model. + + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. 
+ """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__( + self, + diffusion: GaussianMultinomialDiffusion, + history_per_term: int = 10, + uniform_prob: float = 0.001, + ): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.uint) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None: + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self) -> bool: + return (self._loss_counts == self.history_per_term).all() + + +def create_named_schedule_sampler(name: str, diffusion: GaussianMultinomialDiffusion) -> ScheduleSampler: + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + if name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + raise NotImplementedError(f"unknown schedule sampler: {name}") From a9d24e8783decaa131b4b96ee6b019dc740846af Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 31 Jul 2025 17:36:12 -0400 Subject: [PATCH 08/39] Adding module docstring --- src/midst_toolkit/models/clavaddpm/clustering.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py index ddfefa09..c58e7ec6 100644 --- a/src/midst_toolkit/models/clavaddpm/clustering.py +++ b/src/midst_toolkit/models/clavaddpm/clustering.py @@ -1,3 +1,5 @@ +"""Clustering functions for the multi-tableClavaDDPM model.""" + import os import pickle from collections import defaultdict From 53e562ed1357ca3709f8c8c1ee3e24c30029f0f1 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 31 Jul 2025 18:27:15 -0400 Subject: [PATCH 09/39] Moving sampler classes to their own module and some more additional train functions --- src/midst_toolkit/models/clavaddpm/model.py | 78 ----------- src/midst_toolkit/models/clavaddpm/sampler.py | 73 ++++++++-- src/midst_toolkit/models/clavaddpm/train.py | 131 +++++++++++++++++- 3 files changed, 188 insertions(+), 94 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index 9861743e..7a579f83 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -2,7 +2,6 @@ import json import math import pickle -from abc import ABC, abstractmethod from collections import Counter from collections.abc import Callable, Generator from copy import deepcopy @@ -33,9 +32,6 @@ ) from torch import Tensor, nn -from midst_toolkit.core import logger -from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion - Normalization = 
Literal["standard", "quantile", "minmax"] NumNanPolicy = Literal["drop-rows", "mean"] @@ -500,80 +496,6 @@ def get_model( raise ValueError("Unknown model!") -def split_microbatches( - microbatch: int, - batch: Tensor, - labels: Tensor, - t: Tensor, -) -> Generator[tuple[Tensor, Tensor, Tensor]]: - bs = len(batch) - if microbatch == -1 or microbatch >= bs: - yield batch, labels, t - else: - for i in range(0, bs, microbatch): - yield batch[i : i + microbatch], labels[i : i + microbatch], t[i : i + microbatch] - - -def compute_top_k(logits: Tensor, labels: Tensor, k: int, reduction: str = "mean") -> Tensor: - _, top_ks = torch.topk(logits, k, dim=-1) - if reduction == "mean": - return (top_ks == labels[:, None]).float().sum(dim=-1).mean() - if reduction == "none": - return (top_ks == labels[:, None]).float().sum(dim=-1) - - raise ValueError(f"reduction should be one of ['mean', 'none']: {reduction}") - - -def log_loss_dict(diffusion: GaussianMultinomialDiffusion, ts: Tensor, losses: dict[str, Tensor]) -> None: - for key, values in losses.items(): - logger.logkv_mean(key, values.mean().item()) - # Log the quantiles (four quartiles, in particular). - for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): - quartile = int(4 * sub_t / diffusion.num_timesteps) - logger.logkv_mean(f"{key}_q{quartile}", sub_loss) - - -def numerical_forward_backward_log( - classifier: nn.Module, - optimizer: torch.optim.Optimizer, - data_loader: Generator[tuple[Tensor, ...]], - dataset: Dataset, - schedule_sampler: ScheduleSampler, - diffusion: GaussianMultinomialDiffusion, - prefix: str = "train", - remove_first_col: bool = False, - device: str = "cuda", -) -> None: - batch, labels = next(data_loader) - labels = labels.long().to(device) - - if remove_first_col: - # Remove the first column of the batch, which is the label. 
- batch = batch[:, 1:] - - num_batch = batch[:, : dataset.n_num_features].to(device) - - t, _ = schedule_sampler.sample(num_batch.shape[0], device) - batch = diffusion.gaussian_q_sample(num_batch, t).to(device) - - for i, (sub_batch, sub_labels, sub_t) in enumerate(split_microbatches(-1, batch, labels, t)): - logits = classifier(sub_batch, timesteps=sub_t) - loss = F.cross_entropy(logits, sub_labels, reduction="none") - - losses = {} - losses[f"{prefix}_loss"] = loss.detach() - losses[f"{prefix}_acc@1"] = compute_top_k(logits, sub_labels, k=1, reduction="none") - if logits.shape[1] >= 5: - losses[f"{prefix}_acc@5"] = compute_top_k(logits, sub_labels, k=5, reduction="none") - log_loss_dict(diffusion, sub_t, losses) - del losses - loss = loss.mean() - if loss.requires_grad: - if i == 0: - optimizer.zero_grad() - loss.backward(loss * len(sub_batch) / len(batch)) # type: ignore[no-untyped-call] - - def transform_dataset( dataset: Dataset, transformations: Transformations, diff --git a/src/midst_toolkit/models/clavaddpm/sampler.py b/src/midst_toolkit/models/clavaddpm/sampler.py index e60fa7bf..9ae23bd9 100644 --- a/src/midst_toolkit/models/clavaddpm/sampler.py +++ b/src/midst_toolkit/models/clavaddpm/sampler.py @@ -1,3 +1,14 @@ +"""Samplers for the ClavaDDPM model.""" + +from abc import ABC, abstractmethod + +import numpy as np +import torch +from torch import Tensor + +from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion + + class ScheduleSampler(ABC): """ A distribution over timesteps in the diffusion process, intended to reduce @@ -21,11 +32,14 @@ def sample(self, batch_size: int, device: str) -> tuple[Tensor, Tensor]: """ Importance-sample timesteps for a batch. - :param batch_size: the number of timesteps. - :param device: the torch device to save to. - :return: a tuple (timesteps, weights): - - timesteps: a tensor of timestep indices. - - weights: a tensor of weights to scale the resulting losses. + Args: + batch_size: The number of timesteps. + device: The torch device to save to. + + Returns: + A tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. """ w = self.weights().cpu().numpy() p = w / np.sum(w) @@ -38,10 +52,17 @@ def sample(self, batch_size: int, device: str) -> tuple[Tensor, Tensor]: class UniformSampler(ScheduleSampler): def __init__(self, diffusion: GaussianMultinomialDiffusion): + """ + Initialize the UniformSampler. + + Args: + diffusion: The diffusion object. + """ self.diffusion = diffusion self._weights = torch.from_numpy(np.ones([diffusion.num_timesteps])) def weights(self) -> Tensor: + """Return the weights.""" return self._weights @@ -55,8 +76,9 @@ def update_with_local_losses(self, local_ts: Tensor, local_losses: Tensor) -> No This method will perform synchronization to make sure all of the ranks maintain the exact same reweighting. - :param local_ts: an integer Tensor of timesteps. - :param local_losses: a 1D Tensor of losses. + Args: + local_ts: An integer Tensor of timesteps. + local_losses: A 1D Tensor of losses. """ batch_sizes = [ torch.tensor([0], dtype=torch.int32, device=local_ts.device) @@ -91,8 +113,9 @@ def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None: ranks with identical arguments. Thus, it should have deterministic behavior to maintain state across workers. - :param ts: a list of int timesteps. - :param losses: a list of float losses, one per timestep. + Args: + ts: A list of int timesteps. 
+ losses: A list of float losses, one per timestep. """ @@ -103,6 +126,14 @@ def __init__( history_per_term: int = 10, uniform_prob: float = 0.001, ): + """ + Initialize the LossSecondMomentResampler. + + Args: + diffusion: The diffusion object. + history_per_term: The number of losses to keep for each timestep. + uniform_prob: The probability of sampling a uniform timestep. + """ self.diffusion = diffusion self.history_per_term = history_per_term self.uniform_prob = uniform_prob @@ -110,6 +141,11 @@ def __init__( self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.uint) def weights(self): + """ + Return the weights. + + Warms up the sampler if it's not warmed up. + """ if not self._warmed_up(): return np.ones([self.diffusion.num_timesteps], dtype=np.float64) weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) @@ -119,6 +155,13 @@ def weights(self): return weights def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None: + """ + Update the reweighting using losses from the model. + + Args: + ts: The timesteps. + losses: The losses. + """ for t, loss in zip(ts, losses): if self._loss_counts[t] == self.history_per_term: # Shift out the oldest loss term. @@ -129,6 +172,13 @@ def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None: self._loss_counts[t] += 1 def _warmed_up(self) -> bool: + """ + Check if the sampler is warmed up by checking if the loss counts are equal + to the history per term. + + Returns: + True if the sampler is warmed up, False otherwise. + """ return (self._loss_counts == self.history_per_term).all() @@ -136,8 +186,9 @@ def create_named_schedule_sampler(name: str, diffusion: GaussianMultinomialDiffu """ Create a ScheduleSampler from a library of pre-defined samplers. - :param name: the name of the sampler. - :param diffusion: the diffusion object to sample for. + Args: + name: The name of the sampler. + diffusion: The diffusion object to sample for. 
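+
+    Example (illustrative): create_named_schedule_sampler("uniform", diffusion) returns a
+    UniformSampler, while "loss-second-moment" returns a LossSecondMomentResampler.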
""" if name == "uniform": return UniformSampler(diffusion) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 7048dccb..12f439b8 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -1,29 +1,30 @@ """Defines the training functions for the ClavaDDPM model.""" import pickle +from collections.abc import Generator from pathlib import Path from typing import Any import numpy as np import pandas as pd import torch -from torch import optim +from torch import Tensor, optim from midst_toolkit.core import logger from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion from midst_toolkit.models.clavaddpm.model import ( Classifier, + Dataset, Transformations, - create_named_schedule_sampler, get_model, get_model_params, get_T_dict, get_table_info, make_dataset_from_df, - numerical_forward_backward_log, prepare_fast_dataloader, ) from midst_toolkit.models.clavaddpm.params import Configs, RelationOrder, Tables +from midst_toolkit.models.clavaddpm.sampler import ScheduleSampler, create_named_schedule_sampler from midst_toolkit.models.clavaddpm.trainer import Trainer @@ -423,7 +424,7 @@ def train_classifier( "samples", (step + resume_step + 1) * batch_size, ) - numerical_forward_backward_log( + _numerical_forward_backward_log( classifier, classifier_optimizer, train_loader, @@ -438,7 +439,7 @@ def train_classifier( if not step % eval_interval: with torch.no_grad(): classifier.eval() - numerical_forward_backward_log( + _numerical_forward_backward_log( classifier, classifier_optimizer, val_loader, @@ -470,3 +471,123 @@ def train_classifier( print(acc) return classifier + + +def _numerical_forward_backward_log( + classifier: Classifier, + optimizer: torch.optim.Optimizer, + data_loader: Generator[tuple[Tensor, ...]], + dataset: Dataset, + schedule_sampler: ScheduleSampler, + diffusion: GaussianMultinomialDiffusion, + prefix: str = "train", + remove_first_col: bool = False, + device: str = "cuda", +) -> None: + """ + Forward and backward pass for the numerical features of the ClavaDDPM model. + + Args: + classifier: The classifier model. + optimizer: The optimizer. + data_loader: The data loader. + dataset: The dataset. + schedule_sampler: The schedule sampler. + diffusion: The diffusion object. + prefix: The prefix for the loss. Defaults to "train". + remove_first_col: Whether to remove the first column of the batch. Defaults to False. + device: The device to use. Defaults to "cuda". + """ + batch, labels = next(data_loader) + labels = labels.long().to(device) + + if remove_first_col: + # Remove the first column of the batch, which is the label. 
+ batch = batch[:, 1:] + + num_batch = batch[:, : dataset.n_num_features].to(device) + + t, _ = schedule_sampler.sample(num_batch.shape[0], device) + batch = diffusion.gaussian_q_sample(num_batch, t).to(device) + + for i, (sub_batch, sub_labels, sub_t) in enumerate(_split_microbatches(-1, batch, labels, t)): + logits = classifier(sub_batch, timesteps=sub_t) + loss = torch.nn.functional.cross_entropy(logits, sub_labels, reduction="none") + + losses = {} + losses[f"{prefix}_loss"] = loss.detach() + losses[f"{prefix}_acc@1"] = _compute_top_k(logits, sub_labels, k=1, reduction="none") + if logits.shape[1] >= 5: + losses[f"{prefix}_acc@5"] = _compute_top_k(logits, sub_labels, k=5, reduction="none") + _log_loss_dict(diffusion, sub_t, losses) + del losses + loss = loss.mean() + if loss.requires_grad: + if i == 0: + optimizer.zero_grad() + loss.backward(loss * len(sub_batch) / len(batch)) # type: ignore[no-untyped-call] + + +def _compute_top_k(logits: Tensor, labels: Tensor, k: int, reduction: str = "mean") -> Tensor: + """ + Compute the top-k accuracy. + + Args: + logits: The logits of the classifier. + labels: The labels of the data. + k: The number of top-k. + reduction: The reduction method. Should be one of ["mean", "none"]. Defaults to "mean". + + Returns: + The top-k accuracy. + """ + _, top_ks = torch.topk(logits, k, dim=-1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean() + if reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + raise ValueError(f"reduction should be one of ['mean', 'none']: {reduction}") + + +def _log_loss_dict(diffusion: GaussianMultinomialDiffusion, ts: Tensor, losses: dict[str, Tensor]) -> None: + """ + Output the log loss dictionary in the logger. + + Args: + diffusion: The diffusion object. + ts: The timesteps. + losses: The losses. + """ + for key, values in losses.items(): + logger.logkv_mean(key, values.mean().item()) + # Log the quantiles (four quartiles, in particular). + for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): + quartile = int(4 * sub_t / diffusion.num_timesteps) + logger.logkv_mean(f"{key}_q{quartile}", sub_loss) + + +def _split_microbatches( + microbatch: int, + batch: Tensor, + labels: Tensor, + t: Tensor, +) -> Generator[tuple[Tensor, Tensor, Tensor]]: + """ + Split the batch into microbatches. + + Args: + microbatch: The size of the microbatch. If -1, the batch is not split. + batch: The batch of data as a tensor. + labels: The labels of the data as a tensor. + t: The timesteps tensor. + + Returns: + A generator of for the minibatch which outputs tuples of the batch, labels, and timesteps. 
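+
+    For illustration (hypothetical sizes): with microbatch=2 and a batch of 5 rows, the
+    generator yields slices of sizes 2, 2 and 1; with microbatch=-1 the whole batch is
+    yielded once.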
+ """ + bs = len(batch) + if microbatch == -1 or microbatch >= bs: + yield batch, labels, t + else: + for i in range(0, bs, microbatch): + yield batch[i : i + microbatch], labels[i : i + microbatch], t[i : i + microbatch] From 6e0a3e4bdb49a2cbed4c23bce96b5dc6dea40aef Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 5 Aug 2025 18:02:54 -0300 Subject: [PATCH 10/39] 1st round of Fatemeh's CR --- src/midst_toolkit/models/clavaddpm/train.py | 28 +++++++++---------- src/midst_toolkit/models/clavaddpm/trainer.py | 17 +++++++---- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index a51095e9..8d0f921b 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -1,5 +1,6 @@ """Defines the training functions for the ClavaDDPM model.""" +import logging import os import pickle from pathlib import Path @@ -25,7 +26,11 @@ pair_clustering_keep_id, prepare_fast_dataloader, ) -from midst_toolkit.models.clavaddpm.trainer import Trainer +from midst_toolkit.models.clavaddpm.trainer import ClavaDDPMTrainer + + +logging.basicConfig(level=logging.INFO) +LOGGER = logging.getLogger(__name__) Tables = dict[str, dict[str, Any]] @@ -144,7 +149,7 @@ def clava_training( } relation_order: List of tuples of parent and child tables. Example: [("table1", "table2"), ("table1", "table3")] - save_dir: Directory to save the clustering checkpoint. + save_dir: Directory to save the ClavaDDPM models. diffusion_config: Dictionary of configurations for the diffusion model. The following config keys are required: { d_layers = list[int], @@ -192,8 +197,13 @@ def clava_training( models[(parent, child)] = result target_folder = save_dir / "models" + target_file = target_folder / f"{parent}_{child}_ckpt.pkl" + + create_message = f"Creating {target_folder}. 
" if not target_folder.exists() else "" + LOGGER.info(f"{create_message}Saving {parent} -> {child} model to {target_file}") + target_folder.mkdir(parents=True, exist_ok=True) - with open(target_folder / f"{parent}_{child}_ckpt.pkl", "wb") as f: + with open(target_file, "wb") as f: pickle.dump(result, f) return models @@ -359,22 +369,17 @@ def train_model( df_info=df_info, std=0, ) - # print(dataset.n_features) train_loader = prepare_fast_dataloader(dataset, split="train", batch_size=batch_size, y_type="long") - num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 - K = np.array(dataset.get_category_sizes("train")) # ruff: noqa: N806 if len(K) == 0 or T_dict["cat_encoding"] == "one-hot": K = np.array([0]) # ruff: noqa: N806 - # print(K) num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 d_in = np.sum(K) + num_numerical_features model_params["d_in"] = d_in - # print(d_in) print("Model params: {}".format(model_params)) model = get_model(model_type, model_params) @@ -394,7 +399,7 @@ def train_model( diffusion.to(device) diffusion.train() - trainer = Trainer( + trainer = ClavaDDPMTrainer( diffusion, train_loader, lr=lr, @@ -472,7 +477,6 @@ def train_classifier( test_loader = prepare_fast_dataloader(dataset, split="test", batch_size=batch_size, y_type="long") eval_interval = 5 - # log_interval = 10 K = np.array(dataset.get_category_sizes("train")) # ruff: noqa: N806 @@ -542,10 +546,6 @@ def train_classifier( ) classifier.train() - # Removed because it's too verbose - # if not step % log_interval: - # logger.dumpkvs() - # # test classifier classifier.eval() diff --git a/src/midst_toolkit/models/clavaddpm/trainer.py b/src/midst_toolkit/models/clavaddpm/trainer.py index 43736e46..b3c47704 100644 --- a/src/midst_toolkit/models/clavaddpm/trainer.py +++ b/src/midst_toolkit/models/clavaddpm/trainer.py @@ -11,7 +11,7 @@ from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion -class Trainer: +class ClavaDDPMTrainer: def __init__( self, diffusion: GaussianMultinomialDiffusion, @@ -60,19 +60,24 @@ def _anneal_lr(self, step: int) -> None: for param_group in self.optimizer.param_groups: param_group["lr"] = lr - def _run_step(self, x: Tensor, out_dict: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + def _run_step(self, x: Tensor, target: dict[str, Tensor]) -> tuple[Tensor, Tensor]: """ Run a single step of the training loop. Args: x: The input tensor. - out_dict: The output dictionary. + target: The target dictionary (model output). + + Returns: + A tuple with 2 values: + - The multi-class loss. + - The Gaussian loss. 
""" x = x.to(self.device) - for k, v in out_dict.items(): - out_dict[k] = v.long().to(self.device) + for k, v in target.items(): + target[k] = v.long().to(self.device) self.optimizer.zero_grad() - loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict) + loss_multi, loss_gauss = self.diffusion.mixed_loss(x, target) loss = loss_multi + loss_gauss loss.backward() # type: ignore[no-untyped-call] self.optimizer.step() From 932e1af9fbbe7ed485736f22fdc8a65daea7a4d0 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 6 Aug 2025 11:42:57 -0300 Subject: [PATCH 11/39] Fatemeh's CR --- src/midst_toolkit/models/clavaddpm/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 8d0f921b..d95685dd 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -369,7 +369,6 @@ def train_model( df_info=df_info, std=0, ) - train_loader = prepare_fast_dataloader(dataset, split="train", batch_size=batch_size, y_type="long") K = np.array(dataset.get_category_sizes("train")) # ruff: noqa: N806 From e09359205c61a6b7beda18c04942d3bf04be67e5 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 6 Aug 2025 12:10:10 -0300 Subject: [PATCH 12/39] removing print --- src/midst_toolkit/models/clavaddpm/clustering.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py index c58e7ec6..63f77e37 100644 --- a/src/midst_toolkit/models/clavaddpm/clustering.py +++ b/src/midst_toolkit/models/clavaddpm/clustering.py @@ -256,7 +256,6 @@ def _pair_clustering_keep_id( child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int) num_clusters = min(num_clusters, len(cluster_data)) - # print('clustering') if clustering_method == "kmeans": kmeans = KMeans(n_clusters=num_clusters, n_init="auto", init="k-means++") kmeans.fit(cluster_data) From 5b06beeaba67aaa9ddbd4f3dd84b80d317d4a6fe Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 6 Aug 2025 13:22:13 -0300 Subject: [PATCH 13/39] WIP needs docstrings --- src/midst_toolkit/models/clavaddpm/dataset.py | 636 +++++++++++++++++ src/midst_toolkit/models/clavaddpm/model.py | 645 +----------------- src/midst_toolkit/models/clavaddpm/params.py | 7 + src/midst_toolkit/models/clavaddpm/train.py | 41 +- 4 files changed, 681 insertions(+), 648 deletions(-) create mode 100644 src/midst_toolkit/models/clavaddpm/dataset.py diff --git a/src/midst_toolkit/models/clavaddpm/dataset.py b/src/midst_toolkit/models/clavaddpm/dataset.py new file mode 100644 index 00000000..e473716e --- /dev/null +++ b/src/midst_toolkit/models/clavaddpm/dataset.py @@ -0,0 +1,636 @@ +"""Defines the dataset functions for the ClavaDDPM model.""" + +import hashlib +import json +import pickle +from collections import Counter +from copy import deepcopy +from dataclasses import astuple, dataclass, replace +from enum import Enum +from pathlib import Path +from typing import Any, Literal, cast + +import numpy as np +import pandas as pd +import torch +from category_encoders import LeaveOneOutEncoder +from scipy.special import expit, softmax +from sklearn.impute import SimpleImputer +from sklearn.metrics import classification_report, mean_squared_error, r2_score, roc_auc_score +from sklearn.model_selection import train_test_split +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import ( + LabelEncoder, + MinMaxScaler, + OneHotEncoder, + 
OrdinalEncoder, + QuantileTransformer, + StandardScaler, +) + +from midst_toolkit.models.clavaddpm.params import ArrayDict + + +CAT_MISSING_VALUE = "__nan__" +CAT_RARE_VALUE = "__rare__" + + +Normalization = Literal["standard", "quantile", "minmax"] +NumNanPolicy = Literal["drop-rows", "mean"] +CatNanPolicy = Literal["most_frequent"] +CatEncoding = Literal["one-hot", "counter"] +YPolicy = Literal["default"] + + +class TaskType(Enum): + BINCLASS = "binclass" + MULTICLASS = "multiclass" + REGRESSION = "regression" + + def __str__(self) -> str: + return self.value + + +class PredictionType(Enum): + LOGITS = "logits" + PROBS = "probs" + + +@dataclass(frozen=True) +class Transformations: + seed: int = 0 + normalization: Normalization | None = None + num_nan_policy: NumNanPolicy | None = None + cat_nan_policy: CatNanPolicy | None = None + cat_min_frequency: float | None = None + cat_encoding: CatEncoding | None = None + y_policy: YPolicy | None = "default" + + +# TODO move this into the Transformations' class init +def get_T_dict() -> dict[str, Any]: + """ + Return a dictionary used to initialize the transformation object. + + Returns: + The transformation object default parameters. + """ + # ruff: noqa: N802 + return { + "seed": 0, + "normalization": "quantile", + "num_nan_policy": None, + "cat_nan_policy": None, + "cat_min_frequency": None, + "cat_encoding": None, + "y_policy": "default", + } + + +@dataclass(frozen=False) +class Dataset: + X_num: ArrayDict | None + X_cat: ArrayDict | None + y: ArrayDict + y_info: dict[str, Any] + task_type: TaskType + n_classes: int | None + cat_transform: OneHotEncoder | None = None + num_transform: StandardScaler | None = None + + @classmethod + def from_dir(cls, dir_: Path | str) -> "Dataset": + dir_ = Path(dir_) + splits = [k for k in ["train", "val", "test"] if dir_.joinpath(f"y_{k}.npy").exists()] + + def load(item: str) -> ArrayDict: + return {x: cast(np.ndarray, np.load(dir_ / f"{item}_{x}.npy", allow_pickle=True)) for x in splits} + + if Path(dir_ / "info.json").exists(): + info = json.loads(Path(dir_ / "info.json").read_text()) + else: + info = None + # ruff: noqa: SIM108 + + return Dataset( + load("X_num") if dir_.joinpath("X_num_train.npy").exists() else None, + load("X_cat") if dir_.joinpath("X_cat_train.npy").exists() else None, + load("y"), + {}, + TaskType(info["task_type"]), + info.get("n_classes"), + ) + + @property + def is_binclass(self) -> bool: + return self.task_type == TaskType.BINCLASS + + @property + def is_multiclass(self) -> bool: + return self.task_type == TaskType.MULTICLASS + + @property + def is_regression(self) -> bool: + return self.task_type == TaskType.REGRESSION + + @property + def n_num_features(self) -> int: + return 0 if self.X_num is None else self.X_num["train"].shape[1] + + @property + def n_cat_features(self) -> int: + return 0 if self.X_cat is None else self.X_cat["train"].shape[1] + + @property + def n_features(self) -> int: + return self.n_num_features + self.n_cat_features + + def size(self, part: str | None) -> int: + return sum(map(len, self.y.values())) if part is None else len(self.y[part]) + + @property + def nn_output_dim(self) -> int: + if self.is_multiclass: + assert self.n_classes is not None + return self.n_classes + return 1 + + def get_category_sizes(self, part: str) -> list[int]: + return [] if self.X_cat is None else get_category_sizes(self.X_cat[part]) + + def calculate_metrics( + self, + predictions: dict[str, np.ndarray], + prediction_type: str | None, + ) -> dict[str, Any]: + metrics = { + x: 
calculate_metrics(self.y[x], predictions[x], self.task_type, prediction_type, self.y_info) + for x in predictions + } + if self.task_type == TaskType.REGRESSION: + score_key = "rmse" + score_sign = -1 + else: + score_key = "accuracy" + score_sign = 1 + for part_metrics in metrics.values(): + part_metrics["score"] = score_sign * part_metrics[score_key] + return metrics + + +# TODO consider moving all the functions below into the Dataset class +def get_category_sizes(X: torch.Tensor | np.ndarray) -> list[int]: + XT = X.T.cpu().tolist() if isinstance(X, torch.Tensor) else X.T.tolist() + return [len(set(x)) for x in XT] + + +def calculate_metrics( + y_true: np.ndarray, + y_pred: np.ndarray, + task_type: str | TaskType, + prediction_type: str | PredictionType | None, + y_info: dict[str, Any], +) -> dict[str, Any]: + # Example: calculate_metrics(y_true, y_pred, 'binclass', 'logits', {}) + task_type = TaskType(task_type) + if prediction_type is not None: + prediction_type = PredictionType(prediction_type) + + if task_type == TaskType.REGRESSION: + assert prediction_type is None + assert "std" in y_info + rmse = calculate_rmse(y_true, y_pred, y_info["std"]) + r2 = r2_score(y_true, y_pred) + result = {"rmse": rmse, "r2": r2} + else: + labels, probs = _get_labels_and_probs(y_pred, task_type, prediction_type) + result = cast(dict[str, Any], classification_report(y_true, labels, output_dict=True)) + if task_type == TaskType.BINCLASS: + result["roc_auc"] = roc_auc_score(y_true, probs) + return result + + +def calculate_rmse(y_true: np.ndarray, y_pred: np.ndarray, std: float | None) -> float: + rmse = mean_squared_error(y_true, y_pred) ** 0.5 + if std is not None: + rmse *= std + return rmse + + +def _get_labels_and_probs( + y_pred: np.ndarray, task_type: TaskType, prediction_type: PredictionType | None +) -> tuple[np.ndarray, np.ndarray | None]: + assert task_type in (TaskType.BINCLASS, TaskType.MULTICLASS) + + if prediction_type is None: + return y_pred, None + + if prediction_type == PredictionType.LOGITS: + probs = expit(y_pred) if task_type == TaskType.BINCLASS else softmax(y_pred, axis=1) + elif prediction_type == PredictionType.PROBS: + probs = y_pred + else: + raise ValueError(f"Unknown prediction_type: {prediction_type}") + + assert probs is not None + labels = np.round(probs) if task_type == TaskType.BINCLASS else probs.argmax(axis=1) + return labels.astype("int64"), probs + + +def make_dataset_from_df( + # ruff: noqa: PLR0915, PLR0912 + df: pd.DataFrame, + T: Transformations, + is_y_cond: str, + df_info: pd.DataFrame, + ratios: list[float] | None = None, + std: float = 0, +) -> tuple[Dataset, dict[int, LabelEncoder], list[int]]: + """ + The order of the generated dataset: (y, X_num, X_cat). + + is_y_cond: + concat: y is concatenated to X, the model learn a joint distribution of (y, X) + embedding: y is not concatenated to X. During computations, y is embedded + and added to the latent vector of X + none: y column is completely ignored + + How does is_y_cond affect the generation of y? + is_y_cond: + concat: the model synthesizes (y, X) directly, so y is just the first column + embedding: y is first sampled using empirical distribution of y. The model only + synthesizes X. When returning the generated data, we return the generated X + and the sampled y. 
(y is sampled from empirical distribution, instead of being + generated by the model) + Note that in this way, y is still not independent of X, because the model has been + adding the embedding of y to the latent vector of X during computations. + none: + y is synthesized using y's empirical distribution. X is generated by the model. + In this case, y is completely independent of X. + + Note: For now, n_classes has to be set to 0. This is because our matrix is the concatenation + of (X_num, X_cat). In this case, if we have is_y_cond == 'concat', we can guarantee that y + is the first column of the matrix. + However, if we have n_classes > 0, then y is not the first column of the matrix. + """ + if ratios is None: + ratios = [0.7, 0.2, 0.1] + + train_val_df, test_df = train_test_split(df, test_size=ratios[2], random_state=42) + train_df, val_df = train_test_split(train_val_df, test_size=ratios[1] / (ratios[0] + ratios[1]), random_state=42) + + cat_column_orders = [] + num_column_orders = [] + index_to_column = list(df.columns) + column_to_index = {col: i for i, col in enumerate(index_to_column)} + + if df_info["n_classes"] > 0: + X_cat: dict[str, np.ndarray] | None = {} if df_info["cat_cols"] is not None or is_y_cond == "concat" else None + X_num: dict[str, np.ndarray] | None = {} if df_info["num_cols"] is not None else None + y = {} + + cat_cols_with_y = [] + if df_info["cat_cols"] is not None: + cat_cols_with_y += df_info["cat_cols"] + if is_y_cond == "concat": + cat_cols_with_y = [df_info["y_col"]] + cat_cols_with_y + + if len(cat_cols_with_y) > 0: + X_cat["train"] = train_df[cat_cols_with_y].to_numpy(dtype=np.str_) # type: ignore[index] + X_cat["val"] = val_df[cat_cols_with_y].to_numpy(dtype=np.str_) # type: ignore[index] + X_cat["test"] = test_df[cat_cols_with_y].to_numpy(dtype=np.str_) # type: ignore[index] + + y["train"] = train_df[df_info["y_col"]].values.astype(np.float32) + y["val"] = val_df[df_info["y_col"]].values.astype(np.float32) + y["test"] = test_df[df_info["y_col"]].values.astype(np.float32) + + if df_info["num_cols"] is not None: + X_num["train"] = train_df[df_info["num_cols"]].values.astype(np.float32) # type: ignore[index] + X_num["val"] = val_df[df_info["num_cols"]].values.astype(np.float32) # type: ignore[index] + X_num["test"] = test_df[df_info["num_cols"]].values.astype(np.float32) # type: ignore[index] + + cat_column_orders = [column_to_index[col] for col in cat_cols_with_y] + num_column_orders = [column_to_index[col] for col in df_info["num_cols"]] + + else: + X_cat = {} if df_info["cat_cols"] is not None else None + X_num = {} if df_info["num_cols"] is not None or is_y_cond == "concat" else None + y = {} + + num_cols_with_y = [] + if df_info["num_cols"] is not None: + num_cols_with_y += df_info["num_cols"] + if is_y_cond == "concat": + num_cols_with_y = [df_info["y_col"]] + num_cols_with_y + + if len(num_cols_with_y) > 0: + X_num["train"] = train_df[num_cols_with_y].values.astype(np.float32) # type: ignore[index] + X_num["val"] = val_df[num_cols_with_y].values.astype(np.float32) # type: ignore[index] + X_num["test"] = test_df[num_cols_with_y].values.astype(np.float32) # type: ignore[index] + + y["train"] = train_df[df_info["y_col"]].values.astype(np.float32) + y["val"] = val_df[df_info["y_col"]].values.astype(np.float32) + y["test"] = test_df[df_info["y_col"]].values.astype(np.float32) + + if df_info["cat_cols"] is not None: + X_cat["train"] = train_df[df_info["cat_cols"]].to_numpy(dtype=np.str_) # type: ignore[index] + X_cat["val"] = 
val_df[df_info["cat_cols"]].to_numpy(dtype=np.str_) # type: ignore[index] + X_cat["test"] = test_df[df_info["cat_cols"]].to_numpy(dtype=np.str_) # type: ignore[index] + + cat_column_orders = [column_to_index[col] for col in df_info["cat_cols"]] + num_column_orders = [column_to_index[col] for col in num_cols_with_y] + + column_orders = num_column_orders + cat_column_orders + column_orders = [index_to_column[index] for index in column_orders] + + label_encoders = {} + if X_cat is not None and len(df_info["cat_cols"]) > 0: + X_cat_all = np.vstack((X_cat["train"], X_cat["val"], X_cat["test"])) + X_cat_converted = [] + for col_index in range(X_cat_all.shape[1]): + label_encoder = LabelEncoder() + X_cat_converted.append(label_encoder.fit_transform(X_cat_all[:, col_index]).astype(float)) + if std > 0: + # add noise + X_cat_converted[-1] += np.random.normal(0, std, X_cat_converted[-1].shape) + label_encoders[col_index] = label_encoder + + X_cat_converted = np.vstack(X_cat_converted).T # type: ignore[assignment] + + train_num = X_cat["train"].shape[0] + val_num = X_cat["val"].shape[0] + # test_num = X_cat["test"].shape[0] + + X_cat["train"] = X_cat_converted[:train_num, :] # type: ignore[call-overload] + X_cat["val"] = X_cat_converted[train_num : train_num + val_num, :] # type: ignore[call-overload] + X_cat["test"] = X_cat_converted[train_num + val_num :, :] # type: ignore[call-overload] + + if X_num and len(X_num) > 0: + X_num["train"] = np.concatenate((X_num["train"], X_cat["train"]), axis=1) + X_num["val"] = np.concatenate((X_num["val"], X_cat["val"]), axis=1) + X_num["test"] = np.concatenate((X_num["test"], X_cat["test"]), axis=1) + else: + X_num = X_cat + X_cat = None + + D = Dataset( + # ruff: noqa: N806 + X_num, + None, + y, + y_info={}, + task_type=TaskType(df_info["task_type"]), + n_classes=df_info["n_classes"], + ) + + return transform_dataset(D, T, None), label_encoders, column_orders + + +def transform_dataset( + dataset: Dataset, + transformations: Transformations, + cache_dir: Path | None, + transform_cols_num: int = 0, +) -> Dataset: + # WARNING: the order of transformations matters. Moreover, the current + # implementation is not ideal in that sense. 
+ if cache_dir is not None: + transformations_md5 = hashlib.md5(str(transformations).encode("utf-8")).hexdigest() + transformations_str = "__".join(map(str, astuple(transformations))) + cache_path = cache_dir / f"cache__{transformations_str}__{transformations_md5}.pickle" + if cache_path.exists(): + cache_transformations, value = load_pickle(cache_path) + if transformations == cache_transformations: + print(f"Using cached features: {cache_dir.name + '/' + cache_path.name}") + return value + raise RuntimeError(f"Hash collision for {cache_path}") + else: + cache_path = None + + if dataset.X_num is not None: + dataset = num_process_nans(dataset, transformations.num_nan_policy) + + num_transform = None + cat_transform = None + X_num = dataset.X_num + + if X_num is not None and transformations.normalization is not None: + X_num, num_transform = normalize( # type: ignore[assignment] + X_num, + transformations.normalization, + transformations.seed, + return_normalizer=True, + ) + + if dataset.X_cat is None: + assert transformations.cat_nan_policy is None + assert transformations.cat_min_frequency is None + # assert transformations.cat_encoding is None + X_cat = None + else: + X_cat = cat_process_nans(dataset.X_cat, transformations.cat_nan_policy) + if transformations.cat_min_frequency is not None: + X_cat = cat_drop_rare(X_cat, transformations.cat_min_frequency) + X_cat, is_num, cat_transform = cat_encode( + X_cat, + transformations.cat_encoding, + dataset.y["train"], + transformations.seed, + return_encoder=True, + ) + if is_num: + X_num = X_cat if X_num is None else {x: np.hstack([X_num[x], X_cat[x]]) for x in X_num} + X_cat = None + + y, y_info = build_target(dataset.y, transformations.y_policy, dataset.task_type) + + dataset = replace(dataset, X_num=X_num, X_cat=X_cat, y=y, y_info=y_info) + dataset.num_transform = num_transform + dataset.cat_transform = cat_transform + + if cache_path is not None: + dump_pickle((transformations, dataset), cache_path) + # if return_transforms: + # return dataset, num_transform, cat_transform + return dataset + + +def load_pickle(path: Path | str, **kwargs: Any) -> Any: + # ruff: noqa: D103 + return pickle.loads(Path(path).read_bytes(), **kwargs) + + +def dump_pickle(x: Any, path: Path | str, **kwargs: Any) -> None: + # ruff: noqa: D103 + Path(path).write_bytes(pickle.dumps(x, **kwargs)) + + +def num_process_nans(dataset: Dataset, policy: NumNanPolicy | None) -> Dataset: + # ruff: noqa: D103 + assert dataset.X_num is not None + nan_masks = {k: np.isnan(v) for k, v in dataset.X_num.items()} + if not any(x.any() for x in nan_masks.values()): + assert policy is None + return dataset + + assert policy is not None + if policy == "drop-rows": + valid_masks = {k: ~v.any(1) for k, v in nan_masks.items()} + assert valid_masks["test"].all(), "Cannot drop test rows, since this will affect the final metrics." 
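+        # The same per-split masks are applied to X_num, X_cat and y below so that
+        # the splits stay aligned after the rows with NaN numerical features are dropped.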
+ new_data = {} + for data_name in ["X_num", "X_cat", "y"]: + data_dict = getattr(dataset, data_name) + if data_dict is not None: + new_data[data_name] = {k: v[valid_masks[k]] for k, v in data_dict.items()} + dataset = replace(dataset, **new_data) # type: ignore[arg-type] + elif policy == "mean": + new_values = np.nanmean(dataset.X_num["train"], axis=0) # type: ignore[index] + X_num = deepcopy(dataset.X_num) + for k, v in X_num.items(): # type: ignore[union-attr] + num_nan_indices = np.where(nan_masks[k]) + v[num_nan_indices] = np.take(new_values, num_nan_indices[1]) + dataset = replace(dataset, X_num=X_num) + else: + raise ValueError(f"Unknown policy: {policy}") + return dataset + + +# Inspired by: https://github.com/yandex-research/rtdl/blob/a4c93a32b334ef55d2a0559a4407c8306ffeeaee/lib/data.py#L20 +def normalize( + X: ArrayDict, + normalization: Normalization, + seed: int | None, + return_normalizer: bool = False, +) -> ArrayDict | tuple[ArrayDict, StandardScaler | MinMaxScaler | QuantileTransformer]: + # ruff: noqa: D103 + X_train = X["train"] + if normalization == "standard": + normalizer = StandardScaler() + elif normalization == "minmax": + normalizer = MinMaxScaler() + elif normalization == "quantile": + normalizer = QuantileTransformer( + output_distribution="normal", + n_quantiles=max(min(X["train"].shape[0] // 30, 1000), 10), + subsample=int(1e9), + random_state=seed, + ) + else: + raise ValueError(f"Unknown normalization: {normalization}") + normalizer.fit(X_train) + if return_normalizer: + return {k: normalizer.transform(v) for k, v in X.items()}, normalizer + return {k: normalizer.transform(v) for k, v in X.items()} + + +def cat_process_nans(X: ArrayDict, policy: CatNanPolicy | None) -> ArrayDict: + # ruff: noqa: D103 + assert X is not None + nan_masks = {k: v == CAT_MISSING_VALUE for k, v in X.items()} + if any(x.any() for x in nan_masks.values()): + if policy is None: + X_new = X + elif policy == "most_frequent": + imputer = SimpleImputer(missing_values=CAT_MISSING_VALUE, strategy=policy) + imputer.fit(X["train"]) + X_new = {k: cast(np.ndarray, imputer.transform(v)) for k, v in X.items()} + else: + raise ValueError(f"Unknown cat_nan_policy: {policy}") + else: + assert policy is None + X_new = X + return X_new + + +def cat_drop_rare(X: ArrayDict, min_frequency: float) -> ArrayDict: + # ruff: noqa: D103 + assert 0.0 < min_frequency < 1.0 + min_count = round(len(X["train"]) * min_frequency) + X_new: dict[str, list[Any]] = {x: [] for x in X} + for column_idx in range(X["train"].shape[1]): + counter = Counter(X["train"][:, column_idx].tolist()) + popular_categories = {k for k, v in counter.items() if v >= min_count} + for part, _ in X_new.items(): + X_new[part].append( + [(x if x in popular_categories else CAT_RARE_VALUE) for x in X[part][:, column_idx].tolist()] + ) + return {k: np.array(v).T for k, v in X_new.items()} + + +def cat_encode( + X: ArrayDict, + encoding: CatEncoding | None, + y_train: np.ndarray | None, + seed: int | None, + return_encoder: bool = False, +) -> tuple[ArrayDict, bool, Any | None]: # (X, is_converted_to_numerical) + # ruff: noqa: D103 + if encoding != "counter": + y_train = None + + # Step 1. 
Map strings to 0-based ranges + + if encoding is None: + unknown_value = np.iinfo("int64").max - 3 + oe = OrdinalEncoder( + handle_unknown="use_encoded_value", + unknown_value=unknown_value, + dtype="int64", + ).fit(X["train"]) + encoder = make_pipeline(oe) + encoder.fit(X["train"]) + X = {k: encoder.transform(v) for k, v in X.items()} + max_values = X["train"].max(axis=0) + for part in X: + if part == "train": + continue + for column_idx in range(X[part].shape[1]): + X[part][X[part][:, column_idx] == unknown_value, column_idx] = max_values[column_idx] + 1 + if return_encoder: + return X, False, encoder + return X, False, None + + # Step 2. Encode. + + if encoding == "one-hot": + ohe = OneHotEncoder( + handle_unknown="ignore", + sparse=False, + dtype=np.float32, + ) + encoder = make_pipeline(ohe) + + # encoder.steps.append(('ohe', ohe)) + encoder.fit(X["train"]) + X = {k: encoder.transform(v) for k, v in X.items()} + elif encoding == "counter": + assert y_train is not None + assert seed is not None + loe = LeaveOneOutEncoder(sigma=0.1, random_state=seed, return_df=False) + encoder.steps.append(("loe", loe)) + encoder.fit(X["train"], y_train) + X = {k: encoder.transform(v).astype("float32") for k, v in X.items()} + if not isinstance(X["train"], pd.DataFrame): + X = {k: v.values for k, v in X.items()} # type: ignore[attr-defined] + else: + raise ValueError(f"Unknown encoding: {encoding}") + + if return_encoder: + return X, True, encoder + return X, True, None + + +def build_target(y: ArrayDict, policy: YPolicy | None, task_type: TaskType) -> tuple[ArrayDict, dict[str, Any]]: + # ruff: noqa: D103 + info: dict[str, Any] = {"policy": policy} + if policy is None: + pass + elif policy == "default": + if task_type == TaskType.REGRESSION: + mean, std = float(y["train"].mean()), float(y["train"].std()) + y = {k: (v - mean) / std for k, v in y.items()} + info["mean"] = mean + info["std"] = std + else: + raise ValueError(f"Unknown policy: {policy}") + return y, info diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index 7a579f83..d5225014 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -1,224 +1,15 @@ -import hashlib -import json import math -import pickle -from collections import Counter from collections.abc import Callable, Generator -from copy import deepcopy -from dataclasses import astuple, dataclass, replace -from enum import Enum -from pathlib import Path -from typing import Any, Literal, Self, cast +from typing import Any, Self import numpy as np import pandas as pd import torch import torch.nn.functional as F - -# ruff: noqa: N812 -from category_encoders import LeaveOneOutEncoder -from scipy.special import expit, softmax -from sklearn.impute import SimpleImputer -from sklearn.metrics import classification_report, mean_squared_error, r2_score, roc_auc_score -from sklearn.model_selection import train_test_split -from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import ( - LabelEncoder, - MinMaxScaler, - OneHotEncoder, - OrdinalEncoder, - QuantileTransformer, - StandardScaler, -) from torch import Tensor, nn - -Normalization = Literal["standard", "quantile", "minmax"] -NumNanPolicy = Literal["drop-rows", "mean"] -CatNanPolicy = Literal["most_frequent"] -CatEncoding = Literal["one-hot", "counter"] -YPolicy = Literal["default"] - - -ArrayDict = dict[str, np.ndarray] -ModuleType = str | Callable[..., nn.Module] - -CAT_MISSING_VALUE = "__nan__" -CAT_RARE_VALUE = 
"__rare__" - - -class TaskType(Enum): - BINCLASS = "binclass" - MULTICLASS = "multiclass" - REGRESSION = "regression" - - def __str__(self) -> str: - return self.value - - -class PredictionType(Enum): - LOGITS = "logits" - PROBS = "probs" - - -@dataclass(frozen=True) -class Transformations: - seed: int = 0 - normalization: Normalization | None = None - num_nan_policy: NumNanPolicy | None = None - cat_nan_policy: CatNanPolicy | None = None - cat_min_frequency: float | None = None - cat_encoding: CatEncoding | None = None - y_policy: YPolicy | None = "default" - - -@dataclass(frozen=False) -class Dataset: - X_num: ArrayDict | None - X_cat: ArrayDict | None - y: ArrayDict - y_info: dict[str, Any] - task_type: TaskType - n_classes: int | None - cat_transform: OneHotEncoder | None = None - num_transform: StandardScaler | None = None - - @classmethod - def from_dir(cls, dir_: Path | str) -> "Dataset": - dir_ = Path(dir_) - splits = [k for k in ["train", "val", "test"] if dir_.joinpath(f"y_{k}.npy").exists()] - - def load(item: str) -> ArrayDict: - return {x: cast(np.ndarray, np.load(dir_ / f"{item}_{x}.npy", allow_pickle=True)) for x in splits} - - if Path(dir_ / "info.json").exists(): - info = json.loads(Path(dir_ / "info.json").read_text()) - else: - info = None - # ruff: noqa: SIM108 - - return Dataset( - load("X_num") if dir_.joinpath("X_num_train.npy").exists() else None, - load("X_cat") if dir_.joinpath("X_cat_train.npy").exists() else None, - load("y"), - {}, - TaskType(info["task_type"]), - info.get("n_classes"), - ) - - @property - def is_binclass(self) -> bool: - return self.task_type == TaskType.BINCLASS - - @property - def is_multiclass(self) -> bool: - return self.task_type == TaskType.MULTICLASS - - @property - def is_regression(self) -> bool: - return self.task_type == TaskType.REGRESSION - - @property - def n_num_features(self) -> int: - return 0 if self.X_num is None else self.X_num["train"].shape[1] - - @property - def n_cat_features(self) -> int: - return 0 if self.X_cat is None else self.X_cat["train"].shape[1] - - @property - def n_features(self) -> int: - return self.n_num_features + self.n_cat_features - - def size(self, part: str | None) -> int: - return sum(map(len, self.y.values())) if part is None else len(self.y[part]) - - @property - def nn_output_dim(self) -> int: - if self.is_multiclass: - assert self.n_classes is not None - return self.n_classes - return 1 - - def get_category_sizes(self, part: str) -> list[int]: - return [] if self.X_cat is None else get_category_sizes(self.X_cat[part]) - - def calculate_metrics( - self, - predictions: dict[str, np.ndarray], - prediction_type: str | None, - ) -> dict[str, Any]: - metrics = { - x: calculate_metrics(self.y[x], predictions[x], self.task_type, prediction_type, self.y_info) - for x in predictions - } - if self.task_type == TaskType.REGRESSION: - score_key = "rmse" - score_sign = -1 - else: - score_key = "accuracy" - score_sign = 1 - for part_metrics in metrics.values(): - part_metrics["score"] = score_sign * part_metrics[score_key] - return metrics - - -def get_category_sizes(X: torch.Tensor | np.ndarray) -> list[int]: - XT = X.T.cpu().tolist() if isinstance(X, torch.Tensor) else X.T.tolist() - return [len(set(x)) for x in XT] - - -def calculate_metrics( - y_true: np.ndarray, - y_pred: np.ndarray, - task_type: str | TaskType, - prediction_type: str | PredictionType | None, - y_info: dict[str, Any], -) -> dict[str, Any]: - # Example: calculate_metrics(y_true, y_pred, 'binclass', 'logits', {}) - task_type = 
TaskType(task_type) - if prediction_type is not None: - prediction_type = PredictionType(prediction_type) - - if task_type == TaskType.REGRESSION: - assert prediction_type is None - assert "std" in y_info - rmse = calculate_rmse(y_true, y_pred, y_info["std"]) - r2 = r2_score(y_true, y_pred) - result = {"rmse": rmse, "r2": r2} - else: - labels, probs = _get_labels_and_probs(y_pred, task_type, prediction_type) - result = cast(dict[str, Any], classification_report(y_true, labels, output_dict=True)) - if task_type == TaskType.BINCLASS: - result["roc_auc"] = roc_auc_score(y_true, probs) - return result - - -def calculate_rmse(y_true: np.ndarray, y_pred: np.ndarray, std: float | None) -> float: - rmse = mean_squared_error(y_true, y_pred) ** 0.5 - if std is not None: - rmse *= std - return rmse - - -def _get_labels_and_probs( - y_pred: np.ndarray, task_type: TaskType, prediction_type: PredictionType | None -) -> tuple[np.ndarray, np.ndarray | None]: - assert task_type in (TaskType.BINCLASS, TaskType.MULTICLASS) - - if prediction_type is None: - return y_pred, None - - if prediction_type == PredictionType.LOGITS: - probs = expit(y_pred) if task_type == TaskType.BINCLASS else softmax(y_pred, axis=1) - elif prediction_type == PredictionType.PROBS: - probs = y_pred - else: - raise ValueError(f"Unknown prediction_type: {prediction_type}") - - assert probs is not None - labels = np.round(probs) if task_type == TaskType.BINCLASS else probs.argmax(axis=1) - return labels.astype("int64"), probs +from midst_toolkit.models.clavaddpm.dataset import Dataset +from midst_toolkit.models.clavaddpm.params import ModuleType class Classifier(nn.Module): @@ -290,178 +81,6 @@ def get_table_info(df: pd.DataFrame, domain_dict: dict[str, Any], y_col: str) -> return df_info -def get_model_params(rtdl_params: dict[str, Any] | None = None) -> dict[str, Any]: - return { - "num_classes": 0, - "is_y_cond": "none", - "rtdl_params": {"d_layers": [512, 1024, 1024, 1024, 1024, 512], "dropout": 0.0} - if rtdl_params is None - else rtdl_params, - } - - -def get_T_dict() -> dict[str, Any]: - # ruff: noqa: N802 - return { - "seed": 0, - "normalization": "quantile", - "num_nan_policy": None, - "cat_nan_policy": None, - "cat_min_frequency": None, - "cat_encoding": None, - "y_policy": "default", - } - - -def make_dataset_from_df( - # ruff: noqa: PLR0915, PLR0912 - df: pd.DataFrame, - T: Transformations, - is_y_cond: str, - df_info: pd.DataFrame, - ratios: list[float] | None = None, - std: float = 0, -) -> tuple[Dataset, dict[int, LabelEncoder], list[int]]: - """ - The order of the generated dataset: (y, X_num, X_cat). - - is_y_cond: - concat: y is concatenated to X, the model learn a joint distribution of (y, X) - embedding: y is not concatenated to X. During computations, y is embedded - and added to the latent vector of X - none: y column is completely ignored - - How does is_y_cond affect the generation of y? - is_y_cond: - concat: the model synthesizes (y, X) directly, so y is just the first column - embedding: y is first sampled using empirical distribution of y. The model only - synthesizes X. When returning the generated data, we return the generated X - and the sampled y. (y is sampled from empirical distribution, instead of being - generated by the model) - Note that in this way, y is still not independent of X, because the model has been - adding the embedding of y to the latent vector of X during computations. - none: - y is synthesized using y's empirical distribution. X is generated by the model. 
- In this case, y is completely independent of X. - - Note: For now, n_classes has to be set to 0. This is because our matrix is the concatenation - of (X_num, X_cat). In this case, if we have is_y_cond == 'concat', we can guarantee that y - is the first column of the matrix. - However, if we have n_classes > 0, then y is not the first column of the matrix. - """ - if ratios is None: - ratios = [0.7, 0.2, 0.1] - - train_val_df, test_df = train_test_split(df, test_size=ratios[2], random_state=42) - train_df, val_df = train_test_split(train_val_df, test_size=ratios[1] / (ratios[0] + ratios[1]), random_state=42) - - cat_column_orders = [] - num_column_orders = [] - index_to_column = list(df.columns) - column_to_index = {col: i for i, col in enumerate(index_to_column)} - - if df_info["n_classes"] > 0: - X_cat: dict[str, np.ndarray] | None = {} if df_info["cat_cols"] is not None or is_y_cond == "concat" else None - X_num: dict[str, np.ndarray] | None = {} if df_info["num_cols"] is not None else None - y = {} - - cat_cols_with_y = [] - if df_info["cat_cols"] is not None: - cat_cols_with_y += df_info["cat_cols"] - if is_y_cond == "concat": - cat_cols_with_y = [df_info["y_col"]] + cat_cols_with_y - - if len(cat_cols_with_y) > 0: - X_cat["train"] = train_df[cat_cols_with_y].to_numpy(dtype=np.str_) # type: ignore[index] - X_cat["val"] = val_df[cat_cols_with_y].to_numpy(dtype=np.str_) # type: ignore[index] - X_cat["test"] = test_df[cat_cols_with_y].to_numpy(dtype=np.str_) # type: ignore[index] - - y["train"] = train_df[df_info["y_col"]].values.astype(np.float32) - y["val"] = val_df[df_info["y_col"]].values.astype(np.float32) - y["test"] = test_df[df_info["y_col"]].values.astype(np.float32) - - if df_info["num_cols"] is not None: - X_num["train"] = train_df[df_info["num_cols"]].values.astype(np.float32) # type: ignore[index] - X_num["val"] = val_df[df_info["num_cols"]].values.astype(np.float32) # type: ignore[index] - X_num["test"] = test_df[df_info["num_cols"]].values.astype(np.float32) # type: ignore[index] - - cat_column_orders = [column_to_index[col] for col in cat_cols_with_y] - num_column_orders = [column_to_index[col] for col in df_info["num_cols"]] - - else: - X_cat = {} if df_info["cat_cols"] is not None else None - X_num = {} if df_info["num_cols"] is not None or is_y_cond == "concat" else None - y = {} - - num_cols_with_y = [] - if df_info["num_cols"] is not None: - num_cols_with_y += df_info["num_cols"] - if is_y_cond == "concat": - num_cols_with_y = [df_info["y_col"]] + num_cols_with_y - - if len(num_cols_with_y) > 0: - X_num["train"] = train_df[num_cols_with_y].values.astype(np.float32) # type: ignore[index] - X_num["val"] = val_df[num_cols_with_y].values.astype(np.float32) # type: ignore[index] - X_num["test"] = test_df[num_cols_with_y].values.astype(np.float32) # type: ignore[index] - - y["train"] = train_df[df_info["y_col"]].values.astype(np.float32) - y["val"] = val_df[df_info["y_col"]].values.astype(np.float32) - y["test"] = test_df[df_info["y_col"]].values.astype(np.float32) - - if df_info["cat_cols"] is not None: - X_cat["train"] = train_df[df_info["cat_cols"]].to_numpy(dtype=np.str_) # type: ignore[index] - X_cat["val"] = val_df[df_info["cat_cols"]].to_numpy(dtype=np.str_) # type: ignore[index] - X_cat["test"] = test_df[df_info["cat_cols"]].to_numpy(dtype=np.str_) # type: ignore[index] - - cat_column_orders = [column_to_index[col] for col in df_info["cat_cols"]] - num_column_orders = [column_to_index[col] for col in num_cols_with_y] - - column_orders = num_column_orders + 
cat_column_orders - column_orders = [index_to_column[index] for index in column_orders] - - label_encoders = {} - if X_cat is not None and len(df_info["cat_cols"]) > 0: - X_cat_all = np.vstack((X_cat["train"], X_cat["val"], X_cat["test"])) - X_cat_converted = [] - for col_index in range(X_cat_all.shape[1]): - label_encoder = LabelEncoder() - X_cat_converted.append(label_encoder.fit_transform(X_cat_all[:, col_index]).astype(float)) - if std > 0: - # add noise - X_cat_converted[-1] += np.random.normal(0, std, X_cat_converted[-1].shape) - label_encoders[col_index] = label_encoder - - X_cat_converted = np.vstack(X_cat_converted).T # type: ignore[assignment] - - train_num = X_cat["train"].shape[0] - val_num = X_cat["val"].shape[0] - # test_num = X_cat["test"].shape[0] - - X_cat["train"] = X_cat_converted[:train_num, :] # type: ignore[call-overload] - X_cat["val"] = X_cat_converted[train_num : train_num + val_num, :] # type: ignore[call-overload] - X_cat["test"] = X_cat_converted[train_num + val_num :, :] # type: ignore[call-overload] - - if X_num and len(X_num) > 0: - X_num["train"] = np.concatenate((X_num["train"], X_cat["train"]), axis=1) - X_num["val"] = np.concatenate((X_num["val"], X_cat["val"]), axis=1) - X_num["test"] = np.concatenate((X_num["test"], X_cat["test"]), axis=1) - else: - X_num = X_cat - X_cat = None - - D = Dataset( - # ruff: noqa: N806 - X_num, - None, - y, - y_info={}, - task_type=TaskType(df_info["task_type"]), - n_classes=df_info["n_classes"], - ) - - return transform_dataset(D, T, None), label_encoders, column_orders - - def prepare_fast_dataloader( D: Dataset, # ruff: noqa: N803 @@ -496,264 +115,6 @@ def get_model( raise ValueError("Unknown model!") -def transform_dataset( - dataset: Dataset, - transformations: Transformations, - cache_dir: Path | None, - transform_cols_num: int = 0, -) -> Dataset: - # WARNING: the order of transformations matters. Moreover, the current - # implementation is not ideal in that sense. 
- if cache_dir is not None: - transformations_md5 = hashlib.md5(str(transformations).encode("utf-8")).hexdigest() - transformations_str = "__".join(map(str, astuple(transformations))) - cache_path = cache_dir / f"cache__{transformations_str}__{transformations_md5}.pickle" - if cache_path.exists(): - cache_transformations, value = load_pickle(cache_path) - if transformations == cache_transformations: - print(f"Using cached features: {cache_dir.name + '/' + cache_path.name}") - return value - raise RuntimeError(f"Hash collision for {cache_path}") - else: - cache_path = None - - if dataset.X_num is not None: - dataset = num_process_nans(dataset, transformations.num_nan_policy) - - num_transform = None - cat_transform = None - X_num = dataset.X_num - - if X_num is not None and transformations.normalization is not None: - X_num, num_transform = normalize( # type: ignore[assignment] - X_num, - transformations.normalization, - transformations.seed, - return_normalizer=True, - ) - - if dataset.X_cat is None: - assert transformations.cat_nan_policy is None - assert transformations.cat_min_frequency is None - # assert transformations.cat_encoding is None - X_cat = None - else: - X_cat = cat_process_nans(dataset.X_cat, transformations.cat_nan_policy) - if transformations.cat_min_frequency is not None: - X_cat = cat_drop_rare(X_cat, transformations.cat_min_frequency) - X_cat, is_num, cat_transform = cat_encode( - X_cat, - transformations.cat_encoding, - dataset.y["train"], - transformations.seed, - return_encoder=True, - ) - if is_num: - X_num = X_cat if X_num is None else {x: np.hstack([X_num[x], X_cat[x]]) for x in X_num} - X_cat = None - - y, y_info = build_target(dataset.y, transformations.y_policy, dataset.task_type) - - dataset = replace(dataset, X_num=X_num, X_cat=X_cat, y=y, y_info=y_info) - dataset.num_transform = num_transform - dataset.cat_transform = cat_transform - - if cache_path is not None: - dump_pickle((transformations, dataset), cache_path) - # if return_transforms: - # return dataset, num_transform, cat_transform - return dataset - - -def load_pickle(path: Path | str, **kwargs: Any) -> Any: - # ruff: noqa: D103 - return pickle.loads(Path(path).read_bytes(), **kwargs) - - -def dump_pickle(x: Any, path: Path | str, **kwargs: Any) -> None: - # ruff: noqa: D103 - Path(path).write_bytes(pickle.dumps(x, **kwargs)) - - -def num_process_nans(dataset: Dataset, policy: NumNanPolicy | None) -> Dataset: - # ruff: noqa: D103 - assert dataset.X_num is not None - nan_masks = {k: np.isnan(v) for k, v in dataset.X_num.items()} - if not any(x.any() for x in nan_masks.values()): - assert policy is None - return dataset - - assert policy is not None - if policy == "drop-rows": - valid_masks = {k: ~v.any(1) for k, v in nan_masks.items()} - assert valid_masks["test"].all(), "Cannot drop test rows, since this will affect the final metrics." 
- new_data = {} - for data_name in ["X_num", "X_cat", "y"]: - data_dict = getattr(dataset, data_name) - if data_dict is not None: - new_data[data_name] = {k: v[valid_masks[k]] for k, v in data_dict.items()} - dataset = replace(dataset, **new_data) # type: ignore[arg-type] - elif policy == "mean": - new_values = np.nanmean(dataset.X_num["train"], axis=0) # type: ignore[index] - X_num = deepcopy(dataset.X_num) - for k, v in X_num.items(): # type: ignore[union-attr] - num_nan_indices = np.where(nan_masks[k]) - v[num_nan_indices] = np.take(new_values, num_nan_indices[1]) - dataset = replace(dataset, X_num=X_num) - else: - raise ValueError(f"Unknown policy: {policy}") - return dataset - - -# Inspired by: https://github.com/yandex-research/rtdl/blob/a4c93a32b334ef55d2a0559a4407c8306ffeeaee/lib/data.py#L20 -def normalize( - X: ArrayDict, - normalization: Normalization, - seed: int | None, - return_normalizer: bool = False, -) -> ArrayDict | tuple[ArrayDict, StandardScaler | MinMaxScaler | QuantileTransformer]: - # ruff: noqa: D103 - X_train = X["train"] - if normalization == "standard": - normalizer = StandardScaler() - elif normalization == "minmax": - normalizer = MinMaxScaler() - elif normalization == "quantile": - normalizer = QuantileTransformer( - output_distribution="normal", - n_quantiles=max(min(X["train"].shape[0] // 30, 1000), 10), - subsample=int(1e9), - random_state=seed, - ) - # noise = 1e-3 - # if noise > 0: - # assert seed is not None - # stds = np.std(X_train, axis=0, keepdims=True) - # noise_std = noise / np.maximum(stds, noise) # type: ignore[code] - # X_train = X_train + noise_std * np.random.default_rng(seed).standard_normal( - # X_train.shape - # ) - else: - raise ValueError(f"Unknown normalization: {normalization}") - normalizer.fit(X_train) - if return_normalizer: - return {k: normalizer.transform(v) for k, v in X.items()}, normalizer - return {k: normalizer.transform(v) for k, v in X.items()} - - -def cat_process_nans(X: ArrayDict, policy: CatNanPolicy | None) -> ArrayDict: - # ruff: noqa: D103 - assert X is not None - nan_masks = {k: v == CAT_MISSING_VALUE for k, v in X.items()} - if any(x.any() for x in nan_masks.values()): - if policy is None: - X_new = X - elif policy == "most_frequent": - imputer = SimpleImputer(missing_values=CAT_MISSING_VALUE, strategy=policy) - imputer.fit(X["train"]) - X_new = {k: cast(np.ndarray, imputer.transform(v)) for k, v in X.items()} - else: - raise ValueError(f"Unknown cat_nan_policy: {policy}") - else: - assert policy is None - X_new = X - return X_new - - -def cat_drop_rare(X: ArrayDict, min_frequency: float) -> ArrayDict: - # ruff: noqa: D103 - assert 0.0 < min_frequency < 1.0 - min_count = round(len(X["train"]) * min_frequency) - X_new: dict[str, list[Any]] = {x: [] for x in X} - for column_idx in range(X["train"].shape[1]): - counter = Counter(X["train"][:, column_idx].tolist()) - popular_categories = {k for k, v in counter.items() if v >= min_count} - for part, _ in X_new.items(): - X_new[part].append( - [(x if x in popular_categories else CAT_RARE_VALUE) for x in X[part][:, column_idx].tolist()] - ) - return {k: np.array(v).T for k, v in X_new.items()} - - -def cat_encode( - X: ArrayDict, - encoding: CatEncoding | None, - y_train: np.ndarray | None, - seed: int | None, - return_encoder: bool = False, -) -> tuple[ArrayDict, bool, Any | None]: # (X, is_converted_to_numerical) - # ruff: noqa: D103 - if encoding != "counter": - y_train = None - - # Step 1. 
Map strings to 0-based ranges - - if encoding is None: - unknown_value = np.iinfo("int64").max - 3 - oe = OrdinalEncoder( - handle_unknown="use_encoded_value", - unknown_value=unknown_value, - dtype="int64", - ).fit(X["train"]) - encoder = make_pipeline(oe) - encoder.fit(X["train"]) - X = {k: encoder.transform(v) for k, v in X.items()} - max_values = X["train"].max(axis=0) - for part in X: - if part == "train": - continue - for column_idx in range(X[part].shape[1]): - X[part][X[part][:, column_idx] == unknown_value, column_idx] = max_values[column_idx] + 1 - if return_encoder: - return X, False, encoder - return X, False, None - - # Step 2. Encode. - - if encoding == "one-hot": - ohe = OneHotEncoder( - handle_unknown="ignore", - sparse=False, - dtype=np.float32, - ) - encoder = make_pipeline(ohe) - - # encoder.steps.append(('ohe', ohe)) - encoder.fit(X["train"]) - X = {k: encoder.transform(v) for k, v in X.items()} - elif encoding == "counter": - assert y_train is not None - assert seed is not None - loe = LeaveOneOutEncoder(sigma=0.1, random_state=seed, return_df=False) - encoder.steps.append(("loe", loe)) - encoder.fit(X["train"], y_train) - X = {k: encoder.transform(v).astype("float32") for k, v in X.items()} - if not isinstance(X["train"], pd.DataFrame): - X = {k: v.values for k, v in X.items()} # type: ignore[attr-defined] - else: - raise ValueError(f"Unknown encoding: {encoding}") - - if return_encoder: - return X, True, encoder - return X, True, None - - -def build_target(y: ArrayDict, policy: YPolicy | None, task_type: TaskType) -> tuple[ArrayDict, dict[str, Any]]: - # ruff: noqa: D103 - info: dict[str, Any] = {"policy": policy} - if policy is None: - pass - elif policy == "default": - if task_type == TaskType.REGRESSION: - mean, std = float(y["train"].mean()), float(y["train"].std()) - y = {k: (v - mean) / std for k, v in y.items()} - info["mean"] = mean - info["std"] = std - else: - raise ValueError(f"Unknown policy: {policy}") - return y, info - - class FastTensorDataLoader: """ Defines a faster dataloader for PyTorch tensors. 
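As an editorial aid, the sketch below shows how the helpers that this patch moves into dataset.py fit together; it is not part of the patch. The toy DataFrame, its column names, and the hand-written df_info dictionary are hypothetical stand-ins (in the toolkit, df_info is produced by get_table_info), chosen only to mirror the keys that make_dataset_from_df reads.

import numpy as np
import pandas as pd

from midst_toolkit.models.clavaddpm.dataset import Transformations, get_T_dict, make_dataset_from_df

rng = np.random.default_rng(0)
# Toy table with two numerical features and a numeric placeholder target (hypothetical names).
toy_df = pd.DataFrame(
    {
        "age": rng.random(100).astype(np.float32),
        "income": rng.random(100).astype(np.float32),
        "placeholder": np.arange(100, dtype=np.float32),
    }
)
# Hand-written stand-in for the dict returned by get_table_info(); the keys are the ones
# make_dataset_from_df reads (n_classes, cat_cols, num_cols, y_col, task_type).
toy_df_info = {
    "n_classes": 0,
    "cat_cols": [],
    "num_cols": ["age", "income"],
    "y_col": "placeholder",
    "task_type": "regression",
}

transformations = Transformations(**get_T_dict())
dataset, label_encoders, column_orders = make_dataset_from_df(
    toy_df,
    transformations,
    is_y_cond="concat",  # y is prepended to X_num, per the is_y_cond semantics documented above
    ratios=[0.7, 0.2, 0.1],
    df_info=toy_df_info,
)
print(dataset.n_num_features, dataset.size("train"), column_orders)

With is_y_cond="concat" the target ends up as the first numerical column, which is why train_model later rotates column_orders when that mode is used.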
diff --git a/src/midst_toolkit/models/clavaddpm/params.py b/src/midst_toolkit/models/clavaddpm/params.py index 38217fc0..5d5f2080 100644 --- a/src/midst_toolkit/models/clavaddpm/params.py +++ b/src/midst_toolkit/models/clavaddpm/params.py @@ -1,7 +1,14 @@ +from collections.abc import Callable from typing import Any +import numpy as np +from torch import nn + # TODO: Temporary, will wtich to classes later Configs = dict[str, Any] Tables = dict[str, dict[str, Any]] RelationOrder = list[tuple[str, str]] + +ArrayDict = dict[str, np.ndarray] +ModuleType = str | Callable[..., nn.Module] diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 449168f1..da31c6f4 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -12,16 +12,12 @@ from torch import Tensor, optim from midst_toolkit.core import logger +from midst_toolkit.models.clavaddpm.dataset import Dataset, Transformations, get_T_dict, make_dataset_from_df from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion from midst_toolkit.models.clavaddpm.model import ( Classifier, - Dataset, - Transformations, get_model, - get_model_params, - get_T_dict, get_table_info, - make_dataset_from_df, prepare_fast_dataloader, ) from midst_toolkit.models.clavaddpm.params import Configs, RelationOrder, Tables @@ -173,7 +169,7 @@ def child_training( else: y_col = f"{parent_name}_{child_name}_cluster" child_info = get_table_info(child_df_with_cluster, child_domain_dict, y_col) - child_model_params = get_model_params( + child_model_params = _get_model_params( { "d_layers": diffusion_config["d_layers"], "dropout": diffusion_config["dropout"], @@ -590,3 +586,36 @@ def _split_microbatches( else: for i in range(0, bs, microbatch): yield batch[i : i + microbatch], labels[i : i + microbatch], t[i : i + microbatch] + + +# TODO make this into a class with default parameters +def _get_model_params(rtdl_params: dict[str, Any] | None = None) -> dict[str, Any]: + """ + Return the model parameters. + + Args: + rtdl_params: The parameters for the RTDL model. If None, the default parameters below are used: + { + "d_layers": [512, 1024, 1024, 1024, 1024, 512], + "dropout": 0.0, + } + + Returns: + The model parameters as a dictionary containing the following keys: + - num_classes: The number of classes. Defaults to 0. + - is_y_cond: Affects how y is generated. For more information, see the documentation + of the `make_dataset_from_df` function. Can be any of ["none", "concat", "embedding"]. + Defaults to "none". + - rtdl_params: The parameters for the RTDL model. 
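+
+        For example, _get_model_params({"d_layers": [256, 256], "dropout": 0.1}) returns
+        {"num_classes": 0, "is_y_cond": "none", "rtdl_params": {"d_layers": [256, 256], "dropout": 0.1}}.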
+ """ + if rtdl_params is None: + rtdl_params = { + "d_layers": [512, 1024, 1024, 1024, 1024, 512], + "dropout": 0.0, + } + + return { + "num_classes": 0, + "is_y_cond": "none", + "rtdl_params": rtdl_params, + } From 9cc742189bfcd51dcde68ccd49559f744462680d Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 6 Aug 2025 17:52:31 -0300 Subject: [PATCH 14/39] WIP David's and Fatemeh's CR --- src/midst_toolkit/models/clavaddpm/train.py | 27 +++++++++++++------ src/midst_toolkit/models/clavaddpm/trainer.py | 10 +++---- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index d95685dd..ec6457e5 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -45,7 +45,7 @@ def clava_clustering( configs: Configs, ) -> tuple[dict[str, Any], dict[tuple[str, str], dict[int, float]]]: """ - Clustering function for the mutli-table function of theClavaDDPM model. + Clustering function for the mutli-table function of the ClavaDDPM model. Args: tables: Definition of the tables and their relations. Example: @@ -69,6 +69,10 @@ def clava_clustering( clustering_method = str["kmeans" | "both" | "variational" | "gmm"], } + Returns: + A tuple with 2 values: + - The tables dictionary. + - The dictionary with the group lengths probability for all the parent-child pairs. """ relation_order_reversed = relation_order[::-1] all_group_lengths_prob_dicts = {} @@ -76,8 +80,10 @@ def clava_clustering( # Clustering if os.path.exists(save_dir / "cluster_ckpt.pkl"): print("Clustering checkpoint found, loading...") - cluster_ckpt = pickle.load(open(save_dir / "cluster_ckpt.pkl", "rb")) - # ruff: noqa: SIM115 + + with open(save_dir / "cluster_ckpt.pkl", "rb") as f: + cluster_ckpt = pickle.load(f) + tables = cluster_ckpt["tables"] all_group_lengths_prob_dicts = cluster_ckpt["all_group_lengths_prob_dicts"] else: @@ -114,8 +120,8 @@ def clava_clustering( "tables": tables, "all_group_lengths_prob_dicts": all_group_lengths_prob_dicts, } - pickle.dump(cluster_ckpt, open(save_dir / "cluster_ckpt.pkl", "wb")) - # ruff: noqa: SIM115 + with open(save_dir / "cluster_ckpt.pkl", "wb") as f: + pickle.dump(cluster_ckpt, f) for parent, child in relation_order: if parent is None: @@ -259,6 +265,9 @@ def child_training( Dictionary of the training results. """ if parent_name is None: + # If there is no parent for this child table, just set a placeholder + # for its column name. This can happen on single table training or + # when the table is on the top level of the hierarchy. y_col = "placeholder" child_df_with_cluster["placeholder"] = list(range(len(child_df_with_cluster))) else: @@ -311,6 +320,8 @@ def child_training( device=device, ) child_result["classifier"] = child_classifier + else: + LOGGER.warning("Skipping classifier training since classifier_config['iterations'] <= 0") child_result["df_info"] = child_info child_result["model_params"] = child_model_params @@ -406,7 +417,7 @@ def train_model( steps=steps, device=device, ) - trainer.run_loop() + trainer.train() if model_params["is_y_cond"] == "concat": column_orders = column_orders[1:] + [column_orders[0]] @@ -454,11 +465,11 @@ def train_classifier( d_layers: List of the hidden sizes of the classifier. device: Device to use for training. Default is `"cuda"`. cluster_col: Name of the cluster column. Default is `"cluster"`. - dim_t: Dimension of the transformer. Default is 128. + dim_t: Dimension of the timestamp. Default is 128. 
lr: Learning rate to use for the classifier. Default is 0.0001. Returns: - The classifier model. + The trained classifier model. """ T = Transformations(**T_dict) # ruff: noqa: N806 diff --git a/src/midst_toolkit/models/clavaddpm/trainer.py b/src/midst_toolkit/models/clavaddpm/trainer.py index b3c47704..58b0cddc 100644 --- a/src/midst_toolkit/models/clavaddpm/trainer.py +++ b/src/midst_toolkit/models/clavaddpm/trainer.py @@ -60,13 +60,13 @@ def _anneal_lr(self, step: int) -> None: for param_group in self.optimizer.param_groups: param_group["lr"] = lr - def _run_step(self, x: Tensor, target: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + def _train_step(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: """ Run a single step of the training loop. Args: x: The input tensor. - target: The target dictionary (model output). + y: The output tensor. Returns: A tuple with 2 values: @@ -74,6 +74,7 @@ def _run_step(self, x: Tensor, target: dict[str, Tensor]) -> tuple[Tensor, Tenso - The Gaussian loss. """ x = x.to(self.device) + target = {"y": y} for k, v in target.items(): target[k] = v.long().to(self.device) self.optimizer.zero_grad() @@ -84,7 +85,7 @@ def _run_step(self, x: Tensor, target: dict[str, Tensor]) -> tuple[Tensor, Tenso return loss_multi, loss_gauss - def run_loop(self) -> None: + def train(self) -> None: """Run the training loop.""" step = 0 curr_loss_multi = 0.0 @@ -93,8 +94,7 @@ def run_loop(self) -> None: curr_count = 0 while step < self.steps: x, out = next(self.train_iter) - out_dict = {"y": out} - batch_loss_multi, batch_loss_gauss = self._run_step(x, out_dict) + batch_loss_multi, batch_loss_gauss = self._train_step(x, out) self._anneal_lr(step) From 7b8e18742b65fad41d4fed8a306b7b7a1093f90b Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 6 Aug 2025 18:14:38 -0300 Subject: [PATCH 15/39] WIP merging parent --- src/midst_toolkit/models/clavaddpm/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index d5225014..849a3218 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -9,7 +9,7 @@ from torch import Tensor, nn from midst_toolkit.models.clavaddpm.dataset import Dataset -from midst_toolkit.models.clavaddpm.params import ModuleType +from midst_toolkit.models.clavaddpm.typing import ModuleType class Classifier(nn.Module): From fcc00e9dbfcc2ac77ef94d8080f013a324c9fbf0 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 7 Aug 2025 11:29:03 -0300 Subject: [PATCH 16/39] David's CR: renamings --- src/midst_toolkit/models/clavaddpm/train.py | 83 ++++++++++----------- 1 file changed, 41 insertions(+), 42 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index ec6457e5..37c6321a 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -268,6 +268,7 @@ def child_training( # If there is no parent for this child table, just set a placeholder # for its column name. This can happen on single table training or # when the table is on the top level of the hierarchy. 
+ # TODO: find a better name for this variable y_col = "placeholder" child_df_with_cluster["placeholder"] = list(range(len(child_df_with_cluster))) else: @@ -316,7 +317,7 @@ def child_training( cluster_col=y_col, d_layers=classifier_config["d_layers"], dim_t=classifier_config["dim_t"], - lr=classifier_config["lr"], + learning_rate=classifier_config["lr"], device=device, ) child_result["classifier"] = child_classifier @@ -330,18 +331,17 @@ def child_training( def train_model( - df: pd.DataFrame, - df_info: pd.DataFrame, + data_frame: pd.DataFrame, + data_frame_info: pd.DataFrame, model_params: dict[str, Any], - T_dict: dict[str, Any], - # ruff: noqa: N803 + transformations_dict: dict[str, Any], steps: int, batch_size: int, model_type: str, gaussian_loss_type: str, num_timesteps: int, scheduler: str, - lr: float, + learning_rate: float, weight_decay: float, device: str = "cuda", ) -> dict[str, Any]: @@ -349,18 +349,18 @@ def train_model( Training function for the diffusion model. Args: - df: DataFrame to train the model on. - df_info: Dictionary of the table information. + data_frame: DataFrame to train the model on. + data_frame_info: Dictionary of the table information. model_params: Dictionary of the model parameters. - T_dict: Dictionary of the transformations. + transformations_dict: Dictionary of the transformations. steps: Number of steps to train the model. batch_size: Batch size to use for training. model_type: Type of the model to use. gaussian_loss_type: Type of the gaussian loss to use. num_timesteps: Number of timesteps to use for the diffusion model. scheduler: Scheduler to use for the diffusion model. - lr: Learning rate to use for the diffusion model. - weight_decay: Weight decay to use for the diffusion model. + learning_rate: Learning rate to use for the optimizer in the diffusion model. + weight_decay: Weight decay to use for the optimizer in the diffusion model. device: Device to use for training. Default is `"cuda"`. Returns: @@ -370,25 +370,25 @@ def train_model( - dataset: The dataset. - column_orders: The column orders. 
""" - T = Transformations(**T_dict) + transformations = Transformations(**transformations_dict) # ruff: noqa: N806 dataset, label_encoders, column_orders = make_dataset_from_df( - df, - T, + data_frame, + transformations, is_y_cond=model_params["is_y_cond"], ratios=[0.99, 0.005, 0.005], - df_info=df_info, + df_info=data_frame_info, std=0, ) - K = np.array(dataset.get_category_sizes("train")) + category_sizes = np.array(dataset.get_category_sizes("train")) # ruff: noqa: N806 - if len(K) == 0 or T_dict["cat_encoding"] == "one-hot": - K = np.array([0]) + if len(category_sizes) == 0 or transformations_dict["cat_encoding"] == "one-hot": + category_sizes = np.array([0]) # ruff: noqa: N806 num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 - d_in = np.sum(K) + num_numerical_features + d_in = np.sum(category_sizes) + num_numerical_features model_params["d_in"] = d_in print("Model params: {}".format(model_params)) @@ -398,7 +398,7 @@ def train_model( train_loader = prepare_fast_dataloader(dataset, split="train", batch_size=batch_size) diffusion = GaussianMultinomialDiffusion( - num_classes=K, + num_classes=category_sizes, num_numerical_features=num_numerical_features, denoise_fn=model, gaussian_loss_type=gaussian_loss_type, @@ -412,7 +412,7 @@ def train_model( trainer = ClavaDDPMTrainer( diffusion, train_loader, - lr=lr, + lr=learning_rate, weight_decay=weight_decay, steps=steps, device=device, @@ -422,7 +422,7 @@ def train_model( if model_params["is_y_cond"] == "concat": column_orders = column_orders[1:] + [column_orders[0]] else: - column_orders = column_orders + [df_info["y_col"]] + column_orders = column_orders + [data_frame_info["y_col"]] return { "diffusion": diffusion, @@ -433,11 +433,10 @@ def train_model( def train_classifier( - df: pd.DataFrame, - df_info: pd.DataFrame, + data_frame: pd.DataFrame, + data_frame_info: pd.DataFrame, model_params: dict[str, Any], - T_dict: dict[str, Any], - # ruff: noqa: N803 + transformations_dict: dict[str, Any], classifier_steps: int, batch_size: int, gaussian_loss_type: str, @@ -447,16 +446,16 @@ def train_classifier( device: str = "cuda", cluster_col: str = "cluster", dim_t: int = 128, - lr: float = 0.0001, + learning_rate: float = 0.0001, ) -> Classifier: """ Training function for the classifier model. Args: - df: DataFrame to train the model on. - df_info: Dictionary of the table information. + data_frame: DataFrame to train the model on. + data_frame_info: Dictionary of the table information. model_params: Dictionary of the model parameters. - T_dict: Dictionary of the transformations. + transformations_dict: Dictionary of the transformations. classifier_steps: Number of steps to train the classifier. batch_size: Batch size to use for training. gaussian_loss_type: Type of the gaussian loss to use. @@ -466,19 +465,19 @@ def train_classifier( device: Device to use for training. Default is `"cuda"`. cluster_col: Name of the cluster column. Default is `"cluster"`. dim_t: Dimension of the timestamp. Default is 128. - lr: Learning rate to use for the classifier. Default is 0.0001. + learning_rate: Learning rate to use for the optimizer in the classifier. Default is 0.0001. Returns: The trained classifier model. 
""" - T = Transformations(**T_dict) + transformations = Transformations(**transformations_dict) # ruff: noqa: N806 dataset, label_encoders, column_orders = make_dataset_from_df( - df, - T, + data_frame, + transformations, is_y_cond=model_params["is_y_cond"], ratios=[0.99, 0.005, 0.005], - df_info=df_info, + df_info=data_frame_info, std=0, ) print(dataset.n_features) @@ -488,12 +487,12 @@ def train_classifier( eval_interval = 5 - K = np.array(dataset.get_category_sizes("train")) + category_sizes = np.array(dataset.get_category_sizes("train")) # ruff: noqa: N806 - if len(K) == 0 or T_dict["cat_encoding"] == "one-hot": - K = np.array([0]) + if len(category_sizes) == 0 or transformations_dict["cat_encoding"] == "one-hot": + category_sizes = np.array([0]) # ruff: noqa: N806 - print(K) + print(category_sizes) num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 if model_params["is_y_cond"] == "concat": @@ -501,15 +500,15 @@ def train_classifier( classifier = Classifier( d_in=num_numerical_features, - d_out=int(max(df[cluster_col].values) + 1), + d_out=int(max(data_frame[cluster_col].values) + 1), dim_t=dim_t, hidden_sizes=d_layers, ).to(device) - classifier_optimizer = optim.AdamW(classifier.parameters(), lr=lr) + classifier_optimizer = optim.AdamW(classifier.parameters(), lr=learning_rate) empty_diffusion = GaussianMultinomialDiffusion( - num_classes=K, + num_classes=category_sizes, num_numerical_features=num_numerical_features, denoise_fn=None, # type: ignore[arg-type] gaussian_loss_type=gaussian_loss_type, From f20f4820cc59b1885c55ce2d534cf28dc668724d Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 7 Aug 2025 13:06:49 -0300 Subject: [PATCH 17/39] David's CR: remaining comments --- src/midst_toolkit/models/clavaddpm/model.py | 2 ++ src/midst_toolkit/models/clavaddpm/train.py | 25 ++++++++++++------- src/midst_toolkit/models/clavaddpm/trainer.py | 21 ++++++++++------ 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index f636b552..6a7d18cd 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -839,6 +839,8 @@ def make_dataset_from_df( return transform_dataset(D, T, None), label_encoders, column_orders +# TODO: can this be refactored in a way it does not return a generator but +# rather an instance of FastTensorDataLoader and the generator is returned later? def prepare_fast_dataloader( D: Dataset, # ruff: noqa: N803 diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 37c6321a..3b383c5c 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -447,6 +447,7 @@ def train_classifier( cluster_col: str = "cluster", dim_t: int = 128, learning_rate: float = 0.0001, + classifier_evaluation_interval: int = 5, ) -> Classifier: """ Training function for the classifier model. @@ -466,6 +467,8 @@ def train_classifier( cluster_col: Name of the cluster column. Default is `"cluster"`. dim_t: Dimension of the timestamp. Default is 128. learning_rate: Learning rate to use for the optimizer in the classifier. Default is 0.0001. + classifier_evaluation_interval: The amount of classifier_steps to wait + until the next evaluation of the classifier. Default is 5. Returns: The trained classifier model. 
@@ -485,8 +488,6 @@ def train_classifier( val_loader = prepare_fast_dataloader(dataset, split="val", batch_size=batch_size, y_type="long") test_loader = prepare_fast_dataloader(dataset, split="test", batch_size=batch_size, y_type="long") - eval_interval = 5 - category_sizes = np.array(dataset.get_category_sizes("train")) # ruff: noqa: N806 if len(category_sizes) == 0 or transformations_dict["cat_encoding"] == "one-hot": @@ -494,13 +495,19 @@ def train_classifier( # ruff: noqa: N806 print(category_sizes) - num_numerical_features = dataset.X_num["train"].shape[1] if dataset.X_num is not None else 0 + # TODO: understand what's going on here + if dataset.X_num is None: + LOGGER.warning("dataset.X_num is None. num_numerical_features will be set to 0") + num_numerical_features = 0 + else: + num_numerical_features = dataset.X_num["train"].shape[1] + if model_params["is_y_cond"] == "concat": num_numerical_features -= 1 classifier = Classifier( d_in=num_numerical_features, - d_out=int(max(data_frame[cluster_col].values) + 1), + d_out=int(max(data_frame[cluster_col].values) + 1), # TODO: add a comment why we need to add 1 dim_t=dim_t, hidden_sizes=d_layers, ).to(device) @@ -521,12 +528,11 @@ def train_classifier( schedule_sampler = create_named_schedule_sampler("uniform", empty_diffusion) classifier.train() - resume_step = 0 for step in range(classifier_steps): - logger.logkv("step", step + resume_step) + logger.logkv("step", step) logger.logkv( "samples", - (step + resume_step + 1) * batch_size, + (step + 1) * batch_size, ) numerical_forward_backward_log( classifier, @@ -540,7 +546,7 @@ def train_classifier( ) classifier_optimizer.step() - if not step % eval_interval: + if not step % classifier_evaluation_interval: with torch.no_grad(): classifier.eval() numerical_forward_backward_log( @@ -555,10 +561,11 @@ def train_classifier( ) classifier.train() - # # test classifier + # test classifier classifier.eval() correct = 0 + # TODO: why 3000 iterations? Why not just run through the test_loader once? Maybe it's a probabilistic classifier? for _ in range(3000): test_x, test_y = next(test_loader) test_y = test_y.long().to(device) diff --git a/src/midst_toolkit/models/clavaddpm/trainer.py b/src/midst_toolkit/models/clavaddpm/trainer.py index 58b0cddc..c727ee48 100644 --- a/src/midst_toolkit/models/clavaddpm/trainer.py +++ b/src/midst_toolkit/models/clavaddpm/trainer.py @@ -14,7 +14,7 @@ class ClavaDDPMTrainer: def __init__( self, - diffusion: GaussianMultinomialDiffusion, + diffusion_model: GaussianMultinomialDiffusion, train_iter: Generator[tuple[Tensor, ...]], lr: float, weight_decay: float, @@ -25,7 +25,7 @@ def __init__( Trainer class for the ClavaDDPM model. Args: - diffusion: The diffusion model. + diffusion_model: The diffusion model. train_iter: The training iterator. It should yield a tuple of tensors. The first tensor is the input tensor and the second tensor is the output tensor. lr: The learning rate. @@ -33,15 +33,15 @@ def __init__( steps: The number of steps to train. device: The device to use. Default is `"cuda"`. 
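
A minimal sketch of an iterator satisfying the `train_iter` contract described above; in the toolkit this role is filled by `prepare_fast_dataloader`, and the toy generator below (with made-up shapes) only illustrates the expected (input, output) tuples and the fact that the iterator should not be exhausted before `steps` batches have been drawn:

    import torch

    def toy_train_iter(batch_size: int = 8, n_features: int = 4, n_clusters: int = 3):
        # Yields (x, y) batches indefinitely, matching `x, out = next(self.train_iter)`.
        while True:
            x = torch.randn(batch_size, n_features)
            y = torch.randint(0, n_clusters, (batch_size,))
            yield x, y
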
""" - self.diffusion = diffusion - self.ema_model = deepcopy(self.diffusion._denoise_fn) + self.diffusion_model = diffusion_model + self.ema_model = deepcopy(self.diffusion_model._denoise_fn) for param in self.ema_model.parameters(): param.detach_() self.train_iter = train_iter self.steps = steps self.init_lr = lr - self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay) + self.optimizer = torch.optim.AdamW(self.diffusion_model.parameters(), lr=lr, weight_decay=weight_decay) self.device = device self.loss_history = pd.DataFrame(columns=["step", "mloss", "gloss", "loss"]) self.log_every = 100 @@ -78,7 +78,7 @@ def _train_step(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: for k, v in target.items(): target[k] = v.long().to(self.device) self.optimizer.zero_grad() - loss_multi, loss_gauss = self.diffusion.mixed_loss(x, target) + loss_multi, loss_gauss = self.diffusion_model.mixed_loss(x, target) loss = loss_multi + loss_gauss loss.backward() # type: ignore[no-untyped-call] self.optimizer.step() @@ -93,6 +93,8 @@ def train(self) -> None: curr_count = 0 while step < self.steps: + # TODO: improve this design. If self.steps is larger than self.train_iter, + # it will lead to a StopIteration error. x, out = next(self.train_iter) batch_loss_multi, batch_loss_gauss = self._train_step(x, out) @@ -102,11 +104,14 @@ def train(self) -> None: curr_loss_multi += batch_loss_multi.item() * len(x) curr_loss_gauss += batch_loss_gauss.item() * len(x) + # TODO: improve this code, starting by moving it into a function for better readability and modularity. if (step + 1) % self.log_every == 0: mloss = np.around(curr_loss_multi / curr_count, 4) gloss = np.around(curr_loss_gauss / curr_count, 4) if (step + 1) % self.print_every == 0: print(f"Step {(step + 1)}/{self.steps} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}") + + # TODO: switch this for a concat for better code readability self.loss_history.loc[len(self.loss_history)] = [ step + 1, mloss, @@ -117,7 +122,7 @@ def train(self) -> None: curr_loss_gauss = 0.0 curr_loss_multi = 0.0 - update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters()) + update_ema(self.ema_model.parameters(), self.diffusion_model._denoise_fn.parameters()) step += 1 @@ -137,4 +142,6 @@ def update_ema( rate: the EMA rate (closer to 1 means slower). """ for targ, src in zip(target_params, source_params): + # TODO: is this doing anything at all? The detach functions will create new tensors, + # so this will not modify the original tensors, and this function does not return anything. 
targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) From 3213772ddc3b3133c9b90551acec9584bf385237 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 8 Sep 2025 16:23:08 -0400 Subject: [PATCH 18/39] Removing unused ignore --- src/midst_toolkit/models/clavaddpm/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/midst_toolkit/models/clavaddpm/trainer.py b/src/midst_toolkit/models/clavaddpm/trainer.py index c727ee48..56257008 100644 --- a/src/midst_toolkit/models/clavaddpm/trainer.py +++ b/src/midst_toolkit/models/clavaddpm/trainer.py @@ -80,7 +80,7 @@ def _train_step(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: self.optimizer.zero_grad() loss_multi, loss_gauss = self.diffusion_model.mixed_loss(x, target) loss = loss_multi + loss_gauss - loss.backward() # type: ignore[no-untyped-call] + loss.backward() self.optimizer.step() return loss_multi, loss_gauss From 9541cce27abcb9b83cd7dca98fe706d92d2b3890 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 12 Sep 2025 13:22:14 -0400 Subject: [PATCH 19/39] Little refactorings --- .../models/clavaddpm/clustering.py | 152 +++++++++++------- src/midst_toolkit/models/clavaddpm/typing.py | 3 +- 2 files changed, 97 insertions(+), 58 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py index e7fa3889..82097761 100644 --- a/src/midst_toolkit/models/clavaddpm/clustering.py +++ b/src/midst_toolkit/models/clavaddpm/clustering.py @@ -1,4 +1,4 @@ -"""Clustering functions for the multi-tableClavaDDPM model.""" +"""Clustering functions for the multi-table ClavaDDPM model.""" import os import pickle @@ -12,7 +12,7 @@ from sklearn.mixture import BayesianGaussianMixture, GaussianMixture from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, QuantileTransformer -from midst_toolkit.models.clavaddpm.typing import Configs, RelationOrder, Tables +from midst_toolkit.models.clavaddpm.typing import Configs, GroupLengthsProbDicts, RelationOrder, Tables def clava_clustering( @@ -20,7 +20,7 @@ def clava_clustering( relation_order: RelationOrder, save_dir: Path, configs: Configs, -) -> tuple[dict[str, Any], dict[tuple[str, str], dict[int, float]]]: +) -> tuple[dict[str, Any], GroupLengthsProbDicts]: """ Clustering function for the mutli-table function of the ClavaDDPM model. @@ -51,48 +51,15 @@ def clava_clustering( - The tables dictionary. - The dictionary with the group lengths probability for all the parent-child pairs. 
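
For reference, the second return value has the shape of `GroupLengthsProbDicts` from `typing.py`: each (parent, child) pair maps to, per cluster label, a probability distribution over child-group sizes. A small made-up instance (table names are invented):

    all_group_lengths_prob_dicts = {
        ("account", "trans"): {
            0: {1: 0.25, 2: 0.75},  # cluster 0: 25% of parents have 1 child row, 75% have 2
            1: {3: 1.0},            # cluster 1: every parent has exactly 3 child rows
        },
    }
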
""" - relation_order_reversed = relation_order[::-1] - all_group_lengths_prob_dicts = {} - - # Clustering - if os.path.exists(save_dir / "cluster_ckpt.pkl"): - print("Clustering checkpoint found, loading...") - - with open(save_dir / "cluster_ckpt.pkl", "rb") as f: - cluster_ckpt = pickle.load(f) - + cluster_ckpt = _load_clustering_info_from_checkpoint(save_dir) + if cluster_ckpt is not None: tables = cluster_ckpt["tables"] all_group_lengths_prob_dicts = cluster_ckpt["all_group_lengths_prob_dicts"] + else: - for parent, child in relation_order_reversed: - if parent is not None: - print(f"Clustering {parent} -> {child}") - if isinstance(configs["num_clusters"], dict): - num_clusters = configs["num_clusters"][child] - else: - num_clusters = configs["num_clusters"] - ( - parent_df_with_cluster, - child_df_with_cluster, - group_lengths_prob_dicts, - ) = _pair_clustering_keep_id( - tables[child]["df"], - tables[child]["domain"], - tables[parent]["df"], - tables[parent]["domain"], - f"{child}_id", - f"{parent}_id", - num_clusters, - configs["parent_scale"], - 1, # not used for now - parent, - child, - clustering_method=configs["clustering_method"], - ) - tables[parent]["df"] = parent_df_with_cluster - tables[child]["df"] = child_df_with_cluster - all_group_lengths_prob_dicts[(parent, child)] = group_lengths_prob_dicts + tables, all_group_lengths_prob_dicts = _run_clustering(tables, relation_order, configs) + # saving the clustering information in the checkpoint file cluster_ckpt = { "tables": tables, "all_group_lengths_prob_dicts": all_group_lengths_prob_dicts, @@ -100,6 +67,7 @@ def clava_clustering( with open(save_dir / "cluster_ckpt.pkl", "wb") as f: pickle.dump(cluster_ckpt, f) + # adding a placeholder for the top level tables (i.e. tables with no parent) for parent, child in relation_order: if parent is None: tables[child]["df"]["placeholder"] = list(range(len(tables[child]["df"]))) @@ -107,19 +75,88 @@ def clava_clustering( return tables, all_group_lengths_prob_dicts +def _load_clustering_info_from_checkpoint(save_dir: Path) -> dict[str, Any] | None: + """ + Load the clustering information from the checkpoint if it exists. + + Args: + save_dir: Directory to save the clustering checkpoint. + + Returns: + Clustering information as a dictionary if the checkpoint exists, None otherwise. + The dictionary contains the tables under the "tables" key and the group lengths + probabilities under the "all_group_lengths_prob_dicts" key. + """ + if not os.path.exists(save_dir / "cluster_ckpt.pkl"): + return None + + print("Clustering checkpoint found, loading...") + + with open(save_dir / "cluster_ckpt.pkl", "rb") as f: + return pickle.load(f) + + +def _run_clustering( + tables: Tables, + relation_order: RelationOrder, + configs: Configs, +) -> tuple[Tables, GroupLengthsProbDicts]: + """ + Run the clustering process. + + Args: + tables: Dictionary of the tables by name. + relation_order: List of tuples of parent and child tables. Example: + [("table1", "table2"), ("table1", "table3")] + configs: Dictionary of configurations. The following config keys are required: + { + num_clusters = int | dict, + parent_scale = float, + clustering_method = str["kmeans" | "both" | "variational" | "gmm"], + } + + Returns: + Tuple with 2 elements: + - The tables dictionary. + - The dictionary with the group lengths probability for all the parent-child pairs. 
+ """ + all_group_lengths_prob_dicts = {} + relation_order_reversed = relation_order[::-1] + for parent, child in relation_order_reversed: + if parent is not None: + print(f"Clustering {parent} -> {child}") + if isinstance(configs["num_clusters"], dict): + num_clusters = configs["num_clusters"][child] + else: + num_clusters = configs["num_clusters"] + ( + parent_df_with_cluster, + child_df_with_cluster, + group_lengths_prob_dicts, + ) = _pair_clustering_keep_id( + tables, + child, + parent, + num_clusters, + configs["parent_scale"], + 1, # not used for now + clustering_method=configs["clustering_method"], + ) + tables[parent]["df"] = parent_df_with_cluster + tables[child]["df"] = child_df_with_cluster + all_group_lengths_prob_dicts[(parent, child)] = group_lengths_prob_dicts + + return tables, all_group_lengths_prob_dicts + + def _pair_clustering_keep_id( # ruff: noqa: PLR0912, PLR0915 - child_df: pd.DataFrame, - child_domain_dict: dict[str, Any], - parent_df: pd.DataFrame, - parent_domain_dict: dict[str, Any], - child_primary_key: str, - parent_primary_key: str, + tables: Tables, + child_name: str, + parent_name: str, num_clusters: int, parent_scale: float, key_scale: float, - parent_name: str, - child_name: str, clustering_method: Literal["kmeans", "both", "variational", "gmm"] = "kmeans", ) -> tuple[pd.DataFrame, pd.DataFrame, dict[int, dict[int, float]]]: """ @@ -128,17 +165,12 @@ def _pair_clustering_keep_id( Used by the mutli-table function of the ClavaDDPM model. Args: - child_df: DataFrame of the child table, as provided by the load_multi_table function. - child_domain_dict: Dictionary of the child table domain, as provided by the load_multi_table function. - parent_df: DataFrame of the parent table, as provided by the load_multi_table function. - parent_domain_dict: Dictionary of the parent table domain, as provided by the load_multi_table function. - child_primary_key: Name of the child primary key. - parent_primary_key: Name of the parent primary key. + tables: Dictionary of the tables by name. + parent_name: Name of the parent table. + child_name: Name of the child table. num_clusters: Number of clusters. parent_scale: Scale of the parent table, provided by the config. key_scale: Scale of the key. - parent_name: Name of the parent table. - child_name: Name of the child table. clustering_method: Method of clustering. Has to be one of ["kmeans", "both", "variational", "gmm"]. Default is "kmeans". @@ -148,6 +180,12 @@ def _pair_clustering_keep_id( - child_df_with_cluster: DataFrame of the child table with the cluster column. - group_lengths_prob_dicts: Dictionary of group lengths and probabilities. 
""" + child_df = tables[child_name]["df"] + parent_df = tables[parent_name]["df"] + child_domain_dict = tables[child_name]["domain"] + parent_domain_dict = tables[parent_name]["domain"] + child_primary_key = f"{child_name}_id" + parent_primary_key = f"{parent_name}_id" original_child_cols = list(child_df.columns) original_parent_cols = list(parent_df.columns) diff --git a/src/midst_toolkit/models/clavaddpm/typing.py b/src/midst_toolkit/models/clavaddpm/typing.py index 38217fc0..9a571890 100644 --- a/src/midst_toolkit/models/clavaddpm/typing.py +++ b/src/midst_toolkit/models/clavaddpm/typing.py @@ -1,7 +1,8 @@ from typing import Any -# TODO: Temporary, will wtich to classes later +# TODO: Temporary, will switch to classes later Configs = dict[str, Any] Tables = dict[str, dict[str, Any]] RelationOrder = list[tuple[str, str]] +GroupLengthsProbDicts = dict[tuple[str, str], dict[int, dict[int, float]]] From 0534b4195b09be9b6244d97848615361c4dcebcd Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 15 Sep 2025 14:02:13 -0400 Subject: [PATCH 20/39] Addressing some more comments by David --- .../models/clavaddpm/clustering.py | 187 +++++++++++------- .../models/clavaddpm/test_model.py | 2 +- 2 files changed, 115 insertions(+), 74 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py index 82097761..0ef96bdc 100644 --- a/src/midst_toolkit/models/clavaddpm/clustering.py +++ b/src/midst_toolkit/models/clavaddpm/clustering.py @@ -12,6 +12,7 @@ from sklearn.mixture import BayesianGaussianMixture, GaussianMixture from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, QuantileTransformer +from midst_toolkit.core import logger from midst_toolkit.models.clavaddpm.typing import Configs, GroupLengthsProbDicts, RelationOrder, Tables @@ -112,7 +113,7 @@ def _run_clustering( { num_clusters = int | dict, parent_scale = float, - clustering_method = str["kmeans" | "both" | "variational" | "gmm"], + clustering_method = str["kmeans" | "gmm" | "kmeans_and_gmm" | "variational"], } Returns: @@ -157,7 +158,7 @@ def _pair_clustering_keep_id( num_clusters: int, parent_scale: float, key_scale: float, - clustering_method: Literal["kmeans", "both", "variational", "gmm"] = "kmeans", + clustering_method: Literal["kmeans", "gmm", "kmeans_and_gmm", "variational"] = "kmeans", ) -> tuple[pd.DataFrame, pd.DataFrame, dict[int, dict[int, float]]]: """ Pairs clustering information to the parent and child dataframes. @@ -169,9 +170,12 @@ def _pair_clustering_keep_id( parent_name: Name of the parent table. child_name: Name of the child table. num_clusters: Number of clusters. - parent_scale: Scale of the parent table, provided by the config. - key_scale: Scale of the key. - clustering_method: Method of clustering. Has to be one of ["kmeans", "both", "variational", "gmm"]. + parent_scale: Scaling factor applied to the parent table, provided by the config. + It will be applied to the features to weight their importance during clustering. + key_scale: Scaling factor applied to the foreign key values that link + the child table to the parent table. This will weight how much influence + the parent-child relationship has in the clustering algorithm. + clustering_method: Method of clustering. Has to be one of ["kmeans", "gmm", "kmeans_and_gmm", "variational"]. Default is "kmeans". 
Returns: @@ -182,53 +186,35 @@ def _pair_clustering_keep_id( """ child_df = tables[child_name]["df"] parent_df = tables[parent_name]["df"] + # The domain dictionary holds metadata about the columns of each one of the tables. child_domain_dict = tables[child_name]["domain"] parent_domain_dict = tables[parent_name]["domain"] child_primary_key = f"{child_name}_id" parent_primary_key = f"{parent_name}_id" - original_child_cols = list(child_df.columns) - original_parent_cols = list(parent_df.columns) + all_child_cols = list(child_df.columns) + all_parent_cols = list(parent_df.columns) - relation_cluster_name = f"{parent_name}_{child_name}_cluster" + # Splitting the data columns into categorical and numerical based on the domain dictionary. + # Columns that are not in the domain dictionary are ignored (except for the primary and foreign keys). + child_num_cols, child_cat_cols = _get_categorical_and_numerical_columns(all_child_cols, child_domain_dict) + parent_num_cols, parent_cat_cols = _get_categorical_and_numerical_columns(all_parent_cols, parent_domain_dict) - child_data = child_df.to_numpy() - parent_data = parent_df.to_numpy() - - child_num_cols = [] - child_cat_cols = [] - - parent_num_cols = [] - parent_cat_cols = [] - - for col_index, col in enumerate(original_child_cols): - if col in child_domain_dict: - if child_domain_dict[col]["type"] == "discrete": - child_cat_cols.append((col_index, col)) - else: - child_num_cols.append((col_index, col)) - - for col_index, col in enumerate(original_parent_cols): - if col in parent_domain_dict: - if parent_domain_dict[col]["type"] == "discrete": - parent_cat_cols.append((col_index, col)) - else: - parent_num_cols.append((col_index, col)) - - parent_primary_key_index = original_parent_cols.index(parent_primary_key) - foreing_key_index = original_child_cols.index(parent_primary_key) + parent_primary_key_index = all_parent_cols.index(parent_primary_key) + foreign_key_index = all_child_cols.index(parent_primary_key) # sort child data by foreign key - sorted_child_data = child_data[np.argsort(child_data[:, foreing_key_index])] - child_group_data_dict = _get_group_data_dict(sorted_child_data, [foreing_key_index]) + child_data = child_df.to_numpy() + sorted_child_data = child_data[np.argsort(child_data[:, foreign_key_index])] + child_group_data_dict = _get_group_data_dict(sorted_child_data, [foreign_key_index]) # sort parent data by primary key + parent_data = parent_df.to_numpy() sorted_parent_data = parent_data[np.argsort(parent_data[:, parent_primary_key_index])] group_lengths = [] unique_group_ids = sorted_parent_data[:, parent_primary_key_index] for group_id in unique_group_ids: - group_id = tuple([group_id]) - # ruff: noqa: C409 + group_id = (group_id,) if group_id not in child_group_data_dict: group_lengths.append(0) else: @@ -237,12 +223,12 @@ def _pair_clustering_keep_id( group_lengths_np = np.array(group_lengths, dtype=int) sorted_parent_data_repeated = np.repeat(sorted_parent_data, group_lengths_np, axis=0) - assert (sorted_parent_data_repeated[:, parent_primary_key_index] == sorted_child_data[:, foreing_key_index]).all() + assert (sorted_parent_data_repeated[:, parent_primary_key_index] == sorted_child_data[:, foreign_key_index]).all() - sorted_child_num_data = sorted_child_data[:, [col_index for col_index, col in child_num_cols]] - sorted_child_cat_data = sorted_child_data[:, [col_index for col_index, col in child_cat_cols]] - sorted_parent_num_data = sorted_parent_data_repeated[:, [col_index for col_index, col in parent_num_cols]] - 
sorted_parent_cat_data = sorted_parent_data_repeated[:, [col_index for col_index, col in parent_cat_cols]] + sorted_child_num_data = sorted_child_data[:, child_num_cols] + sorted_child_cat_data = sorted_child_data[:, child_cat_cols] + sorted_parent_num_data = sorted_parent_data_repeated[:, parent_num_cols] + sorted_parent_cat_data = sorted_parent_data_repeated[:, parent_cat_cols] joint_num_matrix = np.concatenate([sorted_child_num_data, sorted_parent_num_data], axis=1) joint_cat_matrix = np.concatenate([sorted_child_cat_data, sorted_parent_cat_data], axis=1) @@ -256,6 +242,7 @@ def _pair_clustering_keep_id( for i in range(joint_cat_matrix.shape[1]): # A threshold of 1000 unique values is used to prevent the one-hot encoding of large categorical columns if len(np.unique(joint_cat_matrix[:, i])) > 1000: + logger.warn(f"Categorical column {i} has more than 1000 unique values, skipping...") continue label_encoder = LabelEncoder() cat_converted.append(label_encoder.fit_transform(joint_cat_matrix[:, i]).astype(float)) @@ -279,6 +266,7 @@ def _pair_clustering_keep_id( num_quantile = _quantile_normalize_sklearn(joint_num_matrix) num_min_max = _min_max_normalize_sklearn(joint_num_matrix) + # TODO: change the commented lines below into options/if-conditions. # key_quantile = # quantile_normalize_sklearn(sorted_parent_data_repeated[:, parent_primary_key_index].reshape(-1, 1)) key_min_max = _min_max_normalize_sklearn(sorted_parent_data_repeated[:, parent_primary_key_index].reshape(-1, 1)) @@ -294,7 +282,7 @@ def _pair_clustering_keep_id( else: cluster_data = np.concatenate((num_min_max, key_scaled), axis=1) - child_group_data = _get_group_data(sorted_child_data, [foreing_key_index]) + child_group_data = _get_group_data(sorted_child_data, [foreign_key_index]) child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int) num_clusters = min(num_clusters, len(cluster_data)) @@ -302,7 +290,7 @@ def _pair_clustering_keep_id( kmeans = KMeans(n_clusters=num_clusters, n_init="auto", init="k-means++") kmeans.fit(cluster_data) cluster_labels = kmeans.labels_ - elif clustering_method == "both": + elif clustering_method == "kmeans_and_gmm": gmm = GaussianMixture( n_components=num_clusters, verbose=1, @@ -313,15 +301,15 @@ def _pair_clustering_keep_id( gmm.fit(cluster_data) cluster_labels = gmm.predict(cluster_data) elif clustering_method == "variational": - gmm = BayesianGaussianMixture( + bgmm = BayesianGaussianMixture( n_components=num_clusters, verbose=1, covariance_type="diag", init_params="k-means++", tol=0.0001, ) - gmm.fit(cluster_data) - cluster_labels = gmm.predict_proba(cluster_data) + bgmm.fit(cluster_data) + cluster_labels = bgmm.predict_proba(cluster_data) elif clustering_method == "gmm": gmm = GaussianMixture( n_components=num_clusters, @@ -334,22 +322,9 @@ def _pair_clustering_keep_id( if clustering_method == "variational": group_cluster_labels, agree_rates = _aggregate_and_sample(cluster_labels, child_group_lengths) else: - # voting to determine the cluster label for each parent - group_cluster_labels = [] - curr_index = 0 - agree_rates = [] - for group_length in child_group_lengths: - # First, determine the most common label in the current group - most_common_label_count = np.max(np.bincount(cluster_labels[curr_index : curr_index + group_length])) - group_cluster_label = np.argmax(np.bincount(cluster_labels[curr_index : curr_index + group_length])) - group_cluster_labels.append(int(group_cluster_label)) - - # Compute agree rate using the most common label count - agree_rate 
= most_common_label_count / group_length - agree_rates.append(agree_rate) - - # Then, update the curr_index for the next iteration - curr_index += group_length + group_cluster_labels, agree_rates = _get_group_cluster_labels_through_voting( + cluster_labels, child_group_lengths + ) # Compute the average agree rate across all groups average_agree_rate = np.mean(agree_rates) @@ -375,9 +350,10 @@ def _pair_clustering_keep_id( group_lengths_prob_dicts[group_label] = _freq_to_prob(freq_dict) # recover the preprocessed data back to dataframe + relation_cluster_name = f"{parent_name}_{child_name}_cluster" child_df_with_cluster = pd.DataFrame( sorted_child_data_with_cluster, - columns=original_child_cols + [relation_cluster_name], + columns=all_child_cols + [relation_cluster_name], ) # recover child df order @@ -390,11 +366,11 @@ def _pair_clustering_keep_id( parent_id_to_cluster: dict[Any, Any] = {} for i in range(len(sorted_child_data)): - parent_id = sorted_child_data[i, foreing_key_index] + parent_id = sorted_child_data[i, foreign_key_index] if parent_id in parent_id_to_cluster: assert parent_id_to_cluster[parent_id] == sorted_child_data_with_cluster[i, -1] - continue - parent_id_to_cluster[parent_id] = sorted_child_data_with_cluster[i, -1] + else: + parent_id_to_cluster[parent_id] = sorted_child_data_with_cluster[i, -1] max_cluster_label = max(parent_id_to_cluster.values()) @@ -407,16 +383,14 @@ def _pair_clustering_keep_id( parent_data_clusters_np = np.array(parent_data_clusters).reshape(-1, 1) parent_data_with_cluster = np.concatenate([parent_data, parent_data_clusters_np], axis=1) - parent_df_with_cluster = pd.DataFrame( - parent_data_with_cluster, columns=original_parent_cols + [relation_cluster_name] - ) + parent_df_with_cluster = pd.DataFrame(parent_data_with_cluster, columns=all_parent_cols + [relation_cluster_name]) new_col_entry = { "type": "discrete", "size": len(set(parent_data_clusters_np.flatten())), } - print("Number of cluster centers: ", len(set(parent_data_clusters_np.flatten()))) + logger.info(f"Number of cluster centers: {new_col_entry['size']}") parent_domain_dict[relation_cluster_name] = new_col_entry.copy() child_domain_dict[relation_cluster_name] = new_col_entry.copy() @@ -424,12 +398,40 @@ def _pair_clustering_keep_id( return parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts +def _get_categorical_and_numerical_columns( + all_columns: list[str], + domain_dictionary: dict[str, Any], +) -> tuple[list[int], list[int]]: + """ + Return the list of numerical and categorical column indices from the domain dictionary. + + Args: + all_columns: List of all columns. + domain_dictionary: Dictionary of the domain. + + Returns: + Tuple with two lists of indices, one for the numerical columns and one for the categorical columns. + """ + numerical_columns = [] + categorical_columns = [] + + for col_index, col in enumerate(all_columns): + if col in domain_dictionary: + if domain_dictionary[col]["type"] == "discrete": + categorical_columns.append(col_index) + else: + numerical_columns.append(col_index) + + return numerical_columns, categorical_columns + + def _get_group_data_dict( np_data: np.ndarray, group_id_attrs: list[int] | None = None, ) -> dict[tuple[Any, ...], list[np.ndarray]]: """ - Get the group data dictionary. + Group rows in a numpy array by their values in specified grouping columns into a dictionary. + Returns a dict where keys are tuples of grouping values and values are lists of corresponding rows. Args: np_data: Numpy array of the data. 
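
A small sketch of the grouping behaviour described above, not the function's actual body (illustrative array; column 0 plays the role of the foreign key):

    from collections import defaultdict

    import numpy as np

    np_data = np.array([[1, 10.0], [1, 11.0], [2, 12.0]])
    groups: dict[tuple, list] = defaultdict(list)
    for row in np_data:
        groups[tuple(row[[0]])].append(row)  # group on column 0
    # -> keys (1.0,) and (2.0,), holding 2 rows and 1 row respectively
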
@@ -457,7 +459,9 @@ def _get_group_data( group_id_attrs: list[int] | None = None, ) -> np.ndarray: """ - Get the group data. + Group consecutive rows in a numpy array based on specified grouping attributes. + Returns an array of arrays where each sub-array contains rows with identical + values in the grouping columns. Args: np_data: Numpy array of the data. @@ -476,6 +480,7 @@ def _get_group_data( group = [] row_id = np_data[i, group_id_attrs] + # TODO refactor this condition to be more readable/understandable. while (np_data[i, group_id_attrs] == row_id).all(): group.append(np_data[i]) i += 1 @@ -578,6 +583,42 @@ def _aggregate_and_sample( return group_cluster_labels, agree_rates +def _get_group_cluster_labels_through_voting( + cluster_labels: np.ndarray, + child_group_lengths: np.ndarray, +) -> tuple[list[int], list[float]]: + """ + Get the group cluster labels through voting. + + Used by the non-variational clustering methods. + + Args: + cluster_labels: Numpy array of the cluster labels. + child_group_lengths: Numpy array of the child group lengths. + + Returns: + Tuple of the group cluster labels and the agree rates. + """ + # voting to determine the cluster label for each parent + group_cluster_labels = [] + curr_index = 0 + agree_rates = [] + for group_length in child_group_lengths: + # First, determine the most common label in the current group + most_common_label_count = np.max(np.bincount(cluster_labels[curr_index : curr_index + group_length])) + group_cluster_label = np.argmax(np.bincount(cluster_labels[curr_index : curr_index + group_length])) + group_cluster_labels.append(int(group_cluster_label)) + + # Compute agree rate using the most common label count + agree_rate = most_common_label_count / group_length + agree_rates.append(agree_rate) + + # Then, update the curr_index for the next iteration + curr_index += group_length + + return group_cluster_labels, agree_rates + + def _freq_to_prob(freq_dict: dict[int, int]) -> dict[int, float]: """ Convert a frequency dictionary to a probability dictionary. diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index 344b046c..4895e7b8 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -19,7 +19,7 @@ CLUSTERING_CONFIG = { "parent_scale": 1.0, "num_clusters": 3, - "clustering_method": "both", + "clustering_method": "kmeans_and_gmm", } DIFFUSION_CONFIG = { From 96ab05d796b2c6788f5770679be38776e2a58417 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 16 Sep 2025 17:08:56 -0400 Subject: [PATCH 21/39] Addressing comments by David --- src/midst_toolkit/models/clavaddpm/sampler.py | 20 ++++++++++++++----- src/midst_toolkit/models/clavaddpm/train.py | 9 +++++++-- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/sampler.py b/src/midst_toolkit/models/clavaddpm/sampler.py index 9ae23bd9..2020fcc4 100644 --- a/src/midst_toolkit/models/clavaddpm/sampler.py +++ b/src/midst_toolkit/models/clavaddpm/sampler.py @@ -1,6 +1,7 @@ """Samplers for the ClavaDDPM model.""" from abc import ABC, abstractmethod +from typing import Literal import numpy as np import torch @@ -15,9 +16,9 @@ class ScheduleSampler(ABC): variance of the objective. By default, samplers perform unbiased importance sampling, in which the - objective's mean is unchanged. 
- However, subclasses may override sample() to change how the resampled - terms are reweighted, allowing for actual changes in the objective. + objective's mean is unchanged. However, subclasses may override sample() to + change how the resampled terms are reweighted, allowing for actual changes + in the objective. """ @abstractmethod @@ -29,6 +30,8 @@ def weights(self) -> Tensor: """ def sample(self, batch_size: int, device: str) -> tuple[Tensor, Tensor]: + # TODO: what's happening with batch_size? Is is also the number of timesteps? + # We need to clarify this. """ Importance-sample timesteps for a batch. @@ -182,13 +185,20 @@ def _warmed_up(self) -> bool: return (self._loss_counts == self.history_per_term).all() -def create_named_schedule_sampler(name: str, diffusion: GaussianMultinomialDiffusion) -> ScheduleSampler: +def create_named_schedule_sampler( + name: Literal["uniform", "loss-second-moment"], + diffusion: GaussianMultinomialDiffusion, +) -> ScheduleSampler: """ Create a ScheduleSampler from a library of pre-defined samplers. Args: - name: The name of the sampler. + name: The name of the sampler. Should be one of ["uniform", "loss-second-moment"]. diffusion: The diffusion object to sample for. + + Returns: + The UniformSampler if ``name`` is "uniform", LossSecondMomentResampler if ``name`` + is "loss-second-moment". """ if name == "uniform": return UniformSampler(diffusion) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 8a9fb470..a1d63e17 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -4,7 +4,7 @@ import pickle from collections.abc import Generator from pathlib import Path -from typing import Any +from typing import Any, Literal import numpy as np import pandas as pd @@ -615,7 +615,12 @@ def _numerical_forward_backward_log( loss.backward(loss * len(sub_batch) / len(batch)) -def _compute_top_k(logits: Tensor, labels: Tensor, k: int, reduction: str = "mean") -> Tensor: +def _compute_top_k( + logits: Tensor, + labels: Tensor, + k: int, + reduction: Literal["mean", "none"] = "mean", +) -> Tensor: """ Compute the top-k accuracy. 
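
For intuition, top-k accuracy counts a prediction as correct when the true label appears among the k highest-scoring logits. A self-contained sketch of that idea (made-up logits, not the body of `_compute_top_k` itself):

    import torch

    logits = torch.tensor([[0.1, 0.7, 0.2],
                           [0.5, 0.3, 0.2]])
    labels = torch.tensor([1, 2])
    top2 = logits.topk(2, dim=-1).indices                      # two best classes per row
    hits = (top2 == labels.unsqueeze(-1)).any(dim=-1).float()  # tensor([1., 0.])
    accuracy = hits.mean()                                     # 0.5, i.e. reduction="mean"
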
From 8b0f214124e147c5a78966d99744b0b86cdaeb90 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 16 Sep 2025 17:19:40 -0400 Subject: [PATCH 22/39] Merge branch 'sampler-module' into data-module --- src/midst_toolkit/models/clavaddpm/dataset.py | 2 +- src/midst_toolkit/models/clavaddpm/model.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/dataset.py b/src/midst_toolkit/models/clavaddpm/dataset.py index e473716e..07d0350d 100644 --- a/src/midst_toolkit/models/clavaddpm/dataset.py +++ b/src/midst_toolkit/models/clavaddpm/dataset.py @@ -28,7 +28,7 @@ StandardScaler, ) -from midst_toolkit.models.clavaddpm.params import ArrayDict +from midst_toolkit.models.clavaddpm.typing import ArrayDict CAT_MISSING_VALUE = "__nan__" diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index c9a9d2f4..f1703c43 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -8,8 +8,6 @@ import torch.nn.functional as F from torch import Tensor, nn -from midst_toolkit.common.enumerations import PredictionType, TaskType - from midst_toolkit.models.clavaddpm.dataset import Dataset from midst_toolkit.models.clavaddpm.typing import ModuleType From d1a51dbf402c93a13487c3e0302ea7b9d6cc2e4e Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 17 Sep 2025 12:10:58 -0400 Subject: [PATCH 23/39] Adding docstrings --- src/midst_toolkit/models/clavaddpm/dataset.py | 392 ++++++++++++++---- 1 file changed, 318 insertions(+), 74 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/dataset.py b/src/midst_toolkit/models/clavaddpm/dataset.py index 07d0350d..8a5fbf57 100644 --- a/src/midst_toolkit/models/clavaddpm/dataset.py +++ b/src/midst_toolkit/models/clavaddpm/dataset.py @@ -8,7 +8,7 @@ from dataclasses import astuple, dataclass, replace from enum import Enum from pathlib import Path -from typing import Any, Literal, cast +from typing import Any, Literal, Self, cast import numpy as np import pandas as pd @@ -48,6 +48,12 @@ class TaskType(Enum): REGRESSION = "regression" def __str__(self) -> str: + """ + Return the string representation of the task type, which is the value of the enum. + + Returns: + The string representation of the task type. + """ return self.value @@ -99,7 +105,16 @@ class Dataset: num_transform: StandardScaler | None = None @classmethod - def from_dir(cls, dir_: Path | str) -> "Dataset": + def from_dir(cls, dir_: Path | str) -> Self: + """ + Load a dataset from a directory. + + Args: + dir_: The directory to load the dataset from. Can be a Path object or a path string. + + Returns: + The loaded dataset. + """ dir_ = Path(dir_) splits = [k for k in ["train", "val", "test"] if dir_.joinpath(f"y_{k}.npy").exists()] @@ -108,11 +123,8 @@ def load(item: str) -> ArrayDict: if Path(dir_ / "info.json").exists(): info = json.loads(Path(dir_ / "info.json").read_text()) - else: - info = None - # ruff: noqa: SIM108 - return Dataset( + return cls( load("X_num") if dir_.joinpath("X_num_train.npy").exists() else None, load("X_cat") if dir_.joinpath("X_cat_train.npy").exists() else None, load("y"), @@ -123,46 +135,119 @@ def load(item: str) -> ArrayDict: @property def is_binclass(self) -> bool: + """ + Check if the dataset is a binary classification dataset. + + Returns: + True if the dataset is a binary classification dataset, False otherwise. 
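
A hedged usage sketch for the loader above; the directory path is invented, and it is expected to contain `y_<split>.npy` files plus optional `X_num_<split>.npy` / `X_cat_<split>.npy` arrays and an `info.json` describing the task:

    from midst_toolkit.models.clavaddpm.dataset import Dataset

    dataset = Dataset.from_dir("data/prepared_table")  # accepts a Path or a path string
    print(dataset.n_features, dataset.size("train"))
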
+ """ return self.task_type == TaskType.BINCLASS @property def is_multiclass(self) -> bool: + """ + Check if the dataset is a multiclass classification dataset. + + Returns: + True if the dataset is a multiclass classification dataset, False otherwise. + """ return self.task_type == TaskType.MULTICLASS @property def is_regression(self) -> bool: + """ + Check if the dataset is a regression dataset. + + Returns: + True if the dataset is a regression dataset, False otherwise. + """ return self.task_type == TaskType.REGRESSION @property def n_num_features(self) -> int: + """ + Get the number of numerical features in the dataset. + + Returns: + The number of numerical features in the dataset. + """ return 0 if self.X_num is None else self.X_num["train"].shape[1] @property def n_cat_features(self) -> int: + """ + Get the number of categorical features in the dataset. + + Returns: + The number of categorical features in the dataset. + """ return 0 if self.X_cat is None else self.X_cat["train"].shape[1] @property def n_features(self) -> int: + """ + Get the total number of features in the dataset. + + Returns: + The total number of features in the dataset. + """ return self.n_num_features + self.n_cat_features - def size(self, part: str | None) -> int: - return sum(map(len, self.y.values())) if part is None else len(self.y[part]) + # TODO: make partition into an Enum + def size(self, partition: Literal["train", "val", "test"] | None) -> int: + """ + Get the size of the dataset. + + Args: + partition: The partition of the dataset to get the size of. + If None, the size of the entire dataset is returned. + + Returns: + The size of the dataset. + """ + return sum(map(len, self.y.values())) if partition is None else len(self.y[partition]) @property def nn_output_dim(self) -> int: + """ + Get the output dimension of the neural network. + + Returns: + The output dimension of the neural network. + """ if self.is_multiclass: assert self.n_classes is not None return self.n_classes return 1 - def get_category_sizes(self, part: str) -> list[int]: - return [] if self.X_cat is None else get_category_sizes(self.X_cat[part]) + def get_category_sizes(self, partition: Literal["train", "val", "test"]) -> list[int]: + """ + Get the size of the categories in the dataset. + + Args: + partition: The partition of the dataset to get the size of the categories of. + Returns: + The size of the categories in the partition of the dataset. + """ + return [] if self.X_cat is None else get_category_sizes(self.X_cat[partition]) + + # TODO: prediciton_type should be of type PredictionType def calculate_metrics( self, predictions: dict[str, np.ndarray], - prediction_type: str | None, + prediction_type: str | PredictionType | None, ) -> dict[str, Any]: + """ + Calculate the metrics of the predictions. + + Args: + predictions: The predictions to calculate the metrics of. + prediction_type: The type of the predictions. + + Returns: + The metrics of the predictions. + """ metrics = { x: calculate_metrics(self.y[x], predictions[x], self.task_type, prediction_type, self.y_info) for x in predictions @@ -180,6 +265,15 @@ def calculate_metrics( # TODO consider moving all the functions below into the Dataset class def get_category_sizes(X: torch.Tensor | np.ndarray) -> list[int]: + """ + Get the size of the categories in the data. + + Args: + X: The data to get the size of the categories of. + + Returns: + A list with the category sizes in the data. 
+ """ XT = X.T.cpu().tolist() if isinstance(X, torch.Tensor) else X.T.tolist() return [len(set(x)) for x in XT] @@ -191,7 +285,21 @@ def calculate_metrics( prediction_type: str | PredictionType | None, y_info: dict[str, Any], ) -> dict[str, Any]: - # Example: calculate_metrics(y_true, y_pred, 'binclass', 'logits', {}) + """ + Calculate the metrics of the predictions. + + Usage: calculate_metrics(y_true, y_pred, 'binclass', 'logits', {}) + + Args: + y_true: The true labels as a numpy array. + y_pred: The predicted labels as a numpy array. + task_type: The type of the task. + prediction_type: The type of the predictions. + y_info: A dictionary with metadata about the labels. + + Returns: + The metrics of the predictions. + """ task_type = TaskType(task_type) if prediction_type is not None: prediction_type = PredictionType(prediction_type) @@ -211,6 +319,18 @@ def calculate_metrics( def calculate_rmse(y_true: np.ndarray, y_pred: np.ndarray, std: float | None) -> float: + """ + Calculate the root mean squared error (RMSE) of the predictions. + + Args: + y_true: The true labels as a numpy array. + y_pred: The predicted labels as a numpy array. + std: The standard deviation of the labels. If None, the RMSE is calculated + without the standard deviation. + + Returns: + The RMSE of the predictions. + """ rmse = mean_squared_error(y_true, y_pred) ** 0.5 if std is not None: rmse *= std @@ -220,6 +340,18 @@ def calculate_rmse(y_true: np.ndarray, y_pred: np.ndarray, std: float | None) -> def _get_labels_and_probs( y_pred: np.ndarray, task_type: TaskType, prediction_type: PredictionType | None ) -> tuple[np.ndarray, np.ndarray | None]: + """ + Get the labels and probabilities from the predictions. + + Args: + y_pred: The predicted labels as a numpy array. + task_type: The type of the task. + prediction_type: The type of the predictions. + + Returns: + A tuple with the labels and probabilities. The probabilities are None + if the prediction_type is None. + """ assert task_type in (TaskType.BINCLASS, TaskType.MULTICLASS) if prediction_type is None: @@ -241,37 +373,50 @@ def make_dataset_from_df( # ruff: noqa: PLR0915, PLR0912 df: pd.DataFrame, T: Transformations, - is_y_cond: str, + # ruff: noqa: N803 + is_y_cond: Literal["concat", "embedding", "none"], df_info: pd.DataFrame, ratios: list[float] | None = None, std: float = 0, ) -> tuple[Dataset, dict[int, LabelEncoder], list[int]]: """ - The order of the generated dataset: (y, X_num, X_cat). + Generate a dataset from a pandas DataFrame. - is_y_cond: - concat: y is concatenated to X, the model learn a joint distribution of (y, X) - embedding: y is not concatenated to X. During computations, y is embedded - and added to the latent vector of X - none: y column is completely ignored - - How does is_y_cond affect the generation of y? - is_y_cond: - concat: the model synthesizes (y, X) directly, so y is just the first column - embedding: y is first sampled using empirical distribution of y. The model only - synthesizes X. When returning the generated data, we return the generated X - and the sampled y. (y is sampled from empirical distribution, instead of being - generated by the model) - Note that in this way, y is still not independent of X, because the model has been - adding the embedding of y to the latent vector of X during computations. - none: - y is synthesized using y's empirical distribution. X is generated by the model. - In this case, y is completely independent of X. + The order of the generated dataset: (y, X_num, X_cat). 
Note: For now, n_classes has to be set to 0. This is because our matrix is the concatenation of (X_num, X_cat). In this case, if we have is_y_cond == 'concat', we can guarantee that y is the first column of the matrix. However, if we have n_classes > 0, then y is not the first column of the matrix. + + Args: + df: The pandas DataFrame to generate the dataset from. + T: The transformations to apply to the dataset. + is_y_cond: The condition on the y column. + concat: y is concatenated to X, the model learn a joint distribution of (y, X) + embedding: y is not concatenated to X. During computations, y is embedded + and added to the latent vector of X + none: y column is completely ignored + + How does is_y_cond affect the generation of y? + is_y_cond: + concat: the model synthesizes (y, X) directly, so y is just the first column + embedding: y is first sampled using empirical distribution of y. The model only + synthesizes X. When returning the generated data, we return the generated X + and the sampled y. (y is sampled from empirical distribution, instead of being + generated by the model) + Note that in this way, y is still not independent of X, because the model has been + adding the embedding of y to the latent vector of X during computations. + none: + y is synthesized using y's empirical distribution. X is generated by the model. + In this case, y is completely independent of X. + + df_info: A dictionary with metadata about the DataFrame. + ratios: The ratios of the dataset to split into train, val, and test. Optional, default is [0.7, 0.2, 0.1]. + std: The standard deviation of the labels. Optional, default is 0. + + Returns: + A tuple with the dataset, the label encoders, and the column orders. """ if ratios is None: ratios = [0.7, 0.2, 0.1] @@ -390,11 +535,26 @@ def transform_dataset( dataset: Dataset, transformations: Transformations, cache_dir: Path | None, - transform_cols_num: int = 0, ) -> Dataset: + """ + Transform the dataset. + + Args: + dataset: The dataset to transform. + transformations: The transformations to apply to the dataset. + cache_dir: The directory to cache the transformed dataset. + Optional, default is None. If not None, will check if the transformations exist in the cache directory. + If they do, will returned the cached transformed dataset. If not, will transform the dataset and cache it. + + Returns: + The transformed dataset. + """ # WARNING: the order of transformations matters. Moreover, the current # implementation is not ideal in that sense. 
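
    # For illustration: the cache key built below joins a readable dump of the
    # Transformations fields with an md5 digest of their repr, producing names like
    # "cache__<field1>__<field2>__...__<md5>.pickle" (field values here are
    # placeholders), so changing any transformation setting selects a different
    # cache file.
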
+ cache_path = None if cache_dir is not None: + # if cache_dir is not None, will save the cahe file path into the cache_path variable + # so the transformations can be saved in the cache dir transformations_md5 = hashlib.md5(str(transformations).encode("utf-8")).hexdigest() transformations_str = "__".join(map(str, astuple(transformations))) cache_path = cache_dir / f"cache__{transformations_str}__{transformations_md5}.pickle" @@ -404,8 +564,6 @@ def transform_dataset( print(f"Using cached features: {cache_dir.name + '/' + cache_path.name}") return value raise RuntimeError(f"Hash collision for {cache_path}") - else: - cache_path = None if dataset.X_num is not None: dataset = num_process_nans(dataset, transformations.num_nan_policy) @@ -450,23 +608,91 @@ def transform_dataset( if cache_path is not None: dump_pickle((transformations, dataset), cache_path) - # if return_transforms: - # return dataset, num_transform, cat_transform + return dataset def load_pickle(path: Path | str, **kwargs: Any) -> Any: - # ruff: noqa: D103 + """ + Load a pickle file. + + Args: + path: The path to the pickle file. + **kwargs: Additional arguments to pass to the pickle.loads function. + + Returns: + The loaded pickle file. + """ return pickle.loads(Path(path).read_bytes(), **kwargs) def dump_pickle(x: Any, path: Path | str, **kwargs: Any) -> None: - # ruff: noqa: D103 + """ + Dump an object into a pickle file. + + Args: + x: The object to dump. + path: The path to the pickle file. + **kwargs: Additional arguments to pass to the pickle.dumps function. + """ Path(path).write_bytes(pickle.dumps(x, **kwargs)) +# Inspired by: https://github.com/yandex-research/rtdl/blob/a4c93a32b334ef55d2a0559a4407c8306ffeeaee/lib/data.py#L20 +# TODO: fix this hideous output type +def normalize( + X: ArrayDict, + normalization: Normalization, + seed: int | None, + return_normalizer: bool = False, +) -> ArrayDict | tuple[ArrayDict, StandardScaler | MinMaxScaler | QuantileTransformer]: + """ + Normalize the input data. + + Args: + X: The data to normalize. + normalization: The normalization to use. Can be "standard", "minmax", or "quantile". + seed: The seed to use for the random state. Optional, default is None. + return_normalizer: Whether to return the normalizer. Optional, default is False. + + Returns: + The normalized data. If return_normalizer is True, will return a tuple with the + normalized data and the normalizer. + """ + X_train = X["train"] + if normalization == "standard": + normalizer = StandardScaler() + elif normalization == "minmax": + normalizer = MinMaxScaler() + elif normalization == "quantile": + normalizer = QuantileTransformer( + output_distribution="normal", + n_quantiles=max(min(X["train"].shape[0] // 30, 1000), 10), + subsample=int(1e9), + random_state=seed, + ) + else: + raise ValueError(f"Unknown normalization: {normalization}") + normalizer.fit(X_train) + if return_normalizer: + return {k: normalizer.transform(v) for k, v in X.items()}, normalizer + return {k: normalizer.transform(v) for k, v in X.items()} + + +# TODO: is there any relationship between this function and the cat_process_nans function? +# Can they be made a little more similar to each other (in terms of signature)? def num_process_nans(dataset: Dataset, policy: NumNanPolicy | None) -> Dataset: - # ruff: noqa: D103 + """ + Process the NaN values in the dataset. + + Args: + dataset: The dataset to process. + policy: The policy to use to process the NaN values. Can be "drop-rows" or "mean". + Optional, default is None. 
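A quick sketch of the normalize helper documented above, on a toy ArrayDict; the shapes are arbitrary and the import path assumes the function lives in dataset.py, as the imports later in this series suggest:

    import numpy as np

    from midst_toolkit.models.clavaddpm.dataset import normalize

    rng = np.random.default_rng(0)
    X = {
        "train": rng.normal(size=(300, 4)),
        "val": rng.normal(size=(50, 4)),
        "test": rng.normal(size=(50, 4)),
    }

    # The chosen scaler is fitted on the train split only and then applied to every split.
    X_normalized, normalizer = normalize(X, "quantile", seed=0, return_normalizer=True)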
+ + Returns: + The processed dataset. + """ assert dataset.X_num is not None nan_masks = {k: np.isnan(v) for k, v in dataset.X_num.items()} if not any(x.any() for x in nan_masks.values()): @@ -495,36 +721,18 @@ def num_process_nans(dataset: Dataset, policy: NumNanPolicy | None) -> Dataset: return dataset -# Inspired by: https://github.com/yandex-research/rtdl/blob/a4c93a32b334ef55d2a0559a4407c8306ffeeaee/lib/data.py#L20 -def normalize( - X: ArrayDict, - normalization: Normalization, - seed: int | None, - return_normalizer: bool = False, -) -> ArrayDict | tuple[ArrayDict, StandardScaler | MinMaxScaler | QuantileTransformer]: - # ruff: noqa: D103 - X_train = X["train"] - if normalization == "standard": - normalizer = StandardScaler() - elif normalization == "minmax": - normalizer = MinMaxScaler() - elif normalization == "quantile": - normalizer = QuantileTransformer( - output_distribution="normal", - n_quantiles=max(min(X["train"].shape[0] // 30, 1000), 10), - subsample=int(1e9), - random_state=seed, - ) - else: - raise ValueError(f"Unknown normalization: {normalization}") - normalizer.fit(X_train) - if return_normalizer: - return {k: normalizer.transform(v) for k, v in X.items()}, normalizer - return {k: normalizer.transform(v) for k, v in X.items()} +def cat_process_nans(X: ArrayDict, policy: CatNanPolicy | None) -> ArrayDict: + """ + Process the NaN values in the categorical data. + Args: + X: The data to process. + policy: The policy to use to process the NaN values. Can be "most_frequent". + Optional, default is None. -def cat_process_nans(X: ArrayDict, policy: CatNanPolicy | None) -> ArrayDict: - # ruff: noqa: D103 + Returns: + The processed data. + """ assert X is not None nan_masks = {k: v == CAT_MISSING_VALUE for k, v in X.items()} if any(x.any() for x in nan_masks.values()): @@ -543,8 +751,17 @@ def cat_process_nans(X: ArrayDict, policy: CatNanPolicy | None) -> ArrayDict: def cat_drop_rare(X: ArrayDict, min_frequency: float) -> ArrayDict: - # ruff: noqa: D103 - assert 0.0 < min_frequency < 1.0 + """ + Drop the rare categories in the categorical data. + + Args: + X: The data to drop the rare categories from. + min_frequency: The minimum frequency threshold of the categories to keep. Has to be between 0 and 1. + + Returns: + The processed data. + """ + assert 0.0 < min_frequency < 1.0, "min_frequency has to be between 0 and 1" min_count = round(len(X["train"]) * min_frequency) X_new: dict[str, list[Any]] = {x: [] for x in X} for column_idx in range(X["train"].shape[1]): @@ -559,12 +776,28 @@ def cat_drop_rare(X: ArrayDict, min_frequency: float) -> ArrayDict: def cat_encode( X: ArrayDict, - encoding: CatEncoding | None, + encoding: CatEncoding | None, # TODO: add "ordinal" as one of the options, maybe? y_train: np.ndarray | None, seed: int | None, return_encoder: bool = False, -) -> tuple[ArrayDict, bool, Any | None]: # (X, is_converted_to_numerical) - # ruff: noqa: D103 +) -> tuple[ArrayDict, bool, Any | None]: + """ + Encode the categorical data. + + Args: + X: The data to encode. + encoding: The encoding to use. Can be "one-hot" or "counter". Default is None. + If None, will use the "ordinal" encoding. + y_train: The target values. Optional, default is None. Will only be used for the "counter" encoding. + seed: The seed to use for the random state. Optional, default is None. + return_encoder: Whether to return the encoder. Optional, default is False. + + Returns: + A tuple with the following values: + - The encoded data. 
+ - A boolean value indicating if the data was converted to numerical. + - The encoder, if return_encoder is True. None otherwise. + """ if encoding != "counter": y_train = None @@ -621,7 +854,18 @@ def cat_encode( def build_target(y: ArrayDict, policy: YPolicy | None, task_type: TaskType) -> tuple[ArrayDict, dict[str, Any]]: - # ruff: noqa: D103 + """ + Build the target and return the target values metadata. + + Args: + y: The target values. + policy: The policy to use to build the target. Can be "default". Optional, default is None. + If none, it will no-op. + task_type: The type of the task. + + Returns: + A tuple with the target values and the target values metadata. + """ info: dict[str, Any] = {"policy": policy} if policy is None: pass From c6db8fe2d0a217fc6732e27d58818b23e08f70e6 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 17 Sep 2025 13:44:17 -0400 Subject: [PATCH 24/39] Moiving a few more things around, adding docstrings to the model.py file --- src/midst_toolkit/models/clavaddpm/dataset.py | 87 ++++ src/midst_toolkit/models/clavaddpm/model.py | 383 +++++++++++++----- src/midst_toolkit/models/clavaddpm/train.py | 15 +- 3 files changed, 368 insertions(+), 117 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/dataset.py b/src/midst_toolkit/models/clavaddpm/dataset.py index 8a5fbf57..7b857d84 100644 --- a/src/midst_toolkit/models/clavaddpm/dataset.py +++ b/src/midst_toolkit/models/clavaddpm/dataset.py @@ -4,6 +4,7 @@ import json import pickle from collections import Counter +from collections.abc import Generator from copy import deepcopy from dataclasses import astuple, dataclass, replace from enum import Enum @@ -878,3 +879,89 @@ def build_target(y: ArrayDict, policy: YPolicy | None, task_type: TaskType) -> t else: raise ValueError(f"Unknown policy: {policy}") return y, info + + +class FastTensorDataLoader: + """ + Defines a faster dataloader for PyTorch tensors. + + A DataLoader-like object for a set of tensors that can be much faster than + TensorDataset + DataLoader because dataloader grabs individual indices of + the dataset and calls cat (slow). + Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 + """ + + def __init__(self, *tensors: torch.Tensor, batch_size: int = 32, shuffle: bool = False): + """ + Initialize a FastTensorDataLoader. + + Args: + *tensors: tensors to store. Must have the same length @ dim 0. + batch_size: batch size to load. + shuffle: if True, shuffle the data *in-place* whenever an + iterator is created out of this object. 
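For the FastTensorDataLoader documented above, a small usage sketch; the tensor sizes are arbitrary, and the import path reflects where the class sits at this point in the series (it moves to data_loaders.py a few commits later):

    import torch

    from midst_toolkit.models.clavaddpm.dataset import FastTensorDataLoader

    features = torch.randn(10_000, 8)
    labels = torch.randint(0, 2, (10_000,))

    loader = FastTensorDataLoader(features, labels, batch_size=4096, shuffle=True)
    assert len(loader) == 3  # 4096 + 4096 + 1808 rows

    for feature_batch, label_batch in loader:
        pass  # one optimisation step per batch would go here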
+ """ + assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) + self.tensors = tensors + + self.dataset_len = self.tensors[0].shape[0] + self.batch_size = batch_size + self.shuffle = shuffle + + # Calculate # batches + n_batches, remainder = divmod(self.dataset_len, self.batch_size) + if remainder > 0: + n_batches += 1 + self.n_batches = n_batches + + def __iter__(self): + """Defines the iterator for the FastTensorDataLoader.""" + if self.shuffle: + r = torch.randperm(self.dataset_len) + self.tensors = [t[r] for t in self.tensors] # type: ignore[assignment] + self.i = 0 + return self + + def __next__(self): + """Get the next batch of data from the dataset.""" + if self.i >= self.dataset_len: + raise StopIteration + batch = tuple(t[self.i : self.i + self.batch_size] for t in self.tensors) + self.i += self.batch_size + return batch + + def __len__(self): + """Get the number of batches in the dataset.""" + return self.n_batches + + +def prepare_fast_dataloader( + dataset: Dataset, + split: Literal["train", "val", "test"], + batch_size: int, + y_type: str = "float", +) -> Generator[tuple[torch.Tensor, ...]]: + """ + Prepare a fast dataloader for the dataset. + + Args: + dataset: The dataset to prepare the dataloader for. + split: The split to prepare the dataloader for. + batch_size: The batch size to use for the dataloader. + y_type: The type of the target values. Can be "float" or "long". Default is "float". + + Returns: + A generator of batches of data from the dataset. + """ + if dataset.X_cat is not None: + if dataset.X_num is not None: + X = torch.from_numpy(np.concatenate([dataset.X_num[split], dataset.X_cat[split]], axis=1)).float() + else: + X = torch.from_numpy(dataset.X_cat[split]).float() + else: + assert dataset.X_num is not None + X = torch.from_numpy(dataset.X_num[split]).float() + y = torch.from_numpy(dataset.y[split]).float() if y_type == "float" else torch.from_numpy(dataset.y[split]).long() + dataloader = FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=(split == "train")) + while True: + yield from dataloader diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index f1703c43..e9a6e0b9 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -1,14 +1,13 @@ import math -from collections.abc import Callable, Generator -from typing import Any, Self +from typing import Any, Literal, Self -import numpy as np import pandas as pd import torch import torch.nn.functional as F + +# ruff: noqa: N812 from torch import Tensor, nn -from midst_toolkit.models.clavaddpm.dataset import Dataset from midst_toolkit.models.clavaddpm.typing import ModuleType @@ -23,6 +22,18 @@ def __init__( num_heads: int = 2, num_layers: int = 1, ): + """ + Initialize the classifier model. + + Args: + d_in: The input dimension size. + d_out: The output dimension size. + dim_t: The dimension size of the timestamp. + hidden_sizes: The list of sizes for the hidden layers. + dropout_prob: The dropout probability. Optional, default is 0.5. + num_heads: The number of heads for the transformer layer. Optional, default is 2. + num_layers: The number of layers for the transformer layer. Optional, default is 1. 
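A small instantiation sketch for the Classifier documented above; the sizes are made up and only the constructor arguments listed in the docstring are used:

    import torch

    from midst_toolkit.models.clavaddpm.model import Classifier

    classifier = Classifier(d_in=16, d_out=10, dim_t=128, hidden_sizes=[256, 128], dropout_prob=0.5)

    x = torch.randn(32, 16)                    # batch of 32 rows with 16 features each
    timesteps = torch.randint(0, 1000, (32,))  # one diffusion timestep per row
    logits = classifier(x, timesteps)          # expected shape: (32, 10)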
+ """ super(Classifier, self).__init__() self.dim_t = dim_t @@ -54,7 +65,17 @@ def __init__( # Create a Sequential model from the list of layers self.model = nn.Sequential(*layers) - def forward(self, x, timesteps): + def forward(self, x: Tensor, timesteps: Tensor) -> Tensor: + """ + Forward pass of the classifier model. + + Args: + x: The input tensor. + timesteps: The timesteps tensor. + + Returns: + The output tensor. + """ emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) x = self.proj(x) + emb # x = self.transformer_layer(x, x) @@ -62,6 +83,24 @@ def forward(self, x, timesteps): def get_table_info(df: pd.DataFrame, domain_dict: dict[str, Any], y_col: str) -> dict[str, Any]: + """ + Get the dictionary oftable information. + + Args: + df: The dataframe containing the data. + domain_dict: The domain dictionary of metadata about the data columns. + y_col: The name of the target column. + + Returns: + The table information in the following format: + { + "cat_cols": list[str], + "num_cols": list[str], + "y_col": str, + "n_classes": int, + "task_type": str, + } + """ cat_cols = [] num_cols = [] for col in df.columns: @@ -81,31 +120,20 @@ def get_table_info(df: pd.DataFrame, domain_dict: dict[str, Any], y_col: str) -> return df_info -def prepare_fast_dataloader( - D: Dataset, - # ruff: noqa: N803 - split: str, - batch_size: int, - y_type: str = "float", -) -> Generator[tuple[Tensor, ...]]: - if D.X_cat is not None: - if D.X_num is not None: - X = torch.from_numpy(np.concatenate([D.X_num[split], D.X_cat[split]], axis=1)).float() - else: - X = torch.from_numpy(D.X_cat[split]).float() - else: - assert D.X_num is not None - X = torch.from_numpy(D.X_num[split]).float() - y = torch.from_numpy(D.y[split]).float() if y_type == "float" else torch.from_numpy(D.y[split]).long() - dataloader = FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=(split == "train")) - while True: - yield from dataloader - - def get_model( - model_name: str, + model_name: Literal["mlp", "resnet"], model_params: dict[str, Any], ) -> nn.Module: + """ + Get the model. + + Args: + model_name: The name of the model. Can be "mlp" or "resnet". + model_params: The dictionary of parameters of the model. + + Returns: + The model. + """ print(model_name) if model_name == "mlp": return MLPDiffusion(**model_params) @@ -115,68 +143,17 @@ def get_model( raise ValueError("Unknown model!") -class FastTensorDataLoader: - """ - Defines a faster dataloader for PyTorch tensors. - - A DataLoader-like object for a set of tensors that can be much faster than - TensorDataset + DataLoader because dataloader grabs individual indices of - the dataset and calls cat (slow). - Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 - """ - - def __init__(self, *tensors: Tensor, batch_size: int = 32, shuffle: bool = False): - """ - Initialize a FastTensorDataLoader. - :param *tensors: tensors to store. Must have the same length @ dim 0. - :param batch_size: batch size to load. - :param shuffle: if True, shuffle the data *in-place* whenever an - iterator is created out of this object. - :returns: A FastTensorDataLoader. 
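The sinusoidal embedding described above can be sanity-checked in a couple of lines; the sizes are arbitrary:

    import torch

    from midst_toolkit.models.clavaddpm.model import timestep_embedding

    t = torch.arange(0, 8)                # eight (possibly fractional) timesteps
    emb = timestep_embedding(t, dim=128)  # sinusoids at geometrically spaced frequencies
    assert emb.shape == (8, 128)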
- """ - assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) - self.tensors = tensors - - self.dataset_len = self.tensors[0].shape[0] - self.batch_size = batch_size - self.shuffle = shuffle - - # Calculate # batches - n_batches, remainder = divmod(self.dataset_len, self.batch_size) - if remainder > 0: - n_batches += 1 - self.n_batches = n_batches - - def __iter__(self): - # ruff: noqa: D105 - if self.shuffle: - r = torch.randperm(self.dataset_len) - self.tensors = [t[r] for t in self.tensors] # type: ignore[assignment] - self.i = 0 - return self - - def __next__(self): - # ruff: noqa: D105 - if self.i >= self.dataset_len: - raise StopIteration - batch = tuple(t[self.i : self.i + self.batch_size] for t in self.tensors) - self.i += self.batch_size - return batch - - def __len__(self): - # ruff: noqa: D105 - return self.n_batches - - def timestep_embedding(timesteps: Tensor, dim: int, max_period: int = 10000) -> Tensor: """ Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + An [N x dim] Tensor of positional embeddings. """ half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( @@ -223,12 +200,31 @@ def __init__( activation: ModuleType, dropout: float, ) -> None: + """ + Initialize the MLP block. + + Args: + d_in: The input dimension size. + d_out: The output dimension size. + bias: Whether to use bias. + activation: The activation function. + dropout: The dropout probability. + """ super().__init__() self.linear = nn.Linear(d_in, d_out, bias) self.activation = _make_nn_module(activation) self.dropout = nn.Dropout(dropout) def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the MLP block. + + Args: + x: The input tensor. + + Returns: + The output tensor. + """ return self.dropout(self.activation(self.linear(x))) def __init__( @@ -237,12 +233,21 @@ def __init__( d_in: int, d_layers: list[int], dropouts: float | list[float], - activation: str | Callable[[], nn.Module], + activation: ModuleType, d_out: int, - ) -> None: + ): """ + Initialize the MLP model. + Note: `make_baseline` is the recommended constructor. + + Args: + d_in: The input dimension size. + d_layers: The list of sizes for the hidden layers. + dropouts: Can be either a single value for the dropout rate or a list of dropout rates. + activation: The activation function. + d_out: The output dimension size. """ super().__init__() if isinstance(dropouts, float): @@ -308,6 +313,15 @@ def make_baseline( ) def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the MLP model. + + Args: + x: The input tensor. + + Returns: + The output tensor. + """ x = x.float() for block in self.blocks: x = block(x) @@ -360,7 +374,21 @@ def __init__( normalization: ModuleType, activation: ModuleType, skip_connection: bool, - ) -> None: + ): + """ + Initialize the ResNet block. + + Args: + d_main: The input dimension size. + d_hidden: The output dimension size. + bias_first: Whether to use bias for the first linear layer. + bias_second: Whether to use bias for the second linear layer. 
+ dropout_first: The dropout probability for the first dropout layer. + dropout_second: The dropout probability for the second dropout layer. + normalization: The normalization function. + activation: The activation function. + skip_connection: Whether to use skip connection. + """ super().__init__() self.normalization = _make_nn_module(normalization, d_main) self.linear_first = nn.Linear(d_main, d_hidden, bias_first) @@ -371,6 +399,15 @@ def __init__( self.skip_connection = skip_connection def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the ResNet block. + + Args: + x: The input tensor. + + Returns: + The output tensor. + """ x_input = x x = self.normalization(x) x = self.linear_first(x) @@ -393,13 +430,32 @@ def __init__( bias: bool, normalization: ModuleType, activation: ModuleType, - ) -> None: + ): + """ + Initialize the ResNet head. + + Args: + d_in: The input dimension size. + d_out: The output dimension size. + bias: Whether to use bias. + normalization: The normalization function. + activation: The activation function. + """ super().__init__() self.normalization = _make_nn_module(normalization, d_in) self.activation = _make_nn_module(activation) self.linear = nn.Linear(d_in, d_out, bias) def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the ResNet head. + + Args: + x: The input tensor. + + Returns: + The output tensor. + """ if self.normalization is not None: x = self.normalization(x) x = self.activation(x) @@ -417,10 +473,23 @@ def __init__( normalization: ModuleType, activation: ModuleType, d_out: int, - ) -> None: + ): """ + Initialize the ResNet model. + Note: `make_baseline` is the recommended constructor. + + Args: + d_in: The input dimension size. + n_blocks: The number of blocks. + d_main: The input dimension size. + d_hidden: The output dimension size. + dropout_first: The dropout probability for the first dropout layer. + dropout_second: The dropout probability for the second dropout layer. + normalization: The normalization function. + activation: The activation function. + d_out: The output dimension size. """ super().__init__() @@ -492,6 +561,15 @@ def make_baseline( ) def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the ResNet model. + + Args: + x: The input tensor. + + Returns: + The output tensor. + """ x = x.float() x = self.first_layer(x) x = self.blocks(x) @@ -506,10 +584,20 @@ def __init__( self, d_in: int, num_classes: int, - is_y_cond: str, + is_y_cond: Literal["concat", "embedding", "none"], rtdl_params: dict[str, Any], dim_t: int = 128, ): + """ + Initialize the MLP diffusion model. + + Args: + d_in: The input dimension size. + num_classes: The number of classes. + is_y_cond: The condition on the y column. Can be "concat", "embedding", or "none". + rtdl_params: The dictionary of parameters for the MLP. + dim_t: The dimension size of the timestamp. + """ super().__init__() self.dim_t = dim_t self.num_classes = num_classes @@ -531,7 +619,18 @@ def __init__( self.proj = nn.Linear(d_in, dim_t) self.time_embed = nn.Sequential(nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)) - def forward(self, x, timesteps, y=None): + def forward(self, x: Tensor, timesteps: Tensor, y: Tensor | None = None) -> Tensor: + """ + Forward pass of the MLP diffusion model. + + Args: + x: The input tensor. + timesteps: The timesteps tensor. + y: The y tensor. Optional, default is None. + + Returns: + The output tensor. 
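Tying the backbone pieces together, a sketch of building a diffusion model through get_model; the numbers are placeholders and rtdl_params mirrors the d_layers/dropout dictionary used for the MLP baseline:

    from midst_toolkit.models.clavaddpm.model import get_model

    model_params = {
        "d_in": 12,  # number of numerical features plus the total size of the encoded categoricals
        "num_classes": 0,
        "is_y_cond": "concat",
        "rtdl_params": {"d_layers": [256, 256], "dropout": 0.1},
    }
    diffusion_backbone = get_model("mlp", model_params)  # returns an MLPDiffusion instance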
+ """ emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) if self.is_y_cond == "embedding" and y is not None: y = y.squeeze() if self.num_classes > 0 else y.resize_(y.size(0), 1).float() @@ -547,9 +646,19 @@ def __init__( num_classes: int, rtdl_params: dict[str, Any], dim_t: int = 256, - is_y_cond: str | None = None, + is_y_cond: Literal["concat", "embedding", "none"] | None = None, ): - # ruff: noqa: D107 + """ + Initialize the ResNet diffusion model. + + Args: + d_in: The input dimension size. + num_classes: The number of classes. + rtdl_params: The dictionary of parameters for the ResNet. + dim_t: The dimension size of the timestamp. + is_y_cond: The condition on the y column. Can be "concat", "embedding", or "none". + Optional, default is None. + """ super().__init__() self.dim_t = dim_t self.num_classes = num_classes @@ -567,8 +676,18 @@ def __init__( self.time_embed = nn.Sequential(nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)) - def forward(self, x, timesteps, y=None): - # ruff: noqa: D102 + def forward(self, x: Tensor, timesteps: Tensor, y: Tensor | None = None) -> Tensor: + """ + Forward pass of the ResNet diffusion model. + + Args: + x: The input tensor. + timesteps: The timesteps tensor. + y: The y tensor. Optional, default is None. + + Returns: + The output tensor. + """ emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) if y is not None and self.num_classes > 0: emb += self.label_emb(y.squeeze()) @@ -576,10 +695,17 @@ def forward(self, x, timesteps, y=None): def reglu(x: Tensor) -> Tensor: - """The ReGLU activation function from [1]. + """ + The ReGLU activation function from [1]. References: [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 + + Args: + x: The input tensor. + + Returns: + The output tensor. """ assert x.shape[-1] % 2 == 0 a, b = x.chunk(2, dim=-1) @@ -587,10 +713,17 @@ def reglu(x: Tensor) -> Tensor: def geglu(x: Tensor) -> Tensor: - """The GEGLU activation function from [1]. + """ + The GEGLU activation function from [1]. References: [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 + + Args: + x: The input tensor. + + Returns: + The output tensor. """ assert x.shape[-1] % 2 == 0 a, b = x.chunk(2, dim=-1) @@ -598,30 +731,42 @@ def geglu(x: Tensor) -> Tensor: class ReGLU(nn.Module): - """The ReGLU activation function from [shazeer2020glu]. + """ + The ReGLU activation function from [shazeer2020glu]. Examples: - .. testcode:: - - module = ReGLU() - x = torch.randn(3, 4) - assert module(x).shape == (3, 2) + module = ReGLU() + x = torch.randn(3, 4) + assert module(x).shape == (3, 2) References: * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020 + + Args: + x: The input tensor. + + Returns: + The output tensor. """ def forward(self, x: Tensor) -> Tensor: - # ruff: noqa: D102 + """ + Forward pass of the ReGLU activation function. + + Args: + x: The input tensor. + + Returns: + The output tensor. + """ return reglu(x) class GEGLU(nn.Module): - """The GEGLU activation function from [shazeer2020glu]. + """ + The GEGLU activation function from [shazeer2020glu]. Examples: - .. testcode:: - module = GEGLU() x = torch.randn(3, 4) assert module(x).shape == (3, 2) @@ -631,11 +776,29 @@ class GEGLU(nn.Module): """ def forward(self, x: Tensor) -> Tensor: - # ruff: noqa: D102 + """ + Forward pass of the GEGLU activation function. + + Args: + x: The input tensor. + + Returns: + The output tensor. 
+ """ return geglu(x) -def _make_nn_module(module_type: ModuleType, *args) -> nn.Module: # type: ignore[no-untyped-def] +def _make_nn_module(module_type: ModuleType, *args: Any) -> nn.Module: + """ + Make a neural network module. + + Args: + module_type: The type of the module. + args: The arguments for the module. + + Returns: + The neural network module. + """ return ( (ReGLU() if module_type == "ReGLU" else GEGLU() if module_type == "GEGLU" else getattr(nn, module_type)(*args)) if isinstance(module_type, str) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 4fea79b0..f98a4a08 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -12,14 +12,15 @@ from torch import Tensor, optim from midst_toolkit.core import logger -from midst_toolkit.models.clavaddpm.dataset import Dataset, Transformations, get_T_dict, make_dataset_from_df -from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion -from midst_toolkit.models.clavaddpm.model import ( - Classifier, - get_model, - get_table_info, +from midst_toolkit.models.clavaddpm.dataset import ( + Dataset, + Transformations, + get_T_dict, + make_dataset_from_df, prepare_fast_dataloader, ) +from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion +from midst_toolkit.models.clavaddpm.model import Classifier, get_model, get_table_info from midst_toolkit.models.clavaddpm.sampler import ScheduleSampler, create_named_schedule_sampler from midst_toolkit.models.clavaddpm.trainer import ClavaDDPMTrainer from midst_toolkit.models.clavaddpm.typing import Configs, RelationOrder, Tables @@ -244,7 +245,7 @@ def train_model( transformations_dict: dict[str, Any], steps: int, batch_size: int, - model_type: str, + model_type: Literal["mlp", "resnet"], gaussian_loss_type: str, num_timesteps: int, scheduler: str, From 6b7a7f62645939a6a76be721004e67377370c5d3 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 17 Sep 2025 16:29:05 -0400 Subject: [PATCH 25/39] David's CR --- .../models/clavaddpm/clustering.py | 8 +++---- src/midst_toolkit/models/clavaddpm/model.py | 18 +-------------- src/midst_toolkit/models/clavaddpm/train.py | 23 ++++++++----------- 3 files changed, 14 insertions(+), 35 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py index 0ef96bdc..5b841564 100644 --- a/src/midst_toolkit/models/clavaddpm/clustering.py +++ b/src/midst_toolkit/models/clavaddpm/clustering.py @@ -23,7 +23,7 @@ def clava_clustering( configs: Configs, ) -> tuple[dict[str, Any], GroupLengthsProbDicts]: """ - Clustering function for the mutli-table function of the ClavaDDPM model. + Clustering function for the multi-table function of the ClavaDDPM model. Args: tables: Definition of the tables and their relations. 
Example: @@ -91,7 +91,7 @@ def _load_clustering_info_from_checkpoint(save_dir: Path) -> dict[str, Any] | No if not os.path.exists(save_dir / "cluster_ckpt.pkl"): return None - print("Clustering checkpoint found, loading...") + logger.info("Clustering checkpoint found, loading...") with open(save_dir / "cluster_ckpt.pkl", "rb") as f: return pickle.load(f) @@ -125,7 +125,7 @@ def _run_clustering( relation_order_reversed = relation_order[::-1] for parent, child in relation_order_reversed: if parent is not None: - print(f"Clustering {parent} -> {child}") + logger.info(f"Clustering {parent} -> {child}") if isinstance(configs["num_clusters"], dict): num_clusters = configs["num_clusters"][child] else: @@ -328,7 +328,7 @@ def _pair_clustering_keep_id( # Compute the average agree rate across all groups average_agree_rate = np.mean(agree_rates) - print("Average agree rate: ", average_agree_rate) + logger.info(f"Average agree rate: {average_agree_rate}") group_assignment = np.repeat(group_cluster_labels, child_group_lengths, axis=0).reshape((-1, 1)) diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index f7ec9a42..6ad74307 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -4,7 +4,7 @@ import pickle from abc import ABC, abstractmethod from collections import Counter -from collections.abc import Callable, Generator, Iterator +from collections.abc import Callable, Generator from copy import deepcopy from dataclasses import astuple, dataclass, replace from pathlib import Path @@ -493,22 +493,6 @@ def get_model( raise ValueError("Unknown model!") -def update_ema( - target_params: Iterator[nn.Parameter], - source_params: Iterator[nn.Parameter], - rate: float = 0.999, -) -> None: - """ - Update target parameters to be closer to those of source parameters using - an exponential moving average. - :param target_params: the target parameter sequence. - :param source_params: the source parameter sequence. - :param rate: the EMA rate (closer to 1 means slower). - """ - for targ, src in zip(target_params, source_params): - targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) - - class ScheduleSampler(ABC): """ A distribution over timesteps in the diffusion process, intended to reduce diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index d116d310..47185a13 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -1,6 +1,5 @@ """Defines the training functions for the ClavaDDPM model.""" -import logging import pickle from pathlib import Path from typing import Any @@ -28,10 +27,6 @@ from midst_toolkit.models.clavaddpm.typing import Configs, RelationOrder, Tables -logging.basicConfig(level=logging.INFO) -LOGGER = logging.getLogger(__name__) - - def clava_training( tables: Tables, relation_order: RelationOrder, @@ -110,7 +105,7 @@ def clava_training( target_file = target_folder / f"{parent}_{child}_ckpt.pkl" create_message = f"Creating {target_folder}. 
" if not target_folder.exists() else "" - LOGGER.info(f"{create_message}Saving {parent} -> {child} model to {target_file}") + logger.info(f"{create_message}Saving {parent} -> {child} model to {target_file}") target_folder.mkdir(parents=True, exist_ok=True) with open(target_file, "wb") as f: @@ -232,7 +227,7 @@ def child_training( ) child_result["classifier"] = child_classifier else: - LOGGER.warning("Skipping classifier training since classifier_config['iterations'] <= 0") + logger.warn("Skipping classifier training since classifier_config['iterations'] <= 0") child_result["df_info"] = child_info child_result["model_params"] = child_model_params @@ -414,7 +409,7 @@ def train_classifier( # TODO: understand what's going on here if dataset.X_num is None: - LOGGER.warning("dataset.X_num is None. num_numerical_features will be set to 0") + logger.warn("dataset.X_num is None. num_numerical_features will be set to 0") num_numerical_features = 0 else: num_numerical_features = dataset.X_num["train"].shape[1] @@ -519,13 +514,13 @@ def save_table_info( df_with_cluster = tables[child]["df"] df_without_id = get_df_without_id(df_with_cluster) df_info = result["df_info"] - X_num_real = df_without_id[df_info["num_cols"]].to_numpy().astype(float) - uniq_vals_list = [] - for col in range(X_num_real.shape[1]): - uniq_vals = np.unique(X_num_real[:, col]) - uniq_vals_list.append(uniq_vals) + x_num_real = df_without_id[df_info["num_cols"]].to_numpy().astype(float) + unique_values_list = [] + for column in range(x_num_real.shape[1]): + unique_values = np.unique(x_num_real[:, column]) + unique_values_list.append(unique_values) table_info[(parent, child)] = { - "uniq_vals_list": uniq_vals_list, + "uniq_vals_list": unique_values_list, "size": len(df_with_cluster), "columns": tables[child]["df"].columns, "parents": tables[child]["parents"], From 09e5dc82347866b18634eb30421cb05293f3b199 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 17 Sep 2025 17:03:16 -0400 Subject: [PATCH 26/39] Rewording docstring, replacing logger --- src/midst_toolkit/models/clavaddpm/clustering.py | 13 +++++++------ src/midst_toolkit/models/clavaddpm/train.py | 10 ++++++---- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py index 5b841564..4ad5b494 100644 --- a/src/midst_toolkit/models/clavaddpm/clustering.py +++ b/src/midst_toolkit/models/clavaddpm/clustering.py @@ -3,6 +3,7 @@ import os import pickle from collections import defaultdict +from logging import INFO, WARNING from pathlib import Path from typing import Any, Literal @@ -12,7 +13,7 @@ from sklearn.mixture import BayesianGaussianMixture, GaussianMixture from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, QuantileTransformer -from midst_toolkit.core import logger +from midst_toolkit.common.logger import log from midst_toolkit.models.clavaddpm.typing import Configs, GroupLengthsProbDicts, RelationOrder, Tables @@ -91,7 +92,7 @@ def _load_clustering_info_from_checkpoint(save_dir: Path) -> dict[str, Any] | No if not os.path.exists(save_dir / "cluster_ckpt.pkl"): return None - logger.info("Clustering checkpoint found, loading...") + log(INFO, "Clustering checkpoint found, loading...") with open(save_dir / "cluster_ckpt.pkl", "rb") as f: return pickle.load(f) @@ -125,7 +126,7 @@ def _run_clustering( relation_order_reversed = relation_order[::-1] for parent, child in relation_order_reversed: if parent is not None: - logger.info(f"Clustering 
{parent} -> {child}") + log(INFO, f"Clustering {parent} -> {child}") if isinstance(configs["num_clusters"], dict): num_clusters = configs["num_clusters"][child] else: @@ -242,7 +243,7 @@ def _pair_clustering_keep_id( for i in range(joint_cat_matrix.shape[1]): # A threshold of 1000 unique values is used to prevent the one-hot encoding of large categorical columns if len(np.unique(joint_cat_matrix[:, i])) > 1000: - logger.warn(f"Categorical column {i} has more than 1000 unique values, skipping...") + log(WARNING, f"Categorical column {i} has more than 1000 unique values, skipping...") continue label_encoder = LabelEncoder() cat_converted.append(label_encoder.fit_transform(joint_cat_matrix[:, i]).astype(float)) @@ -328,7 +329,7 @@ def _pair_clustering_keep_id( # Compute the average agree rate across all groups average_agree_rate = np.mean(agree_rates) - logger.info(f"Average agree rate: {average_agree_rate}") + log(INFO, f"Average agree rate: {average_agree_rate}") group_assignment = np.repeat(group_cluster_labels, child_group_lengths, axis=0).reshape((-1, 1)) @@ -390,7 +391,7 @@ def _pair_clustering_keep_id( "size": len(set(parent_data_clusters_np.flatten())), } - logger.info(f"Number of cluster centers: {new_col_entry['size']}") + log(INFO, f"Number of cluster centers: {new_col_entry['size']}") parent_domain_dict[relation_cluster_name] = new_col_entry.copy() child_domain_dict[relation_cluster_name] = new_col_entry.copy() diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 47185a13..40816f7f 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -1,6 +1,7 @@ """Defines the training functions for the ClavaDDPM model.""" import pickle +from logging import INFO, WARNING from pathlib import Path from typing import Any @@ -9,6 +10,7 @@ import torch from torch import optim +from midst_toolkit.common.logger import log from midst_toolkit.core import logger from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion from midst_toolkit.models.clavaddpm.model import ( @@ -105,7 +107,7 @@ def clava_training( target_file = target_folder / f"{parent}_{child}_ckpt.pkl" create_message = f"Creating {target_folder}. " if not target_folder.exists() else "" - logger.info(f"{create_message}Saving {parent} -> {child} model to {target_file}") + log(INFO, f"{create_message}Saving {parent} -> {child} model to {target_file}") target_folder.mkdir(parents=True, exist_ok=True) with open(target_file, "wb") as f: @@ -227,7 +229,7 @@ def child_training( ) child_result["classifier"] = child_classifier else: - logger.warn("Skipping classifier training since classifier_config['iterations'] <= 0") + log(WARNING, "Skipping classifier training since classifier_config['iterations'] <= 0") child_result["df_info"] = child_info child_result["model_params"] = child_model_params @@ -379,7 +381,7 @@ def train_classifier( cluster_col: Name of the cluster column. Default is `"cluster"`. dim_t: Dimension of the timestamp. Default is 128. learning_rate: Learning rate to use for the optimizer in the classifier. Default is 0.0001. - classifier_evaluation_interval: The amount of classifier_steps to wait + classifier_evaluation_interval: The number of classifier training steps to wait until the next evaluation of the classifier. Default is 5. 
Returns: @@ -409,7 +411,7 @@ def train_classifier( # TODO: understand what's going on here if dataset.X_num is None: - logger.warn("dataset.X_num is None. num_numerical_features will be set to 0") + log(WARNING, "dataset.X_num is None. num_numerical_features will be set to 0") num_numerical_features = 0 else: num_numerical_features = dataset.X_num["train"].shape[1] From e39db1fbf7a07f44bf2d8e68a12a9e1cac9befa6 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 17 Sep 2025 17:40:45 -0400 Subject: [PATCH 27/39] Moving data_loaders.py to the clavaddpm function and adding missing docstrings and type hints --- .../clavaddpm}/data_loaders.py | 204 ++++++++++++++++-- src/midst_toolkit/models/clavaddpm/dataset.py | 87 -------- src/midst_toolkit/models/clavaddpm/train.py | 2 +- .../models/clavaddpm/test_model.py | 2 +- 4 files changed, 183 insertions(+), 112 deletions(-) rename src/midst_toolkit/{core => models/clavaddpm}/data_loaders.py (59%) diff --git a/src/midst_toolkit/core/data_loaders.py b/src/midst_toolkit/models/clavaddpm/data_loaders.py similarity index 59% rename from src/midst_toolkit/core/data_loaders.py rename to src/midst_toolkit/models/clavaddpm/data_loaders.py index 81513f7b..0dd6317a 100644 --- a/src/midst_toolkit/core/data_loaders.py +++ b/src/midst_toolkit/models/clavaddpm/data_loaders.py @@ -1,16 +1,36 @@ import json import os -from typing import Any +from collections.abc import Generator +from logging import INFO +from typing import Any, Literal import numpy as np import pandas as pd +import torch +from midst_toolkit.common.logger import log +from midst_toolkit.models.clavaddpm.dataset import Dataset -def load_multi_table(data_dir, verbose=True): + +def load_multi_table( + data_dir: str, verbose: bool = True +) -> tuple[dict[str, Any], list[tuple[str, str]], dict[str, Any]]: + """ + Load the multi-table dataset from the data directory. + + Args: + data_dir: The directory to load the dataset from. + verbose: Whether to print verbose output. Optional, default is True. + + Returns: + A tuple with 3 values: + - The tables dictionary. + - The relation order between the tables. + - The dataset metadata dictionary. + """ dataset_meta = json.load(open(os.path.join(data_dir, "dataset_meta.json"), "r")) relation_order = dataset_meta["relation_order"] - # relation_order_reversed = relation_order[::-1] tables = {} @@ -31,7 +51,7 @@ def load_multi_table(data_dir, verbose=True): id_cols = [col for col in tables[table]["df"].columns if "_id" in col] df_no_id = tables[table]["df"].drop(columns=id_cols) info = get_info_from_domain(df_no_id, tables[table]["domain"]) - data, info = pipeline_process_data( + _, info = pipeline_process_data( name=table, data_df=df_no_id, info=info, @@ -45,7 +65,21 @@ def load_multi_table(data_dir, verbose=True): def get_info_from_domain(data_df: pd.DataFrame, domain_dict: dict[str, Any]) -> dict[str, Any]: - # ruff: noqa: D103 + """ + Get the information dictionaryfrom the domain dictionary. + + Args: + data_df: The dataframe containing the data. + domain_dict: The domain dictionary containing metadata about the data columns. + + Returns: + The information dictionary containing the following keys: + - num_col_idx: The indices of the numerical columns. + - cat_col_idx: The indices of the categorical columns. + - target_col_idx: The indices of the target columns. + - task_type: The type of the task. + - column_names: The names of the columns. 
+ """ info: dict[str, Any] = {} info["num_col_idx"] = [] info["cat_col_idx"] = [] @@ -72,7 +106,30 @@ def pipeline_process_data( save: bool = False, verbose: bool = True, ) -> tuple[dict[str, Any], dict[str, Any]]: - # ruff: noqa: D103 + """ + Process the data through the pipeline. + + Args: + name: The name of the table. + data_df: The dataframe containing the data. + info: The information dictionary, retrieved from the get_info_from_domain function. + ratio: The ratio of the data to be used for training. Optional, default is 0.9. + save: Whether to save the data. Optional, default is False. + verbose: Whether to print verbose output. Optional, default is True. + + Returns: + A tuple with 2 values: + - The data dictionary containing the following keys: + - df: The dataframe containing the data. + - numpy: A dictionary with the numeric data, containing the keys: + - X_num_train: The numeric data for the training set. + - X_cat_train: The categorical data for the training set. + - y_train: The target data for the training set. + - X_num_test: The numeric data for the test set. + - X_cat_test: The categorical data for the test set. + - y_test: The target data for the test set. + - The information dictionary with updated values. + """ num_data = data_df.shape[0] column_names = info["column_names"] if info["column_names"] else data_df.columns.tolist() @@ -235,21 +292,7 @@ def pipeline_process_data( str_shape += ", Numerical data shape: {}".format(X_num_train.shape) str_shape += ", Categorical data shape: {}".format(X_cat_train.shape) - print(str_shape) - - # print(name) - # print('Total', info['train_num'] + info['test_num'] if ratio < 1 else info['train_num']) - # print('Train', info['train_num']) - # if ratio < 1: - # print('Test', info['test_num']) - # if info['task_type'] == 'regression': - # num = len(info['num_col_idx'] + info['target_col_idx']) - # cat = len(info['cat_col_idx']) - # else: - # cat = len(info['cat_col_idx'] + info['target_col_idx']) - # num = len(info['num_col_idx']) - # print('Num', num) - # print('Cat', cat) + log(INFO, str_shape) data = { "df": {"train": train_df}, @@ -276,7 +319,22 @@ def get_column_name_mapping( target_col_idx: list[int], column_names: list[str] | None = None, ) -> tuple[dict[int, int], dict[int, int], dict[int, str]]: - # ruff: noqa: D103 + """ + Get the column name mapping. + + Args: + data_df: The dataframe containing the data. + num_col_idx: The indices of the numerical columns. + cat_col_idx: The indices of the categorical columns. + target_col_idx: The indices of the target columns. + column_names: The names of the columns. + + Returns: + A tuple with 3 values: + - The mapping of the categorical and numerical columns to the indices. + - The mapping of the column names to the indices. + - The mapping of all the indices to the column names. + """ if not column_names: column_names = data_df.columns.tolist() @@ -315,7 +373,21 @@ def train_val_test_split( num_train: int = 0, num_test: int = 0, ) -> tuple[pd.DataFrame, pd.DataFrame, int]: - # ruff: noqa: D103 + """ + Split the data into training and test sets. + + Args: + data_df: The dataframe containing the data. + cat_columns: The names of the categorical columns. + num_train: The number of rows in the training set. Optional, default is 0. + num_test: The number of rows in the test set. Optional, default is 0. + + Returns: + A tuple with 3 values: + - The training dataframe. + - The test dataframe. + - The seed used for the random number generator. 
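A usage sketch for load_multi_table, assuming data_dir follows the layout the function expects (a dataset_meta.json plus per-table data and domain files); the table name is purely illustrative:

    from midst_toolkit.models.clavaddpm.data_loaders import load_multi_table

    tables, relation_order, dataset_meta = load_multi_table("path/to/data_dir", verbose=False)

    # Each entry keeps the raw dataframe together with its domain metadata.
    account_df = tables["account"]["df"]          # "account" is a placeholder table name
    account_domain = tables["account"]["domain"]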
+ """ total_num = data_df.shape[0] idx = np.arange(total_num) @@ -342,3 +414,89 @@ def train_val_test_split( seed += 1 return train_df, test_df, seed + + +class FastTensorDataLoader: + """ + Defines a faster dataloader for PyTorch tensors. + + A DataLoader-like object for a set of tensors that can be much faster than + TensorDataset + DataLoader because dataloader grabs individual indices of + the dataset and calls cat (slow). + Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 + """ + + def __init__(self, *tensors: torch.Tensor, batch_size: int = 32, shuffle: bool = False): + """ + Initialize a FastTensorDataLoader. + + Args: + *tensors: tensors to store. Must have the same length @ dim 0. + batch_size: batch size to load. + shuffle: if True, shuffle the data *in-place* whenever an + iterator is created out of this object. + """ + assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) + self.tensors = tensors + + self.dataset_len = self.tensors[0].shape[0] + self.batch_size = batch_size + self.shuffle = shuffle + + # Calculate # batches + n_batches, remainder = divmod(self.dataset_len, self.batch_size) + if remainder > 0: + n_batches += 1 + self.n_batches = n_batches + + def __iter__(self): + """Defines the iterator for the FastTensorDataLoader.""" + if self.shuffle: + r = torch.randperm(self.dataset_len) + self.tensors = [t[r] for t in self.tensors] # type: ignore[assignment] + self.i = 0 + return self + + def __next__(self): + """Get the next batch of data from the dataset.""" + if self.i >= self.dataset_len: + raise StopIteration + batch = tuple(t[self.i : self.i + self.batch_size] for t in self.tensors) + self.i += self.batch_size + return batch + + def __len__(self): + """Get the number of batches in the dataset.""" + return self.n_batches + + +def prepare_fast_dataloader( + dataset: Dataset, + split: Literal["train", "val", "test"], + batch_size: int, + y_type: str = "float", +) -> Generator[tuple[torch.Tensor, ...]]: + """ + Prepare a fast dataloader for the dataset. + + Args: + dataset: The dataset to prepare the dataloader for. + split: The split to prepare the dataloader for. + batch_size: The batch size to use for the dataloader. + y_type: The type of the target values. Can be "float" or "long". Default is "float". + + Returns: + A generator of batches of data from the dataset. 
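prepare_fast_dataloader wraps the loader above in an endless generator, so training code simply pulls batches with next(); a brief sketch, assuming dataset was produced by make_dataset_from_df:

    from midst_toolkit.models.clavaddpm.data_loaders import prepare_fast_dataloader

    train_iterator = prepare_fast_dataloader(dataset, split="train", batch_size=4096, y_type="long")
    x_batch, y_batch = next(train_iterator)  # keeps cycling over the train split indefinitely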
+ """ + if dataset.X_cat is not None: + if dataset.X_num is not None: + X = torch.from_numpy(np.concatenate([dataset.X_num[split], dataset.X_cat[split]], axis=1)).float() + else: + X = torch.from_numpy(dataset.X_cat[split]).float() + else: + assert dataset.X_num is not None + X = torch.from_numpy(dataset.X_num[split]).float() + y = torch.from_numpy(dataset.y[split]).float() if y_type == "float" else torch.from_numpy(dataset.y[split]).long() + dataloader = FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=(split == "train")) + while True: + yield from dataloader diff --git a/src/midst_toolkit/models/clavaddpm/dataset.py b/src/midst_toolkit/models/clavaddpm/dataset.py index 7b857d84..8a5fbf57 100644 --- a/src/midst_toolkit/models/clavaddpm/dataset.py +++ b/src/midst_toolkit/models/clavaddpm/dataset.py @@ -4,7 +4,6 @@ import json import pickle from collections import Counter -from collections.abc import Generator from copy import deepcopy from dataclasses import astuple, dataclass, replace from enum import Enum @@ -879,89 +878,3 @@ def build_target(y: ArrayDict, policy: YPolicy | None, task_type: TaskType) -> t else: raise ValueError(f"Unknown policy: {policy}") return y, info - - -class FastTensorDataLoader: - """ - Defines a faster dataloader for PyTorch tensors. - - A DataLoader-like object for a set of tensors that can be much faster than - TensorDataset + DataLoader because dataloader grabs individual indices of - the dataset and calls cat (slow). - Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 - """ - - def __init__(self, *tensors: torch.Tensor, batch_size: int = 32, shuffle: bool = False): - """ - Initialize a FastTensorDataLoader. - - Args: - *tensors: tensors to store. Must have the same length @ dim 0. - batch_size: batch size to load. - shuffle: if True, shuffle the data *in-place* whenever an - iterator is created out of this object. - """ - assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) - self.tensors = tensors - - self.dataset_len = self.tensors[0].shape[0] - self.batch_size = batch_size - self.shuffle = shuffle - - # Calculate # batches - n_batches, remainder = divmod(self.dataset_len, self.batch_size) - if remainder > 0: - n_batches += 1 - self.n_batches = n_batches - - def __iter__(self): - """Defines the iterator for the FastTensorDataLoader.""" - if self.shuffle: - r = torch.randperm(self.dataset_len) - self.tensors = [t[r] for t in self.tensors] # type: ignore[assignment] - self.i = 0 - return self - - def __next__(self): - """Get the next batch of data from the dataset.""" - if self.i >= self.dataset_len: - raise StopIteration - batch = tuple(t[self.i : self.i + self.batch_size] for t in self.tensors) - self.i += self.batch_size - return batch - - def __len__(self): - """Get the number of batches in the dataset.""" - return self.n_batches - - -def prepare_fast_dataloader( - dataset: Dataset, - split: Literal["train", "val", "test"], - batch_size: int, - y_type: str = "float", -) -> Generator[tuple[torch.Tensor, ...]]: - """ - Prepare a fast dataloader for the dataset. - - Args: - dataset: The dataset to prepare the dataloader for. - split: The split to prepare the dataloader for. - batch_size: The batch size to use for the dataloader. - y_type: The type of the target values. Can be "float" or "long". Default is "float". - - Returns: - A generator of batches of data from the dataset. 
- """ - if dataset.X_cat is not None: - if dataset.X_num is not None: - X = torch.from_numpy(np.concatenate([dataset.X_num[split], dataset.X_cat[split]], axis=1)).float() - else: - X = torch.from_numpy(dataset.X_cat[split]).float() - else: - assert dataset.X_num is not None - X = torch.from_numpy(dataset.X_num[split]).float() - y = torch.from_numpy(dataset.y[split]).float() if y_type == "float" else torch.from_numpy(dataset.y[split]).long() - dataloader = FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=(split == "train")) - while True: - yield from dataloader diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index c5d96ead..23d3d044 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -13,12 +13,12 @@ from midst_toolkit.common.logger import log from midst_toolkit.core import logger +from midst_toolkit.models.clavaddpm.data_loaders import prepare_fast_dataloader from midst_toolkit.models.clavaddpm.dataset import ( Dataset, Transformations, get_T_dict, make_dataset_from_df, - prepare_fast_dataloader, ) from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion from midst_toolkit.models.clavaddpm.model import Classifier, get_model, get_table_info diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index 4895e7b8..42bc5e76 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -10,8 +10,8 @@ from torch.nn import functional from midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds -from midst_toolkit.core.data_loaders import load_multi_table from midst_toolkit.models.clavaddpm.clustering import clava_clustering +from midst_toolkit.models.clavaddpm.data_loaders import load_multi_table from midst_toolkit.models.clavaddpm.model import Classifier from midst_toolkit.models.clavaddpm.train import clava_training From bc462dd07523b0a6623388e4ba49738ebe1839aa Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 19 Sep 2025 11:27:22 -0400 Subject: [PATCH 28/39] David's last comments --- src/midst_toolkit/models/clavaddpm/sampler.py | 2 +- src/midst_toolkit/models/clavaddpm/train.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/midst_toolkit/models/clavaddpm/sampler.py b/src/midst_toolkit/models/clavaddpm/sampler.py index 2020fcc4..3edb30bc 100644 --- a/src/midst_toolkit/models/clavaddpm/sampler.py +++ b/src/midst_toolkit/models/clavaddpm/sampler.py @@ -198,7 +198,7 @@ def create_named_schedule_sampler( Returns: The UniformSampler if ``name`` is "uniform", LossSecondMomentResampler if ``name`` - is "loss-second-moment". + is "loss-second-moment". 
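For the sampler factory touched above, a hedged usage sketch; it assumes a GaussianMultinomialDiffusion instance named diffusion and the sample(batch_size, device) interface of the upstream OpenAI schedule samplers, which this diff does not show:

    import torch

    from midst_toolkit.models.clavaddpm.sampler import create_named_schedule_sampler

    schedule_sampler = create_named_schedule_sampler("uniform", diffusion)
    # Assumed interface: draw per-example timesteps and importance weights for one batch.
    timesteps, weights = schedule_sampler.sample(4096, torch.device("cuda"))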
""" if name == "uniform": return UniformSampler(diffusion) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 9c1843ab..a570c75e 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -612,6 +612,7 @@ def _numerical_forward_backward_log( loss.backward(loss * len(sub_batch) / len(batch)) +# TODO: Think about moving this to a metrics module def _compute_top_k( logits: Tensor, labels: Tensor, From ff2d434342494c6953a44ee4fafae135c64185ee Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 19 Sep 2025 11:46:06 -0400 Subject: [PATCH 29/39] Removing data_loaders module from the docs --- docs/api.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/docs/api.md b/docs/api.md index 5ebf872e..30fbca29 100644 --- a/docs/api.md +++ b/docs/api.md @@ -7,13 +7,6 @@ show_root_heading: true show_root_full_path: true -## Data Loaders Module - -::: midst_toolkit.core.data_loaders - options: - show_root_heading: true - show_root_full_path: true - ## Logger Module ::: midst_toolkit.core.logger From 03232f2c21c0a03573fc64cb31d892f5d370ed23 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 19 Sep 2025 14:05:20 -0400 Subject: [PATCH 30/39] WIP needs to fix bug --- src/midst_toolkit/common/logger.py | 47 ++ src/midst_toolkit/core/__init__.py | 1 - src/midst_toolkit/core/logger.py | 504 -------------------- src/midst_toolkit/models/clavaddpm/train.py | 33 +- 4 files changed, 69 insertions(+), 516 deletions(-) delete mode 100644 src/midst_toolkit/core/__init__.py delete mode 100644 src/midst_toolkit/core/logger.py diff --git a/src/midst_toolkit/common/logger.py b/src/midst_toolkit/common/logger.py index e8cf269a..f9f7f8b9 100644 --- a/src/midst_toolkit/common/logger.py +++ b/src/midst_toolkit/common/logger.py @@ -174,3 +174,50 @@ def redirect_output(output_buffer: StringIO) -> None: sys.stdout = output_buffer sys.stderr = output_buffer console_handler.stream = sys.stdout + + +class KeyValueLogger: + def __init__(self, log_level: int = logging.DEBUG): + self.key_to_value: dict[str, float] = {} + self.key_to_count: dict[str, int] = {} + self.log_level = log_level + + def save_entry(self, key: str, value: Any) -> None: + print(key) + self.key_to_value[key] = value + + def save_entry_mean(self, key: str, value: Any) -> None: + old_value = self.key_to_value[key] + count = self.key_to_count[key] + self.key_to_value[key] = old_value * count / (count + 1) + value / (count + 1) + self.key_to_count[key] = count + 1 + + def dump(self) -> None: + # Create strings for printing + key_to_string = {} + for key, value in sorted(self.key_to_value.items()): + value_string = "%-8.3g" % value if hasattr(value, "__float__") else str(value) + key_to_string[self._truncate(key)] = self._truncate(value_string) + + if len(key_to_string) == 0: + log(self.log_level, "WARNING: tried to write empty key-value dict") + return + + # Find max widths + key_width = max(map(len, key_to_string.keys())) + value_width = max(map(len, key_to_string.values())) + + # Write out the data + dashes = "-" * (key_width + value_width + 7) + log(self.log_level, dashes) + for key, value in sorted(key_to_string.items(), key=lambda kv: kv[0].lower()): + line = "| %s%s | %s%s |" % (key, " " * (key_width - len(key)), value, " " * (value_width - len(value))) + log(self.log_level, line) + log(self.log_level, dashes) + + self.key_to_value.clear() + self.key_to_count.clear() + + def _truncate(self, s: str) -> str: + max_length = 30 + return s[: 
max_length - 3] + "..." if len(s) > max_length else s diff --git a/src/midst_toolkit/core/__init__.py b/src/midst_toolkit/core/__init__.py deleted file mode 100644 index 0349622b..00000000 --- a/src/midst_toolkit/core/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""PLACEHOLDER.""" diff --git a/src/midst_toolkit/core/logger.py b/src/midst_toolkit/core/logger.py deleted file mode 100644 index b76ce8b7..00000000 --- a/src/midst_toolkit/core/logger.py +++ /dev/null @@ -1,504 +0,0 @@ -""" -Logger copied from OpenAI baselines to avoid extra RL-based dependencies. - -https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py -""" - -# TODO is this file necessary at all? - -import datetime -import json -import os -import os.path as osp -import sys -import tempfile -import time -import warnings -from collections import defaultdict -from collections.abc import Generator, Iterable -from contextlib import contextmanager -from typing import IO, Any - - -DEBUG = 10 -INFO = 20 -WARN = 30 -ERROR = 40 - -DISABLED = 50 - - -class KVWriter(object): - def writekvs(self, kvs: dict[str, Any]) -> None: - raise NotImplementedError - - def close(self) -> None: - raise NotImplementedError - - -class SeqWriter(object): - def writeseq(self, seq: Iterable[str]) -> None: - raise NotImplementedError - - def close(self) -> None: - raise NotImplementedError - - -class HumanOutputFormat(KVWriter, SeqWriter): - def __init__(self, filename_or_file: str | IO[str]): - if isinstance(filename_or_file, str): - self.file = open(filename_or_file, "wt") - # ruff: noqa: SIM115 - self.own_file = True - else: - assert hasattr(filename_or_file, "read"), "expected file or str, got %s" % filename_or_file - self.file = filename_or_file # type: ignore[assignment] - self.own_file = False - - def writekvs(self, kvs): - # Create strings for printing - key2str = {} - for key, val in sorted(kvs.items()): - valstr = "%-8.3g" % val if hasattr(val, "__float__") else str(val) - key2str[self._truncate(key)] = self._truncate(valstr) - - # Find max widths - if len(key2str) == 0: - print("WARNING: tried to write empty key-value dict") - return - keywidth = max(map(len, key2str.keys())) - valwidth = max(map(len, key2str.values())) - - # Write out the data - dashes = "-" * (keywidth + valwidth + 7) - lines = [dashes] - for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()): - lines.append("| %s%s | %s%s |" % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))) - lines.append(dashes) - self.file.write("\n".join(lines) + "\n") - - # Flush the output to the file - self.file.flush() - - def _truncate(self, s: str) -> str: - maxlen = 30 - return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s - - def writeseq(self, seq): - seq = list(seq) - for i, elem in enumerate(seq): - self.file.write(elem) - if i < len(seq) - 1: # add space unless this is the last one - self.file.write(" ") - self.file.write("\n") - self.file.flush() - - def close(self) -> None: - if self.own_file: - self.file.close() - - -class JSONOutputFormat(KVWriter): - def __init__(self, filename: str): - self.file = open(filename, "wt") - # ruff: noqa: SIM115 - - def writekvs(self, kvs): - for k, v in sorted(kvs.items()): - if hasattr(v, "dtype"): - kvs[k] = float(v) - self.file.write(json.dumps(kvs) + "\n") - self.file.flush() - - def close(self) -> None: - self.file.close() - - -class CSVOutputFormat(KVWriter): - def __init__(self, filename: str): - self.file = open(filename, "w+t") - # ruff: noqa: SIM115 - self.keys: list[str] = [] - self.sep = "," - - def writekvs(self, kvs): - # Add our current row to the history - extra_keys = list(kvs.keys() - self.keys) - extra_keys.sort() - if extra_keys: - self.keys.extend(extra_keys) - self.file.seek(0) - lines = self.file.readlines() - self.file.seek(0) - for i, k in enumerate(self.keys): - if i > 0: - self.file.write(",") - self.file.write(k) - self.file.write("\n") - for line in lines[1:]: - self.file.write(line[:-1]) - self.file.write(self.sep * len(extra_keys)) - self.file.write("\n") - for i, k in enumerate(self.keys): - if i > 0: - self.file.write(",") - v = kvs.get(k) - if v is not None: - self.file.write(str(v)) - self.file.write("\n") - self.file.flush() - - def close(self) -> None: - self.file.close() - - -class TensorBoardOutputFormat(KVWriter): - """Dumps key/value pairs into TensorBoard's numeric format.""" - - def __init__(self, dir: str): - os.makedirs(dir, exist_ok=True) - self.dir = dir - self.step = 1 - prefix = "events" - path = osp.join(osp.abspath(dir), prefix) - import tensorflow as tf - from tensorflow.core.util import event_pb2 - from tensorflow.python import pywrap_tensorflow - from tensorflow.python.util import compat - - self.tf = tf - self.event_pb2 = event_pb2 - self.pywrap_tensorflow = pywrap_tensorflow - self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) - - def writekvs(self, kvs: dict[str, Any]) -> None: - def summary_val(k: str, v: Any) -> Any: - kwargs = {"tag": k, "simple_value": float(v)} - return self.tf.Summary.Value(**kwargs) - - summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) - event = self.event_pb2.Event(wall_time=time.time(), summary=summary) - event.step = self.step # is there any reason why you'd want to specify the step? 
- self.writer.WriteEvent(event) - self.writer.Flush() - self.step += 1 - - def close(self) -> None: - if self.writer: - self.writer.Close() - self.writer = None - - -def make_output_format(format: str, ev_dir: str, log_suffix: str = "") -> KVWriter | SeqWriter: - os.makedirs(ev_dir, exist_ok=True) - if format == "stdout": - return HumanOutputFormat(sys.stdout) - if format == "log": - return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) - if format == "json": - return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) - if format == "csv": - return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) - if format == "tensorboard": - return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) - raise ValueError("Unknown format specified: %s" % (format,)) - - -# ================================================================ -# API -# ================================================================ - - -def logkv(key: str, val: Any) -> None: - """ - Log a value of some diagnostic. - - Call this once for each diagnostic quantity, each iteration - If called many times, last value will be used. - """ - get_current().logkv(key, val) - - -def logkv_mean(key: str, val: Any) -> None: - """The same as logkv(), but if called many times, values averaged.""" - get_current().logkv_mean(key, val) - - -def logkvs(d: dict[str, Any]) -> None: - """Log a dictionary of key-value pairs.""" - for k, v in d.items(): - logkv(k, v) - - -def dumpkvs() -> dict[str, Any]: - """Write all of the diagnostics from the current iteration.""" - return get_current().dumpkvs() - - -def getkvs() -> dict[str, Any]: - return get_current().name2val - - -def log(*args: Iterable[Any], level: int = INFO) -> None: - """ - Logs the args in the desired level. - - Write the sequence of args, with no separators, to the console and output - files (if you've configured an output file). - """ - get_current().log(*args, level=level) - - -def debug(*args: Iterable[Any]) -> None: - log(*args, level=DEBUG) - - -def info(*args: Iterable[Any]) -> None: - log(*args, level=INFO) - - -def warn(*args: Iterable[Any]) -> None: - log(*args, level=WARN) - - -def error(*args: Iterable[Any]) -> None: - log(*args, level=ERROR) - - -def set_level(level: int) -> None: - """Set logging threshold on current logger.""" - get_current().set_level(level) - - -def set_comm(comm: Any | None) -> None: - get_current().set_comm(comm) - - -def get_dir() -> str: - """ - Get directory that log files are being written to. - - will be None if there is no output directory (i.e., if you didn't call start) - """ - return get_current().get_dir() - - -record_tabular = logkv -dump_tabular = dumpkvs - - -@contextmanager -def profile_kv(scopename: str) -> Generator[None, None, None]: - logkey = "wait_" + scopename - tstart = time.time() - try: - yield - finally: - get_current().name2val[logkey] += time.time() - tstart - - -def profile(n): - """ - Usage. - - @profile("my_func") - def my_func(): code - """ - - def decorator_with_name(func): - def func_wrapper(*args, **kwargs): - with profile_kv(n): - return func(*args, **kwargs) - - return func_wrapper - - return decorator_with_name - - -# ================================================================ -# Backend -# ================================================================ - - -class Logger(object): - DEFAULT = None # A logger with no output files. 
(See right below class definition) - # So that you can still log to the terminal without setting up any output files - CURRENT = None # Current logger being used by the free functions above - - def __init__(self, dir: str, output_formats: list[KVWriter | SeqWriter], comm: Any | None = None): - # ruff: noqa: D107 - self.name2val: defaultdict[str, float] = defaultdict(float) # values this iteration - self.name2cnt: defaultdict[str, int] = defaultdict(int) - self.level = INFO - self.dir = dir - self.output_formats = output_formats - self.comm = comm - - # Logging API, forwarded - # ---------------------------------------- - def logkv(self, key: str, val: Any) -> None: - # ruff: noqa: D102 - self.name2val[key] = val - - def logkv_mean(self, key: str, val: Any) -> None: - oldval, cnt = self.name2val[key], self.name2cnt[key] - self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) - self.name2cnt[key] = cnt + 1 - - def dumpkvs(self) -> dict[str, Any]: - # ruff: noqa: D102 - if self.comm is None: - d = self.name2val - else: - d = mpi_weighted_mean( # type: ignore[assignment] - self.comm, - {name: (val, self.name2cnt.get(name, 1)) for (name, val) in self.name2val.items()}, - ) - if self.comm.rank != 0: - d["dummy"] = 1 # so we don't get a warning about empty dict - out = d.copy() # Return the dict for unit testing purposes - for fmt in self.output_formats: - if isinstance(fmt, KVWriter): - fmt.writekvs(d) - self.name2val.clear() - self.name2cnt.clear() - return out - - def log(self, *args: Iterable[Any], level: int = INFO) -> None: - # ruff: noqa: D102 - if self.level <= level: - self._do_log(args) - - # Configuration - # ---------------------------------------- - def set_level(self, level: int) -> None: - # ruff: noqa: D102 - self.level = level - - def set_comm(self, comm: Any | None) -> None: - # ruff: noqa: D102 - self.comm = comm - - def get_dir(self) -> str: - # ruff: noqa: D102 - return self.dir - - def close(self) -> None: - # ruff: noqa: D102 - for fmt in self.output_formats: - fmt.close() - - # Misc - # ---------------------------------------- - def _do_log(self, args: Iterable[Any]) -> None: - for fmt in self.output_formats: - if isinstance(fmt, SeqWriter): - fmt.writeseq(map(str, args)) - - -def get_rank_without_mpi_import() -> int: - # check environment variables here instead of importing mpi4py - # to avoid calling MPI_Init() when this module is imported - for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: - if varname in os.environ: - return int(os.environ[varname]) - return 0 - - -def mpi_weighted_mean(comm: Any, local_name2valcount: dict[str, tuple[float, float]]) -> dict[str, float]: - """ - Copied from below. 
- - https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 - Perform a weighted average over dicts that are each on a different node - Input: local_name2valcount: dict mapping key -> (value, count) - Returns: key -> mean - """ - all_name2valcount = comm.gather(local_name2valcount) - if comm.rank == 0: - name2sum: defaultdict[str, float] = defaultdict(float) - name2count: defaultdict[str, float] = defaultdict(float) - for n2vc in all_name2valcount: - for name, (val, count) in n2vc.items(): - try: - val = float(val) - except ValueError: - if comm.rank == 0: - warnings.warn("WARNING: tried to compute mean on non-float {}={}".format(name, val)) - # ruff: noqa: B028 - else: - name2sum[name] += val * count - name2count[name] += count - return {name: name2sum[name] / name2count[name] for name in name2sum} - return {} - - -def configure( - dir: str | None = None, - format_strs: list[str] | None = None, - comm: Any | None = None, - log_suffix: str = "", -) -> None: - """If comm is provided, average all numerical stats across that comm.""" - if dir is None: - dir = os.getenv("OPENAI_LOGDIR") - if dir is None: - dir = osp.join( - tempfile.gettempdir(), - datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), - ) - assert isinstance(dir, str) - dir = os.path.expanduser(dir) - os.makedirs(os.path.expanduser(dir), exist_ok=True) - - rank = get_rank_without_mpi_import() - if rank > 0: - log_suffix = log_suffix + "-rank%03i" % rank - - if format_strs is None: - if rank == 0: - format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") - else: - format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") - format_strs_filter = filter(None, format_strs) - output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs_filter] - - Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) # type: ignore[assignment] - if output_formats: - log("Logging to %s" % dir) - - -def _configure_default_logger() -> None: - # ruff: noqa: D103 - configure() - Logger.DEFAULT = Logger.CURRENT - - -def reset() -> None: - # ruff: noqa: D103 - if Logger.CURRENT is not Logger.DEFAULT: - Logger.CURRENT.close() - Logger.CURRENT = Logger.DEFAULT - log("Reset logger") - - -@contextmanager -def scoped_configure(dir=None, format_strs=None, comm=None): - # ruff: noqa: D103 - prevlogger = Logger.CURRENT - configure(dir=dir, format_strs=format_strs, comm=comm) - try: - yield - finally: - assert Logger.CURRENT is not None - Logger.CURRENT.close() - Logger.CURRENT = prevlogger - - -def get_current() -> Logger: - # ruff: noqa: D103 - if Logger.CURRENT is None: - _configure_default_logger() - - assert isinstance(Logger.CURRENT, Logger) - return Logger.CURRENT diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 8c789082..c9a5db49 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -11,8 +11,7 @@ import torch from torch import Tensor, optim -from midst_toolkit.common.logger import log -from midst_toolkit.core import logger +from midst_toolkit.common.logger import log, KeyValueLogger from midst_toolkit.models.clavaddpm.data_loaders import prepare_fast_dataloader from midst_toolkit.models.clavaddpm.dataset import ( Dataset, @@ -438,14 +437,12 @@ def train_classifier( empty_diffusion.to(device) schedule_sampler = create_named_schedule_sampler("uniform", empty_diffusion) + key_value_logger = KeyValueLogger() 
classifier.train() for step in range(classifier_steps): - logger.logkv("step", step) - logger.logkv( - "samples", - (step + 1) * batch_size, - ) + key_value_logger.save_entry("step", step) + key_value_logger.save_entry("samples", (step + 1) * batch_size) _numerical_forward_backward_log( classifier, classifier_optimizer, @@ -455,6 +452,7 @@ def train_classifier( empty_diffusion, prefix="train", device=device, + key_value_logger=key_value_logger, ) classifier_optimizer.step() @@ -470,9 +468,12 @@ def train_classifier( empty_diffusion, prefix="val", device=device, + key_value_logger=key_value_logger, ) classifier.train() + key_value_logger.dump() + # test classifier classifier.eval() @@ -564,6 +565,7 @@ def _numerical_forward_backward_log( prefix: str = "train", remove_first_col: bool = False, device: str = "cuda", + key_value_logger: KeyValueLogger | None = None, ) -> None: """ Forward and backward pass for the numerical features of the ClavaDDPM model. @@ -578,6 +580,7 @@ def _numerical_forward_backward_log( prefix: The prefix for the loss. Defaults to "train". remove_first_col: Whether to remove the first column of the batch. Defaults to False. device: The device to use. Defaults to "cuda". + key_value_logger: The key-value logger to log ther losses. If None, the losses are not logged. """ batch, labels = next(data_loader) labels = labels.long().to(device) @@ -600,7 +603,7 @@ def _numerical_forward_backward_log( losses[f"{prefix}_acc@1"] = _compute_top_k(logits, sub_labels, k=1, reduction="none") if logits.shape[1] >= 5: losses[f"{prefix}_acc@5"] = _compute_top_k(logits, sub_labels, k=5, reduction="none") - _log_loss_dict(diffusion, sub_t, losses) + _log_loss_dict(diffusion, sub_t, losses, key_value_logger) del losses loss = loss.mean() if loss.requires_grad: @@ -637,7 +640,11 @@ def _compute_top_k( raise ValueError(f"reduction should be one of ['mean', 'none']: {reduction}") -def _log_loss_dict(diffusion: GaussianMultinomialDiffusion, ts: Tensor, losses: dict[str, Tensor]) -> None: +def _log_loss_dict( + diffusion: GaussianMultinomialDiffusion, + ts: Tensor, losses: dict[str, Tensor], + key_value_logger: KeyValueLogger | None = None, +) -> None: """ Output the log loss dictionary in the logger. @@ -645,13 +652,17 @@ def _log_loss_dict(diffusion: GaussianMultinomialDiffusion, ts: Tensor, losses: diffusion: The diffusion object. ts: The timesteps. losses: The losses. + key_value_logger: The key-value logger. If None, the losses are not logged. """ + if key_value_logger is None: + return + for key, values in losses.items(): - logger.logkv_mean(key, values.mean().item()) + key_value_logger.save_entry_mean(key, values.mean().item()) # Log the quantiles (four quartiles, in particular). 
for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): quartile = int(4 * sub_t / diffusion.num_timesteps) - logger.logkv_mean(f"{key}_q{quartile}", sub_loss) + key_value_logger.save_entry_mean(f"{key}_q{quartile}", sub_loss) def _split_microbatches( From af1597911668f8b43971b28aa6618a391beb5f90 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 19 Sep 2025 16:46:27 -0400 Subject: [PATCH 31/39] Finished implementation, needs tests --- src/midst_toolkit/common/logger.py | 50 ++++++++++++++++++--- src/midst_toolkit/models/clavaddpm/train.py | 15 ++++--- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/src/midst_toolkit/common/logger.py b/src/midst_toolkit/common/logger.py index f9f7f8b9..3f360bce 100644 --- a/src/midst_toolkit/common/logger.py +++ b/src/midst_toolkit/common/logger.py @@ -2,6 +2,7 @@ import logging import sys +from collections import defaultdict from io import StringIO from logging import LogRecord from pathlib import Path @@ -177,22 +178,48 @@ def redirect_output(output_buffer: StringIO) -> None: class KeyValueLogger: + """Logger for key-value pairs.""" + def __init__(self, log_level: int = logging.DEBUG): - self.key_to_value: dict[str, float] = {} - self.key_to_count: dict[str, int] = {} + """ + Initialize the key-value logger. + + Args: + log_level: The log level to use when dumping the key-value pairs. Defaults to logging.DEBUG. + """ + self.key_to_value: defaultdict[str, float] = defaultdict(float) + self.key_to_count: defaultdict[str, int] = defaultdict(int) self.log_level = log_level def save_entry(self, key: str, value: Any) -> None: - print(key) + """ + Save an entry to the key-value logger. + + Args: + key: The key to save. + value: The value to save. + """ self.key_to_value[key] = value def save_entry_mean(self, key: str, value: Any) -> None: + """ + Save an entry to the key-value logger with mean calculation. + + Args: + key: The key to save. + value: The value to add to the mean calculation. + """ old_value = self.key_to_value[key] count = self.key_to_count[key] self.key_to_value[key] = old_value * count / (count + 1) + value / (count + 1) self.key_to_count[key] = count + 1 def dump(self) -> None: + """ + Dump the key-value pairs at the log level specified in the constructor. + + Will clear the key-value pairs after dumping. + """ # Create strings for printing key_to_string = {} for key, value in sorted(self.key_to_value.items()): @@ -203,21 +230,32 @@ def dump(self) -> None: log(self.log_level, "WARNING: tried to write empty key-value dict") return - # Find max widths + # Find max widths key_width = max(map(len, key_to_string.keys())) value_width = max(map(len, key_to_string.values())) # Write out the data dashes = "-" * (key_width + value_width + 7) log(self.log_level, dashes) - for key, value in sorted(key_to_string.items(), key=lambda kv: kv[0].lower()): - line = "| %s%s | %s%s |" % (key, " " * (key_width - len(key)), value, " " * (value_width - len(value))) + sorted_key_to_string = sorted(key_to_string.items(), key=lambda kv: kv[0].lower()) + for k, v in sorted_key_to_string: + line = "| %s%s | %s%s |" % (k, " " * (key_width - len(k)), v, " " * (value_width - len(v))) log(self.log_level, line) log(self.log_level, dashes) + # Clear the key-value pairs self.key_to_value.clear() self.key_to_count.clear() def _truncate(self, s: str) -> str: + """ + Truncate a string to a maximum length of 30 characters. + + Args: + s: The string to truncate. + + Returns: + The string truncated to 30 characters. 
+ """ max_length = 30 return s[: max_length - 3] + "..." if len(s) > max_length else s diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index c9a5db49..a6cf2ec7 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -11,7 +11,7 @@ import torch from torch import Tensor, optim -from midst_toolkit.common.logger import log, KeyValueLogger +from midst_toolkit.common.logger import KeyValueLogger, log from midst_toolkit.models.clavaddpm.data_loaders import prepare_fast_dataloader from midst_toolkit.models.clavaddpm.dataset import ( Dataset, @@ -359,6 +359,7 @@ def train_classifier( dim_t: int = 128, learning_rate: float = 0.0001, classifier_evaluation_interval: int = 5, + logger_interval: int = 10, ) -> Classifier: """ Training function for the classifier model. @@ -380,6 +381,8 @@ def train_classifier( learning_rate: Learning rate to use for the optimizer in the classifier. Default is 0.0001. classifier_evaluation_interval: The number of classifier training steps to wait until the next evaluation of the classifier. Default is 5. + logger_interval: The number of classifier training steps to wait until the next logging + of its metrics. Default is 10. Returns: The trained classifier model. @@ -441,8 +444,8 @@ def train_classifier( classifier.train() for step in range(classifier_steps): - key_value_logger.save_entry("step", step) - key_value_logger.save_entry("samples", (step + 1) * batch_size) + key_value_logger.save_entry("step", float(step)) + key_value_logger.save_entry("samples", float((step + 1) * batch_size)) _numerical_forward_backward_log( classifier, classifier_optimizer, @@ -472,7 +475,8 @@ def train_classifier( ) classifier.train() - key_value_logger.dump() + if not step % logger_interval: + key_value_logger.dump() # test classifier classifier.eval() @@ -642,7 +646,8 @@ def _compute_top_k( def _log_loss_dict( diffusion: GaussianMultinomialDiffusion, - ts: Tensor, losses: dict[str, Tensor], + ts: Tensor, + losses: dict[str, Tensor], key_value_logger: KeyValueLogger | None = None, ) -> None: """ From 226c97ce0d603991d072eebd757cc4e307871665 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 19 Sep 2025 16:53:38 -0400 Subject: [PATCH 32/39] Changing from Any to float --- src/midst_toolkit/common/logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/midst_toolkit/common/logger.py b/src/midst_toolkit/common/logger.py index 3f360bce..8c200b3b 100644 --- a/src/midst_toolkit/common/logger.py +++ b/src/midst_toolkit/common/logger.py @@ -191,7 +191,7 @@ def __init__(self, log_level: int = logging.DEBUG): self.key_to_count: defaultdict[str, int] = defaultdict(int) self.log_level = log_level - def save_entry(self, key: str, value: Any) -> None: + def save_entry(self, key: str, value: float) -> None: """ Save an entry to the key-value logger. @@ -201,7 +201,7 @@ def save_entry(self, key: str, value: Any) -> None: """ self.key_to_value[key] = value - def save_entry_mean(self, key: str, value: Any) -> None: + def save_entry_mean(self, key: str, value: float) -> None: """ Save an entry to the key-value logger with mean calculation. 
From 941803eab1c913f57c8d220f419c88e50f2b44b1 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 22 Sep 2025 11:02:05 -0400 Subject: [PATCH 33/39] Adding tests for the key value logger --- src/midst_toolkit/common/logger.py | 2 +- tests/integration/core/__init__.py | 1 - tests/unit/common/__init__.py | 0 tests/unit/common/test_logger.py | 77 ++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 2 deletions(-) delete mode 100644 tests/integration/core/__init__.py create mode 100644 tests/unit/common/__init__.py create mode 100644 tests/unit/common/test_logger.py diff --git a/src/midst_toolkit/common/logger.py b/src/midst_toolkit/common/logger.py index 8c200b3b..de1ed80e 100644 --- a/src/midst_toolkit/common/logger.py +++ b/src/midst_toolkit/common/logger.py @@ -178,7 +178,7 @@ def redirect_output(output_buffer: StringIO) -> None: class KeyValueLogger: - """Logger for key-value pairs.""" + """Logger for key-value pairs of numerical metrics.""" def __init__(self, log_level: int = logging.DEBUG): """ diff --git a/tests/integration/core/__init__.py b/tests/integration/core/__init__.py deleted file mode 100644 index 4ec8f6b4..00000000 --- a/tests/integration/core/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Unit tests for midst-toolkit.""" diff --git a/tests/unit/common/__init__.py b/tests/unit/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/common/test_logger.py b/tests/unit/common/test_logger.py new file mode 100644 index 00000000..1c0377e1 --- /dev/null +++ b/tests/unit/common/test_logger.py @@ -0,0 +1,77 @@ +import logging +from unittest.mock import Mock, call, patch + +from midst_toolkit.common.logger import KeyValueLogger + + +def test_key_value_logger_init() -> None: + key_value_logger = KeyValueLogger() + assert key_value_logger.key_to_value == {} + assert key_value_logger.key_to_count == {} + assert key_value_logger.log_level == logging.DEBUG + + +def test_key_value_logger_save_entry() -> None: + key_value_logger = KeyValueLogger() + key_value_logger.save_entry("test_key", 1.0) + assert key_value_logger.key_to_value["test_key"] == 1.0 + + +def test_key_value_logger_save_entry_mean() -> None: + key_value_logger = KeyValueLogger() + key_value_logger.save_entry_mean("test_key", 1.0) + assert key_value_logger.key_to_value["test_key"] == 1.0 + assert key_value_logger.key_to_count["test_key"] == 1 + key_value_logger.save_entry_mean("test_key", 2.0) + assert key_value_logger.key_to_value["test_key"] == 1.5 + assert key_value_logger.key_to_count["test_key"] == 2 + + +def test_key_value_logger_truncate() -> None: + key_value_logger = KeyValueLogger() + result = key_value_logger._truncate("test string " * 3) # total of 36 characters + assert result == "test string test string tes..." 
+ + +@patch("midst_toolkit.common.logger.log") +def test_key_value_logger_dump(mock_log: Mock) -> None: + key_value_logger = KeyValueLogger() + key_value_logger.save_entry("test_key", 1.0) + key_value_logger.save_entry("test_key_2", 0.79) + key_value_logger.dump() + + assert mock_log.call_count == 4 + mock_log.assert_has_calls( + [ + call(logging.DEBUG, "-------------------------"), + call(logging.DEBUG, "| test_key | 1 |"), + call(logging.DEBUG, "| test_key_2 | 0.79 |"), + call(logging.DEBUG, "-------------------------"), + ] + ) + mock_log.reset_mock() + assert len(key_value_logger.key_to_value) == 0 + assert len(key_value_logger.key_to_count) == 0 + + key_value_logger.save_entry("test_key", 164537) + key_value_logger.save_entry("really_long_key_with_more_than_30_characters", 0.98765357623989) + key_value_logger.dump() + + assert mock_log.call_count == 4 + mock_log.assert_has_calls( + [ + call(logging.DEBUG, "---------------------------------------------"), + call(logging.DEBUG, "| really_long_key_with_more_t... | 0.988 |"), + call(logging.DEBUG, "| test_key | 1.65e+05 |"), + call(logging.DEBUG, "---------------------------------------------"), + ] + ) + + +@patch("midst_toolkit.common.logger.log") +def test_key_value_logger_dump_empty(mock_log: Mock) -> None: + key_value_logger = KeyValueLogger() + key_value_logger.dump() + + assert mock_log.call_count == 1 + mock_log.assert_has_calls([call(logging.DEBUG, "WARNING: tried to write empty key-value dict")]) From b15059ea053be76b66f51793197e535293ecc802 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 23 Sep 2025 12:30:03 -0400 Subject: [PATCH 34/39] Addressing David's comments --- src/midst_toolkit/common/logger.py | 28 +++++++++++++++------ src/midst_toolkit/models/clavaddpm/train.py | 13 +++++----- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/midst_toolkit/common/logger.py b/src/midst_toolkit/common/logger.py index de1ed80e..342213ed 100644 --- a/src/midst_toolkit/common/logger.py +++ b/src/midst_toolkit/common/logger.py @@ -185,7 +185,8 @@ def __init__(self, log_level: int = logging.DEBUG): Initialize the key-value logger. Args: - log_level: The log level to use when dumping the key-value pairs. Defaults to logging.DEBUG. + log_level: The log level to use when dumping the key-value pairs. Should match + one of the logging levels in the logging python module. Defaults to logging.DEBUG. 
""" self.key_to_value: defaultdict[str, float] = defaultdict(float) self.key_to_count: defaultdict[str, int] = defaultdict(int) @@ -223,23 +224,23 @@ def dump(self) -> None: # Create strings for printing key_to_string = {} for key, value in sorted(self.key_to_value.items()): - value_string = "%-8.3g" % value if hasattr(value, "__float__") else str(value) + value_string = "%-8.3g" % value key_to_string[self._truncate(key)] = self._truncate(value_string) if len(key_to_string) == 0: log(self.log_level, "WARNING: tried to write empty key-value dict") return - # Find max widths - key_width = max(map(len, key_to_string.keys())) - value_width = max(map(len, key_to_string.values())) + # Find maximum number of characters for the keys and values + key_max_width = max(map(len, key_to_string.keys())) + value_max_width = max(map(len, key_to_string.values())) # Write out the data - dashes = "-" * (key_width + value_width + 7) + dashes = "-" * (key_max_width + value_max_width + 7) log(self.log_level, dashes) sorted_key_to_string = sorted(key_to_string.items(), key=lambda kv: kv[0].lower()) for k, v in sorted_key_to_string: - line = "| %s%s | %s%s |" % (k, " " * (key_width - len(k)), v, " " * (value_width - len(v))) + line = f"| {k}{self._add_spacing(k, key_max_width)} | {v}{self._add_spacing(v, value_max_width)} |" log(self.log_level, line) log(self.log_level, dashes) @@ -259,3 +260,16 @@ def _truncate(self, s: str) -> str: """ max_length = 30 return s[: max_length - 3] + "..." if len(s) > max_length else s + + def _add_spacing(self, element: str, max_width: int) -> str: + """ + Add spacing to an element to make it the same length as the maximum width. + + Args: + element: The element to add spacing to. + max_width: The maximum width to add spacing to. + + Returns: + The element with spacing added to make it the same length as the maximum width. + """ + return " " * (max_width - len(element)) diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index a6cf2ec7..ca360352 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -475,7 +475,8 @@ def train_classifier( ) classifier.train() - if not step % logger_interval: + if step % logger_interval == 0: + # Dump the metrics every logger_interval number of steps key_value_logger.dump() # test classifier @@ -584,7 +585,7 @@ def _numerical_forward_backward_log( prefix: The prefix for the loss. Defaults to "train". remove_first_col: Whether to remove the first column of the batch. Defaults to False. device: The device to use. Defaults to "cuda". - key_value_logger: The key-value logger to log ther losses. If None, the losses are not logged. + key_value_logger: The key-value logger to log the losses. If None, the losses are not logged. """ batch, labels = next(data_loader) labels = labels.long().to(device) @@ -646,7 +647,7 @@ def _compute_top_k( def _log_loss_dict( diffusion: GaussianMultinomialDiffusion, - ts: Tensor, + timesteps: Tensor, losses: dict[str, Tensor], key_value_logger: KeyValueLogger | None = None, ) -> None: @@ -655,9 +656,9 @@ def _log_loss_dict( Args: diffusion: The diffusion object. - ts: The timesteps. + timesteps: The timesteps tensor. losses: The losses. - key_value_logger: The key-value logger. If None, the losses are not logged. + key_value_logger: The key-value logger to log the losses. If None, the losses are not logged. 
""" if key_value_logger is None: return @@ -665,7 +666,7 @@ def _log_loss_dict( for key, values in losses.items(): key_value_logger.save_entry_mean(key, values.mean().item()) # Log the quantiles (four quartiles, in particular). - for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): + for sub_t, sub_loss in zip(timesteps.cpu().numpy(), values.detach().cpu().numpy()): quartile = int(4 * sub_t / diffusion.num_timesteps) key_value_logger.save_entry_mean(f"{key}_q{quartile}", sub_loss) From aeb8eef1d434ac85fec0ba5ecf97db75eb99aae6 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 23 Sep 2025 17:10:15 -0400 Subject: [PATCH 35/39] WIP Partially addressed comments --- .../models/clavaddpm/data_loaders.py | 143 +++++++++++------- src/midst_toolkit/models/clavaddpm/dataset.py | 67 +++++--- .../models/clavaddpm/test_model.py | 10 +- 3 files changed, 136 insertions(+), 84 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/data_loaders.py b/src/midst_toolkit/models/clavaddpm/data_loaders.py index 0dd6317a..4ac510ee 100644 --- a/src/midst_toolkit/models/clavaddpm/data_loaders.py +++ b/src/midst_toolkit/models/clavaddpm/data_loaders.py @@ -2,6 +2,7 @@ import os from collections.abc import Generator from logging import INFO +from pathlib import Path from typing import Any, Literal import numpy as np @@ -13,7 +14,7 @@ def load_multi_table( - data_dir: str, verbose: bool = True + data_dir: Path, verbose: bool = True ) -> tuple[dict[str, Any], list[tuple[str, str]], dict[str, Any]]: """ Load the multi-table dataset from the data directory. @@ -28,21 +29,25 @@ def load_multi_table( - The relation order between the tables. - The dataset metadata dictionary. """ - dataset_meta = json.load(open(os.path.join(data_dir, "dataset_meta.json"), "r")) + with open(data_dir / "dataset_meta.json", "r") as f: + dataset_meta = json.load(f) relation_order = dataset_meta["relation_order"] tables = {} for table, meta in dataset_meta["tables"].items(): - if os.path.exists(os.path.join(data_dir, "train.csv")): - train_df = pd.read_csv(os.path.join(data_dir, "train.csv")) + if (data_dir / "train.csv").exists(): + train_df = pd.read_csv(data_dir / "train.csv") else: - train_df = pd.read_csv(os.path.join(data_dir, f"{table}.csv")) + train_df = pd.read_csv(data_dir / f"{table}.csv") + + with open(data_dir / f"{table}_domain.json", "r") as f: + domain = json.load(f) + tables[table] = { "df": train_df, - "domain": json.load(open(os.path.join(data_dir, f"{table}_domain.json"))), - # ruff: noqa: SIM115 + "domain": domain, "children": meta["children"], "parents": meta["parents"], } @@ -66,7 +71,7 @@ def load_multi_table( def get_info_from_domain(data_df: pd.DataFrame, domain_dict: dict[str, Any]) -> dict[str, Any]: """ - Get the information dictionaryfrom the domain dictionary. + Get the information dictionary from the domain dictionary. Args: data_df: The dataframe containing the data. @@ -78,7 +83,7 @@ def get_info_from_domain(data_df: pd.DataFrame, domain_dict: dict[str, Any]) -> - cat_col_idx: The indices of the categorical columns. - target_col_idx: The indices of the target columns. - task_type: The type of the task. - - column_names: The names of the columns. + - column_names: The names of all the columns. """ info: dict[str, Any] = {} info["num_col_idx"] = [] @@ -107,29 +112,41 @@ def pipeline_process_data( verbose: bool = True, ) -> tuple[dict[str, Any], dict[str, Any]]: """ - Process the data through the pipeline. + Processes the data to be sent through the pipeline. 
+ + Will split the data into training and test sets (saving the data when specified), + replace invalid and missing values, split the data sets categorical, numerical + and target columns, and populate the information dictionary with additional + metadata. Args: - name: The name of the table. + name: The name of the table. Used to name the files when saving the data. data_df: The dataframe containing the data. info: The information dictionary, retrieved from the get_info_from_domain function. - ratio: The ratio of the data to be used for training. Optional, default is 0.9. + ratio: The ratio of the data to be used for training. Should be between 0 and 1. + If it's == 1, it will only return the training set. Optional, default is 0.9. save: Whether to save the data. Optional, default is False. verbose: Whether to print verbose output. Optional, default is True. Returns: A tuple with 2 values: - The data dictionary containing the following keys: - - df: The dataframe containing the data. - - numpy: A dictionary with the numeric data, containing the keys: - - X_num_train: The numeric data for the training set. - - X_cat_train: The categorical data for the training set. - - y_train: The target data for the training set. - - X_num_test: The numeric data for the test set. - - X_cat_test: The categorical data for the test set. - - y_test: The target data for the test set. + - "df": The dataframe containing the data. + - "train": The dataframe containing the training set. + - "test": The dataframe containing the test set. It will be absent if ratio == 1. + - "numpy": A dictionary with the numeric data, containing the keys: + - "X_num_train": The numeric data for the training set. + - "X_cat_train": The categorical data for the training set. + - "y_train": The target data for the training set. + - "X_num_test": The numeric data for the test set. It will be absent if ratio == 1. + - "X_cat_test": The categorical data for the test set. It will be absent if ratio == 1. + - "y_test": The target data for the test set. It will be absent if ratio == 1. - The information dictionary with updated values. """ + assert 0 < ratio <= 1, "Ratio must be between 0 and 1." 
+ if ratio == 1: + log(INFO, "Ratio is 1, so the data will not be split into training and test sets.") + num_data = data_df.shape[0] column_names = info["column_names"] if info["column_names"] else data_df.columns.tolist() @@ -139,7 +156,7 @@ def pipeline_process_data( target_col_idx = info["target_col_idx"] idx_mapping, inverse_idx_mapping, idx_name_mapping = get_column_name_mapping( - data_df, num_col_idx, cat_col_idx, target_col_idx, column_names + data_df, num_col_idx, cat_col_idx, column_names ) num_columns = [column_names[i] for i in num_col_idx] @@ -151,7 +168,7 @@ def pipeline_process_data( num_test = num_data - num_train if ratio < 1: - train_df, test_df, seed = train_val_test_split(data_df, cat_columns, num_train, num_test) + train_df, test_df, seed = train_test_split(data_df, cat_columns, num_train, num_test) else: train_df = data_df.copy() @@ -284,14 +301,12 @@ def pipeline_process_data( if verbose: if ratio < 1: - str_shape = "Train dataframe shape: {}, Test dataframe shape: {}, Total dataframe shape: {}".format( - train_df.shape, test_df.shape, data_df.shape - ) + str_shape = f"Train dataframe shape: {train_df.shape}, Test dataframe shape: {test_df.shape}, Total dataframe shape: {data_df.shape}" else: - str_shape = "Table name: {}, Total dataframe shape: {}".format(name, data_df.shape) + str_shape = f"Table name: {name}, Total dataframe shape: {data_df.shape}" - str_shape += ", Numerical data shape: {}".format(X_num_train.shape) - str_shape += ", Categorical data shape: {}".format(X_cat_train.shape) + str_shape += f", Numerical data shape: {X_num_train.shape}" + str_shape += f", Categorical data shape: {X_cat_train.shape}" log(INFO, str_shape) data = { @@ -316,24 +331,32 @@ def get_column_name_mapping( data_df: pd.DataFrame, num_col_idx: list[int], cat_col_idx: list[int], - target_col_idx: list[int], column_names: list[str] | None = None, ) -> tuple[dict[int, int], dict[int, int], dict[int, str]]: """ - Get the column name mapping. + Get the column name mappings. + + Will produce 3 mappings: + - The mapping of the categorical and numerical columns from their original indices + in the dataframe to their indices in the num_col_idx and cat_col_idx lists. + - The inverse mapping of the above, i.e. the mapping from their indices in the + num_col_idx and cat_col_idx lists to their original indices in the dataframe. + - The mapping of the indices in the original dataframe to the column names for all columns. Args: data_df: The dataframe containing the data. num_col_idx: The indices of the numerical columns. cat_col_idx: The indices of the categorical columns. - target_col_idx: The indices of the target columns. - column_names: The names of the columns. + column_names: The names of the columns. Optional, default is None. If None, + it will use the columns of the dataframe. Returns: A tuple with 3 values: - - The mapping of the categorical and numerical columns to the indices. - - The mapping of the column names to the indices. - - The mapping of all the indices to the column names. + - The mapping of the categorical and numerical columns from their original indices + in the dataframe to their indices in the num_col_idx and cat_col_idx lists. + - The inverse mapping of the above, i.e. the mapping from their indices in the + num_col_idx and cat_col_idx lists to their original indices in the dataframe. + - The mapping of the indices in the original dataframe to the column names for all columns. 
""" if not column_names: column_names = data_df.columns.tolist() @@ -346,28 +369,29 @@ def get_column_name_mapping( for idx in range(len(column_names)): if idx in num_col_idx: - idx_mapping[int(idx)] = curr_num_idx + idx_mapping[idx] = curr_num_idx curr_num_idx += 1 elif idx in cat_col_idx: - idx_mapping[int(idx)] = curr_cat_idx + idx_mapping[idx] = curr_cat_idx curr_cat_idx += 1 else: - idx_mapping[int(idx)] = curr_target_idx + idx_mapping[idx] = curr_target_idx curr_target_idx += 1 inverse_idx_mapping = {} for k, v in idx_mapping.items(): - inverse_idx_mapping[int(v)] = k + inverse_idx_mapping[v] = k idx_name_mapping = {} for i in range(len(column_names)): - idx_name_mapping[int(i)] = column_names[i] + idx_name_mapping[i] = column_names[i] return idx_mapping, inverse_idx_mapping, idx_name_mapping -def train_val_test_split( +# TODO: refactor this function so it doesn't run the risk of running indefinitely. +def train_test_split( data_df: pd.DataFrame, cat_columns: list[str], num_train: int = 0, @@ -376,6 +400,9 @@ def train_val_test_split( """ Split the data into training and test sets. + Will make the split in a way that both sets have all the values for the categorical + columns represented. + Args: data_df: The dataframe containing the data. cat_columns: The names of the categorical columns. @@ -386,7 +413,7 @@ def train_val_test_split( A tuple with 3 values: - The training dataframe. - The test dataframe. - - The seed used for the random number generator. + - The seed used by the random number generator to generate the split. """ total_num = data_df.shape[0] idx = np.arange(total_num) @@ -417,26 +444,24 @@ def train_val_test_split( class FastTensorDataLoader: - """ - Defines a faster dataloader for PyTorch tensors. - - A DataLoader-like object for a set of tensors that can be much faster than - TensorDataset + DataLoader because dataloader grabs individual indices of - the dataset and calls cat (slow). - Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 - """ - def __init__(self, *tensors: torch.Tensor, batch_size: int = 32, shuffle: bool = False): """ Initialize a FastTensorDataLoader. + A DataLoader-like object for a set of tensors that can be much faster than + TensorDataset + DataLoader because dataloader grabs individual indices of + the dataset and calls cat (slow). + Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 + Args: - *tensors: tensors to store. Must have the same length @ dim 0. - batch_size: batch size to load. + *tensors: tensors to store. All tensors must have the same length as the tensor at dimension 0. + batch_size: batch size to load. Optional, default is 32. shuffle: if True, shuffle the data *in-place* whenever an - iterator is created out of this object. + iterator is created out of this object. Optional, default is False. """ - assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) + assert all(t.shape[0] == tensors[0].shape[0] for t in tensors), ( + "All tensors must have the same length as the tensor at dimension 0." 
+ ) self.tensors = tensors self.dataset_len = self.tensors[0].shape[0] @@ -450,7 +475,7 @@ def __init__(self, *tensors: torch.Tensor, batch_size: int = 32, shuffle: bool = self.n_batches = n_batches def __iter__(self): - """Defines the iterator for the FastTensorDataLoader.""" + """Define the iterator for the FastTensorDataLoader.""" if self.shuffle: r = torch.randperm(self.dataset_len) self.tensors = [t[r] for t in self.tensors] # type: ignore[assignment] @@ -458,7 +483,11 @@ def __iter__(self): return self def __next__(self): - """Get the next batch of data from the dataset.""" + """Get the next batch of data from the dataset. + + Returns: + A tuple of tensors, one for each tensor in the FastTensorDataLoader. + """ if self.i >= self.dataset_len: raise StopIteration batch = tuple(t[self.i : self.i + self.batch_size] for t in self.tensors) diff --git a/src/midst_toolkit/models/clavaddpm/dataset.py b/src/midst_toolkit/models/clavaddpm/dataset.py index 8a5fbf57..d38abdcd 100644 --- a/src/midst_toolkit/models/clavaddpm/dataset.py +++ b/src/midst_toolkit/models/clavaddpm/dataset.py @@ -31,6 +31,7 @@ from midst_toolkit.models.clavaddpm.typing import ArrayDict +# TODO: Dunders are special case in python, rename these values to something else. CAT_MISSING_VALUE = "__nan__" CAT_RARE_VALUE = "__rare__" @@ -105,34 +106,47 @@ class Dataset: num_transform: StandardScaler | None = None @classmethod - def from_dir(cls, dir_: Path | str) -> Self: + def from_dir(cls, directory: Path) -> Self: """ Load a dataset from a directory. Args: - dir_: The directory to load the dataset from. Can be a Path object or a path string. + directory: The directory to load the dataset from. Can be a Path object or a path string. Returns: The loaded dataset. """ - dir_ = Path(dir_) - splits = [k for k in ["train", "val", "test"] if dir_.joinpath(f"y_{k}.npy").exists()] - - def load(item: str) -> ArrayDict: - return {x: cast(np.ndarray, np.load(dir_ / f"{item}_{x}.npy", allow_pickle=True)) for x in splits} - - if Path(dir_ / "info.json").exists(): - info = json.loads(Path(dir_ / "info.json").read_text()) + if Path(directory / "info.json").exists(): + info = json.loads(Path(directory / "info.json").read_text()) return cls( - load("X_num") if dir_.joinpath("X_num_train.npy").exists() else None, - load("X_cat") if dir_.joinpath("X_cat_train.npy").exists() else None, - load("y"), + cls._load_datasets(directory, "X_num") if directory.joinpath("X_num_train.npy").exists() else None, + cls._load_datasets(directory, "X_cat") if directory.joinpath("X_cat_train.npy").exists() else None, + cls._load_datasets(directory, "y"), {}, TaskType(info["task_type"]), info.get("n_classes"), ) + @classmethod + def _load_datasets(cls, directory: Path, dataset_name: str) -> ArrayDict: + """ + Load all the dataset splits from a directory. + + Will check which of the splits exist in the directory for the + given dataset_name and load all of them. + + Args: + directory: The directory to load the dataset from. + dataset_name: The dataset_name to load. + + Returns: + The loaded datasets with all the splits. + """ + splits = [k for k in ["train", "val", "test"] if directory.joinpath(f"y_{k}.npy").exists()] + # TODO: figure out if there is a way of getting rid of the cast + return {x: cast(np.ndarray, np.load(directory / f"{dataset_name}_{x}.npy", allow_pickle=True)) for x in splits} + @property def is_binclass(self) -> bool: """ @@ -168,6 +182,8 @@ def n_num_features(self) -> int: """ Get the number of numerical features in the dataset. 
+ That number should be in the second dimension of the tensors of X_num. + Returns: The number of numerical features in the dataset. """ @@ -178,6 +194,8 @@ def n_cat_features(self) -> int: """ Get the number of categorical features in the dataset. + That number should be in the second dimension of the tensors of X_cat tensor. + Returns: The number of categorical features in the dataset. """ @@ -194,24 +212,28 @@ def n_features(self) -> int: return self.n_num_features + self.n_cat_features # TODO: make partition into an Enum - def size(self, partition: Literal["train", "val", "test"] | None) -> int: + def size(self, split: Literal["train", "val", "test"] | None) -> int: """ - Get the size of the dataset. + Get the size of a dataset split. If no split is provided, the size of + the entire dataset is returned. Args: - partition: The partition of the dataset to get the size of. + split: The split of the dataset to get the size of. If None, the size of the entire dataset is returned. Returns: The size of the dataset. """ - return sum(map(len, self.y.values())) if partition is None else len(self.y[partition]) + return sum(map(len, self.y.values())) if split is None else len(self.y[split]) @property def nn_output_dim(self) -> int: """ Get the output dimension of the neural network. + This only works for multiclass classification and regression tasks. Binary classification + tasks have output dimension of 2. + Returns: The output dimension of the neural network. """ @@ -220,17 +242,17 @@ def nn_output_dim(self) -> int: return self.n_classes return 1 - def get_category_sizes(self, partition: Literal["train", "val", "test"]) -> list[int]: + def get_category_sizes(self, split: Literal["train", "val", "test"]) -> list[int]: """ - Get the size of the categories in the dataset. + Get the size of the categories in the specified split of the dataset. Args: - partition: The partition of the dataset to get the size of the categories of. + split: The split of the dataset to get the size of the categories of. Returns: - The size of the categories in the partition of the dataset. + The size of the categories in the specified split of the dataset. 
""" - return [] if self.X_cat is None else get_category_sizes(self.X_cat[partition]) + return [] if self.X_cat is None else get_category_sizes(self.X_cat[split]) # TODO: prediciton_type should be of type PredictionType def calculate_metrics( @@ -312,6 +334,7 @@ def calculate_metrics( result = {"rmse": rmse, "r2": r2} else: labels, probs = _get_labels_and_probs(y_pred, task_type, prediction_type) + # TODO: figure out if there is a way of getting rid of the cast result = cast(dict[str, Any], classification_report(y_true, labels, output_dict=True)) if task_type == TaskType.BINCLASS: result["roc_auc"] = roc_auc_score(y_true, probs) diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index 42bc5e76..701c9a1f 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -46,7 +46,7 @@ @pytest.mark.integration_test() def test_load_single_table(): - tables, relation_order, dataset_meta = load_multi_table("tests/integration/assets/single_table/") + tables, relation_order, dataset_meta = load_multi_table(Path("tests/integration/assets/single_table/")) assert list(tables.keys()) == ["trans"] @@ -121,7 +121,7 @@ def test_load_single_table(): @pytest.mark.integration_test() def test_load_multi_table(): - tables, relation_order, dataset_meta = load_multi_table("tests/integration/assets/multi_table/") + tables, relation_order, dataset_meta = load_multi_table(Path("tests/integration/assets/multi_table/")) assert list(tables.keys()) == ["account", "trans"] @@ -251,7 +251,7 @@ def test_train_single_table(tmp_path: Path): set_all_random_seeds(seed=133742, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Act - tables, relation_order, _ = load_multi_table("tests/integration/assets/single_table/") + tables, relation_order, _ = load_multi_table(Path("tests/integration/assets/single_table/")) tables, models = clava_training( tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device="cpu" ) @@ -309,7 +309,7 @@ def test_train_multi_table(tmp_path: Path): set_all_random_seeds(seed=133742, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Act - tables, relation_order, _ = load_multi_table("tests/integration/assets/multi_table/") + tables, relation_order, _ = load_multi_table(Path("tests/integration/assets/multi_table/")) tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device="cpu") @@ -395,7 +395,7 @@ def test_clustering_reload(tmp_path: Path): set_all_random_seeds(seed=133742, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Act - tables, relation_order, dataset_meta = load_multi_table("tests/integration/assets/multi_table/") + tables, relation_order, dataset_meta = load_multi_table(Path("tests/integration/assets/multi_table/")) tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) # Assert From 4545a9a75834ff29d10a75efa4f8e8255c064579 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 25 Sep 2025 13:19:02 -0400 Subject: [PATCH 36/39] Fixing the rest of David's comments --- .../models/clavaddpm/data_loaders.py | 7 +- src/midst_toolkit/models/clavaddpm/dataset.py | 65 +++++++++++++++---- src/midst_toolkit/models/clavaddpm/model.py | 8 +-- 3 files changed, 61 insertions(+), 19 deletions(-) diff --git 
a/src/midst_toolkit/models/clavaddpm/data_loaders.py b/src/midst_toolkit/models/clavaddpm/data_loaders.py index 4ac510ee..748e8d66 100644 --- a/src/midst_toolkit/models/clavaddpm/data_loaders.py +++ b/src/midst_toolkit/models/clavaddpm/data_loaders.py @@ -495,7 +495,12 @@ def __next__(self): return batch def __len__(self): - """Get the number of batches in the dataset.""" + """ + Get the number of batches in the dataset. + + Returns: + (int) The number of batches in the dataset. + """ return self.n_batches diff --git a/src/midst_toolkit/models/clavaddpm/dataset.py b/src/midst_toolkit/models/clavaddpm/dataset.py index d38abdcd..e759f9f7 100644 --- a/src/midst_toolkit/models/clavaddpm/dataset.py +++ b/src/midst_toolkit/models/clavaddpm/dataset.py @@ -194,7 +194,7 @@ def n_cat_features(self) -> int: """ Get the number of categorical features in the dataset. - That number should be in the second dimension of the tensors of X_cat tensor. + That number should be in the second dimension of the tensors of X_cat. Returns: The number of categorical features in the dataset. @@ -288,7 +288,8 @@ def calculate_metrics( # TODO consider moving all the functions below into the Dataset class def get_category_sizes(X: torch.Tensor | np.ndarray) -> list[int]: """ - Get the size of the categories in the data. + Get the size of the categories in the data by counting the number of + unique values in each column. Args: X: The data to get the size of the categories of. @@ -320,7 +321,40 @@ def calculate_metrics( y_info: A dictionary with metadata about the labels. Returns: - The metrics of the predictions. + The metrics of the predictions as a dictionary with the following keys: + If the task type is TaskType.REGRESSION: + { + "rmse": The root mean squared error. + "r2": The R^2 score. + } + + If the task type is TaskType.MULTICLASS, it will have a key for each label + with the following metrics (result of sklearn.metrics.classification_report): + { + "label-1": { + "precision": The precision of the label. + "recall": The recall of the label. + "f1-score": The F1 score of the label. + "support": The number of occurrences of this label in y_true. + }, + "label-2": {...} + ... + } + + If the task type is TaskType.BINCLASS, it will have a key for each label + with the following metrics ((result of sklearn.metrics.classification_report), + and an additional ROC AUC metric: + { + "label-1": { + "precision": The precision of the label. + "recall": The recall of the label. + "f1-score": The F1 score of the label. + "support": The number of occurrences of this label in y_true. + }, + "label-2": {...} + ... + "roc_auc": The ROC AUC score. + } """ task_type = TaskType(task_type) if prediction_type is not None: @@ -333,7 +367,7 @@ def calculate_metrics( r2 = r2_score(y_true, y_pred) result = {"rmse": rmse, "r2": r2} else: - labels, probs = _get_labels_and_probs(y_pred, task_type, prediction_type) + labels, probs = _get_predicted_labels_and_probs(y_pred, task_type, prediction_type) # TODO: figure out if there is a way of getting rid of the cast result = cast(dict[str, Any], classification_report(y_true, labels, output_dict=True)) if task_type == TaskType.BINCLASS: @@ -360,11 +394,13 @@ def calculate_rmse(y_true: np.ndarray, y_pred: np.ndarray, std: float | None) -> return rmse -def _get_labels_and_probs( +def _get_predicted_labels_and_probs( y_pred: np.ndarray, task_type: TaskType, prediction_type: PredictionType | None ) -> tuple[np.ndarray, np.ndarray | None]: """ Get the labels and probabilities from the predictions. 
+ If prediction_type is None, will return the predicted labels as is + and the probabilities as None. Args: y_pred: The predicted labels as a numpy array. @@ -395,10 +431,9 @@ def _get_labels_and_probs( def make_dataset_from_df( # ruff: noqa: PLR0915, PLR0912 df: pd.DataFrame, - T: Transformations, - # ruff: noqa: N803 + transformations: Transformations, is_y_cond: Literal["concat", "embedding", "none"], - df_info: pd.DataFrame, + df_info: dict[str, Any], ratios: list[float] | None = None, std: float = 0, ) -> tuple[Dataset, dict[int, LabelEncoder], list[int]]: @@ -414,9 +449,9 @@ def make_dataset_from_df( Args: df: The pandas DataFrame to generate the dataset from. - T: The transformations to apply to the dataset. + transformations: The transformations to apply to the dataset. is_y_cond: The condition on the y column. - concat: y is concatenated to X, the model learn a joint distribution of (y, X) + concat: y is concatenated to X, the model learns a joint distribution of (y, X) embedding: y is not concatenated to X. During computations, y is embedded and added to the latent vector of X none: y column is completely ignored @@ -435,7 +470,8 @@ def make_dataset_from_df( In this case, y is completely independent of X. df_info: A dictionary with metadata about the DataFrame. - ratios: The ratios of the dataset to split into train, val, and test. Optional, default is [0.7, 0.2, 0.1]. + ratios: The ratios of the dataset to split into train, val, and test. The sum of + the ratios must amount to 1 (with a tolerance of 0.01). Optional, default is [0.7, 0.2, 0.1]. std: The standard deviation of the labels. Optional, default is 0. Returns: @@ -444,6 +480,8 @@ def make_dataset_from_df( if ratios is None: ratios = [0.7, 0.2, 0.1] + assert np.isclose(sum(ratios), 1, atol=0.01), "The sum of the ratios must amount to 1 (with a tolerance of 0.01)." + train_val_df, test_df = train_test_split(df, test_size=ratios[2], random_state=42) train_df, val_df = train_test_split(train_val_df, test_size=ratios[1] / (ratios[0] + ratios[1]), random_state=42) @@ -541,8 +579,7 @@ def make_dataset_from_df( X_num = X_cat X_cat = None - D = Dataset( - # ruff: noqa: N806 + dataset = Dataset( X_num, None, y, @@ -551,7 +588,7 @@ def make_dataset_from_df( n_classes=df_info["n_classes"], ) - return transform_dataset(D, T, None), label_encoders, column_orders + return transform_dataset(dataset, transformations, None), label_encoders, column_orders def transform_dataset( diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py index e9a6e0b9..bb7dfd61 100644 --- a/src/midst_toolkit/models/clavaddpm/model.py +++ b/src/midst_toolkit/models/clavaddpm/model.py @@ -28,7 +28,7 @@ def __init__( Args: d_in: The input dimension size. d_out: The output dimension size. - dim_t: The dimension size of the timestamp. + dim_t: The dimension size of the timestep. hidden_sizes: The list of sizes for the hidden layers. dropout_prob: The dropout probability. Optional, default is 0.5. num_heads: The number of heads for the transformer layer. Optional, default is 2. @@ -84,7 +84,7 @@ def forward(self, x: Tensor, timesteps: Tensor) -> Tensor: def get_table_info(df: pd.DataFrame, domain_dict: dict[str, Any], y_col: str) -> dict[str, Any]: """ - Get the dictionary oftable information. + Get the dictionary of table information. Args: df: The dataframe containing the data. @@ -596,7 +596,7 @@ def __init__( num_classes: The number of classes. is_y_cond: The condition on the y column. 
Can be "concat", "embedding", or "none". rtdl_params: The dictionary of parameters for the MLP. - dim_t: The dimension size of the timestamp. + dim_t: The dimension size of the timestep. """ super().__init__() self.dim_t = dim_t @@ -655,7 +655,7 @@ def __init__( d_in: The input dimension size. num_classes: The number of classes. rtdl_params: The dictionary of parameters for the ResNet. - dim_t: The dimension size of the timestamp. + dim_t: The dimension size of the timestep. is_y_cond: The condition on the y column. Can be "concat", "embedding", or "none". Optional, default is None. """ From 75823f8706ec813cdebecf159282694029a4d3ba Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 25 Sep 2025 14:37:25 -0400 Subject: [PATCH 37/39] Fixing more code comments --- src/midst_toolkit/models/clavaddpm/data_loaders.py | 9 +++++---- src/midst_toolkit/models/clavaddpm/dataset.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/data_loaders.py b/src/midst_toolkit/models/clavaddpm/data_loaders.py index 7de18567..72c12ef2 100644 --- a/src/midst_toolkit/models/clavaddpm/data_loaders.py +++ b/src/midst_toolkit/models/clavaddpm/data_loaders.py @@ -461,7 +461,7 @@ def train_test_split( class FastTensorDataLoader: - def __init__(self, *tensors: torch.Tensor, batch_size: int = 32, shuffle: bool = False): + def __init__(self, tensors: tuple[torch.Tensor, ...], batch_size: int = 32, shuffle: bool = False): """ Initialize a FastTensorDataLoader. @@ -471,13 +471,14 @@ def __init__(self, *tensors: torch.Tensor, batch_size: int = 32, shuffle: bool = Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 Args: - *tensors: tensors to store. All tensors must have the same length as the tensor at dimension 0. + tensors: a tuple of tensors to store. The first dimension for each tensor is the + number of samples, and all tensors must have the same number of samples. batch_size: batch size to load. Optional, default is 32. shuffle: if True, shuffle the data *in-place* whenever an iterator is created out of this object. Optional, default is False. """ assert all(t.shape[0] == tensors[0].shape[0] for t in tensors), ( - "All tensors must have the same length as the tensor at dimension 0." + "All tensors must have the same amount of samples." ) self.tensors = tensors @@ -553,6 +554,6 @@ def prepare_fast_dataloader( assert dataset.X_num is not None X = torch.from_numpy(dataset.X_num[split]).float() y = torch.from_numpy(dataset.y[split]).float() if y_type == "float" else torch.from_numpy(dataset.y[split]).long() - dataloader = FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=(split == "train")) + dataloader = FastTensorDataLoader((X, y), batch_size=batch_size, shuffle=(split == "train")) while True: yield from dataloader diff --git a/src/midst_toolkit/models/clavaddpm/dataset.py b/src/midst_toolkit/models/clavaddpm/dataset.py index 4fa83cb4..b969b974 100644 --- a/src/midst_toolkit/models/clavaddpm/dataset.py +++ b/src/midst_toolkit/models/clavaddpm/dataset.py @@ -227,15 +227,15 @@ def size(self, split: Literal["train", "val", "test"] | None) -> int: return sum(map(len, self.y.values())) if split is None else len(self.y[split]) @property - def nn_output_dim(self) -> int: + def output_dimension(self) -> int: """ - Get the output dimension of the neural network. + Get the output dimension of the model. This only works for multiclass classification and regression tasks. Binary classification tasks have output dimension of 2. 
Returns: - The output dimension of the neural network. + The output dimension of the model. """ if self.is_multiclass: assert self.n_classes is not None From eedad3b1d08f96dff460a0cfbe429936f67bc620 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 26 Sep 2025 11:25:40 -0400 Subject: [PATCH 38/39] Addressing one more comment. --- src/midst_toolkit/models/clavaddpm/dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/dataset.py b/src/midst_toolkit/models/clavaddpm/dataset.py index b969b974..9157418e 100644 --- a/src/midst_toolkit/models/clavaddpm/dataset.py +++ b/src/midst_toolkit/models/clavaddpm/dataset.py @@ -231,8 +231,9 @@ def output_dimension(self) -> int: """ Get the output dimension of the model. - This only works for multiclass classification and regression tasks. Binary classification - tasks have output dimension of 2. + For self.task_type == TaskType.MULTICLASS, the output dimension is the number of classes. + For self.task_type == TaskType.REGRESSION, the output dimension is 1. + For self.task_type == TaskType.BINCLASS, the output dimension is also 1 because it is label encoded. Returns: The output dimension of the model. From abcc4796cb23eac4c547050b60cb2db11148cc2d Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 26 Sep 2025 12:48:29 -0400 Subject: [PATCH 39/39] Removing logger from docs --- docs/api.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/docs/api.md b/docs/api.md index 30fbca29..9df08cf6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -7,13 +7,6 @@ show_root_heading: true show_root_full_path: true -## Logger Module - -::: midst_toolkit.core.logger - options: - show_root_heading: true - show_root_full_path: true - ## Diffusion Utils Module ::: midst_toolkit.models.clavaddpm.diffusion_utils
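Note on the FastTensorDataLoader change in patch 37 above: the constructor now takes a single tuple of tensors instead of variadic *tensors, matching the updated call inside prepare_fast_dataloader. Below is a minimal usage sketch of that new signature, not part of the patch itself. The tensor names and shapes are illustrative only, and the per-batch unpacking assumes the iterator yields one slice per stored tensor, in line with the upstream forum snippet the class docstring credits.

    import torch

    from midst_toolkit.models.clavaddpm.data_loaders import FastTensorDataLoader

    # Illustrative data: 128 samples with 8 numerical features and a binary label.
    features = torch.randn(128, 8)
    labels = torch.randint(0, 2, (128,))

    # All tensors in the tuple must share the same first dimension (number of samples);
    # the constructor asserts this before storing them.
    loader = FastTensorDataLoader((features, labels), batch_size=32, shuffle=True)

    for feature_batch, label_batch in loader:
        # len(loader) reports the number of batches (4 here); each batch is a slice
        # of the stored tensors along the sample dimension.
        pass

Grouping the tensors into one explicit tuple keeps batch_size and shuffle as unambiguous keyword arguments and makes the shape contract (shared first dimension across all tensors) easier to state in the docstring.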