In [1]:
import gensim
import os
import collections
import smart_open
import random
import json
import torch
import torch.nn.functional as F
from gensim.test.utils import get_tmpfile
import numpy as np
from allennlp.modules.elmo import Elmo, batch_to_ids

def load_cleaned_data(file = 'train_cleaned.json'):
    file = open(file, 'r', encoding='utf8').read()
    recipe = json.loads(file) #json file contains data in str, convert str to dict
    recipe_context = recipe['context']
    recipe_answer = recipe['answer']
    recipe_choice = recipe['choice']
    recipe_question = recipe['question']
    recipe_images = recipe['images']
    return recipe_context, recipe_images, recipe_question, recipe_choice, recipe_answer 
def accuracy(preds, y):
    preds = F.softmax(preds, dim=1)
    correct = 0 
    pred = preds.max(1, keepdim=True)[1]
    correct += pred.eq(y.view_as(pred)).sum().item()
    acc = correct/len(y)
    return acc 
def seed_everything(seed=123):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything()

recipe_context, recipe_images, recipe_question, recipe_choice, recipe_answer = load_cleaned_data('train_cleaned.json')
recipe_context_valid, recipe_images_valid, recipe_question_valid, recipe_choice_valid, recipe_answer_valid = load_cleaned_data('val_cleaned.json')
fname = get_tmpfile("/Users/LYB/Desktop/coursework/Msc-project/recipe_baseline/hasty_simple_baseline/my_doc2vec_model")
model = gensim.models.doc2vec.Doc2Vec.load(fname) 
print()
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
elmo = Elmo(options_file, weight_file, 1, dropout=0.2, requires_grad = False)
from tqdm import tqdm




In [2]:
import re
choice = []
for i in tqdm(range(len(recipe_choice_valid))):
    temp = []
    for j in range(len(recipe_choice_valid[i])):
        a = gensim.utils.simple_preprocess(recipe_choice_valid[i][j]) 
        #a = re.sub(r"[^a-zA-Z0-9]"," ", recipe_choice_valid[i][j]).lower().split()
        if len(a) == 0:
            a = ['0']
        character_ids = batch_to_ids([a]) 
        b  = torch.sum(elmo(character_ids)['elmo_representations'][0], dim=1)
        temp.append(b.detach().numpy())
    choice.append(temp) 

100%|██████████| 961/961 [06:44<00:00,  2.26it/s]


In [3]:
question = []
for i in recipe_question_valid:
    question.append(gensim.utils.simple_preprocess(' '.join(i)))
for i in tqdm(range(len(question))):
    character_ids = batch_to_ids([question[i]])
    question[i] = torch.sum(elmo(character_ids)['elmo_representations'][0], dim=1) 

100%|██████████| 961/961 [03:47<00:00,  6.08it/s]


In [4]:
choice = torch.FloatTensor(choice)

In [5]:
for i in range(len(question)):
    question[i] = question[i].detach().numpy()

In [6]:
question = torch.FloatTensor(question)

In [7]:
choice = choice.squeeze(2)
question = question.squeeze(1)

In [8]:
def exponent_neg_manhattan_distance(x1, x2):
        return torch.sum(torch.abs(x1 - x2), dim=1)
def cosine_dot_distance(x1, x2):
        return torch.sum(torch.mul(x1, x2), dim=1)
def Infersent(x1, x2): 
        a = torch.nn.functional.cosine_similarity(x1, x2.expand(4, -1), dim=1)
        b = torch.sum(torch.mul(x1, torch.abs(x1 - x2)), dim=1)
        c = torch.sum(torch.mul(x1, x1 * x2),dim=1) 
        d = torch.matmul(x2, torch.abs(x1 - x2).permute(1, 0))
        e = torch.matmul(x2, (x1 * x2).permute(1, 0))  
        f = torch.sum(torch.mul(torch.abs(x1 - x2), x1 * x2), dim=1)
        return a
answer = [] 
for i in range(len(question)):
    answer.append(Infersent(choice[i], question[i]).numpy())
answer = torch.FloatTensor(answer) 
answer_valid = torch.LongTensor(recipe_answer_valid)
acc_val = accuracy(answer, answer_valid)
print('validation accuracy', acc_val) 

validation accuracy 0.30697190426638915
