# Example: FLAVA training and inference on Harm-P dataset

This notebook walks through using the MATK (Multimodal AI Toolkit) library to train the FLAVA model on the Harm-P dataset and evaluate its performance.

## Step 1: Configuring the Dataset

The dataset class helps us do the following:
* Itemize preprocessed datasets as records.
* Aggregate primary dataset information (images, image features) with their respective records.
* Aggregate auxiliary information (captions, web entities, explanations) with their respective records.

The dataset configurations are defined in the `configs/dataset` directory. 

**Creating Harm-P datasets**

Locate the **datasets** folder. You can duplicate one of the existing dataset files (e.g., `fhm.py`) and ensure that each record provides the following information:
- `img`
- `id`
- labels, such as `harm_p_intensity` and `harm_p_target`
- `templated_text`
- templated labels, stored as `templated_<label name>` (e.g., `templated_harm_p_intensity`), for language outputs
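
As a rough illustration of the fields listed above, a single record could look like the dictionary below once preprocessing and formatting have run. The values and the helper `missing_fields` are hypothetical and not part of MATK; only the key names come from this guide.

```python
# Hypothetical example record; only the key names are taken from this guide.
REQUIRED_FIELDS = ("img", "id", "templated_text")

example_record = {
    "img": "memes_0001.png",
    "id": "memes_0001",
    "text": "example meme text",
    "harm_p_intensity": 1,
    "harm_p_target": 2,
    "templated_text": "example meme text",
}


def missing_fields(record, label_keys=("harm_p_intensity",)):
    """Return the required keys that are absent from a record."""
    expected = REQUIRED_FIELDS + tuple(label_keys)
    return [key for key in expected if key not in record]


print(missing_fields(example_record))  # []
```

A check like this can be run over a handful of records after loading a new dataset class to catch a forgotten field early.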

**Expected Output**
```
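# parameters inferred from the body below; the actual signature may take more arguments
def __init__(
        self,
        annotation_filepath: str,
        auxiliary_dicts: dict,
        text_template: str,
        labels_template: str,
        labels_mapping: dict
    ):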
        self.annotations = utils._load_jsonl(annotation_filepath)
        self._preprocess_dataset()

        self.auxiliary_data = self._load_auxiliary(auxiliary_dicts)
        self._format_input_output(
            text_template,
            labels_template,
            labels_mapping
        )
```

```
def _preprocess_dataset(self):
        for record in tqdm.tqdm(self.annotations, desc="Dataset preprocessing"):
            record["img"] = record["image"]
            record["id"] = os.path.splitext(record["img"])[0]
            del record["image"]

            # convert label to numeric values
            if record["labels"][0] == '':
                continue
            
            record[f"{DATASET_PREFIX}_intensity"] = INTENSITY_MAP[record["labels"][0]]
            record[f"{DATASET_PREFIX}_target"] = (
                TARGET_MAP[record["labels"][1]] if len(record["labels"]) > 1 else 0
            )
```
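
The preprocessing code above relies on `DATASET_PREFIX`, `INTENSITY_MAP` and `TARGET_MAP`, which are not shown in this guide. The sketch below applies the same steps to one record under assumed values for these constants: the label strings and their numeric codes are guesses for illustration (the binary intensity mapping is chosen to match `harm_p_intensity: 2` in the experiment configuration), and `preprocess_record` is a hypothetical standalone helper.

```python
import os

# Assumed values for illustration; the real constants live in the dataset module.
DATASET_PREFIX = "harm_p"
INTENSITY_MAP = {"not harmful": 0, "somewhat harmful": 1, "very harmful": 1}
TARGET_MAP = {"individual": 1, "organization": 2, "community": 3, "society": 4}


def preprocess_record(record):
    """Apply the same per-record steps as `_preprocess_dataset` to one record."""
    record["img"] = record.pop("image")
    record["id"] = os.path.splitext(record["img"])[0]

    # Records with an empty first label are left without numeric labels.
    if record["labels"][0] == "":
        return record

    record[f"{DATASET_PREFIX}_intensity"] = INTENSITY_MAP[record["labels"][0]]
    record[f"{DATASET_PREFIX}_target"] = (
        TARGET_MAP[record["labels"][1]] if len(record["labels"]) > 1 else 0
    )
    return record


sample = {"image": "memes_0001.png", "labels": ["very harmful", "individual"]}
print(preprocess_record(sample))
```

Note that records skipped by the empty-label check carry no `harm_p_intensity` key, so any later step that reads this key needs to handle them.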

```
def _format_input_output(
        self,
        text_template: str,
        labels_template: str,
        labels_mapping: dict
    ):
        for record in tqdm.tqdm(self.annotations, desc="Input/Output formatting"):
            # format input text template
            input_kwargs = {"text": record['text']}
            for key, data in self.auxiliary_data.items():
                input_kwargs[key] = data[record["id"]]
            text = text_template.format(**input_kwargs)
            record["templated_text"] = text

            # format output text template (for text-to-text generation)
            if labels_mapping:
                for cls_name, label2word in labels_mapping.items():
                    label = record[cls_name]
                    record[f"templated_{cls_name}"] = labels_template.format(
                        label=label2word[label]
                    )
```
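
To make the templating step concrete, the sketch below runs the same logic as `_format_input_output` on a single record. The template strings, the label words and the caption are hypothetical values chosen for illustration.

```python
# Hypothetical inputs; the template strings and mapping are assumptions.
text_template = "{text} [caption] {caption}"
labels_template = "{label}"
labels_mapping = {"harm_p_intensity": {0: "harmless", 1: "harmful"}}

record = {"id": "memes_0001", "text": "example meme text", "harm_p_intensity": 1}
auxiliary_data = {"caption": {"memes_0001": "a man at a podium"}}

# Input side: fill the text template with the record text and auxiliary fields.
input_kwargs = {"text": record["text"]}
for key, data in auxiliary_data.items():
    input_kwargs[key] = data[record["id"]]
record["templated_text"] = text_template.format(**input_kwargs)

# Output side: turn each numeric label into its word via the mapping.
for cls_name, label2word in labels_mapping.items():
    record[f"templated_{cls_name}"] = labels_template.format(
        label=label2word[record[cls_name]]
    )

print(record["templated_text"])              # example meme text [caption] a man at a podium
print(record["templated_harm_p_intensity"])  # harmful
```

Every placeholder in `text_template` other than `{text}` must match a key in the auxiliary data, otherwise `str.format` raises a `KeyError`.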

**Adding Harm-P Configuration**

Locate the **configs/dataset** folder. You can duplicate one of the existing YAML configurations (e.g., `fhm.yaml`) and change the file paths for the following keys:
- `annotation_filepaths`
- `image_dirs`
- `auxiliary_dicts`

**Expected Output**

```
harm_p:
  annotation_filepaths:
    train: /mnt/data1/datasets/memes/harmp/annotations/train_v1.jsonl
    validate: /mnt/data1/datasets/memes/harmp/annotations/val_v1.jsonl
    test: /mnt/data1/datasets/memes/harmp/annotations/test_v1.jsonl
    predict: /mnt/data1/datasets/memes/harmp/annotations/test_v1.jsonl

  image_dirs:
    train: /mnt/data1/datasets/memes/harmp/images/img_clean/harmeme_images_us_pol
    validate: /mnt/data1/datasets/memes/harmp/images/img_clean/harmeme_images_us_pol
    test: /mnt/data1/datasets/memes/harmp/images/img_clean/harmeme_images_us_pol
    predict: /mnt/data1/datasets/memes/harmp/images/img_clean/harmeme_images_us_pol

  auxiliary_dicts:
    train: 
      caption: /mnt/data1/datasets/memes/harmp/preprocessing/blip2_captions/harmeme_images_us_pol/
    validate:
      caption: /mnt/data1/datasets/memes/harmp/preprocessing/blip2_captions/harmeme_images_us_pol/
    test: 
      caption: /mnt/data1/datasets/memes/harmp/preprocessing/blip2_captions/harmeme_images_us_pol/
    predict: 
      caption: /mnt/data1/datasets/memes/harmp/preprocessing/blip2_captions/harmeme_images_us_pol/

...
```
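
Since this configuration consists mostly of file paths, a quick existence check can catch typos before a run starts. The sketch below is a hypothetical helper, not part of MATK; it assumes the three path sections have already been loaded into a plain nested dictionary, and the paths shown are placeholders.

```python
import os


def missing_paths(dataset_cfg):
    """Collect every path in a nested config dict that does not exist on disk."""
    missing = []

    def walk(node, trail):
        if isinstance(node, dict):
            for key, value in node.items():
                walk(value, trail + [key])
        elif isinstance(node, str) and not os.path.exists(node):
            missing.append((".".join(trail), node))

    walk(dataset_cfg, [])
    return missing


# Hypothetical config fragment mirroring the YAML layout above.
cfg = {
    "annotation_filepaths": {"train": "/path/to/train.jsonl"},
    "image_dirs": {"train": "/path/to/images"},
    "auxiliary_dicts": {"train": {"caption": "/path/to/captions/"}},
}
for key, path in missing_paths(cfg):
    print(f"missing: {key} -> {path}")
```

Only pass the path sections to this helper; other string values in the configuration would be reported as missing paths.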

## Step 2: Configuring the FLAVA Model

The model class should inherit from `BaseLightningModule` and implement the following PyTorch Lightning step methods:
- `training_step`
- `validation_step`
- `test_step`
- `predict_step`

As FLAVA is supported within the MATK package by default, we can simply reference the model class in the model configuration. A model configuration needs to have the following keys:
- `class_path`
- `model_class_or_path`
- `dropout`
- `optimizers`

**Expected Output**
```
class_path: models.flava.FlavaClassificationModel
model_class_or_path: facebook/flava-full
dropout: 0.1
optimizers: ???
```
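
The `class_path` value is a dotted import path. The sketch below shows one common way such a string can be turned into a class using `importlib`; it is an illustration only, MATK's own loading code may differ, and the helper name `resolve_class` is hypothetical.

```python
import importlib


def resolve_class(class_path):
    """Import `package.module.ClassName` and return the class object."""
    module_name, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Demonstrated with a standard-library class so the snippet runs anywhere;
# inside the MATK repository the argument would be the `class_path` value above.
cls = resolve_class("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```

A consequence of this style of lookup is that a typo in `class_path` only surfaces at run time, as an `ImportError` or `AttributeError`.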

## Step 3: Configuring Experiments

Finally, we construct an experiment configuration, which is stored inside the **configs/experiments.yaml** file. The experiment configuration is a composition of the configurations constructed earlier:

```
# @package _global_
defaults:
  - /model: flava
  - /dataset: 
    - harm_p
  - /datamodule: processor_datamodule
  - /trainer: single_gpu_trainer
  - /metric:
    - accuracy
    - auroc
  - /hydra: experiment
  - _self_
```




The experiment configuration then **overrides** optional, experiment-specific parameters such as:
```
model:
  optimizers: 
  - class_path: torch.optim.Adam
    lr: 2e-5

dataset:
  harm_p:
    dataset_class: datasets.harm_p.ImageDataset
    text_template: "{text}"
    labels:
      harm_p_intensity: 2

datamodule:
  processor_class_or_path: facebook/flava-full

monitor_metric: validate_harm_p_intensity_average
monitor_mode: max
save_top_ks: 1

# Experiment settings
experiment_name: baseline/harmeme/flava

# Job settings
hydra.verbose: True
seed_everything: 1111
overwrite: False
action: ???
```
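
Because `_self_` is listed last in the defaults list, values in the experiment file take precedence over the composed defaults, and any value still set to `???` must be supplied before it is used (here, `action` comes from the command line). The sketch below imitates this behaviour with plain dictionaries. It is a simplified illustration, not Hydra's implementation, and `deep_merge` and `missing_keys` are hypothetical helpers.

```python
def deep_merge(base, override):
    """Return a new dict where `override` wins over `base`, key by key."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def missing_keys(cfg, prefix=""):
    """List dotted keys whose value is still the `???` placeholder."""
    found = []
    for key, value in cfg.items():
        dotted = f"{prefix}{key}"
        if isinstance(value, dict):
            found.extend(missing_keys(value, dotted + "."))
        elif value == "???":
            found.append(dotted)
    return found


# Simplified stand-ins for the model defaults and the experiment overrides.
model_cfg = {"model": {"dropout": 0.1, "optimizers": "???"}}
experiment_cfg = {
    "model": {"optimizers": [{"class_path": "torch.optim.Adam", "lr": 2e-5}]},
    "action": "???",
}

composed = deep_merge(model_cfg, experiment_cfg)
print(composed["model"]["dropout"])  # 0.1
print(missing_keys(composed))        # ['action']
```

The model's `optimizers: ???` placeholder is filled by the experiment file, while `dropout` keeps its default because the experiment does not mention it.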

## Step 4: Train the Model

!python3 ../main.py \
    +experiment=harm_p/flava.yaml \
    action=fit \
    trainer=single_gpu_trainer

[INFO] - Setting custom seed: 1111...
Seed set to 1111
Some weights of the model checkpoint at facebook/flava-full were not used when initializing FlavaModel: ['itm_head.pooler.dense.weight', 'mim_head.transform.dense.bias', 'image_codebook.blocks.group_2.group.block_1.res_path.path.conv_2.bias', 'image_codebook.blocks.group_4.group.block_2.res_path.path.conv_2.bias', 'mmm_text_head.bias', 'image_codebook.blocks.group_4.group.block_2.res_path.path.conv_1.weight', 'mlm_head.bias', 'image_codebook.blocks.group_2.group.block_2.res_path.path.conv_1.bias', 'mmm_text_head.transform.LayerNorm.bias', 'image_codebook.blocks.group_2.group.block_1.res_path.path.conv_3.weight', 'image_codebook.blocks.group_3.group.block_1.res_path.path.conv_2.weight', 'image_codebook.blocks.group_2.group.block_2.res_path.path.conv_4.weight', 'mmm_text_head.decoder.bias', 'image_codebook.blocks.group_4.group.block_1.res_path.path.conv_2.bias', 'image_codebook.blocks.group_2.group.block_1.id_path.bias', 'image_codeb

## Step 5: Test the Model

!python3 ../main.py \
    +experiment="harm_p/flava.yaml" \
    +model_checkpoint="/mnt/data1/mshee/test-repository/MATK/examples/experiments/baseline/harm_p/flava/epoch\=2-step\=285.ckpt" \
    trainer=single_gpu_trainer \
    action=test

[INFO] - Setting custom seed: 1111...
Seed set to 1111
Some weights of the model checkpoint at facebook/flava-full were not used when initializing FlavaModel: ['mmm_image_head.transform.dense.bias', 'image_codebook.blocks.group_3.group.block_2.res_path.path.conv_2.weight', 'image_codebook.blocks.group_4.group.block_1.id_path.bias', 'image_codebook.blocks.group_4.group.block_1.res_path.path.conv_4.weight', 'image_codebook.blocks.group_3.group.block_2.res_path.path.conv_2.bias', 'mlm_head.transform.dense.bias', 'mim_head.transform.LayerNorm.weight', 'image_codebook.blocks.group_4.group.block_1.res_path.path.conv_2.weight', 'image_codebook.blocks.group_3.group.block_2.res_path.path.conv_3.weight', 'image_codebook.blocks.group_2.group.block_1.res_path.path.conv_1.bias', 'image_codebook.blocks.group_4.group.block_1.res_path.path.conv_1.bias', 'mmm_image_head.decoder.weight', 'mim_head.transform.dense.weight', 'image_codebook.blocks.input.weight', 'mmm_image_head.transform.dense.weight', 'im
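
The checkpoint path in the test command escapes each `=` as `\=`, presumably because `=` otherwise separates key and value in a command-line override. When the checkpoint name is generated programmatically, the escaping can be done with a small helper such as the hypothetical `escape_for_hydra` below; the filename is an example in the style used above.

```python
def escape_for_hydra(value):
    """Backslash-escape `=` so the value can be used in a command-line override."""
    return value.replace("=", "\\=")


# Hypothetical checkpoint filename in the style produced during training.
print(escape_for_hydra("epoch=2-step=285.ckpt"))  # epoch\=2-step\=285.ckpt
```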