<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

In [None]:
#| include: false
from nbdev.showdoc import *

In [1]:
#| echo: false
#| output: asis
show_doc(train)

---

### train

>      train (model, df, groups, optimizer, n_batches=20, criterion=MMD_loss(),
>             use_cuda=False, sample_size=(100,), sample_with_replacement=False,
>             local_loss=True, global_loss=False, hold_one_out=False,
>             hold_out='random', apply_losses_in_time=True, top_k=5,
>             hinge_value=0.01, use_density_loss=True, lambda_density=1.0,
>             autoencoder=None, use_emb=True, use_gae=False,
>             use_gaussian:bool=True, add_noise:bool=False,
>             noise_scale:float=0.1, logger=None, use_penalty=False,
>             lambda_energy=1.0, reverse:bool=False)

MIOFlow training loop

Notes:
    - The argument `model` must have a method `forward` that accepts two arguments
        in its function signature:
            ```python
            model.forward(x, t)
            ```
        where `x` is the input tensor and `t` is a `torch.Tensor` of time points (float).
    - The training loop is divided into two parts: local (predict t+1 from t) and global (predict the entire trajectory).
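
As a minimal sketch of the `forward(x, t)` requirement above, the following defines a toy module with that signature. The class name `ToyODEModel`, its layer sizes, and the explicit Euler stepping are assumptions made for illustration only; they are not MIOFlow's own model.

```python
import torch
from torch import nn


class ToyODEModel(nn.Module):
    """Hypothetical stand-in for an ODE model; not part of MIOFlow."""

    def __init__(self, dim: int, hidden: int = 16):
        super().__init__()
        # The network sees the state concatenated with the current time.
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.Tanh(), nn.Linear(hidden, dim)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Explicit Euler steps between consecutive entries of `t`
        # (an illustrative choice; a real model might call an ODE solver).
        for t0, t1 in zip(t[:-1], t[1:]):
            t_col = t0.reshape(1, 1).expand(x.shape[0], 1)
            x = x + (t1 - t0) * self.net(torch.cat([x, t_col], dim=1))
        return x


model = ToyODEModel(dim=2)
x = torch.randn(5, 2)          # five points in two dimensions
t = torch.tensor([1.0, 2.0])   # float time points, as the note requires
out = model(x, t)              # invokes model.forward(x, t)
```

Here `out` has the same shape as `x`; whether a real model returns only the final state or the whole trajectory is not specified by the note above.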

Arguments:
    model (nn.Module): the initialized pytorch ODE model.

    df (pd.DataFrame): the DataFrame from which to extract batch data.

    groups (list): the list of the numerical groups in the data, e.g. 
        `[1.0, 2.0, 3.0, 4.0, 5.0]`, if the data has five groups.

    optimizer (torch.optim): an optimizer initialized with the model's parameters.

    n_batches (int): Defaults to `20`. The number of batches from which to randomly sample each consecutive pair
        of groups.

    criterion (Callable | nn.Loss): a loss function.

    use_cuda (bool): Defaults to `False`. Whether or not to send the model and data to cuda. 

    sample_size (tuple): Defaults to `(100, )`

    sample_with_replacement (bool): Defaults to `False`. Whether or not to sample data points with replacement.

    local_loss (bool): Defaults to `True`. Whether or not to use a local loss in the model.
        See notes for more detail.

    global_loss (bool): Defaults to `False`. Whether or not to use a global loss in the model.

    hold_one_out (bool): Defaults to `False`. Whether or not to randomly hold out one time pair
        (e.g. t_1 to t_2) when computing the global loss.

    hold_out (str | int): Defaults to `"random"`. Which time point to hold out when calculating the
        global loss.

    apply_losses_in_time (bool): Defaults to `True`. Applies the losses and does backpropagation
        as soon as a loss is calculated. See notes for more detail.

    top_k (int): Defaults to `5`. The k for the k-NN used in the density loss.

    hinge_value (float): Defaults to `0.01`. The hinge value for density loss.

    use_density_loss (bool): Defaults to `True`. Whether or not to add density regularization.

    lambda_density (float): Defaults to `1.0`. The weight for density loss.

    autoencoder (NoneType|nn.Module): Defaults to `None`. The full Geodesic Autoencoder.

    use_emb (bool): Defaults to `True`. Whether or not to use the embedding model.

    use_gae (bool): Defaults to `False`. Whether or not to use the full Geodesic AutoEncoder.

    use_gaussian (bool): Defaults to `True`. Whether to use Gaussian noise (`True`) or random noise (`False`).

    add_noise (bool): Defaults to `False`. Whether or not to add noise.

    noise_scale (float): Defaults to `0.1`. How much to scale the noise by.

    logger (NoneType|Logger): Defaults to `None`. The logger to record information.

    use_penalty (bool): Defaults to `False`. Whether or not to use $L_e$ during training (norm of the derivative).

    lambda_energy (float): Defaults to `1.0`. The weight of the energy penalty.

    reverse (bool): Defaults to `False`. Whether to train backwards in time.
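
The roles of `groups` and `hold_out` described above can be sketched in plain Python. The helper `without_held_out` is hypothetical and written only for this illustration; how `train` actually selects and removes the held-out time point may differ.

```python
import random

groups = [1.0, 2.0, 3.0, 4.0, 5.0]

# Consecutive (t, t+1) pairs, the unit a local loss is computed over.
steps = list(zip(groups[:-1], groups[1:]))


def without_held_out(groups, hold_out="random", rng=random):
    """Hypothetical helper: drop one time point from `groups`."""
    held = rng.choice(groups) if hold_out == "random" else hold_out
    return [g for g in groups if g != held], held


# Hold out a specific time point rather than a random one.
kept, held = without_held_out(groups, hold_out=3.0)
```

With `hold_out="random"`, the sketch picks the held-out group with `random.choice`, which is an assumption about what "random" means here.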

In [2]:
#| echo: false
#| output: asis
show_doc(train_ae)

---

### train_ae

>      train_ae (model, df, groups, optimizer, n_epochs=60, criterion=MSELoss(),
>                dist=None, recon=True, use_cuda=False, sample_size=(100,),
>                sample_with_replacement=False, noise_min_scale=0.09,
>                noise_max_scale=0.15, hold_one_out:bool=False,
>                hold_out='random')

Geodesic Autoencoder training loop.

Notes:
    - We can train only the encoder to fit the geodesic distance (recon=False), or the full Geodesic Autoencoder (recon=True),
        i.e. matching both the distance and the reconstruction of the inputs.
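
The difference between `recon=False` and `recon=True` can be illustrated with a small loss sketch. The linear `encoder` and `decoder`, the function `ae_loss`, and the use of Euclidean distances as stand-in targets are all assumptions for illustration; the actual loss inside `train_ae` is not shown in this document.

```python
import torch
from torch import nn

torch.manual_seed(0)
encoder = nn.Linear(4, 2)   # hypothetical encoder
decoder = nn.Linear(2, 4)   # hypothetical decoder
criterion = nn.MSELoss()

x = torch.randn(8, 4)
# Stand-in target distances; Euclidean here, in place of the geodesic
# distances that a `dist` object would supply.
target_dist = torch.cdist(x, x)


def ae_loss(x, target_dist, recon=True):
    z = encoder(x)
    # Distance term: embedding distances should match the targets.
    loss = criterion(torch.cdist(z, z), target_dist)
    if recon:
        # Reconstruction term: decoding the embedding should recover x.
        loss = loss + criterion(decoder(z), x)
    return loss


loss_encoder_only = ae_loss(x, target_dist, recon=False)
loss_full = ae_loss(x, target_dist, recon=True)
```

Because the reconstruction term is a non-negative addition, `loss_full` is never smaller than `loss_encoder_only` for the same weights.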

Arguments:

    model (nn.Module): the initialized pytorch Geodesic Autoencoder model.

    df (pd.DataFrame): the DataFrame from which to extract batch data.

    groups (list): the list of the numerical groups in the data, e.g. 
        `[1.0, 2.0, 3.0, 4.0, 5.0]`, if the data has five groups.

    optimizer (torch.optim): an optimizer initialized with the model's parameters.

    n_epochs (int): Defaults to `60`. The number of training epochs.

    criterion (torch.nn): Defaults to `nn.MSELoss()`. The criterion to minimize.

    dist (NoneType|Class): Defaults to `None`. The distance class with a `fit(X)` method for a dataset `X`. Computes the pairwise distances in `X`.

    recon (bool): Defaults to `True`. Whether or not to apply the reconstruction loss.

    use_cuda (bool): Defaults to `False`. Whether or not to send the model and data to cuda. 

    sample_size (tuple): Defaults to `(100, )`.

    sample_with_replacement (bool): Defaults to `False`. Whether or not to sample data points with replacement.

    noise_min_scale (float): Defaults to `0.09`. The minimum noise scale.

    noise_max_scale (float): Defaults to `0.15`. The maximum noise scale. The true scale is sampled between these two bounds for each epoch.

    hold_one_out (bool): Defaults to `False`. Whether or not to ignore a timepoint during training.

    hold_out (str|int): Defaults to `"random"`. The timepoint to hold out, either a specific element of `groups` or a random one.
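
The per-epoch noise scale bounded by `noise_min_scale` and `noise_max_scale` can be sketched as follows. Uniform sampling is an assumption of this sketch; the arguments above only state that the scale is sampled between the two bounds.

```python
import random

noise_min_scale, noise_max_scale = 0.09, 0.15  # defaults from the signature
n_epochs = 60

rng = random.Random(0)
# One scale per epoch; uniform sampling is an assumption of this sketch.
scales = [
    rng.uniform(noise_min_scale, noise_max_scale) for _ in range(n_epochs)
]
```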

In [3]:
#| echo: false
#| output: asis
show_doc(training_regimen)

---

### training_regimen

>      training_regimen (n_local_epochs, n_epochs, n_post_local_epochs, exp_dir,
>                        model, df, groups, optimizer, n_batches=20,
>                        criterion=MMD_loss(), use_cuda=False,
>                        hold_one_out=False, hold_out='random',
>                        hinge_value=0.01, use_density_loss=True, top_k=5,
>                        lambda_density=1.0, autoencoder=None, use_emb=True,
>                        use_gae=False, sample_size=(100,),
>                        sample_with_replacement=False, logger=None,
>                        add_noise=False, noise_scale=0.1, use_gaussian=True,
>                        use_penalty=False, lambda_energy=1.0, steps=None,
>                        plot_every=None, n_points=100, n_trajectories=100,
>                        n_bins=100, local_losses=None, batch_losses=None,
>                        globe_losses=None, reverse_schema=True, reverse_n=4)

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| n_local_epochs |  |  |  |
| n_epochs |  |  |  |
| n_post_local_epochs |  |  |  |
| exp_dir |  |  |  |
| model |  |  |  |
| df |  |  |  |
| groups |  |  |  |
| optimizer |  |  |  |
| n_batches | int | 20 | BEGIN: train params |
| criterion | MMD_loss | MMD_loss() |  |
| use_cuda | bool | False |  |
| hold_one_out | bool | False |  |
| hold_out | str | random |  |
| hinge_value | float | 0.01 |  |
| use_density_loss | bool | True |  |
| top_k | int | 5 |  |
| lambda_density | float | 1.0 |  |
| autoencoder | NoneType | None |  |
| use_emb | bool | True |  |
| use_gae | bool | False |  |
| sample_size | tuple | (100,) |  |
| sample_with_replacement | bool | False |  |
| logger | NoneType | None |  |
| add_noise | bool | False |  |
| noise_scale | float | 0.1 |  |
| use_gaussian | bool | True |  |
| use_penalty | bool | False |  |
| lambda_energy | float | 1.0 |  |
| steps | NoneType | None |  |
| plot_every | NoneType | None |  |
| n_points | int | 100 |  |
| n_trajectories | int | 100 |  |
| n_bins | int | 100 |  |
| local_losses | NoneType | None |  |
| batch_losses | NoneType | None |  |
| globe_losses | NoneType | None |  |
| reverse_schema | bool | True |  |
| reverse_n | int | 4 |  |