In [None]:
#default_exp model

# Model
> Our project requires four separate types of models: an image model, a text model, a tabular model, and a decoder network. The relationship between the four is shown in the figure below. 

![](model_diagram.jpg)

In [None]:
#exporti

from DSAI_proj.dataset import *  # kept first so the explicit imports below take precedence over star-exported names
from functools import partial

import torch
import torchvision.models as models
from fastai.data.core import DataLoaders
from torch import nn
from transformers import DistilBertModel

In [None]:
#exporti

def freeze_all_but_layer(m, layer):
    """Turn off gradients for a module's own `weight`/`bias`, unless it is an instance of `layer`.

    Intended to be used with `nn.Module.apply` (via `functools.partial`), which calls it once
    per submodule, so that only modules of type `layer` keep trainable weight/bias.

    Args:
        m: the module being visited.
        layer: a module class (or tuple of classes, as accepted by `isinstance`) to leave trainable.
    """
    if isinstance(m, layer):
        return
    # Containers and parameter-free modules have no weight/bias; bias may also be None (bias=False).
    for attr_name in ('weight', 'bias'):
        tensor = getattr(m, attr_name, None)
        if tensor is not None:
            tensor.requires_grad_(False)

We first design a `cnn_encoder` module built on a pretrained ResNet-18. We keep the pretrained weights frozen so that they are not overwritten during training. The exception is the batch-norm layers, which we leave trainable, since unfreezing them during fine-tuning has been shown to help the model adapt to the new data distribution. 

In [None]:
#export

def cnn_encoder(pretrained: bool, in_channels: int, out_channels: int):
    """Build an image encoder: a ResNet-18 trunk followed by a trainable projection and global pooling.

    Args:
        pretrained: forwarded to `torchvision.models.resnet18` (load ImageNet weights or not).
        in_channels: input channels of the projection conv; must equal the trunk's output
            channels (512 for ResNet-18).
        out_channels: width of the produced image feature.

    Returns:
        nn.Sequential made of ResNet-18's children minus the last two (avgpool and fc in
        torchvision's ResNet), then Conv2d(in_channels -> out_channels, 3x3, padding=1, no bias),
        then AdaptiveAvgPool2d(1). The output is therefore (B, out_channels, 1, 1).

    Only the trunk is frozen, except its BatchNorm2d layers. The newly added projection conv
    stays trainable; freezing it would pin it at its random initialisation.
    """
    resnet = models.resnet18(pretrained=pretrained)
    backbone = list(resnet.children())[:-2]
    img_freeze_fn = partial(freeze_all_but_layer, layer=nn.BatchNorm2d)
    for module in backbone:
        module.apply(img_freeze_fn)
    head = [
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
        nn.AdaptiveAvgPool2d(1),
    ]
    return nn.Sequential(*backbone, *head)

Next, we design a `text_encoder` module, which consists of a pretrained DistilBERT model. As with the `cnn_encoder`, we freeze all layers except the normalization layers, which in this case are `LayerNorm` layers. 

In [None]:
#export

def text_encoder(model_type: str):
    """Load a pretrained DistilBERT and leave only its LayerNorm modules trainable.

    Args:
        model_type: checkpoint name or path passed to `DistilBertModel.from_pretrained`
            (e.g. 'distilbert-base-uncased').

    Returns:
        The loaded `DistilBertModel`. Every non-LayerNorm submodule with a `weight`/`bias`
        has gradients disabled on those tensors.
    """
    distilbert = DistilBertModel.from_pretrained(model_type)
    # Per the notebook's design, only the normalization layers remain trainable.
    distilbert.apply(partial(freeze_all_but_layer, layer=nn.LayerNorm))
    return distilbert

We also need a module for the tabular metadata. A single linear layer suffices: it maps the input metadata features to the required output dimension. 

In [None]:
#export

def meta_encoder(in_channels: int, out_channels: int):
    """Project tabular metadata features to a target width with one affine (linear) layer.

    Args:
        in_channels: number of input metadata features.
        out_channels: output feature width.

    Returns:
        An `nn.Linear(in_channels, out_channels)` with a bias term.
    """
    return nn.Linear(in_channels, out_channels, bias=True)

The last piece of the puzzle is a decoder network. It decodes the outputs of the three encoder modules above and produces prediction scores whose last dimension equals the vocabulary size. In other words, these are raw logits over every token in the vocabulary; applying a softmax turns them into probabilities, from which the most likely word is chosen. 

In [None]:
#export

