In [1]:
from mmf.common.registry import registry
from dataset import (
    HatefulMemesFeaturesDataset,
    HatefulMemesImageDataset
)
from mmf.datasets.mmf_dataset_builder import MMFDatasetBuilder
from mmf.utils.configuration import get_mmf_env
from mmf.utils.general import get_absolute_path
from mmf.utils.file_io import PathManager
import os

In [2]:
@registry.register_builder("hateful_memes_data_builder")
class HatefulMemesBuilder(MMFDatasetBuilder):
    """Dataset builder for the Hateful Memes dataset.

    Registered with the MMF registry under the key
    ``"hateful_memes_data_builder"``. It selects between the image-based and
    the feature-based dataset class depending on ``config.use_features`` and
    verifies that the annotation files are present before building.
    """

    def __init__(
        self,
        dataset_name="hateful_memes",
        dataset_class=HatefulMemesImageDataset,
        *args,
        **kwargs
    ):
        """Initialise the builder.

        Args:
            dataset_name: Name forwarded to ``MMFDatasetBuilder``; also used
                to build the data path and the registry keys below.
            dataset_class: Dataset class forwarded to ``MMFDatasetBuilder``.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.
        """
        super().__init__(dataset_name, dataset_class, *args, **kwargs)
        # NOTE(review): this unconditionally resets the class to the image
        # dataset, ignoring a custom ``dataset_class`` argument. Kept as-is to
        # preserve existing behaviour; ``load`` may switch it to the features
        # dataset later.
        self.dataset_class = HatefulMemesImageDataset

    @classmethod
    def config_path(cls):
        """Return the path of the dataset configuration YAML file."""
        return "configs/dataset_config.yaml"

    def load(self, config, dataset_type, *args, **kwargs):
        """Load the dataset for ``dataset_type`` and cache it on the builder.

        Args:
            config: Dataset configuration; ``config.use_features`` selects the
                feature-based dataset class when truthy.
            dataset_type: Dataset split identifier, forwarded to the parent
                ``load``.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            The dataset object returned by the parent ``load``; it is also
            stored in ``self.dataset``.
        """
        if config.use_features:
            self.dataset_class = HatefulMemesFeaturesDataset

        self.dataset = super().load(config, dataset_type, *args, **kwargs)
        return self.dataset

    def build(self, config, *args, **kwargs):
        """Check that the annotations exist, then delegate to the parent build.

        Args:
            config: Dataset configuration, forwarded to the parent ``build``.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Raises:
            AssertionError: If the training annotation file is not found
                under the configured data directory.
        """
        # First, check whether the manual download has been performed by
        # probing for the training annotation file.
        data_dir = get_mmf_env(key="data_dir")
        test_path = get_absolute_path(
            os.path.join(
                data_dir,
                "datasets",
                self.dataset_name,
                "defaults",
                "annotations",
                "train.jsonl",
            )
        )

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not PathManager.exists(test_path):
            raise AssertionError(f"No hateful memes data found at {test_path}")

        super().build(config, *args, **kwargs)

    def update_registry_for_model(self, config):
        """Register dataset-derived values in the MMF registry.

        Registers ``<dataset_name>_text_vocab_size`` when the loaded dataset
        has a text processor exposing ``get_vocab_size``, and always registers
        ``<dataset_name>_num_final_outputs`` as 2.

        Args:
            config: Unused here; accepted for interface compatibility.
        """
        text_processor = getattr(self.dataset, "text_processor", None)
        # The guard must inspect the text processor itself, because that is
        # the object ``get_vocab_size`` is called on.
        if hasattr(text_processor, "get_vocab_size"):
            registry.register(
                self.dataset_name + "_text_vocab_size",
                text_processor.get_vocab_size(),
            )

        registry.register(self.dataset_name + "_num_final_outputs", 2)