In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
from torchvision import models


In [None]:
# Define the MRNN class, which inherits from nn.Module
class MRNN(nn.Module):
    """Multimodal recurrent network that fuses image features with a text sequence.

    A frozen, ImageNet-pretrained ResNet50 (with a new trainable final linear
    layer) encodes each image into a ``cnn_output_size`` vector. An LSTM
    encodes the text sequence step by step. The image vector is repeated at
    every time step, concatenated with the text LSTM output, and passed
    through a second LSTM that produces the final sequence output.
    """

    def __init__(self, cnn_output_size, rnn_hidden_size, output_size, num_layers):
        """Build the CNN backbone and the two LSTMs.

        Args:
            cnn_output_size: Size of the image feature vector produced by the
                replaced final layer of the CNN.
            rnn_hidden_size: Hidden size of the text LSTM.
            output_size: Hidden size of the output LSTM, i.e. the feature size
                of the tensor returned by ``forward``.
            num_layers: Number of stacked layers used by both LSTMs.
        """
        super().__init__()

        self.rnn_hidden_size = rnn_hidden_size
        self.num_layers = num_layers

        # Pre-trained ResNet50 used as the image feature extractor.
        # ``pretrained=True`` is deprecated in torchvision >= 0.13; prefer the
        # ``weights`` API when available. IMAGENET1K_V1 is the checkpoint the
        # legacy flag loads, so the weights are unchanged.
        if hasattr(models, "ResNet50_Weights"):
            self.cnn = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        else:
            self.cnn = models.resnet50(pretrained=True)

        # Freeze the pre-trained weights so they are not updated by the optimizer.
        # NOTE: this does not stop BatchNorm running statistics from changing
        # in train mode; put ``self.cnn`` in eval mode if the backbone must
        # stay completely fixed.
        for param in self.cnn.parameters():
            param.requires_grad = False

        # Replace the classification head. The new layer is created after the
        # freeze loop, so it remains trainable.
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, cnn_output_size)

        # LSTM over the text sequence; each time step is a 300-dim embedding.
        self.text_rnn = nn.LSTM(
            input_size=300,
            hidden_size=rnn_hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        # LSTM over the fused (image features + text LSTM output) sequence.
        self.output_rnn = nn.LSTM(
            input_size=cnn_output_size + rnn_hidden_size,
            hidden_size=output_size,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, images, text):
        """Run the fused image/text model.

        Args:
            images: Batch of images fed to the CNN
                (assumed ``(batch, 3, H, W)`` as expected by ResNet50).
            text: Embedded text of shape ``(batch, seq_len, 300)``.

        Returns:
            Tensor of shape ``(batch, seq_len, output_size)`` holding the
            output LSTM's top-layer hidden state at every time step.
        """
        # (batch, cnn_output_size)
        image_features = self.cnn(images)

        # (batch, seq_len, rnn_hidden_size). When no initial state is given,
        # nn.LSTM starts from zeros with the correct shape, device and dtype.
        text_features, _ = self.text_rnn(text)

        # Repeat the image vector at every time step so it can be concatenated
        # with the text features along the feature axis for any seq_len.
        seq_len = text_features.size(1)
        tiled_image_features = image_features.unsqueeze(1).expand(-1, seq_len, -1)
        combined = torch.cat((tiled_image_features, text_features), dim=-1)

        # The output LSTM's hidden size is ``output_size`` (not
        # ``rnn_hidden_size``), so rely on the default zero initial state
        # rather than hand-building one with the wrong size.
        output, _ = self.output_rnn(combined)

        return output