# In [1]:
import torch
import torch.nn as nn
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.modules.embeddings import ProjectionEmbedding #Haoyun: to be checked
from mmf.utils.build import build_classifier_layer, build_image_encoder #Haoyun: to be checked

# In [2]:
@registry.register_model("concat_vl")
class LanguageAndVisionConcat(BaseModel):
    """Concatenate text and image features, fuse them, and classify.

    The model is registered with the MMF registry under the key
    ``"concat_vl"``. Sub-modules are created in :meth:`build` from the
    model config, and :meth:`forward` runs them on a sample list.
    """

    def __init__(self, config, *args, **kargs):
        """Forward the model config and any extra arguments to ``BaseModel``.

        Args:
            config: Everything stored in this model's config
                (hyperparameters).
            *args: Extra positional arguments passed through to the parent.
            **kargs: Extra keyword arguments passed through to the parent.
        """
        super().__init__(config, *args, **kargs)

    # This classmethod tells MMF where to look for the default config of
    # this model.
    # TODO(Haoyun): clarify the difference between the config passed to
    # __init__ and the one referenced by this config path.
    @classmethod
    def config_path(cls):
        """Return the path of this model's default config file."""
        return "configs/models/concat_vl.yaml"  # TODO(Haoyun): to be checked

    def build(self):
        """Create the sub-modules from ``self.config``.

        Reads ``text_encoder.params``, ``image_encoder``, ``fusion.params``,
        ``dropout`` and ``classifier`` from the config.
        """
        # TODO(Haoyun): self.config contents to be checked.
        self.language_module = ProjectionEmbedding(
            **self.config.text_encoder.params
        )
        self.vision_module = build_image_encoder(self.config.image_encoder)
        self.fusion = nn.Linear(**self.config.fusion.params)
        self.dropout = nn.Dropout(self.config.dropout)
        self.classifier = build_classifier_layer(self.config.classifier)

    def forward(self, sample_list):
        """Compute classification scores for a batch.

        Args:
            sample_list: Mapping-like batch container; its ``"text"`` and
                ``"image"`` entries are fed to the language and vision
                modules respectively.

        Returns:
            dict: ``{"scores": logits}`` where ``logits`` is the classifier
            output.
        """
        text = sample_list["text"]
        image = sample_list["image"]

        text_features = nn.functional.relu(self.language_module(text))
        # Fixed: this line used `==` (a comparison), which left
        # `image_features` undefined and raised NameError.
        image_features = nn.functional.relu(self.vision_module(image))

        # NOTE(review): squeeze(dim=1) presumes the image encoder output has
        # a singleton dimension 1 so both tensors can be concatenated along
        # dim 1 -- confirm against the configured encoder.
        combined = torch.cat(
            [text_features, image_features.squeeze(dim=1)], dim=1
        )

        fused = self.dropout(nn.functional.relu(self.fusion(combined)))

        logits = self.classifier(fused)

        # For loss calculations (automatically done by MMF as per the loss
        # defined in the config), we need to return a dict with "scores"
        # key as logits.
        output = {"scores": logits}

        return output