In [None]:
from collections import Counter

import torch
from torch import nn
from torchvision import models
from torchvision import transforms
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, Vocab
from torchtext.vocab import GloVe
from PIL import Image

In [None]:
# Image preprocessing pipeline applied to every input image before it reaches the CNN
image_transforms = transforms.Compose(
    [
        # Resize to the 224x224 input size expected by ResNet50
        transforms.Resize((224, 224)),
        # Convert the PIL image to a tensor
        transforms.ToTensor(),
        # Normalize each channel with the ImageNet dataset statistics
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Whitespace tokenizer: splits a string on runs of whitespace
tokenizer = str.split


def yield_tokens(data_iter):
    """Lazily yield the token list of every text string in ``data_iter``.

    Args:
        data_iter: iterable of text strings.

    Yields:
        list[str]: the result of applying ``tokenizer`` to each string, in order.
    """
    yield from (tokenizer(text_str) for text_str in data_iter)


# Tokenize the training corpus once so that the frequency counter and the
# vocabulary are built from the same tokens, even when train_data can only be
# iterated a single time (e.g. a generator).
# NOTE(review): train_data is not defined in the visible cells; it is assumed to
# be an iterable of raw text strings loaded earlier — confirm before running.
# The previous code iterated over the literal string "TRAIN_DATA_PATH", which
# counted the characters of that string rather than tokens of the corpus.
train_tokens = list(yield_tokens(train_data))

# Token frequencies over the training corpus
counter = Counter()

for tokens in train_tokens:

    counter.update(tokens)

#Vocabulary
vocab = build_vocab_from_iterator(train_tokens, min_freq=1, specials=('<unk>', '<BOS>', '<EOS>', '<PAD>'))

# Map out-of-vocabulary tokens to '<unk>'; without a default index, looking up
# an unknown token raises an error.
vocab.set_default_index(vocab['<unk>'])

# The Vocab returned by build_vocab_from_iterator(min_freq=..., specials=...) has
# no load_vectors method, so fetch the GloVe vectors explicitly. Row i of
# pretrained_embeddings is the 300-d vector of the token with index i; tokens
# GloVe does not know (including the specials) get an all-zero vector.
glove = GloVe(name='6B', dim=300)
pretrained_embeddings = glove.get_vecs_by_tokens(vocab.get_itos())

def process_text(text_str):
    """Convert a text string into a list of vocabulary indices.

    Args:
        text_str: text to encode.

    Returns:
        list: ``vocab[token]`` for every token produced by ``tokenizer``,
        in the original token order.
    """
    indices = []

    # Tokenize, then look each token up in the vocabulary
    for token in tokenizer(text_str):
        indices.append(vocab[token])

    return indices

def load_data(image_path, text_str):
    """Load one (image, text) sample and convert both parts to tensors.

    Args:
        image_path: path (or file object) of the image to load.
        text_str: text associated with the image.

    Returns:
        tuple: ``(image, text)`` where ``image`` is the output of
        ``image_transforms`` and ``text`` is a 1-D ``torch.long`` tensor of
        vocabulary indices (not embeddings).
    """
    # Load and preprocess image. Image.open is lazy and keeps the file open, so
    # use a context manager to release the handle. Converting to RGB guarantees
    # three channels, which the 3-channel Normalize step requires (grayscale or
    # RGBA files would otherwise fail).
    with Image.open(image_path) as raw_image:
        image = image_transforms(raw_image.convert('RGB'))

    # Convert text to vocabulary indices. The explicit dtype keeps an empty
    # token list from producing a float tensor.
    text = torch.tensor(process_text(text_str), dtype=torch.long)

    return image, text


In [None]:
# Define the MRNN class, which inherits from nn.Module
class MRNN(nn.Module):
    """Multimodal recurrent network combining image features with a text sequence.

    A ResNet50 backbone (all pre-trained weights frozen, final ``fc`` layer
    replaced by a new trainable ``nn.Linear``) encodes the image; an LSTM
    encodes the text; a second LSTM consumes, at every time step, the image
    feature vector concatenated with the text LSTM output.

    Args:
        cnn_output_size: size of the image feature vector produced by the CNN.
        rnn_hidden_size: hidden size of the text LSTM.
        output_size: hidden size of the output LSTM (size of each output step).
        num_layers: number of stacked layers used by both LSTMs.
    """

    def __init__(self, cnn_output_size, rnn_hidden_size, output_size, num_layers):

        super(MRNN, self).__init__()

        #Initialize variables for hidden size and number of layers
        self.rnn_hidden_size = rnn_hidden_size
        self.num_layers = num_layers

        #Use ResNet50 for image processing to extract features of images
        # NOTE(review): newer torchvision releases prefer the `weights=` argument
        # over `pretrained=True` — confirm against the installed version.
        self.cnn = models.resnet50(pretrained=True)

        #Freeze the weights of the pre-trained CNN
        for param in self.cnn.parameters():

            param.requires_grad = False

        #Change the last layer of the CNN to have output size as 'cnn_output_size'
        #(created after the freeze loop, so this layer stays trainable)
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features,
                                cnn_output_size)

        #LSTM for text processing (expects 300-dimensional inputs per token)
        self.text_rnn = nn.LSTM(input_size=300,
                                hidden_size=rnn_hidden_size,
                                num_layers=num_layers,
                                batch_first=True)

        #Use another LSTM for the final output taking both the output from the CNN and the output from the text processing LSTM
        self.output_rnn = nn.LSTM(input_size=cnn_output_size + rnn_hidden_size,
                                  hidden_size=output_size,
                                  num_layers=num_layers,
                                  batch_first=True)

    def forward(self, images, text):
        """Run the model on a batch of images and already-embedded text.

        Args:
            images: batch of images accepted by the CNN; the batch size must
                match that of ``text``.
            text: float tensor of shape (batch, seq_len, input_size of the text
                LSTM). These are token embeddings, not vocabulary indices, as
                this module contains no embedding layer.

        Returns:
            Tensor of shape (batch, seq_len, output LSTM hidden size) holding
            the output LSTM's output at every time step.
        """
        #Process the images through the CNN -> (batch, cnn_output_size)
        cnn_output = self.cnn(images)

        #Process the text through the text LSTM. No initial state is passed:
        #nn.LSTM then starts from zero hidden/cell states on the input's device.
        text_rnn_output, _ = self.text_rnn(text)

        #Repeat the image features along the time axis so they can be
        #concatenated with every step of the text LSTM output. Only unsqueezing
        #gives (batch, 1, C), which cannot be concatenated with (batch, seq_len, H)
        #on the last dimension unless seq_len == 1.
        seq_len = text_rnn_output.size(1)
        image_features = cnn_output.unsqueeze(1).expand(-1, seq_len, -1)

        #Combine the outputs of the CNN and text LSTM along the feature dimension
        combined = torch.cat((image_features, text_rnn_output), dim=-1)

        #Generate the final output. The default zero initial states are sized
        #from the output LSTM's own hidden size; sizing them with rnn_hidden_size
        #fails whenever output_size != rnn_hidden_size.
        output, _ = self.output_rnn(combined)

        return output