In [None]:
!pip install pytorch-lightning
!pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
!pip install git+https://github.com/lgvaz/mantisshrimp.git

# Wheat

In [None]:
from mantisshrimp.imports import *
from mantisshrimp import *
import pandas as pd
import albumentations as A

## Parser

The first step is to understand the data. For this task we are given a `.csv` file with annotations, so let's take a look at it.

<div class="alert alert-info">
    
**Note:**  

Replace `source` with your own path for the dataset directory.
    
</div>

In [None]:
source = Path('../input/global-wheat-detection/')
df = pd.read_csv(source / "train.csv")
df.head()

At first glance, we can make the following observations:  
* Multiple rows share the same `image_id`, `width`, and `height`  
* Each row has a different `bbox`  
* `source` doesn't seem relevant right now  
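
These observations can be checked with a few lines of pandas. The sketch below uses a small made-up frame named `toy` that only mimics the columns of `train.csv` (the values are illustrative, not taken from the dataset); the same checks should work on the real `df`.

```python
import pandas as pd

# Made-up rows that mimic the columns of train.csv (illustrative values only).
toy = pd.DataFrame(
    {
        "image_id": ["a1", "a1", "a1", "b2", "b2"],
        "width": [1024, 1024, 1024, 1024, 1024],
        "height": [1024, 1024, 1024, 1024, 1024],
        "bbox": [
            "[834.0, 222.0, 56.0, 36.0]",
            "[226.0, 548.0, 130.0, 58.0]",
            "[377.0, 504.0, 74.0, 160.0]",
            "[26.0, 144.0, 124.0, 117.0]",
            "[569.0, 382.0, 119.0, 111.0]",
        ],
        "source": ["usask_1", "usask_1", "usask_1", "arvalis_1", "arvalis_1"],
    }
)

# Number of annotation rows (one bbox each) per image.
boxes_per_image = toy.groupby("image_id").size()

# Count distinct width/height values per image; 1 means "constant within the image".
dims_per_image = toy.groupby("image_id")[["width", "height"]].nunique()
sizes_are_constant = bool((dims_per_image == 1).all().all())

# True if no (image_id, bbox) pair appears twice.
bboxes_are_unique = not toy.duplicated(subset=["image_id", "bbox"]).any()

print(boxes_per_image.to_dict())
print(sizes_are_constant, bboxes_are_unique)
```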

Once we know what our data provides we can create our custom `Parser`.  

When creating a `Parser` we inherit from smaller building blocks that provide the functionality we want:  
* `DefaultImageInfoParser`: Will parse standard fields for image information, e.g. `filepath`, `height`, `width`  
* `FasterRCNNParser`: Since we only need to predict bboxes, we will use a `FasterRCNN` model; this parser handles all the fields such a model requires.  

We can also specify exactly which fields we would like to parse. In fact, the parsers we are using here are just helper classes that group a collection of individual parsers.  
We will see how to use individual parsers in a future tutorial.

<div class="alert alert-info">
    
**Note:**

If you are using an IDE, a little bit of magic can happen. Once you have defined your class, you can right-click on it and select the option _"implement abstract methods"_; this will automatically populate your class with all the methods you need to override.

If you are using a notebook, or your IDE does not support that, check the documentation to find out which methods you should override.

</div>


<div class="alert alert-warning">
    
**Important:**  
    
Be sure to return the correct type on all overridden methods!
    
</div>

In [None]:
class WheatParser(DefaultImageInfoParser, FasterRCNNParser):
    def __init__(self, df, source):
        self.df = df
        self.source = source
        self.imageid_map = IDMap()

    def __iter__(self):
        yield from self.df.itertuples()

    def __len__(self):
        return len(self.df)

    def imageid(self, o) -> int:
        return self.imageid_map[o.image_id]

    def filepath(self, o) -> Union[str, Path]:
        return self.source / f"{o.image_id}.jpg"

    def height(self, o) -> int:
        return o.height

    def width(self, o) -> int:
        return o.width

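    def label(self, o) -> int:
        # Only one foreground class (wheat); 0 is left for the background class.
        return 1

    def bbox(self, o) -> BBox:
        # `o.bbox` is a string like "[x, y, w, h]": strip the brackets, then parse the numbers.
        return BBox.from_xywh(*np.fromstring(o.bbox[1:-1], sep=","))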

How you define `__init__` is completely up to you. Normally we pass in our data (the `df` in our case) and the folder that contains our images (`source` in our case).

We then override `__iter__`, telling our parser how to iterate over our data. In our case we call `df.itertuples` to iterate over all `df` rows.

`__len__` is optional, but it helps display progress while parsing.

Finally, we override all the other methods. Each of them receives a single argument `o`, which is one of the objects yielded by `__iter__`.
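
To make `o` concrete, here is a minimal sketch of what `itertuples` yields and how a `bbox` string can be turned into numbers. It runs on a made-up two-row frame, and the helpers `parse_xywh` and `xywh_to_xyxy` are hypothetical names used only for illustration. They parse with the standard-library `json` module instead of the `np.fromstring` call used in the parser, assuming the string has the form `"[x, y, w, h]"`.

```python
import json

import pandas as pd

# Made-up rows with the same columns the parser reads (illustrative values only).
toy = pd.DataFrame(
    {
        "image_id": ["a1", "a1"],
        "width": [1024, 1024],
        "height": [1024, 1024],
        "bbox": ["[834.0, 222.0, 56.0, 36.0]", "[226.0, 548.0, 130.0, 58.0]"],
    }
)


def parse_xywh(bbox_str):
    """Turn the string '[x, y, w, h]' into a tuple of four floats."""
    x, y, w, h = json.loads(bbox_str)
    return float(x), float(y), float(w), float(h)


def xywh_to_xyxy(x, y, w, h):
    """Convert (left, top, width, height) to (left, top, right, bottom)."""
    return x, y, x + w, y + h


# Each item is a namedtuple, so columns are available as attributes.
rows = list(toy.itertuples())
first = rows[0]
first_box = parse_xywh(first.bbox)
first_corners = xywh_to_xyxy(*first_box)

print(first.image_id, first.width, first.height)  # a1 1024 1024
print(first_box)      # (834.0, 222.0, 56.0, 36.0)
print(first_corners)  # (834.0, 222.0, 890.0, 258.0)
```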

Now we just need to decide how to split our data and call `Parser.parse`!

In [None]:
data_splitter = RandomSplitter([.8, .2])
parser = WheatParser(df, source / "train")
train_rs, valid_rs = parser.parse(data_splitter)
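
The idea behind an 80/20 random split can be sketched in plain Python. This is only an illustration of the concept, not the implementation of `RandomSplitter`; the function `random_split` and the ids below are hypothetical. Splitting at the image level keeps all boxes of one image in the same split.

```python
import random


def random_split(ids, fractions, seed=42):
    """Shuffle the ids and cut them into consecutive chunks sized by `fractions`."""
    ids = list(ids)
    random.Random(seed).shuffle(ids)
    splits, start = [], 0
    for i, frac in enumerate(fractions):
        # The last split takes whatever is left, so rounding never drops an id.
        is_last = i == len(fractions) - 1
        end = len(ids) if is_last else start + round(frac * len(ids))
        splits.append(ids[start:end])
        start = end
    return splits


# Hypothetical image ids, one per image (not one per annotation row).
image_ids = [f"img_{i}" for i in range(10)]
train_ids, valid_ids = random_split(image_ids, [0.8, 0.2])

print(len(train_ids), len(valid_ids))  # 8 2
```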

Let's take a look at one record.

In [None]:
show_record(train_rs[0], label=False)

## Transforms and Datasets

Mantisshrimp is agnostic to the transform library you use. We provide default support for [albumentations](https://github.com/albumentations-team/albumentations), but if you want to use another library you just need to inherit from `Transform` and override all of its abstract methods.

For simplicity, let's use a single transform on the train data and no transforms on the validation data.

In [None]:
train_tfm = AlbuTransform([A.Flip()])
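
For detection, a flip has to move the boxes together with the pixels. The sketch below shows only the geometry for boxes in `(x, y, w, h)` format; it is not albumentations' code, and `hflip_xywh` and `vflip_xywh` are hypothetical helper names.

```python
def hflip_xywh(box, image_width):
    """Mirror a (x, y, w, h) box around the vertical axis of the image."""
    x, y, w, h = box
    # The old right edge (x + w) becomes the new left edge, measured from the other side.
    return (image_width - x - w, y, w, h)


def vflip_xywh(box, image_height):
    """Mirror a (x, y, w, h) box around the horizontal axis of the image."""
    x, y, w, h = box
    return (x, image_height - y - h, w, h)


box = (834.0, 222.0, 56.0, 36.0)
flipped_h = hflip_xywh(box, image_width=1024)
flipped_v = vflip_xywh(box, image_height=1024)

print(flipped_h)  # (134.0, 222.0, 56.0, 36.0)
print(flipped_v)  # (834.0, 766.0, 56.0, 36.0)
```

Applying the same flip twice returns the original box, which is a quick sanity check for this kind of transform.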

To create a `Dataset` we just need to pass the parsed records from the previous step and, optionally, a transform.

In [None]:
train_ds = Dataset(train_rs, train_tfm)
valid_ds = Dataset(valid_rs)

## Model

Now [pytorch-lightning](https://github.com/PytorchLightning/pytorch-lightning) enters the picture.  

From here on, everything is almost pure lightning. The only big difference is that instead of inheriting from `LightningModule` we inherit from the specialized `MantisFasterRCNN`, which automatically creates the model architecture and downloads the pre-trained model weights.

If you are not familiar with lightning, be sure to check their excellent [documentation](https://pytorch-lightning.readthedocs.io/en/stable/).

In [None]:
class WheatModel(MantisFasterRCNN):
    def configure_optimizers(self):
        opt = SGD(self.parameters(), 1e-3, momentum=0.9)
        return opt

We create the model passing how many classes we have.  

In our case we have two: `wheat` and `background`.

In [None]:
model = WheatModel(2)

## DataLoader

Another difference from lightning is that all mantis models have a `dataloader` method that returns a customized `DataLoader` for each model.

In [None]:
train_dl = model.dataloader(train_ds, shuffle=True, batch_size=4, num_workers=2)
valid_dl = model.dataloader(valid_ds, batch_size=4, num_workers=2)
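
One likely reason detection models need a customized `DataLoader` is that each image has a different number of boxes, so samples cannot simply be stacked into one tensor. The sketch below shows a common collate pattern for this situation. It is an assumption about the general approach, not mantisshrimp's actual code, and `detection_collate` is a hypothetical name; plain strings and lists stand in for tensors.

```python
def detection_collate(batch):
    """Group a list of (image, target) samples into (images, targets) tuples."""
    images, targets = zip(*batch)
    return images, targets


# Stand-in samples: the second image has two boxes, the first only one.
batch = [
    ("image_0", {"boxes": [[834, 222, 890, 258]], "labels": [1]}),
    ("image_1", {"boxes": [[226, 548, 356, 606], [377, 504, 451, 664]], "labels": [1, 1]}),
]

images, targets = detection_collate(batch)
print(images)                              # ('image_0', 'image_1')
print([len(t["boxes"]) for t in targets])  # [1, 2]
```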

## Train

That's it! Time to train with the `Trainer`! 🚀 🚀 🚀

In [None]:
trainer = Trainer(max_epochs=1, gpus=1)
trainer.fit(model, train_dl, valid_dl)