In [14]:
#!pip install -U sentence-transformers

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# The classifier's first layer (7168 inputs) only fits when exactly this many frame
# embeddings are concatenated with the sentence embedding
# (presumably 3 * 2048 ResNet-50 features + a 1024-d sentence embedding — confirm).
_N_FRAMES = 3

# Trained LIERN classifier checkpoint, resolved against the current working directory.
_CLASSIFIER_PATH = 'classify_model.pth'

# Process-wide cache filled by `_load_liern_models` on first use. Clear it to pick up a
# new checkpoint without restarting the interpreter.
_LIERN_MODELS = {}


class Classification_Net(nn.Module):
    """MLP mapping a combined (frames + instruction) embedding to a probability.

    Kept at module level: `classify_model.pth` is loaded as a whole pickled module (it is
    called directly after `torch.load`), and unpickling looks the class up by module and
    name, which a class defined inside a function can never satisfy.
    NOTE(review): the lookup only succeeds if this is the module the model was saved from
    (e.g. the notebook's ``__main__``) — confirm.
    """

    def __init__(self, n_features):
        super(Classification_Net, self).__init__()
        # Layer sizes must match the saved checkpoint; keep them unchanged.
        self.fc1 = nn.Linear(n_features, int(7168/2))
        self.fc2 = nn.Linear(int(7168/2), int(7168/10))
        self.fc3 = nn.Linear(int(7168/10), int(7168/100))
        self.fc4 = nn.Linear(int(7168/100), 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return torch.sigmoid(self.fc4(x))


def _load_liern_models():
    """Return a dict with the ResNet-50, its `avgpool` layer, the sentence encoder and the
    trained classifier, building them only on the first call."""
    if not _LIERN_MODELS:
        import torchvision.models as models
        from sentence_transformers import SentenceTransformer

        img_model = models.resnet50(pretrained=True)
        img_model.eval()

        sentence_model = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')

        # map_location='cpu': the embeddings fed to the classifier are CPU tensors, so a
        # checkpoint saved from a GPU would otherwise fail with a device mismatch.
        # NOTE(review): torch.load un-pickles arbitrary objects — only load trusted files.
        # Newer torch releases (2.6+) default to weights_only=True, which rejects whole-module
        # pickles like this one; pass weights_only=False there or switch to a state_dict.
        classify_net = torch.load(_CLASSIFIER_PATH, map_location='cpu')
        classify_net.eval()

        # Populate the cache only after everything loaded, so a failure leaves it empty.
        _LIERN_MODELS.update(
            img_model=img_model,
            avgpool=img_model.avgpool,
            sentence_model=sentence_model,
            classify_net=classify_net,
        )
    return _LIERN_MODELS


def _frames_to_embedding(frames, img_model, avgpool):
    """Return the ResNet-50 `avgpool` features of all frames, concatenated into a 1-D tensor.

    Raises:
        ValueError: if ``frames`` does not contain exactly ``_N_FRAMES`` images.
    """
    import torchvision.transforms as transforms
    from PIL import Image

    frames = list(frames)
    if len(frames) != _N_FRAMES:
        raise ValueError('expected %d frames, got %d' % (_N_FRAMES, len(frames)))

    # NOTE(review): the original referenced `scaler`/`to_tensor`/`normalize` without defining
    # them; these values (224x224 resize + ImageNet mean/std) are assumed to match the
    # preprocessing used to build the classifier's training embeddings — confirm.
    scaler = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    batch = torch.stack([normalize(to_tensor(scaler(Image.fromarray(frame).convert('RGB'))))
                         for frame in frames])

    captured = {}

    def _capture(module, inputs, output):
        # Flatten the pooled features frame after frame (the original reshaped to 2048 * 3).
        captured['features'] = output.detach().reshape(-1)

    hook = avgpool.register_forward_hook(_capture)
    try:
        with torch.no_grad():
            img_model(batch)
    finally:
        # The model is cached, so a leaked hook would keep firing on every later call.
        hook.remove()
    return captured['features']


def intermediate_rewards_function(three_frames, language_instruction):
    """Reward the agent when three consecutive game frames agree with a language instruction.

    The frames are embedded with ResNet-50 and the instruction with a sentence-transformer.
    The concatenated embedding is then scored by the previously trained LIERN classifier.

    Args:
        three_frames: iterable of exactly three images accepted by ``PIL.Image.fromarray``
            (presumably HxWx3 uint8 arrays — confirm against the caller).
        language_instruction: the instruction text to check the behaviour against.

    Returns:
        int: ``1`` if the classifier's probability exceeds 0.5, otherwise ``0``.

    Raises:
        ValueError: if ``three_frames`` does not hold exactly three images.
    """
    liern = _load_liern_models()

    img_embedding = _frames_to_embedding(three_frames, liern['img_model'], liern['avgpool'])
    sentence_embedding = np.asarray(liern['sentence_model'].encode(language_instruction),
                                    dtype=np.float32).reshape(-1)

    # Yields the same float32 values the original produced via numpy/list round trips.
    combined = torch.cat([img_embedding.float(), torch.from_numpy(sentence_embedding)])

    with torch.no_grad():
        probability = liern['classify_net'](combined)

    return 1 if probability > 0.5 else 0