class _DecoderWithHead(nn.Sequential):
    """Two-module Sequential (TransformerDecoder, Linear) whose forward also threads `memory`.

    The stock `nn.Sequential.forward` passes a single input from module to module. However,
    `nn.TransformerDecoder.forward` requires both `tgt` and `memory`, so a plain Sequential
    could never run this pair.
    """

    def forward(self, tgt, memory, **decoder_kwargs):
        # decoder_kwargs are forwarded to nn.TransformerDecoder (e.g. tgt_mask, *_key_padding_mask).
        transformer_decoder, linear_layer = self[0], self[1]
        return linear_layer(transformer_decoder(tgt, memory, **decoder_kwargs))


def decoder(hidden_dim: int, nhead: int, num_decoders: int, vocab_size: int):
    """Build a transformer decoder stack followed by a linear projection to vocabulary logits.

    Args:
        hidden_dim: model width (`d_model`) of each decoder layer and the Linear input size.
        nhead: attention heads per layer (PyTorch requires it to divide `hidden_dim`).
        num_decoders: number of stacked decoder layers.
        vocab_size: size of the output logit dimension.

    Returns:
        An nn.Sequential subclass with [0] = nn.TransformerDecoder and [1] = nn.Linear.
        Call it as `model(tgt, memory, **mask_kwargs)`. The layers use PyTorch's default
        `batch_first=False`, so `tgt` is (T, B, hidden_dim), `memory` is (S, B, hidden_dim),
        and the result is raw logits of shape (T, B, vocab_size).
    """
    decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=nhead)
    transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoders)
    linear_layer = nn.Linear(in_features=hidden_dim, out_features=vocab_size)
    return _DecoderWithHead(transformer_decoder, linear_layer)

In [None]:
class TaglinePredictorModel(nn.Module):
    """Multimodal tagline model: encodes poster, backdrop, text and metadata, then decodes.

    Every modality is mapped to width `HIDDEN_DIM` so that the per-modality tokens can be
    concatenated into one memory sequence for the transformer decoder.

    Args:
        vocab_size: size of the decoder's output logit dimension (30522 for
            'distilbert-base-uncased').
        meta_features: number of tabular metadata features fed to the meta encoder.
    """

    HIDDEN_DIM = 768  # DistilBERT-base hidden size; image/meta tokens are projected to match it

    def __init__(self, vocab_size: int, meta_features: int):
        super().__init__()
        # One image encoder is shared by poster and backdrop; 512 = ResNet-18's final-stage channels.
        self.cnn_encoder = cnn_encoder(pretrained=True, in_channels=512, out_channels=self.HIDDEN_DIM)
        self.text_encoder = text_encoder(model_type='distilbert-base-uncased')
        self.meta_encoder = meta_encoder(in_channels=meta_features, out_channels=self.HIDDEN_DIM)
        # Use the caller-supplied vocab_size rather than a hard-coded 30522.
        self.decoder = decoder(hidden_dim=self.HIDDEN_DIM, nhead=12, num_decoders=3, vocab_size=vocab_size)

    def _image_token(self, img):
        """Encode an image batch into one token per sample: (B, C, 1, 1) -> (B, 1, C)."""
        pooled = self.cnn_encoder(img)
        return pooled.flatten(1).unsqueeze(1)

    def encode(self, x: dict):
        """Concatenate the per-modality tokens along dim 1 into a batch-first memory tensor.

        Reads keys 'poster_img', 'backdrop_img', 'text_inputs' and 'meta' from `x`.
        Assumes 'text_inputs' is what the text encoder accepts positionally (presumably
        input_ids). TODO: confirm against the dataset.

        Returns:
            Tensor of shape (B, 1 + 1 + L + 1, D), where L is the text length.
        """
        poster_feature = self._image_token(x['poster_img'])
        backdrop_feature = self._image_token(x['backdrop_img'])
        # HF models return a ModelOutput/tuple whose element 0 is the last hidden state (B, L, D).
        text_feature = self.text_encoder(x['text_inputs'])[0]
        meta_feature = self.meta_encoder(x['meta']).unsqueeze(1)
        return torch.cat((poster_feature, backdrop_feature, text_feature, meta_feature), dim=1)

    def forward(self, x: dict):
        """Not implemented yet: the decoding step was left unfinished in this notebook.

        TODO: decode from `self.encode(x)`. `self.decoder` is built with PyTorch's default
        batch_first=False (seq-first), while `encode` returns batch-first, so the memory will
        need `.transpose(0, 1)` before it is passed to the decoder.
        """
        raise NotImplementedError('TaglinePredictorModel.forward: decoding step is not implemented yet')

The values for this Tagline model are mostly hard-coded as we are limited by architectural choices. As we will be using DistilBert, the 