<a href="https://colab.research.google.com/github/PyTorchLightning/lightning-flash/blob/master/flash_notebooks/custom_task_tutorial" target="_parent">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Tutorial: Creating a Custom Task

In this tutorial we will go over the process of creating a custom task, along with a custom data module.

In [None]:
%%capture
! pip install git+https://github.com/PyTorchLightning/pytorch-flash.git

In [None]:
import flash

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
from sklearn import datasets
from sklearn.model_selection import train_test_split

## The Task

Here we create a basic linear regression task by subclassing `flash.Task`. For the majority of tasks, you will likely only need to override the `__init__` and `forward` methods of the task.

In [None]:
class LinearRegression(flash.Task):
    def __init__(self, num_inputs, learning_rate=0.001, metrics=None):
        # what kind of model do we want?
        model = nn.Linear(num_inputs, 1)

        # what loss function do we want?
        loss_fn = torch.nn.functional.mse_loss
        
        # what optimizer do we want?
        optimizer = torch.optim.SGD
        
        super().__init__(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
        )
        
    def forward(self, x):
        # we don't actually need to override this method for this example
        return self.model(x)

### Where is the training step?

Most models can be trained simply by passing the output of `forward` to the supplied `loss_fn`, and then passing the resulting loss to the supplied `optimizer`. If you need a more custom configuration, you can override `step` (which is called for training, validation, and testing) or override `training_step`, `validation_step`, and `test_step` individually. These methods behave identically to PyTorch Lightning's [methods](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#methods).
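To make the default behaviour concrete, here is a minimal sketch in plain Python (no Flash or PyTorch required) of what one such step amounts to for a one-feature linear model: run the forward pass, compute the mean squared error, and apply an SGD update. The function name `sgd_step` and the toy data are assumptions made for illustration; this is not Flash's actual implementation.

```python
def sgd_step(w, b, xs, ys, learning_rate):
    """One SGD step for the model y_hat = w * x + b with an MSE loss.

    Hypothetical helper for illustration only; Flash does this work
    for you through the supplied loss_fn and optimizer.
    """
    n = len(xs)
    # forward pass: compute predictions
    preds = [w * x + b for x in xs]
    # loss: mean squared error between predictions and targets
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # gradients of the loss with respect to w and b
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
    # optimizer: plain SGD update
    return w - learning_rate * grad_w, b - learning_rate * grad_b, loss


# toy data generated from y = 2x + 1
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]

w, b = 0.0, 0.0
losses = []
for _ in range(200):
    w, b, loss = sgd_step(w, b, xs, ys, learning_rate=0.05)
    losses.append(loss)

print(f"w={w:.3f}, b={b:.3f}, final loss={losses[-1]:.6f}")
```

Overriding `step` or the individual `*_step` methods is only needed when this forward, loss, update pattern does not fit your task.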


## The Data

For a task you will likely need a specific way of loading data. For this example, let's say we want a `flash.DataModule` to be used explicitly for the prediction of diabetes disease progression. We can create this `DataModule` below, wrapping the scikit-learn [Diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset).

In [None]:
class DiabetesPipeline(flash.data.process.Postprocess):
    def per_sample_transform(self, samples):
        return [f"disease progression: {float(s):.2f}" for s in samples]

class DiabetesData(flash.DataModule):
    
    postprocess_cls = DiabetesPipeline
    
    def __init__(self, batch_size=64, num_workers=0):
        x, y = datasets.load_diabetes(return_X_y=True)
        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float().unsqueeze(1)
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)

        train_ds = TensorDataset(x_train, y_train)
        test_ds = TensorDataset(x_test, y_test)
        
        super().__init__(
            train_ds=train_ds,
            test_ds=test_ds,
            batch_size=batch_size,
            num_workers=num_workers
        )
        self.num_inputs = x.shape[1]  

You'll notice we added a `Postprocess` subclass, `DiabetesPipeline`, which will be used when we call `.predict()` on our model. In this case we want to format the model's output with the string `"disease progression"`, but you could do any sort of post-processing you want!
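The formatting step itself is plain Python and can be tried in isolation. The sketch below assumes a hypothetical helper name, `format_predictions`, and made-up raw values; it only mirrors the f-string used in the class above and does not involve Flash.

```python
def format_predictions(values):
    """Format raw numeric predictions as human-readable strings.

    Hypothetical stand-alone helper that mirrors the f-string used in
    DiabetesPipeline.per_sample_transform.
    """
    return [f"disease progression: {float(v):.2f}" for v in values]


# made-up raw model outputs, for illustration only
raw_outputs = [14.8412, 14.8579, 14.7801]
formatted = format_predictions(raw_outputs)
print(formatted)
# → ['disease progression: 14.84', 'disease progression: 14.86', 'disease progression: 14.78']
```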

## Fit

Like any Flash Task, we can fit our model using the `flash.Trainer` by supplying the task itself, and the associated data:

In [None]:
data = DiabetesData()
model = LinearRegression(num_inputs=data.num_inputs)

trainer = flash.Trainer(max_epochs=10, progress_bar_refresh_rate=20)
trainer.fit(model, data)

With a trained model we can now perform inference. Here we will use a few examples from the test set of our data:

In [None]:
predict_data = torch.tensor([
    [ 0.0199,  0.0507,  0.1048,  0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037,  0.0403],
    [-0.0128, -0.0446,  0.0606,  0.0529,  0.0480,  0.0294, -0.0176,  0.0343, 0.0702,  0.0072],
    [ 0.0381,  0.0507,  0.0089,  0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181,  0.0072],
    [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167,  0.0046, -0.0176, -0.0026, -0.0385, -0.0384],
    [-0.0237, -0.0446,  0.0455,  0.0907, -0.0181, -0.0354,  0.0707, -0.0395, -0.0345, -0.0094]])

model.predict(predict_data)

Because of the `per_sample_transform` method of our custom `DiabetesPipeline` postprocess, we will get nicely formatted output like the following:
```
[['disease progression: 14.84'],
 ['disease progression: 14.86'],
 ['disease progression: 14.78'],
 ['disease progression: 14.73'],
 ['disease progression: 14.71']]
```

<code style="color:#792ee5;">
    <h1> <strong> Congratulations - Time to Join the Community! </strong>  </h1>
</code>

Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!

### Help us build Flash by adding support for new data-types and new tasks.
Flash aims to become the first task hub, so that anyone can get started building amazing applications using deep learning.
If you are interested, please open a PR with your contributions!


### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub
The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

* Please star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)

### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!
The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the `#general` channel.

### Interested in SOTA AI models? Check out [Bolts](https://github.com/PyTorchLightning/lightning-bolts)
Bolts is a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning), that can be easily integrated into your own projects.

* Please star [Bolts](https://github.com/PyTorchLightning/lightning-bolts)

### Contributions!
The best way to contribute to our community is to become a code contributor! At any time you can go to the [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolts](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for "good first issue".

* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* [Bolts good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* You can also contribute your own notebooks with useful examples!

### Great thanks from the entire PyTorch Lightning Team for your interest!

<img src="https://raw.githubusercontent.com/PyTorchLightning/lightning-flash/18c591747e40a0ad862d4f82943d209b8cc25358/docs/source/_static/images/logo.svg" width="800" height="200" />