# Example: FLAVA training and inference on FHM dataset

This notebook provides a comprehensive guide on using the MATK (Multimodal AI Toolkit) library to evaluate the performance of the FLAVA model on the Facebook Hateful Memes dataset.

Researchers must accept and adhere to Facebook AI's Hateful Memes dataset license agreement. This means downloading the original dataset directly from Facebook AI.

## Step 1. Review and Accept Facebook AI's Dataset License Agreement
Researchers can read the Hateful Memes dataset license agreement on the official website at https://hatefulmemeschallenge.com/. After reviewing and accepting its terms, they may download the Hateful Memes dataset, which includes:

* train, dev, dev_seen and test annotations
* images (critical for vision-language multimodal models)
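
Before configuring anything, it can help to confirm that the downloaded files are where you expect them. The sketch below is one minimal way to do that; the file names (`train.jsonl`, `dev.jsonl`, and so on) and the `img` folder name are assumptions for illustration, so adjust them to match your own download.

```python
from pathlib import Path

# Assumed names: change these to match the files in your own download.
EXPECTED_ANNOTATIONS = ["train.jsonl", "dev.jsonl", "dev_seen.jsonl", "test.jsonl"]
EXPECTED_IMAGE_DIR = "img"


def find_missing(root, annotations=EXPECTED_ANNOTATIONS, image_dir=EXPECTED_IMAGE_DIR):
    """Return the expected files and folders that are absent under `root`."""
    root = Path(root)
    missing = [name for name in annotations if not (root / name).is_file()]
    if not (root / image_dir).is_dir():
        missing.append(image_dir + "/")
    return missing
```

Calling `find_missing("path/to/your/download")` returns an empty list when everything is in place.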

## Step 2. Configuring the dataset

1. Locate the **configs/datasets.yaml** file. Enter the paths for 'annotation_filepaths', 'image_dirs' and 'feats_dirs' as required. For FLAVA, we need 'annotation_filepaths' and 'image_dirs' because our ImagesDataModule takes these arguments.

2. Next locate the **configs/data** folder. We will use the fhm_data.yaml file because we are using the FHM dataset.

3. Inside fhm_data.yaml, you should see the 'datamodules' listed in the table below.

| DataModule           | Dataset           | Usage                      |
|----------------------|-------------------|----------------------------|
| FasterRCNNDataModule | FasterRCNNDataset | For vision-language models |
| ImagesDataModule     | ImagesDataset     | For vision-language models |
| TextDataModule       | TextDataset       | For language models        |


Each datamodule accepts arguments such as:

1. **tokenizer_class_or_path**: specifies the tokenizer or processor class/path for the model
2. **frcnn_class_or_path**: specifies the class/path for Faster R-CNN feature extraction
3. **dataset_class**: specifies the dataset class to use for the current datamodule
4. **dataset_handler**: specifies the file passed to the dataset handler, which tells the dataset class where to load its data from
5. **auxiliary_dicts**: path to a .pkl file containing auxiliary information, such as captions for images
6. **num_workers**: enables multi-process data loading when set to a positive integer


### Modification

1. Adjust the **batch_size**, **num_workers** and **shuffle_train** arguments to suit your needs.
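
As a rough illustration of the loader-related arguments, the sketch below mirrors one datamodule entry as a Python dictionary and checks the three values you are most likely to change. Every value shown is a placeholder rather than a MATK default, and `check_loader_args` is a hypothetical helper, not part of the library.

```python
# Placeholder values only: these are not defaults shipped with MATK.
images_datamodule = {
    "tokenizer_class_or_path": "facebook/flava-full",  # assumed processor checkpoint
    "dataset_class": "ImagesDataset",                  # assumed spelling of the class
    "batch_size": 32,
    "num_workers": 4,
    "shuffle_train": True,
}


def _is_int(value):
    """True for real integers; bool is excluded because it subclasses int."""
    return isinstance(value, int) and not isinstance(value, bool)


def check_loader_args(cfg):
    """Return a list of problems found in the loader-related arguments."""
    problems = []
    if not _is_int(cfg.get("batch_size")) or cfg["batch_size"] < 1:
        problems.append("batch_size must be a positive integer")
    if not _is_int(cfg.get("num_workers")) or cfg["num_workers"] < 0:
        problems.append("num_workers must be a non-negative integer")
    if not isinstance(cfg.get("shuffle_train"), bool):
        problems.append("shuffle_train must be true or false")
    return problems
```

An empty list from `check_loader_args(images_datamodule)` means the three values are at least well-formed.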

## Step 3. Configuring the model-to-dataset mapping

Our model needs to know some information about the data it is going to handle so that it can appropriately initiate metrics calculation and logging.

1. Locate the **fhm** key
2. Locate the **flava** key within the above key
3. Inside cls_dict, we specify 'label: 2' because our dataset has exactly one label, called 'label', which can take two values: 0 or 1. Similarly, the FHM finegrained dataset has exactly one label, called 'hate', which can take two values: 0 or 1.
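
To make the role of `cls_dict` concrete, the sketch below treats it as a mapping from label name to number of classes and checks an annotation record against it. The record layout and the `validate_record` helper are hypothetical; MATK's own handling may differ.

```python
cls_dict = {"label": 2}  # one label named 'label' with two classes: 0 and 1


def validate_record(record, cls_dict):
    """Raise ValueError if a record lacks a label or holds an out-of-range value."""
    for name, num_classes in cls_dict.items():
        if name not in record:
            raise ValueError(f"missing label '{name}'")
        value = record[name]
        if value not in range(num_classes):
            raise ValueError(f"'{name}'={value} is outside 0..{num_classes - 1}")
    return True
```

For the finegrained example, the same check would run with `{"hate": 2}` in place of `{"label": 2}`.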

## Step 4: Configuring the Model

Everything related to the model configuration is stored inside the **configs/models.yaml** file.

For the **flava** key we specify the following arguments:
1. **class_path**: specifies path to file under **[models/](https://github.com/Social-AI-Studio/MATK/tree/main/models)**
2. **model_class_or_path**: specifies the pretrained model to be used
3. **metrics** (only for VL models): a list of metrics from torchmetrics, where each element specifies the torchmetrics metric name, task and num_classes.

For further metric configuration, see the **[torchmetrics documentation](https://torchmetrics.readthedocs.io/en/stable/all-metrics.html)**; additional arguments go inside the **args** of each metric in the metrics list.
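
The sketch below shows one way such a metrics list could be turned into metric objects. The key names `name` and `args`, and the placement of `task` and `num_classes` inside `args`, are assumptions about the YAML layout; `build_metrics` is a hypothetical helper, not MATK's own loader.

```python
# Assumed Python mirror of the `metrics` list for one model.
metrics = [
    {"name": "Accuracy", "args": {"task": "multiclass", "num_classes": 2}},
    {"name": "AUROC", "args": {"task": "multiclass", "num_classes": 2}},
]


def build_metrics(specs, namespace):
    """Instantiate each metric by looking up its class on `namespace`."""
    built = {}
    for spec in specs:
        metric_cls = getattr(namespace, spec["name"])   # e.g. namespace.Accuracy
        built[spec["name"]] = metric_cls(**spec.get("args", {}))
    return built
```

With torchmetrics installed, `build_metrics(metrics, torchmetrics)` would look up `torchmetrics.Accuracy` and `torchmetrics.AUROC` and construct each with the listed arguments.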




## Step 5: Configuring the Trainer

The Trainer automates several aspects of training. According to the Lightning documentation, it handles all loop details for you, including:
* Automatically enabling/disabling grads
* Running the training, validation and test dataloaders
* Calling the Callbacks at the appropriate times
* Putting batches and computations on the correct devices

Everything related to the trainer is specified under the **{dataset}_{task}_trainer.yaml** file inside the **configs/trainers** folder.

### Modification
1. Modify the **dirpath** and **name** arguments under callbacks to choose where your checkpoints are stored and what they are named.
2. Modify the **save_dir** and **name** arguments under logger to choose where your Lightning logs are stored and what they are named.
3. To add arguments like 'seed_everything' or 'ckpt_path', add them at the same level as the **trainer** key.
4. You can also modify other hyperparameters such as **max_epochs**, or tweak the trainer further (within the **trainer** key) by adding keys described here: https://lightning.ai/docs/pytorch/stable/common/trainer.html
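
The difference between keys that sit beside **trainer** and keys that sit inside it is easy to get wrong. The sketch below mirrors one model's trainer entry as a Python dictionary and flags misplaced keys. The nesting of `callbacks` and `logger` inside `trainer`, all values, and the `misplaced_keys` helper are assumptions for illustration.

```python
# Keys that belong beside `trainer`, not inside it.
TOP_LEVEL_ONLY = {"seed_everything", "ckpt_path"}

# Assumed layout with placeholder values.
flava_trainer_cfg = {
    "seed_everything": 1111,
    "trainer": {
        "max_epochs": 10,
        "callbacks": [{"dirpath": "checkpoints/fhm/flava"}],
        "logger": {"save_dir": "lightning_logs", "name": "fhm_flava"},
    },
}


def misplaced_keys(cfg):
    """Return top-level-only keys that were nested inside `trainer` by mistake."""
    return sorted(TOP_LEVEL_ONLY & set(cfg.get("trainer", {})))
```

An empty list from `misplaced_keys(flava_trainer_cfg)` indicates that 'seed_everything' and 'ckpt_path' are not nested under **trainer**.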


## Step 6: Train the Model

Make sure you have already run:
```
pip install -r requirements.txt
```
Then run:
```
python main.py --model flava --dataset fhm --datamodule ImagesDataModule --action fit
```

--model => one of the keys under the **models** key in models.yaml \
--dataset => one of the keys under the **datasets** key in datasets.yaml \
--datamodule => one of the keys in the {dataset}_{task}_data.yaml file
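
For readers who want to see how these four flags fit together, the sketch below parses the same command line with `argparse`. It is an illustration only and is not MATK's actual main.py; in particular, restricting `--action` to `fit` and `test` is an assumption based on the two commands used in this guide.

```python
import argparse


def build_parser():
    """Build a parser that accepts the four flags used in this guide."""
    parser = argparse.ArgumentParser(description="Sketch of a MATK-style command line")
    parser.add_argument("--model", required=True, help="key under `models` in models.yaml")
    parser.add_argument("--dataset", required=True, help="key under `datasets` in datasets.yaml")
    parser.add_argument("--datamodule", required=True, help="key in the data YAML file")
    parser.add_argument("--action", required=True, choices=["fit", "test"])
    return parser


args = build_parser().parse_args(
    ["--model", "flava", "--dataset", "fhm", "--datamodule", "ImagesDataModule", "--action", "fit"]
)
print(args.model, args.dataset, args.datamodule, args.action)
# prints: flava fhm ImagesDataModule fit
```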

## Step 7: Test the Model

Remember to create a key called **ckpt_path** at the same level as the **trainer** key under **flava** in the configs/fhm_trainer.yaml file, and set it to the checkpoint you want to evaluate. The checkpoint's name comes from the **filename** key of your ModelCheckpoint callback.

Then run:
```
python main.py --model flava --dataset fhm --datamodule ImagesDataModule --action test
```
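
If you are unsure which checkpoint file to point **ckpt_path** at, a small helper like the one below can list what was saved. It is a hypothetical convenience, not part of MATK, and it assumes your checkpoints are `.ckpt` files stored directly in the **dirpath** folder configured in Step 5.

```python
from pathlib import Path


def newest_checkpoint(dirpath):
    """Return the most recently modified .ckpt file in `dirpath`, or None."""
    folder = Path(dirpath)
    if not folder.is_dir():
        return None
    checkpoints = sorted(folder.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return checkpoints[-1] if checkpoints else None
```

The most recent checkpoint is not necessarily the best-performing one, so compare the returned path against your ModelCheckpoint settings before using it.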


