diff --git a/src/midst_toolkit/models/clavaddpm/model.py b/src/midst_toolkit/models/clavaddpm/model.py
index 6ad74307..c88fbdbc 100644
--- a/src/midst_toolkit/models/clavaddpm/model.py
+++ b/src/midst_toolkit/models/clavaddpm/model.py
@@ -2,7 +2,6 @@
 import json
 import math
 import pickle
-from abc import ABC, abstractmethod
 from collections import Counter
 from collections.abc import Callable, Generator
 from copy import deepcopy
@@ -38,10 +37,6 @@
 from torch import Tensor, nn
 
 from midst_toolkit.common.enumerations import PredictionType, TaskType
-from midst_toolkit.core import logger
-from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import (
-    GaussianMultinomialDiffusion,
-)
 
 
 Normalization = Literal["standard", "quantile", "minmax"]
@@ -493,228 +488,6 @@ def get_model(
     raise ValueError("Unknown model!")
 
 
-class ScheduleSampler(ABC):
-    """
-    A distribution over timesteps in the diffusion process, intended to reduce
-    variance of the objective.
-
-    By default, samplers perform unbiased importance sampling, in which the
-    objective's mean is unchanged.
-    However, subclasses may override sample() to change how the resampled
-    terms are reweighted, allowing for actual changes in the objective.
-    """
-
-    @abstractmethod
-    def weights(self) -> Tensor:
-        """
-        Get a numpy array of weights, one per diffusion step.
-
-        The weights needn't be normalized, but must be positive.
-        """
-
-    def sample(self, batch_size: int, device: str) -> tuple[Tensor, Tensor]:
-        """
-        Importance-sample timesteps for a batch.
-
-        :param batch_size: the number of timesteps.
-        :param device: the torch device to save to.
-        :return: a tuple (timesteps, weights):
-            - timesteps: a tensor of timestep indices.
-            - weights: a tensor of weights to scale the resulting losses.
-        """
-        w = self.weights().cpu().numpy()
-        p = w / np.sum(w)
-        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
-        indices = torch.from_numpy(indices_np).long().to(device)
-        weights_np = 1 / (len(p) * p[indices_np])
-        weights = torch.from_numpy(weights_np).float().to(device)
-        return indices, weights
-
-
-class UniformSampler(ScheduleSampler):
-    def __init__(self, diffusion: GaussianMultinomialDiffusion):
-        self.diffusion = diffusion
-        self._weights = torch.from_numpy(np.ones([diffusion.num_timesteps]))
-
-    def weights(self) -> Tensor:
-        return self._weights
-
-
-class LossAwareSampler(ScheduleSampler):
-    def update_with_local_losses(self, local_ts: Tensor, local_losses: Tensor) -> None:
-        """
-        Update the reweighting using losses from a model.
-
-        Call this method from each rank with a batch of timesteps and the
-        corresponding losses for each of those timesteps.
-        This method will perform synchronization to make sure all of the ranks
-        maintain the exact same reweighting.
-
-        :param local_ts: an integer Tensor of timesteps.
-        :param local_losses: a 1D Tensor of losses.
-        """
-        batch_sizes = [
-            torch.tensor([0], dtype=torch.int32, device=local_ts.device)
-            for _ in range(torch.distributed.get_world_size())
-        ]
-        torch.distributed.all_gather(
-            batch_sizes,
-            torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device),
-        )
-
-        # Pad all_gather batches to be the maximum batch size.
-        max_bs = max([int(x.item()) for x in batch_sizes])
-
-        timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes]
-        loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes]
-        torch.distributed.all_gather(timestep_batches, local_ts)
-        torch.distributed.all_gather(loss_batches, local_losses)
-        timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
-        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
-        self.update_with_all_losses(timesteps, losses)
-
-    @abstractmethod
-    def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None:
-        """
-        Update the reweighting using losses from a model.
-
-        Sub-classes should override this method to update the reweighting
-        using losses from the model.
-
-        This method directly updates the reweighting without synchronizing
-        between workers. It is called by update_with_local_losses from all
-        ranks with identical arguments. Thus, it should have deterministic
-        behavior to maintain state across workers.
-
-        :param ts: a list of int timesteps.
-        :param losses: a list of float losses, one per timestep.
-        """
-
-
-class LossSecondMomentResampler(LossAwareSampler):
-    def __init__(
-        self,
-        diffusion: GaussianMultinomialDiffusion,
-        history_per_term: int = 10,
-        uniform_prob: float = 0.001,
-    ):
-        self.diffusion = diffusion
-        self.history_per_term = history_per_term
-        self.uniform_prob = uniform_prob
-        self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
-        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.uint)
-
-    def weights(self):
-        if not self._warmed_up():
-            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
-        weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
-        weights /= np.sum(weights)
-        weights *= 1 - self.uniform_prob
-        weights += self.uniform_prob / len(weights)
-        return weights
-
-    def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None:
-        for t, loss in zip(ts, losses):
-            if self._loss_counts[t] == self.history_per_term:
-                # Shift out the oldest loss term.
-                self._loss_history[t, :-1] = self._loss_history[t, 1:]
-                self._loss_history[t, -1] = loss
-            else:
-                self._loss_history[t, self._loss_counts[t]] = loss
-                self._loss_counts[t] += 1
-
-    def _warmed_up(self) -> bool:
-        return (self._loss_counts == self.history_per_term).all()
-
-
-def create_named_schedule_sampler(name: str, diffusion: GaussianMultinomialDiffusion) -> ScheduleSampler:
-    """
-    Create a ScheduleSampler from a library of pre-defined samplers.
-
-    :param name: the name of the sampler.
-    :param diffusion: the diffusion object to sample for.
-    """
-    if name == "uniform":
-        return UniformSampler(diffusion)
-    if name == "loss-second-moment":
-        return LossSecondMomentResampler(diffusion)
-    raise NotImplementedError(f"unknown schedule sampler: {name}")
-
-
-def split_microbatches(
-    microbatch: int,
-    batch: Tensor,
-    labels: Tensor,
-    t: Tensor,
-) -> Generator[tuple[Tensor, Tensor, Tensor]]:
-    bs = len(batch)
-    if microbatch == -1 or microbatch >= bs:
-        yield batch, labels, t
-    else:
-        for i in range(0, bs, microbatch):
-            yield batch[i : i + microbatch], labels[i : i + microbatch], t[i : i + microbatch]
-
-
-def compute_top_k(logits: Tensor, labels: Tensor, k: int, reduction: str = "mean") -> Tensor:
-    _, top_ks = torch.topk(logits, k, dim=-1)
-    if reduction == "mean":
-        return (top_ks == labels[:, None]).float().sum(dim=-1).mean()
-    if reduction == "none":
-        return (top_ks == labels[:, None]).float().sum(dim=-1)
-
-    raise ValueError(f"reduction should be one of ['mean', 'none']: {reduction}")
-
-
-def log_loss_dict(diffusion: GaussianMultinomialDiffusion, ts: Tensor, losses: dict[str, Tensor]) -> None:
-    for key, values in losses.items():
-        logger.logkv_mean(key, values.mean().item())
-        # Log the quantiles (four quartiles, in particular).
-        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
-            quartile = int(4 * sub_t / diffusion.num_timesteps)
-            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
-
-
-def numerical_forward_backward_log(
-    classifier: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    data_loader: Generator[tuple[Tensor, ...]],
-    dataset: Dataset,
-    schedule_sampler: ScheduleSampler,
-    diffusion: GaussianMultinomialDiffusion,
-    prefix: str = "train",
-    remove_first_col: bool = False,
-    device: str = "cuda",
-) -> None:
-    batch, labels = next(data_loader)
-    labels = labels.long().to(device)
-
-    if remove_first_col:
-        # Remove the first column of the batch, which is the label.
-        batch = batch[:, 1:]
-
-    num_batch = batch[:, : dataset.n_num_features].to(device)
-
-    t, _ = schedule_sampler.sample(num_batch.shape[0], device)
-    batch = diffusion.gaussian_q_sample(num_batch, t).to(device)
-
-    for i, (sub_batch, sub_labels, sub_t) in enumerate(split_microbatches(-1, batch, labels, t)):
-        logits = classifier(sub_batch, timesteps=sub_t)
-        loss = F.cross_entropy(logits, sub_labels, reduction="none")
-
-        losses = {}
-        losses[f"{prefix}_loss"] = loss.detach()
-        losses[f"{prefix}_acc@1"] = compute_top_k(logits, sub_labels, k=1, reduction="none")
-        if logits.shape[1] >= 5:
-            losses[f"{prefix}_acc@5"] = compute_top_k(logits, sub_labels, k=5, reduction="none")
-        log_loss_dict(diffusion, sub_t, losses)
-        del losses
-        loss = loss.mean()
-        if loss.requires_grad:
-            if i == 0:
-                optimizer.zero_grad()
-            loss.backward(loss * len(sub_batch) / len(batch))
-
-
 def transform_dataset(
     dataset: Dataset,
     transformations: Transformations,
diff --git a/src/midst_toolkit/models/clavaddpm/sampler.py b/src/midst_toolkit/models/clavaddpm/sampler.py
new file mode 100644
index 00000000..3edb30bc
--- /dev/null
+++ b/src/midst_toolkit/models/clavaddpm/sampler.py
@@ -0,0 +1,207 @@
+"""Samplers for the ClavaDDPM model."""
+
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion
+
+
+class ScheduleSampler(ABC):
+    """
+    A distribution over timesteps in the diffusion process, intended to reduce
+    variance of the objective.
+
+    By default, samplers perform unbiased importance sampling, in which the
+    objective's mean is unchanged. However, subclasses may override sample() to
+    change how the resampled terms are reweighted, allowing for actual changes
+    in the objective.
+    """
+
+    @abstractmethod
+    def weights(self) -> Tensor:
+        """
+        Get a tensor of weights, one per diffusion step.
+
+        The weights needn't be normalized, but must be positive.
+        """
+
+    def sample(self, batch_size: int, device: str) -> tuple[Tensor, Tensor]:
+        """
+        Importance-sample timesteps for a batch.
+
+        Args:
+            batch_size: The number of timesteps to sample, one per element in the batch.
+            device: The torch device to place the sampled tensors on.
+
+        Returns:
+            A tuple (timesteps, weights):
+                - timesteps: a tensor of timestep indices.
+                - weights: a tensor of weights to scale the resulting losses.
+        """
+        w = self.weights().cpu().numpy()
+        p = w / np.sum(w)
+        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+        indices = torch.from_numpy(indices_np).long().to(device)
+        weights_np = 1 / (len(p) * p[indices_np])
+        weights = torch.from_numpy(weights_np).float().to(device)
+        return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+    def __init__(self, diffusion: GaussianMultinomialDiffusion):
+        """
+        Initialize the UniformSampler.
+
+        Args:
+            diffusion: The diffusion object.
+        """
+        self.diffusion = diffusion
+        self._weights = torch.from_numpy(np.ones([diffusion.num_timesteps]))
+
+    def weights(self) -> Tensor:
+        """Return the weights."""
+        return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+    def update_with_local_losses(self, local_ts: Tensor, local_losses: Tensor) -> None:
+        """
+        Update the reweighting using losses from a model.
+
+        Call this method from each rank with a batch of timesteps and the
+        corresponding losses for each of those timesteps.
+        This method will perform synchronization to make sure all of the ranks
+        maintain the exact same reweighting.
+
+        Args:
+            local_ts: An integer Tensor of timesteps.
+            local_losses: A 1D Tensor of losses.
+        """
+        batch_sizes = [
+            torch.tensor([0], dtype=torch.int32, device=local_ts.device)
+            for _ in range(torch.distributed.get_world_size())
+        ]
+        torch.distributed.all_gather(
+            batch_sizes,
+            torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device),
+        )
+
+        # Pad all_gather batches to be the maximum batch size.
+        max_bs = max([int(x.item()) for x in batch_sizes])
+
+        timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+        loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+        torch.distributed.all_gather(timestep_batches, local_ts)
+        torch.distributed.all_gather(loss_batches, local_losses)
+        timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
+        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+        self.update_with_all_losses(timesteps, losses)
+
+    @abstractmethod
+    def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None:
+        """
+        Update the reweighting using losses from a model.
+
+        Sub-classes should override this method to update the reweighting
+        using losses from the model.
+
+        This method directly updates the reweighting without synchronizing
+        between workers. It is called by update_with_local_losses from all
+        ranks with identical arguments. Thus, it should have deterministic
+        behavior to maintain state across workers.
+
+        Args:
+            ts: A list of int timesteps.
+            losses: A list of float losses, one per timestep.
+        """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+    def __init__(
+        self,
+        diffusion: GaussianMultinomialDiffusion,
+        history_per_term: int = 10,
+        uniform_prob: float = 0.001,
+    ):
+        """
+        Initialize the LossSecondMomentResampler.
+
+        Args:
+            diffusion: The diffusion object.
+            history_per_term: The number of losses to keep for each timestep.
+            uniform_prob: The probability mass reserved for uniform sampling over timesteps.
+        """
+        self.diffusion = diffusion
+        self.history_per_term = history_per_term
+        self.uniform_prob = uniform_prob
+        self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
+        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.uint)
+
+    def weights(self) -> Tensor:
+        """
+        Return the weights as a tensor.
+
+        Until the loss history is fully populated (i.e. the sampler is warmed up),
+        uniform weights are returned.
+        """
+        if not self._warmed_up():
+            return torch.ones([self.diffusion.num_timesteps], dtype=torch.float64)
+        weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
+        weights /= np.sum(weights)
+        weights *= 1 - self.uniform_prob
+        weights += self.uniform_prob / len(weights)
+        # Convert to a tensor so that ScheduleSampler.sample, which calls .cpu() on the
+        # result of weights(), also works with this sampler.
+        return torch.from_numpy(weights)
+
+    def update_with_all_losses(self, ts: list[int], losses: list[float]) -> None:
+        """
+        Update the reweighting using losses from the model.
+
+        Args:
+            ts: The timesteps.
+            losses: The losses.
+        """
+        for t, loss in zip(ts, losses):
+            if self._loss_counts[t] == self.history_per_term:
+                # Shift out the oldest loss term.
+                self._loss_history[t, :-1] = self._loss_history[t, 1:]
+                self._loss_history[t, -1] = loss
+            else:
+                self._loss_history[t, self._loss_counts[t]] = loss
+                self._loss_counts[t] += 1
+
+    def _warmed_up(self) -> bool:
+        """
+        Check whether the sampler is warmed up, i.e. every timestep has accumulated
+        history_per_term losses.
+
+        Returns:
+            True if the sampler is warmed up, False otherwise.
+        """
+        return (self._loss_counts == self.history_per_term).all()
+
+
+def create_named_schedule_sampler(
+    name: Literal["uniform", "loss-second-moment"],
+    diffusion: GaussianMultinomialDiffusion,
+) -> ScheduleSampler:
+    """
+    Create a ScheduleSampler from a library of pre-defined samplers.
+
+    Args:
+        name: The name of the sampler. Should be one of ["uniform", "loss-second-moment"].
+        diffusion: The diffusion object to sample for.
+
+    Returns:
+        The UniformSampler if ``name`` is "uniform", LossSecondMomentResampler if ``name``
+        is "loss-second-moment".
+    """
+    if name == "uniform":
+        return UniformSampler(diffusion)
+    if name == "loss-second-moment":
+        return LossSecondMomentResampler(diffusion)
+    raise NotImplementedError(f"unknown schedule sampler: {name}")
diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py
index 40816f7f..a570c75e 100644
--- a/src/midst_toolkit/models/clavaddpm/train.py
+++ b/src/midst_toolkit/models/clavaddpm/train.py
@@ -1,30 +1,31 @@
 """Defines the training functions for the ClavaDDPM model."""
 
 import pickle
+from collections.abc import Generator
 from logging import INFO, WARNING
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 
 import numpy as np
 import pandas as pd
 import torch
-from torch import optim
+from torch import Tensor, optim
 
 from midst_toolkit.common.logger import log
 from midst_toolkit.core import logger
 from midst_toolkit.models.clavaddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion
 from midst_toolkit.models.clavaddpm.model import (
     Classifier,
+    Dataset,
     Transformations,
-    create_named_schedule_sampler,
     get_model,
     get_model_params,
     get_T_dict,
     get_table_info,
     make_dataset_from_df,
-    numerical_forward_backward_log,
     prepare_fast_dataloader,
 )
+from midst_toolkit.models.clavaddpm.sampler import ScheduleSampler, create_named_schedule_sampler
 from midst_toolkit.models.clavaddpm.trainer import ClavaDDPMTrainer
 from midst_toolkit.models.clavaddpm.typing import Configs, RelationOrder, Tables
 
@@ -448,7 +449,7 @@ def train_classifier(
                 "samples",
                 (step + 1) * batch_size,
             )
-            numerical_forward_backward_log(
+            _numerical_forward_backward_log(
                 classifier,
                 classifier_optimizer,
                 train_loader,
@@ -463,7 +464,7 @@ def train_classifier(
         if not step % classifier_evaluation_interval:
             with torch.no_grad():
                 classifier.eval()
-                numerical_forward_backward_log(
+                _numerical_forward_backward_log(
                     classifier,
                     classifier_optimizer,
                     val_loader,
@@ -554,3 +555,129 @@ def get_df_without_id(df: pd.DataFrame) -> pd.DataFrame:
     """
     id_cols = [col for col in df.columns if "_id" in col]
     return df.drop(columns=id_cols)
+
+
+def _numerical_forward_backward_log(
+    classifier: Classifier,
+    optimizer: torch.optim.Optimizer,
+    data_loader: Generator[tuple[Tensor, ...]],
+    dataset: Dataset,
+    schedule_sampler: ScheduleSampler,
+    diffusion: GaussianMultinomialDiffusion,
+    prefix: str = "train",
+    remove_first_col: bool = False,
+    device: str = "cuda",
+) -> None:
+    """
+    Forward and backward pass for the numerical features of the ClavaDDPM model.
+
+    Args:
+        classifier: The classifier model.
+        optimizer: The optimizer.
+        data_loader: The data loader.
+        dataset: The dataset.
+        schedule_sampler: The schedule sampler.
+        diffusion: The diffusion object.
+        prefix: The prefix for the loss. Defaults to "train".
+        remove_first_col: Whether to remove the first column of the batch. Defaults to False.
+        device: The device to use. Defaults to "cuda".
+    """
+    batch, labels = next(data_loader)
+    labels = labels.long().to(device)
+
+    if remove_first_col:
+        # Remove the first column of the batch, which is the label.
+        batch = batch[:, 1:]
+
+    num_batch = batch[:, : dataset.n_num_features].to(device)
+
+    t, _ = schedule_sampler.sample(num_batch.shape[0], device)
+    batch = diffusion.gaussian_q_sample(num_batch, t).to(device)
+
+    for i, (sub_batch, sub_labels, sub_t) in enumerate(_split_microbatches(-1, batch, labels, t)):
+        logits = classifier(sub_batch, timesteps=sub_t)
+        loss = torch.nn.functional.cross_entropy(logits, sub_labels, reduction="none")
+
+        losses = {}
+        losses[f"{prefix}_loss"] = loss.detach()
+        losses[f"{prefix}_acc@1"] = _compute_top_k(logits, sub_labels, k=1, reduction="none")
+        if logits.shape[1] >= 5:
+            losses[f"{prefix}_acc@5"] = _compute_top_k(logits, sub_labels, k=5, reduction="none")
+        _log_loss_dict(diffusion, sub_t, losses)
+        del losses
+        loss = loss.mean()
+        if loss.requires_grad:
+            if i == 0:
+                optimizer.zero_grad()
+            loss.backward(loss * len(sub_batch) / len(batch))
+
+
+# TODO: Think about moving this to a metrics module
+def _compute_top_k(
+    logits: Tensor,
+    labels: Tensor,
+    k: int,
+    reduction: Literal["mean", "none"] = "mean",
+) -> Tensor:
+    """
+    Compute the top-k accuracy.
+
+    Args:
+        logits: The logits of the classifier.
+        labels: The labels of the data.
+        k: The number of top predictions to consider.
+        reduction: The reduction method. Should be one of ["mean", "none"]. Defaults to "mean".
+
+    Returns:
+        The top-k accuracy.
+    """
+    _, top_ks = torch.topk(logits, k, dim=-1)
+    if reduction == "mean":
+        return (top_ks == labels[:, None]).float().sum(dim=-1).mean()
+    if reduction == "none":
+        return (top_ks == labels[:, None]).float().sum(dim=-1)
+
+    raise ValueError(f"reduction should be one of ['mean', 'none']: {reduction}")
+
+
+def _log_loss_dict(diffusion: GaussianMultinomialDiffusion, ts: Tensor, losses: dict[str, Tensor]) -> None:
+    """
+    Log the loss dictionary to the logger.
+
+    Args:
+        diffusion: The diffusion object.
+        ts: The timesteps.
+        losses: The losses.
+    """
+    for key, values in losses.items():
+        logger.logkv_mean(key, values.mean().item())
+        # Log the quantiles (four quartiles, in particular).
+        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
+            quartile = int(4 * sub_t / diffusion.num_timesteps)
+            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
+
+
+def _split_microbatches(
+    microbatch: int,
+    batch: Tensor,
+    labels: Tensor,
+    t: Tensor,
+) -> Generator[tuple[Tensor, Tensor, Tensor]]:
+    """
+    Split the batch into microbatches.
+
+    Args:
+        microbatch: The size of the microbatch. If -1, the batch is not split.
+        batch: The batch of data as a tensor.
+        labels: The labels of the data as a tensor.
+        t: The timesteps tensor.
+
+    Returns:
+        A generator over microbatches, yielding tuples of (batch, labels, timesteps).
+    """
+    bs = len(batch)
+    if microbatch == -1 or microbatch >= bs:
+        yield batch, labels, t
+    else:
+        for i in range(0, bs, microbatch):
+            yield batch[i : i + microbatch], labels[i : i + microbatch], t[i : i + microbatch]
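
Below is a minimal, self-contained sketch (not part of the patch) of the importance-sampling scheme that ScheduleSampler.sample() implements: timesteps are drawn from a possibly non-uniform distribution p, and each draw is paired with the weight 1 / (num_timesteps * p[t]), so that the weighted mean of per-timestep losses remains an unbiased estimate of the uniform-timestep objective. The names here (toy_loss, the linear w) are illustrative stand-ins, not toolkit code.

import numpy as np

rng = np.random.default_rng(0)
num_timesteps = 100

# A non-uniform sampling distribution over timesteps (e.g. derived from a loss history).
w = np.linspace(1.0, 5.0, num_timesteps)
p = w / w.sum()


def toy_loss(t: np.ndarray) -> np.ndarray:
    # Stand-in for a per-timestep training loss.
    return np.sqrt(t + 1.0)


batch_size = 200_000
t = rng.choice(num_timesteps, size=batch_size, p=p)
# Importance weights, mirroring weights_np = 1 / (len(p) * p[indices_np]) in sample().
weights = 1.0 / (num_timesteps * p[t])

uniform_estimate = toy_loss(np.arange(num_timesteps)).mean()
importance_estimate = (weights * toy_loss(t)).mean()
print(uniform_estimate, importance_estimate)  # The two agree up to sampling noise.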
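
A second small sketch, under the same caveat that the model and tensors are synthetic rather than toolkit code: the microbatch pattern behind _split_microbatches, where each slice's mean loss is scaled by its share of the full batch before backward(), so the accumulated gradient matches a single full-batch step. (In _numerical_forward_backward_log the split size is currently fixed at -1, i.e. the batch is not actually split.)

import torch

torch.manual_seed(0)
batch, labels = torch.randn(64, 8), torch.randint(0, 3, (64,))
model = torch.nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
microbatch = 16

optimizer.zero_grad()
for i in range(0, len(batch), microbatch):
    sub_batch = batch[i : i + microbatch]
    sub_labels = labels[i : i + microbatch]
    loss = torch.nn.functional.cross_entropy(model(sub_batch), sub_labels)
    # Scale so that the gradients summed over microbatches equal the full-batch mean loss gradient.
    (loss * len(sub_batch) / len(batch)).backward()
optimizer.step()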