# Training DETR

This tutorial explains how to use the [LitDetr] module to train the [DETR50 architecture] from scratch, with the [COCO2017 detection dataset] as input.

<div class="alert alert-info">
    
**Goals**
    
1. Learn the different ways to instantiate the [LitDetr] class
2. Train the [DETR50 architecture]
3. Load trained weights and run inference with pre-trained weights

</div>

[DETR50 architecture]: https://arxiv.org/abs/2005.12872
[COCO2017 detection dataset]: https://cocodataset.org/#detection-2017
[LitDetr]: ../alonet/detr_training.rst#alonet.detr.train.LitDetr

## 1. LitDetr argument levels

[Aloception] is built on the [Pytorch Lightning] framework and provides modules that simplify working with datasets and training models.

<div class="alert alert-info">

**See also**
    
More information is available at:
    
 * [End-to-End Object Detection with Transformers (DETR)]   
 * [Pytorch Lightning Module]   

</div>

There are multiple ways to instantiate the module. The most common one uses the default parameters:

[Pytorch Lightning]: https://www.pytorchlightning.ai/
[Pytorch Lightning Module]: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
[End-to-End Object Detection with Transformers (DETR)]: https://arxiv.org/abs/2005.12872
[Aloception]: ../index.rst

In [None]:
from alonet.detr import LitDetr

litdetr = LitDetr()

Like all modules in [Aloception] based on [Pytorch Lightning], [LitDetr] has a static method, `add_argparse_args`, that appends its default parameters to an existing argument parser so they can be combined with those of other modules.

[Pytorch Lightning]: https://www.pytorchlightning.ai/
[Aloception]: ../index.rst
[LitDetr]: ../alonet/detr_training.rst#alonet.detr.train.LitDetr

In [None]:
from argparse import ArgumentParser

parser = ArgumentParser()
parser = litdetr.add_argparse_args(parser)
parser.parse_args([])
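
The pattern behind `add_argparse_args` can be illustrated with plain `argparse`. The sketch below is not the alonet implementation: `ToyModule` is a hypothetical class, its two options only borrow names from attributes used later in this tutorial, and the default values are invented for the example.

```python
from argparse import ArgumentParser


class ToyModule:
    """Hypothetical stand-in for a module exposing `add_argparse_args`."""

    @staticmethod
    def add_argparse_args(parent_parser):
        # Append this module's options to a parser that may already hold others
        group = parent_parser.add_argument_group("ToyModule")
        group.add_argument("--gradient_clip_val", type=float, default=0.1)
        group.add_argument("--accumulate_grad_batches", type=int, default=4)
        return parent_parser


toy_parser = ArgumentParser()
toy_parser.add_argument("--seed", type=int, default=0)  # option owned by the caller
toy_parser = ToyModule.add_argparse_args(toy_parser)

toy_args = toy_parser.parse_args([])  # empty list: every option keeps its default
print(toy_args)
```

Because the method returns the parser it received, several modules can be chained on the same parser, each one contributing its own options.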

To change a specific parameter, pass it as a keyword argument when instantiating the class:

In [None]:
from argparse import Namespace

def params2Namespace(litdetr):
    return Namespace(
        accumulate_grad_batches=litdetr.accumulate_grad_batches, 
        gradient_clip_val=litdetr.gradient_clip_val, 
        model_name=litdetr.model_name, 
        weights=litdetr.weights
    )

litdetr = LitDetr(gradient_clip_val=0.6)
params2Namespace(litdetr)

These parameters can also be set from the command line if the parsed arguments are passed to the module:

In [None]:
args = parser.parse_args([]) # Remove [] to run in script
litdetr = LitDetr(args)
params2Namespace(litdetr)
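
To preview what a command line would change, a list of strings can be passed to `parse_args` in place of the real program arguments. This is a minimal sketch with plain `argparse`; the flag `--gradient_clip_val` and its default are assumptions that mirror the attribute name used above, not a verified alonet flag.

```python
from argparse import ArgumentParser

demo_parser = ArgumentParser()
# Hypothetical flag and default, used only for this illustration
demo_parser.add_argument("--gradient_clip_val", type=float, default=0.1)

notebook_args = demo_parser.parse_args([])  # nothing passed: the default is used
cli_args = demo_parser.parse_args(["--gradient_clip_val", "0.6"])  # as typed in a shell

print(notebook_args.gradient_clip_val, cli_args.gradient_clip_val)
```

Calling `parse_args()` with no argument reads `sys.argv[1:]`, which is why the `[]` has to be removed when the code runs as a script.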

Both approaches can be combined: fix some parameters explicitly and take the remaining values from the command line.

In [None]:
from alonet.detr import LitDetr, DetrR50Finetune

my_model = DetrR50Finetune(num_classes = 2)
litdetr = LitDetr(args, model_name="finetune", model=my_model)
params2Namespace(litdetr)

<div class="alert alert-warning">

**Attention**
    
Parameters passed explicitly override the ones in the **args** variable.
</div>
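
The precedence rule can be pictured as a dictionary merge in which explicit keyword arguments are applied last. This is only a sketch of the rule stated above, not the code used by `LitDetr`; `merge_params` is a hypothetical helper and the values are invented.

```python
from argparse import Namespace


def merge_params(args=None, **kwargs):
    """Hypothetical helper: values in `kwargs` override those found in `args`."""
    params = dict(vars(args)) if args is not None else {}
    params.update(kwargs)  # explicit keyword arguments win
    return Namespace(**params)


cli_like = Namespace(model_name="detr-r50", gradient_clip_val=0.1)
merged = merge_params(cli_like, model_name="finetune")
print(merged)
```

The original namespace is copied before the update, so `cli_like` keeps the values that came from the command line.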

<div class="alert alert-info">

**See also**

Since [LitDetr] is a Pytorch Lightning based module, any additional functionality can be implemented by inheriting from [LitDetr] as a parent class. See the information in [Pytorch Lightning Module].
</div>

[LitDetr]: ../alonet/detr_training.rst#alonet.detr.train.LitDetr
[Pytorch Lightning Module]: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html

## 2. Train process

<div class="alert alert-info">

**See also**

The training process is based on the [Pytorch Lightning Trainer Module]. For more information, please consult their online documentation.
</div>

As an example, let's use the [COCO detection 2017 dataset] as training data. The usual pipeline is shown below:

[COCO detection 2017 dataset]: https://cocodataset.org/#detection-2017
[Pytorch Lightning Trainer Module]: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html

In [None]:
from argparse import ArgumentParser

import alonet
from alonet.detr import CocoDetection2Detr, LitDetr

import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Parameters definition
# Build parser (concatenates arguments to modify the entire project)
parser = ArgumentParser(conflict_handler="resolve")
parser = CocoDetection2Detr.add_argparse_args(parser)
parser = LitDetr.add_argparse_args(parser)
parser = alonet.common.add_argparse_args(parser)  # Add common arguments in train process
args = parser.parse_args([])

# Dataset used for training
coco_loader = CocoDetection2Detr(args)
lit_detr = LitDetr(args)

# Train process
# args.save = True # Uncomment this line to store trained weights
lit_detr.run_train(
    data_loader=coco_loader, 
    args=args, 
    project="detr", 
    expe_name="coco_detr", 
)
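
The parser above is created with `conflict_handler="resolve"`. The tutorial does not say which options overlap, but a likely reason is that several `add_argparse_args` calls may register the same option. The sketch below shows the effect with plain `argparse`; the two functions and the `--batch_size` option with its defaults are hypothetical.

```python
import argparse


def add_dataset_args(parser):
    # Hypothetical dataset options
    parser.add_argument("--batch_size", type=int, default=2)
    return parser


def add_trainer_args(parser):
    # Hypothetical trainer options; `--batch_size` is registered a second time
    parser.add_argument("--batch_size", type=int, default=4)
    return parser


def build_parser(conflict_handler):
    parser = argparse.ArgumentParser(conflict_handler=conflict_handler)
    parser = add_dataset_args(parser)
    parser = add_trainer_args(parser)
    return parser


try:
    build_parser("error")  # argparse's default behaviour
    duplicate_rejected = False
except argparse.ArgumentError:
    duplicate_rejected = True

resolved_args = build_parser("resolve").parse_args([])
print(duplicate_rejected, resolved_args.batch_size)
```

With `"resolve"`, the option registered last replaces the earlier one, so the order of the `add_argparse_args` calls decides which default applies.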

<div class="alert alert-warning">
    
**Attention**

Because the model is initialized from scratch, this code is computationally expensive and takes several hours to train. It is recommended to skip to the next section to see the results of a trained network.
</div>

## 3. Make inferences

Once training is finished, the trained weights can be loaded from the project and run id (they are stored under the `~/.aloception/project_run_id/run_id` path). The `load_training` function of the `alonet.common` module can be used for this:

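```python
from argparse import Namespace

from alonet.common import load_training
from alonet.detr import LitDetr

# Replace both placeholders with the project and run id of your own training
args = Namespace(project_run_id="project_run_id", run_id="run_id")
lit_detr = load_training(LitDetr, args=args)
```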
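
Before loading, it can help to check that the run directory exists. The sketch below only rebuilds the path pattern quoted above; `run_directory` is a hypothetical helper that is not part of alonet, and the two ids are left as placeholders.

```python
from pathlib import Path


def run_directory(project_run_id, run_id, root="~/.aloception"):
    """Hypothetical helper: build `<root>/<project_run_id>/<run_id>`."""
    return Path(root).expanduser() / project_run_id / run_id


candidate = run_directory("project_run_id", "run_id")
if not candidate.is_dir():
    print(f"No run found at {candidate}; check the project and run id")
```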

[LitDetr] can also download and load pre-trained weights. This is done through the `weights` parameter:

[LitDetr]: ../alonet/detr_training.rst#alonet.detr.train.LitDetr

In [None]:
lit_detr = LitDetr(weights = "detr-r50")

Finally, we have a pre-trained model ready to make some detections.

In [None]:
%matplotlib inline
from alonet.detr import CocoDetection2Detr, LitDetr

import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Dataset used to draw a validation sample
coco_loader = CocoDetection2Detr()
lit_detr = LitDetr(weights = "detr-r50")
lit_detr.model = lit_detr.model.eval().to(device)

# Check a random result
frame = next(iter(coco_loader.val_dataloader()))
frame = frame[0].batch_list(frame).to(device)
pred_boxes = lit_detr.inference(lit_detr(frame))[0]  # Inference from forward result
gt_boxes = frame[0].boxes2d

frame.get_view(
    [
        gt_boxes.get_view(frame[0], title="Ground truth boxes"),
        pred_boxes.get_view(frame[0], title="Predicted boxes"),
    ], size = (1920,1080)
).render()
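
For intuition about what an inference step has to do, the sketch below applies the usual DETR-style filtering in pure Python: softmax over the class logits, drop the trailing "no object" class, and keep queries above a confidence threshold. It is an illustration under these assumptions and may differ from what `lit_detr.inference` does; `filter_queries`, the `0.7` threshold and the toy values are all hypothetical.

```python
import math


def softmax(logits):
    # Subtract the maximum for numerical stability
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def filter_queries(class_logits, boxes, threshold=0.7):
    """Keep queries whose best real-class probability exceeds `threshold`.

    The last logit of each query is assumed to be the "no object" class.
    """
    kept = []
    for logits, box in zip(class_logits, boxes):
        probs = softmax(logits)[:-1]  # drop the "no object" probability
        score = max(probs)
        if score > threshold:
            kept.append({"label": probs.index(score), "score": score, "box": box})
    return kept


# Two queries, two real classes plus "no object"
class_logits = [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]]
boxes = [(0.5, 0.5, 0.2, 0.2), (0.1, 0.1, 0.05, 0.05)]
kept = filter_queries(class_logits, boxes)
print(kept)
```

Only the first query survives here, because the second one is dominated by the "no object" class.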

<div class="alert alert-info">
    
**What is next?**

Learn how to train a custom architecture in the **[Finetuning DETR]** tutorial, and discover a more complex model based on the *deformable attention module* in the **[Training Deformable]** tutorial.
</div>

[Finetuning DETR]: finetuning_detr.rst
[Training Deformable]: training_deformable_detr.rst