# Implementing New Synthesis Methods

This notebook assumes you are familiar with synthesis methods and have interacted with them in this package, and that you would like to either understand what happens under the hood a little better or implement your own method using the existing API.

`Synthesis` is an abstract class that should be inherited by any synthesis method (e.g., `Metamer`, `MADCompetition`). It provides many helper functions which, depending on your method, can be used directly or modified slightly. For the two extremes of this, see the source code for `Metamer` and `MADCompetition`: `Metamer` uses almost everything exactly as written, whereas `MADCompetition`, because it works with two models, requires extensive modification. Even when you're modifying the methods, however, you should try to:

1. Maintain the names and, as much as possible, the call signatures. We want it to be easy to use the different synthesis methods so the way users interact with them should be as similar as possible.
  - The most common reason you'd modify the call signature would be for adding arguments. If an argument is a modification / tweak of an existing one, place it next to that existing argument. If it's completely novel (and important), place it near the beginning. 
  - For example, `MADCompetition` requires two models, instead of one, during initialization, and a new required argument, `synthesis_target`, for `synthesize()`. The standard initialization call signature is `(target_image, model, loss_function, **model_kwargs)`, so `MADCompetition`'s is `(target_image, model_1, model_2, loss_function, model_1_kwargs, model_2_kwargs)`. The new argument for `synthesize()` goes at the beginning.
2. Reuse existing methods. The basic idea of many synthesis methods is pretty similar: update the input image based on the gradient (or a function of the gradient) of the model. Therefore, much of the code you'll need already exists, and you will just need to, e.g., call it with a different argument, specify which model to use, or modify the gradient before updating the image.
3. Make sure all the existing public-facing methods either work or raise a `NotImplementedError`. We want people to either be able to use the methods they know from other synthesis methods to better understand synthesis (for example, plotting the synthesis status or creating an animation of progress), or to know why they cannot. For example, because `MADCompetition` has two models, we want to plot both losses in `plot_loss()`. We can make use of the existing `Synthesis.plot_loss()` method, just modifying where it grabs the data from, and call it twice. To the user, there's no difference in how it creates the plot. There's no need to do this for private methods (e.g., `_set_seed()`).
4. Add any natural generalizations. `MADCompetition` stimuli come in sets of 4, and so it makes sense to provide a function that generalizes `plot_synthesized_image()` to show all 4 of them: `plot_synthesized_image_all()`.
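The sketch below illustrates these guidelines on a hypothetical two-model method. The `Synthesis` class here is a stripped-down stand-in written for this example, not the package's actual class, and `TwoModelSynthesis`, `loss_1`, and `loss_2` are invented names. The stub's `plot_loss()` returns the values it would plot so the example runs without matplotlib.

```python
class Synthesis:
    """Hypothetical, stripped-down stand-in for the package's Synthesis class."""

    def __init__(self, target_image, model, loss_function=None, **model_kwargs):
        self.base_signal = target_image
        self.model = model
        self.loss_function = loss_function
        self.model_kwargs = model_kwargs
        self.loss = []

    def synthesize(self, max_iter=100):
        raise NotImplementedError("Subclasses must implement synthesize()")

    def plot_loss(self):
        # The real method draws a figure; this stub returns the data it would plot.
        return list(enumerate(self.loss))

    def animate(self):
        # Stands in for real animation code.
        return "animation"


class TwoModelSynthesis(Synthesis):
    """Hypothetical two-model method, laid out the way MADCompetition is described."""

    def __init__(self, target_image, model_1, model_2, loss_function=None,
                 model_1_kwargs=None, model_2_kwargs=None):
        # Reuse the parent initializer for everything tied to the first model...
        super().__init__(target_image, model_1, loss_function, **(model_1_kwargs or {}))
        # ...then set up what is new: a second model and two loss histories.
        self.model_2 = model_2
        self.model_2_kwargs = model_2_kwargs or {}
        self.loss_1 = []
        self.loss_2 = []

    def synthesize(self, synthesis_target, max_iter=100):
        # The new required argument comes first; shared arguments keep their names.
        self.synthesis_target = synthesis_target
        self.max_iter = max_iter
        # ... the optimization loop would go here ...

    def plot_loss(self):
        # Reuse the parent's plotting code by pointing it at each loss history in turn.
        plots = []
        original = self.loss
        try:
            for history in (self.loss_1, self.loss_2):
                self.loss = history
                plots.append(super().plot_loss())
        finally:
            self.loss = original
        return plots

    def animate(self):
        # A public method this sketch can't support: fail loudly instead of misbehaving.
        raise NotImplementedError("animate() is not supported for TwoModelSynthesis")
```

Temporarily swapping `self.loss` and restoring it in a `finally` block is one way to reuse the parent's plotting code unchanged; passing the data in as an argument would work equally well if the parent method accepted one.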

## Structure

Now, let's walk through the structure of a `Synthesis` class.

The two most important methods are `__init__()`, which initializes the class, and `synthesize()`, which synthesizes an image.

`Synthesis.__init__()` provides a lot of code that you can use (as well as a basis for the docstring), and should be called unless you have a *really* good reason not to. It automatically supports the use of models and metrics, modifies the loss function, and initializes many of the class's attributes. You may want to call it and then do additional setup, e.g., set up a second model or initialize new attributes.

`Synthesis.synthesize()` cannot be called, but it provides a skeleton of what `synthesize()` should look like (as well as a basis for the docstring). It shows how the various private helper methods are used to set up the synthesis call and the core loop. You'll probably want to copy it into your new synthesis method's `synthesize()` and then modify it. You'll certainly need to change the initialization of the synthesized image, which varies from method to method (for instance, `Metamer` uses random noise or a new image, whereas `MADCompetition` uses the reference image plus some noise). Otherwise, you may be able to use the method as written, modifying only the helper functions.
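As a rough illustration of the skeleton idea, here is a hypothetical, dependency-free sketch in which one shared `synthesize()` loop calls an initialization hook that each subclass overrides. The toy model (the mean of the signal), the hook name `_init_synthesized_signal`, and the plain gradient-descent step are all assumptions made for this example. Both subclasses optimize the same toy metamer objective, so only the initialization reflects the difference between `Metamer` and `MADCompetition`.

```python
import random


class SynthesisSketch:
    """Hypothetical skeleton: one shared loop, initialization left to subclasses."""

    def __init__(self, base_signal, seed=0):
        self.base_signal = list(base_signal)
        self.rng = random.Random(seed)
        self.synthesized_signal = None
        self.loss = []

    def model(self, signal):
        # Toy model: the representation is just the mean, so many signals match it.
        return sum(signal) / len(signal)

    def _init_synthesized_signal(self):
        raise NotImplementedError("each synthesis method chooses its own initialization")

    def _optimizer_step(self, learning_rate):
        n = len(self.synthesized_signal)
        error = self.model(self.synthesized_signal) - self.model(self.base_signal)
        # The gradient of error**2 with respect to each pixel is 2 * error / n.
        grad = 2 * error / n
        self.synthesized_signal = [s - learning_rate * grad for s in self.synthesized_signal]
        new_error = self.model(self.synthesized_signal) - self.model(self.base_signal)
        return new_error ** 2

    def synthesize(self, max_iter=100, learning_rate=1.0):
        # Shared skeleton: initialize, then repeatedly step and record the loss.
        self.synthesized_signal = self._init_synthesized_signal()
        for _ in range(max_iter):
            self.loss.append(self._optimizer_step(learning_rate))
        return self.synthesized_signal


class NoiseInit(SynthesisSketch):
    # Metamer-style initialization: start from random noise.
    def _init_synthesized_signal(self):
        return [self.rng.uniform(0, 1) for _ in self.base_signal]


class PerturbedInit(SynthesisSketch):
    # MADCompetition-style initialization: the reference signal plus a little noise.
    def _init_synthesized_signal(self):
        return [b + self.rng.gauss(0, 0.1) for b in self.base_signal]
```

The step size `1.0` suits this toy objective on a four-pixel signal; it is not a recommendation for real models.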

`Synthesis` also contains a variety of plotting and animating functions. You will probably need to think about what to plot, but should hopefully be able to adapt the existing display code to your needs:
- `Synthesis.plot_representation_error()` calls `po.tools.display.plot_representation` on `Synthesis.representation_error()`, which takes the difference between `Synthesis.saved_representation` and `Synthesis.base_representation`.
- `Synthesis.plot_loss()` plots `Synthesis.loss` as a function of iterations.
- `Synthesis.plot_synthesized_image()` calls `po.imshow` on `Synthesis.synthesized_signal`.
- `Synthesis.plot_synthesis_status()` combines the three above plots into one figure.
- `Synthesis.animate()` animates the above figure over iterations.
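The sketch below shows, in plain Python, the data those display methods draw on. `DisplaySketch`, `synthesis_status()`, and `animation_frames()` are hypothetical names, representations are flat lists, and it assumes progress was stored at every iteration so that saved entries line up one-to-one with `loss`.

```python
class DisplaySketch:
    """Hypothetical container for the attributes the display methods read."""

    def __init__(self, base_representation):
        self.base_representation = list(base_representation)
        self.saved_representation = []
        self.saved_signal = []
        self.loss = []

    def representation_error(self, iteration=None):
        # Stored representation minus the target representation (default: latest).
        idx = -1 if iteration is None else iteration
        saved = self.saved_representation[idx]
        return [s - b for s, b in zip(saved, self.base_representation)]

    def synthesis_status(self, iteration=None):
        # Data for the three panels that a status figure combines.
        idx = -1 if iteration is None else iteration
        end = len(self.loss) if iteration is None else iteration + 1
        return {
            "image": self.saved_signal[idx],
            "loss": self.loss[:end],
            "representation_error": self.representation_error(iteration),
        }

    def animation_frames(self):
        # An animation redraws the status figure once per stored iteration.
        return [self.synthesis_status(i) for i in range(len(self.saved_signal))]
```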

## Important Attributes

In order to mesh with `Synthesis`, you'll need to adopt its naming conventions for its attributes:
- At initialization, you should take something like the following arguments, which will get stored as attributes:
  - `base_signal`: the signal you're basing your synthesis off of.
  - `model`: the model (`torch.nn.Module`) or metric (callable) that you're basing your synthesis off of.
  - `loss_function`: the callable to use for computing distance, which must return a scalar. Can be `None`, in which case we use the L2 norm of the difference in representation space.
- The model's representation of `base_signal` should be `base_representation`.
- During iterative synthesis: 
  - The synthesis-in-progress is `synthesized_signal` and the model's representation of it is `synthesized_representation`.
  - The loss is `loss`, the norm of the gradient is `gradient`, and the learning rate is `learning_rate`.
  - If the user wants to store progress, `store_progress` is either a boolean or an integer specifying how often to update the following attributes, each of which stores the corresponding attribute:
    - `saved_signal` contains `synthesized_signal`
    - `saved_representation` contains `synthesized_representation`
    - `saved_signal_gradient` contains `synthesized_signal.grad`
    - `saved_representation_gradient` contains `synthesized_representation.grad`
  - If you want to make use of coarse-to-fine optimization, `_init_ctf_and_randomizer` initializes the following attributes, which `_optimizer_step` and `_closure` then use:
    - `scales`: a copy of `model.scales`, edited over the course of optimization to specify which scale we're currently working on
    - `scales_loss`: scale-specific loss at each iteration (`loss` contains the loss computed with the whole model)
    - `scales_timing`: dictionary containing the iterations where we started and stopped synthesizing each scale
    - `scales_finished`: list of scales that we've finished optimizing
  - For saving during synthesis (e.g., so progress isn't lost if synthesis fails), `save_progress` acts like `store_progress`, and `save_path` specifies the path to the `.pt` file for saving.
  - The other arguments to `synthesize()`, as documented there, are also set as attributes and made use of by `_optimizer_step` and `_closure`, but are not necessary for the other functionality.
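Two of the conventions above are easy to pin down in code: the fallback loss used when `loss_function` is `None`, and how `store_progress` controls when the `saved_*` attributes grow. This is a hypothetical, stdlib-only sketch (`default_loss`, `should_store`, and `ProgressSketch` are invented names); in particular, treating iteration 0 as a storing step is an assumption, not documented behavior.

```python
import math


def default_loss(synth_rep, base_rep):
    # Fallback when loss_function is None: L2 norm of the representation difference.
    return math.sqrt(sum((s - b) ** 2 for s, b in zip(synth_rep, base_rep)))


def should_store(store_progress, iteration):
    if not store_progress:            # False (or 0): never store
        return False
    if store_progress is True:        # True: store at every iteration
        return True
    # int n: store every n iterations (assumption: iteration 0 counts)
    return iteration % store_progress == 0


class ProgressSketch:
    """Hypothetical holder for the saved_* attributes."""

    def __init__(self, store_progress=False):
        self.store_progress = store_progress
        self.saved_signal = []
        self.saved_representation = []

    def _store(self, iteration, signal, representation):
        if should_store(self.store_progress, iteration):
            # Store copies so later updates to the live signal don't change history.
            self.saved_signal.append(list(signal))
            self.saved_representation.append(list(representation))
```

The `True` check comes after the falsy check and before the modulo because `bool` is a subclass of `int` in Python; otherwise `True` would be treated as "every 1 iteration" only by coincidence.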
  
## Required methods

The only methods you need to implement are `__init__()`, `save()`, and `load()`:
- For `save()`, you just need to tell `super().save()` which attributes you wish to save. We recommend including the `save_model_reduced` argument as well (see the `Metamer` tutorial notebook for an explanation).
- For `load()`, you need to tell `super().load()` the name of the attribute that contains the model (e.g., `model`).
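A minimal sketch of this pattern, assuming a `save(file_path, attrs)` / `load(file_path, model_attr_name, model)` interface on the parent. Those signatures are invented for this example, and the package's real methods (which write `.pt` files) may differ, so check the `Synthesis` source before copying. Here `pickle` stands in for the real serialization, and the model is deliberately left out of the saved attributes and re-attached on load; `save_model_reduced` is omitted because its behavior is covered in the `Metamer` notebook.

```python
import os
import pickle
import tempfile


class Synthesis:
    """Hypothetical stand-in; the package's real save()/load() signatures may differ."""

    def save(self, file_path, attrs):
        # Write only the attributes the subclass asked for.
        with open(file_path, "wb") as f:
            pickle.dump({name: getattr(self, name) for name in attrs}, f)

    @classmethod
    def load(cls, file_path, model_attr_name="model", model=None):
        with open(file_path, "rb") as f:
            saved = pickle.load(f)
        obj = cls.__new__(cls)  # skip __init__; attributes come from the file
        for name, value in saved.items():
            setattr(obj, name, value)
        # Knowing the model's attribute name lets load() re-attach a caller-supplied model.
        if model is not None:
            setattr(obj, model_attr_name, model)
        return obj


class MyMethod(Synthesis):
    """Hypothetical synthesis method showing what the subclass has to supply."""

    def __init__(self, base_signal, model):
        self.base_signal = base_signal
        self.model = model
        self.loss = []

    def save(self, file_path):
        # Tell the parent which attributes to save (the model is excluded in this sketch).
        super().save(file_path, attrs=["base_signal", "loss"])

    @classmethod
    def load(cls, file_path, model=None):
        # Tell the parent which attribute holds the model.
        return super().load(file_path, model_attr_name="model", model=model)
```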