Following [this](https://pytorch-lightning.readthedocs.io/en/stable/extensions/loops.html)

## [❗] Loops

Loops let advanced users **swap out the default gradient descent optimization loop** at the core of Lightning with a **different optimization paradigm**.


The Lightning Trainer is built on top of the standard gradient descent optimization loop which works for 90%+ of machine learning use cases:

```python
for i, batch in enumerate(dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

However, some new research use cases such as 
* meta-learning,
* active learning,
* recommendation systems, etc., 

require a different loop structure.

For example, here is a simple loop that guides the weight updates with a loss from a special validation split:

```python
for i, batch in enumerate(train_dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()

    # Added mechanic: accumulate the loss over the validation split.
    val_loss = 0
    for j, val_batch in enumerate(val_dataloader):
        val_x, val_y = val_batch
        val_y_hat = model(val_x)
        val_loss += loss_function(val_y_hat, val_y)

    # Scale the training gradients by the inverse validation loss, then step.
    scale_gradients(model, 1 / val_loss)
    optimizer.step()
```
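The helper `scale_gradients` is not defined in the snippet above. A minimal sketch of what it could look like, assuming the model exposes `parameters()` and each parameter carries a `.grad` attribute (as a `torch.nn.Module` does). `ToyModel` is a hypothetical stand-in, used only to keep the example dependency-free:

```python
from types import SimpleNamespace


def scale_gradients(model, factor):
    """Multiply every existing gradient of `model` by `factor`."""
    for param in model.parameters():
        if param.grad is not None:  # parameters without a gradient are skipped
            param.grad = param.grad * factor


class ToyModel:
    """Hypothetical stand-in exposing `parameters()` like a `torch.nn.Module`."""

    def __init__(self, grads):
        self._params = [SimpleNamespace(grad=g) for g in grads]

    def parameters(self):
        return iter(self._params)


toy = ToyModel([2.0, None, -4.0])
scale_gradients(toy, 0.5)
scaled = [p.grad for p in toy.parameters()]
print(scaled)
```

Because the function only relies on `parameters()` and `.grad`, the same body should also work when `factor` is a scalar tensor such as `1 / val_loss`.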

With Lightning Loops, you can plug in a non-standard optimization scheme such as the loop above:

```python
trainer = Trainer()
trainer.fit_loop.epoch_loop = MyGradientDescentLoop()  # NOTE!
```

### ℹ️ Understanding the Default Trainer Loop

The Lightning `Trainer` automates the standard optimization loop which every PyTorch user is familiar with:

```python
for i, batch in enumerate(dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

The core research logic is simply shifted to the `LightningModule`:

```python
for i, batch in enumerate(dataloader):
    # x, y = batch                      moved to training_step
    # y_hat = model(x)                  moved to training_step
    # loss = loss_function(y_hat, y)    moved to training_step
    loss = lightning_module.training_step(batch, i)

    # Lightning handles automatically:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Under the hood, the above loop is implemented using **the `Loop` API** like so:

```python
class DefaultLoop(Loop):
    def advance(self, batch, i):  # NOTE.
        loss = lightning_module.training_step(batch, i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def run(self, dataloader):  # NOTE.
        for i, batch in enumerate(dataloader):
            self.advance(batch, i)
```
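The `DefaultLoop` above is schematic: it relies on names (`lightning_module`, `optimizer`) defined elsewhere. A dependency-free sketch of the same pattern, showing how `run` keeps calling `advance` until `done` is true. `ToyLoop` is a hypothetical class written for these notes, not Lightning's actual `Loop`:

```python
class ToyLoop:
    """Simplified stand-in for a loop class; not Lightning's actual `Loop`."""

    def __init__(self, batches):
        self.batches = batches
        self.batch_idx = 0
        self.seen = []

    @property
    def done(self):
        # Stop once every batch has been consumed.
        return self.batch_idx >= len(self.batches)

    def advance(self):
        # One iteration: record the batch instead of training on it.
        self.seen.append(self.batches[self.batch_idx])
        self.batch_idx += 1

    def run(self):
        # The driver: call `advance` until `done` says stop.
        while not self.done:
            self.advance()
        return self.seen


loop = ToyLoop(batches=["b0", "b1", "b2"])
result = loop.run()
print(result)
```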

Defining a loop within a class interface instead of hard-coding a raw Python for/while loop has several benefits:
* You can have full control over the data flow through loops.
* You can add new loops and nest as many of them as you want.
* If needed, the state of a loop can be [saved and resumed](https://pytorch-lightning.readthedocs.io/en/stable/extensions/loops_advanced.html#persisting-loop-state).
* New hooks can be injected at any point.
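To make the "saved and resumed" point concrete, here is a small sketch assuming a loop keeps its progress in plain attributes. The method names `state_dict` and `load_state_dict` and the `stop_after` argument are choices made for this toy example, not a statement about Lightning's internals:

```python
class CountingLoop:
    """Toy loop whose progress can be captured and restored (hypothetical API)."""

    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.step = 0

    @property
    def done(self):
        return self.step >= self.max_steps

    def advance(self):
        self.step += 1

    def run(self, stop_after=None):
        # `stop_after` simulates an interruption after a number of iterations.
        iterations = 0
        while not self.done:
            if stop_after is not None and iterations >= stop_after:
                break
            self.advance()
            iterations += 1

    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state):
        self.step = state["step"]


first = CountingLoop(max_steps=5)
first.run(stop_after=2)          # interrupted after two iterations
checkpoint = first.state_dict()  # {"step": 2}

resumed = CountingLoop(max_steps=5)
resumed.load_state_dict(checkpoint)
resumed.run()                    # continues from step 2 up to step 5
print(resumed.step)
```

A raw `for` loop has no comparable place to hold this state, which is one reason to wrap the loop in a class.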

See gif:

<img src="./assets/epoch-loop-steps.gif" />

### Overriding the default Loops

The fastest way to get started with loops is to override the functionality of an existing loop.

Lightning has 4 main loops (fit, validate, test, and predict), which rely on three classes:
* [`FitLoop`](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.loops.FitLoop.html#pytorch_lightning.loops.FitLoop) for fitting (training and validating),
* [`EvaluationLoop`](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.loops.dataloader.EvaluationLoop.html#pytorch_lightning.loops.dataloader.EvaluationLoop) for validating or testing,
* [`PredictionLoop`](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.loops.dataloader.PredictionLoop.html#pytorch_lightning.loops.dataloader.PredictionLoop) for predicting.

For simple changes that don’t require a custom loop, **you can modify each of these loops**.

Each loop has **a series of methods that can be modified**. 

For example, with the `FitLoop`:

```python
from pytorch_lightning.loops import FitLoop


class MyLoop(FitLoop):
    def advance(self):
        """Advance from one iteration to the next."""

    def on_advance_end(self):
        """Do something at the end of an iteration."""

    def on_run_end(self):
        """Do something when the loop ends."""
```

> A full list with all built-in loops and subloops can be found here:
> https://pytorch-lightning.readthedocs.io/en/stable/extensions/loops.html#loop-structure-extensions

To add your own modifications to a loop, simply subclass an existing loop class and override what you need.

Here is a simple example that overrides `advance`:

```python
from pytorch_lightning.loops import FitLoop


class CustomFitLoop(FitLoop):
    def advance(self):
        """Put your custom logic here."""
```
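Note that overriding `advance` with only a docstring replaces the inherited iteration entirely. A common pattern is to add custom logic and then delegate to the parent. The sketch below illustrates this with `BaseFitLoop`, a hypothetical stand-in for an existing loop (it is not Lightning's `FitLoop`):

```python
class BaseFitLoop:
    """Hypothetical stand-in for an existing loop; not Lightning's `FitLoop`."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.current_epoch = 0
        self.log = []

    @property
    def done(self):
        return self.current_epoch >= self.max_epochs

    def advance(self):
        self.log.append(f"train epoch {self.current_epoch}")
        self.current_epoch += 1

    def run(self):
        while not self.done:
            self.advance()


class LoggingFitLoop(BaseFitLoop):
    def advance(self):
        # Custom logic before the inherited behaviour...
        self.log.append(f"before epoch {self.current_epoch}")
        super().advance()  # ...then keep the original iteration intact.


loop = LoggingFitLoop(max_epochs=2)
loop.run()
print(loop.log)
```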

Now attach the custom loop to the trainer directly:
```python
trainer = Trainer(...)
trainer.fit_loop = CustomFitLoop()  # NOTE.

# fit() now uses the new FitLoop!
trainer.fit(...)

# the equivalent for validate()
val_loop = CustomValLoop()
trainer = Trainer()
trainer.validate_loop = val_loop
trainer.validate(...)
```

Finally, see gif:

<img src="./assets/replace-fit-loop.gif" />

### Creating a *New Loop From Scratch*

You can also go wild and implement a full loop from scratch by subclassing the `Loop` base class.

You will need to override *a minimum of two things*:

```python
from pytorch_lightning.loops import Loop


class MyFancyLoop(Loop):
    @property
    def done(self):
        """Provide a condition to stop the loop."""

    def advance(self):
        """
        Access your dataloader/s in whatever way you want.
        Do your fancy optimization things.
        Call the LightningModule methods at your leisure.
        """
```
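As a self-contained illustration of the two pieces, `done` and `advance`, here is a toy loop that runs gradient descent on `(w - 3)^2` and stops on convergence or when a step budget runs out. The class, its parameters, and the objective are all made up for this sketch and do not depend on Lightning:

```python
class ThresholdLoop:
    """Toy from-scratch loop: gradient descent on (w - 3)^2 (hypothetical)."""

    def __init__(self, lr=0.1, tolerance=1e-6, max_steps=1000):
        self.lr = lr
        self.tolerance = tolerance
        self.max_steps = max_steps
        self.w = 0.0
        self.steps = 0

    @property
    def loss(self):
        return (self.w - 3.0) ** 2

    @property
    def done(self):
        # Stop on convergence or when the step budget is exhausted.
        return self.loss < self.tolerance or self.steps >= self.max_steps

    def advance(self):
        grad = 2.0 * (self.w - 3.0)  # derivative of (w - 3)^2
        self.w -= self.lr * grad
        self.steps += 1

    def run(self):
        while not self.done:
            self.advance()
        return self.w


loop = ThresholdLoop()
final_w = loop.run()
print(final_w, loop.steps)
```

The stopping rule lives entirely in `done`, so changing the convergence criterion does not touch the update logic in `advance`.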

Finally, attach it to the `Trainer`:
```python
trainer = Trainer(...)
trainer.fit_loop = MyFancyLoop()

# fit() now uses your fancy loop!
trainer.fit(...)
```

> ⚠️ But beware: Loop customization gives you more power and full control over the Trainer, and with great power comes great responsibility.
> We recommend that you familiarize yourself with *overriding the default loops* first before you start building a new loop from the ground up.

### Loop API

The `Loop` class is the base of all loops in the same way as the `LightningModule` is the base of all models.
It defines a public interface that each loop implementation must follow. The key members (for example `done`, `advance`, and `run`) are listed in the `Loop` API reference:

https://pytorch-lightning.readthedocs.io/en/stable/extensions/loops.html#loop-api

### Subloops

**[❓] The explanation in this section makes little sense.**

> When you want to customize **nested loops within loops**, use the `replace()` method:

```python
# This takes care of properly instantiating the new Loop and setting all references
trainer.fit_loop.replace(epoch_loop=MyEpochLoop)
# Trainer runs the fit loop with your new epoch loop!
trainer.fit(model)
```

> Alternatively, for more fine-grained control, use the `connect()` method:

```python
# Optional: stitch back the trainer arguments
epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
# Optional: connect children loops as they might have existing state
epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
# Instantiate and connect the loop.
trainer.fit_loop.connect(epoch_loop=epoch_loop)
trainer.fit(model)
```
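One way to read the snippets above: a parent loop holds references to its child loops, and connecting a new child swaps that reference before the parent runs. A toy sketch of the idea, where `ParentLoop`, `ChildLoop`, and the `child` keyword are hypothetical names and not Lightning's API:

```python
class ChildLoop:
    """Hypothetical child loop that just reports its name when run."""

    def __init__(self, name):
        self.name = name
        self.calls = 0

    def run(self):
        self.calls += 1
        return self.name


class ParentLoop:
    """Hypothetical parent loop that delegates to a replaceable child loop."""

    def __init__(self, repeats):
        self.repeats = repeats
        self.child = ChildLoop("default")

    def connect(self, child):
        # Swap in a different child loop instance.
        self.child = child

    def run(self):
        return [self.child.run() for _ in range(self.repeats)]


parent = ParentLoop(repeats=3)
custom_child = ChildLoop("custom")
parent.connect(child=custom_child)
outputs = parent.run()
print(outputs)
```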

More about the built-in loops and how they are composed is explained in the next section.

![img](./assets/connect-epoch-loop.gif)

### Built-in Loops

The training loop in Lightning is called the **fit loop** and is actually a combination of several loops.

Here is what the structure would look like in plain Python:

```python
# FitLoop
for epoch in range(max_epochs):

    # TrainingEpochLoop
    for batch_idx, batch in enumerate(train_dataloader):

        # TrainingBatchLoop
        for split_batch in tbptt_split(batch):

            # OptimizerLoop
            for optimizer_idx, opt in enumerate(optimizers):

                loss = lightning_module.training_step(split_batch, batch_idx, optimizer_idx)
                ...

        # ValidationEpochLoop
        for batch_idx, batch in enumerate(val_dataloader):
            lightning_module.validation_step(batch, batch_idx)
            ...
```

* Each of these `for`-loops represents a class implementing the Loop interface.
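The nesting can be traced with a small runnable sketch that mirrors the structure above and records the visiting order. `tbptt_split` here is a hypothetical helper that cuts a sequence into fixed-size chunks, and the "optimizers" are plain strings. Nothing in this block calls Lightning:

```python
def tbptt_split(batch, split_size):
    """Hypothetical helper: cut a sequence batch into chunks of `split_size`."""
    for start in range(0, len(batch), split_size):
        yield batch[start:start + split_size]


def trace_fit(max_epochs, train_batches, val_batches, optimizers, split_size):
    """Record the order in which the nested loops visit their work items."""
    events = []
    for epoch in range(max_epochs):  # FitLoop
        for batch_idx, batch in enumerate(train_batches):  # TrainingEpochLoop
            for split_batch in tbptt_split(batch, split_size):  # TrainingBatchLoop
                for optimizer_idx, _ in enumerate(optimizers):  # OptimizerLoop
                    events.append(("train", epoch, batch_idx, split_batch, optimizer_idx))
            # ValidationEpochLoop, nested in the epoch loop as in the outline above
            for val_idx, _ in enumerate(val_batches):
                events.append(("val", epoch, val_idx))
    return events


events = trace_fit(
    max_epochs=1,
    train_batches=[[1, 2, 3, 4]],
    val_batches=[[5, 6]],
    optimizers=["opt_a", "opt_b"],
    split_size=2,
)
for event in events:
    print(event)
```

With one training batch split into two chunks and two optimizers, the trace contains four training events followed by one validation event.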

**See also the list of built-in loops:**

https://pytorch-lightning.readthedocs.io/en/stable/extensions/loops.html#built-in-loops

* `FitLoop`
* `TrainingEpochLoop`
* `TrainingBatchLoop`
* `OptimizerLoop`
* `ManualOptimization`
* `EvaluationLoop`
* `PredictionLoop`

### Available Loops in Lightning Flash

> 🤔 Investigate further: [Lightning Flash](https://github.com/Lightning-AI/lightning-flash)
>
> Recipes for more complex ML scenarios implemented in Lightning.

See the Active Learning example in the tutorial:

https://pytorch-lightning.readthedocs.io/en/stable/extensions/loops.html#available-loops-in-lightning-flash


### Advanced Examples

https://pytorch-lightning.readthedocs.io/en/stable/extensions/loops.html#advanced-examples

* [`K-fold Cross Validation`](https://github.com/Lightning-AI/lightning/blob/master/examples/pl_loops/kfold.py)
* [`Yielding Training Step`](https://github.com/Lightning-AI/lightning/blob/master/examples/pl_loops/yielding_training_step.py)