# This notebook is used for extracting feature embeddings from the recipe text - NetID:  gg676, xl598, vt152, smk371

In [None]:
import json
import time
from ast import literal_eval
import glob
import os
import pyarrow

import pandas as pd
import torch
import gc
from torch import nn
from transformers import BertTokenizer, BertModel

import pickle
import numpy as np
from multiprocessing import Pool

In [None]:
# First JSON from the recipe text: records containing title, ingredients, id, instructions.
# The encoding is given explicitly so decoding does not depend on the machine's locale default.
with open('/common/home/gg676/536/data/text_data/layer1.json', 'r', encoding='utf-8') as fp:
    data_1 = json.load(fp)

In [None]:
# Second JSON from the recipe text: the mapping between each text id and its image id(s).
# The encoding is given explicitly so decoding does not depend on the machine's locale default.
with open('/common/home/gg676/536/data/text_data/layer2.json', 'r', encoding='utf-8') as fp:
    data_2 = json.load(fp)

In [None]:
# Lookup table from text (recipe) id to its image ids, with the file extension
# stripped from each image id. A dict makes each later lookup constant-time.
data_3 = {
    entry['id']: [os.path.splitext(image['id'])[0] for image in entry['images']]
    for entry in data_2
}

In [None]:
#Get ids of images from the id that we saved already from the image dataset
def load_ids(path):
    """Collect the ids stored in every file under ``path`` into a dict.

    Every file matched by ``path + '/*'`` is read with ``torch.load`` and its
    items are appended to one running list.

    Args:
        path: Directory holding the saved id files.

    Returns:
        A dict whose keys are the collected ids and whose values are all
        ``None``, so membership checks are constant-time.
    """
    collected = []
    for file_path in glob.glob(path+'/*'):
        collected.extend(torch.load(file_path))
    return dict.fromkeys(collected)
# Dict keyed by every image id that was saved from the image dataset (values are None).
id_dict = load_ids('/common/home/gg676/536/data/ids')

In [None]:
# Finding all the text data that belong to either train, validation or test using the image id as reference
# I have already created a split data of image features along with their id
def prepare_df(data_1, data_2, id_dict):
    """Build one DataFrame row per (recipe, image) pair.

    Args:
        data_1: Iterable of recipe dicts with ``id``, ``title``,
            ``ingredients`` and ``instructions``; the last two are lists of
            dicts that each carry a ``text`` key.
        data_2: Either the raw recipe-to-image records (dicts with ``id`` and
            an ``images`` list whose items carry an ``id`` file name), or an
            already-built ``{recipe_id: [image_id, ...]}`` mapping.
        id_dict: Not used by this function; kept so existing callers keep
            working.

    Returns:
        DataFrame with columns ``id``, ``img_id``, ``title``, ``ingredients``,
        ``instructions`` and ``combined``. Recipes with no entry in the
        recipe-to-image mapping (or missing an expected key) are skipped.
    """
    # Resolve image ids from the argument rather than from a module-level
    # global, so the function only depends on what it is given.
    if isinstance(data_2, dict):
        images_by_recipe = data_2
    else:
        images_by_recipe = {
            entry['id']: [os.path.splitext(image['id'])[0] for image in entry['images']]
            for entry in data_2
        }

    columns = ['id', 'img_id', 'title', 'ingredients', 'instructions', 'combined']
    rows = []
    for count, item in enumerate(data_1):
        try:
            recipe_id = item['id']
            image_ids = images_by_recipe[recipe_id]
            # The text is the same for every image of a recipe, so build it once.
            # NOTE(review): pieces are joined without any separator, so the last
            # word of one piece runs into the first word of the next; confirm
            # this is intended before the text is tokenized.
            title = item['title']
            ingredients = "".join(part['text'] for part in item['ingredients'])
            instructions = "".join(step['text'] for step in item['instructions'])
        except KeyError:
            # Best effort: a recipe without images (or with a missing field)
            # contributes no rows.
            pass
        else:
            combined = title + ingredients + instructions
            rows.extend(
                (recipe_id, img_id, title, ingredients, instructions, combined)
                for img_id in image_ids
            )

        # Progress indicator for the long-running pass over all recipes.
        if count % 10000 == 0:
            print(count)

    return pd.DataFrame(rows, columns=columns)

# Flatten the recipe text into a DataFrame with one row per (recipe id, image id) pair.
df = prepare_df(data_1, data_2, id_dict)

In [None]:
# Get both the ids and the features extracted from the image dataset.
# This is required because we will create a dataframe that can be merged with the text data
# (using a left join) to make sure the ids match.
def load_data(path):
    """Load every file directly under ``path`` and concatenate their items.

    Files are visited in sorted name order. ``glob.glob`` makes no ordering
    promise, and results loaded from two different directories are paired
    up by position, so the traversal order must be reproducible.

    Args:
        path: Directory containing files readable by ``torch.load``; a
            trailing slash is accepted.

    Returns:
        A single list with the items of every file, in sorted-file order.
        Empty when the directory has no files or does not exist.
    """
    data_list = []
    for file_path in sorted(glob.glob(os.path.join(path, '*'))):
        data_list.extend(torch.load(file_path))
    return data_list
# Load the saved image ids and the saved image features from their directories.
id_list = load_data('/common/home/gg676/536/data/ids/')
feature_list = load_data('/common/home/gg676/536/data/features/')
# Pair each id with a feature purely by position.
# NOTE(review): this assumes both directories yield their items in the same order — confirm.
df_image = pd.DataFrame(zip(id_list, feature_list), columns=['img_id', 'features'])

In [None]:
# Left join on image id: every row of df_image is kept, and the recipe text columns
# of df are attached wherever the img_id matches. Image ids without a matching
# recipe row end up with missing values in the text columns.
df_merged = pd.merge(df_image, df, how='left', on='img_id')

In [None]:
# this is done to tokenize each of title, ingredients, instructions and combined and store as pickle    
def prepare_df(data):
    """Flatten recipe records into a DataFrame with one row per recipe.

    Args:
        data: Iterable of recipe dicts with ``id``, ``title``, ``ingredients``
            and ``instructions``; the last two are lists of dicts that each
            carry a ``text`` key.

    Returns:
        DataFrame with columns ``id``, ``title``, ``ingredients``,
        ``instructions`` and ``combined``, where ``combined`` is the three
        text fields concatenated in that order with no separator.
    """
    records = []
    for recipe in data:
        title = recipe['title']
        ingredients = "".join(part['text'] for part in recipe['ingredients'])
        instructions = "".join(step['text'] for step in recipe['instructions'])
        combined = title + ingredients + instructions
        records.append((recipe['id'], title, ingredients, instructions, combined))
    return pd.DataFrame(
        records,
        columns=['id', 'title', 'ingredients', 'instructions', 'combined'],
    )
def apply_tokenize(data, tokenizer, field_name):
    """Replace column ``field_name`` of ``data`` with its tokenizer output.

    The tokenizer is called once per row with ``max_length=512``,
    ``padding='max_length'`` and ``truncation=True``. ``data`` is modified
    in place and also returned.
    """
    def encode(row):
        return tokenizer(row[field_name], max_length=512, padding='max_length', truncation=True)

    data[field_name] = data.apply(encode, axis=1)
    return data

def parallelize_dataframe(df, tokenizer, func, field_name, n_workers=8):
    """Apply ``func`` to row-wise chunks of ``df`` in worker processes.

    Args:
        df: DataFrame to process.
        tokenizer: Passed through to ``func`` as its second argument.
        func: Callable invoked as ``func(chunk, tokenizer, field_name)`` and
            returning a DataFrame. It must be picklable, because it is sent
            to worker processes.
        field_name: Passed through to ``func`` as its third argument.
        n_workers: Number of chunks and of worker processes (default 8).

    Returns:
        The per-chunk results concatenated in chunk order. An empty ``df``
        is returned as-is without starting any worker.
    """
    if len(df) == 0:
        return df

    # Split by row position; chunks that would be empty (fewer rows than
    # workers) are dropped so ``func`` never sees an empty frame.
    position_groups = np.array_split(np.arange(len(df)), n_workers)
    chunks = [df.iloc[positions] for positions in position_groups if len(positions)]

    pool = Pool(n_workers)
    try:
        # starmap unpacks each tuple into func's positional arguments;
        # Pool.map only accepts a single iterable of single arguments.
        results = pool.starmap(func, [(chunk, tokenizer, field_name) for chunk in chunks])
    finally:
        pool.close()
        pool.join()
    return pd.concat(results)
    
def tokenize_data(df, tokenizer, field_no):
    """Tokenize one positional field of every row in ``df``.

    Rows are read with ``df.itertuples()``, so ``field_no`` is a position in
    the row tuple, where position 0 is the DataFrame index and the columns
    start at 1. ``df`` must have ``img_id`` and ``id`` columns.

    Args:
        df: DataFrame to iterate.
        tokenizer: Callable invoked with the field value plus
            ``max_length=512``, ``padding='max_length'``, ``truncation=True``.
        field_no: Tuple position of the text field to tokenize.

    Returns:
        ``(id_list, img_id_list, result_list)`` as three parallel lists.
    """
    recipe_ids, image_ids, encodings = [], [], []
    for n_done, record in enumerate(df.itertuples()):
        image_ids.append(record.img_id)
        recipe_ids.append(record.id)
        encodings.append(
            tokenizer(record[field_no], max_length=512, padding='max_length', truncation=True)
        )
        # Progress indicator, printed every 100000 rows.
        if n_done % 100000 == 0:
            print(n_done)
    return recipe_ids, image_ids, encodings

def save_data(data, path, file_name):
    """Pickle ``data`` to the file named ``path + file_name``.

    ``path`` and ``file_name`` are joined by plain string concatenation, so
    ``path`` must already end with a path separator.
    """
    destination = path+file_name
    with open(destination, 'wb') as handle:
        pickle.dump(data, handle)

def main():
    """Tokenize each text column of ``df_merged`` and pickle the results.

    For every text column, three pickles are written under the output
    directory: the recipe ids, the image ids and the tokenizer outputs.
    """
    output_path ='/common/home/gg676/536/data/text_data/'
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Positions of the text fields in the tuples from df_merged.itertuples(),
    # where position 0 is the index.
    # NOTE(review): assumes df_merged's column order is img_id, features, id,
    # title, ingredients, instructions, combined — confirm.
    col_pos_list = [4, 5, 6, 7]
    output_file_name_list = ['title', 'ingredients', 'instructions', 'combined']

    for idx, (col_pos, col_name) in enumerate(zip(col_pos_list, output_file_name_list)):
        print("idx: ", idx)
        start = time.time()
        id_list, img_id_list, result_list = tokenize_data(df_merged, tokenizer, col_pos)
        print("Time Taken: ", time.time()-start)
        outputs = (
            ("id_col_", id_list),
            ("img_id_col_", img_id_list),
            ("result_col_", result_list),
        )
        for prefix, payload in outputs:
            save_data(payload, output_path, prefix+col_name)
main()

In [None]:
def load_pickle_data(file_name, base_dir='/common/home/gg676/536/data/text_data/'):
    """Load one pickled object from ``base_dir + file_name``.

    Args:
        file_name: Name of the pickle file.
        base_dir: Directory prefix, joined by plain string concatenation, so
            it must end with a path separator. Defaults to the directory the
            tokenization step writes to.

    Returns:
        The unpickled object.
    """
    # Unpickling can execute arbitrary code; only point this at files that
    # this pipeline wrote itself.
    with open(base_dir+file_name, 'rb') as fp:
        return pickle.load(fp)

# Reload the pickled tokenizer outputs and recipe ids saved for the 'combined' text.
x_data = load_pickle_data('result_col_combined')
x_id = load_pickle_data('id_col_combined')
# One row per tokenized text.
# NOTE(review): assumes every entry is dict-like with 'input_ids' and 'attention_mask' keys — confirm.
df_x_data = pd.DataFrame(x_data)
x_input_id = df_x_data['input_ids']
x_attn_mask = df_x_data['attention_mask']

In [None]:
#os.environ["CUDA_VISIBLE_DEVICES"]="1,2,3"
# Pretrained BERT encoder, configured to also return the hidden states of its layers.
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
#for param in model.bert.parameters():
#    param.requires_grad = False
# The model is only used for feature extraction, so switch it to evaluation mode.
model.eval()
# Wrap the model in DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
  print(torch.cuda.device_count(), "GPUs!")
  model = nn.DataParallel(model)

In [None]:
# Target device for the model. NOTE(review): hard-coded to CUDA, with no CPU fallback.
device = torch.device("cuda")

In [None]:
# Move the model's parameters and buffers onto the selected device.
model.to(device)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Index-aligned view over recipe ids, token ids and attention masks."""

    def __init__(self, ids, input_ids, attn_mask):
        """Store the three parallel, index-aligned collections as given."""
        self.id, self.input_ids, self.attn_mask = ids, input_ids, attn_mask

    def __len__(self):
        """Return the number of samples, i.e. the number of ids."""
        return len(self.id)

    def get_batch_texts(self, idx):
        """Return the same ``(id, input_ids, attention_mask)`` triple as ``self[idx]``."""
        return self.__getitem__(idx)

    def __getitem__(self, idx):
        """Return ``(id, input_ids tensor, attention_mask tensor)`` for sample ``idx``."""
        sample_id = self.id[idx]
        token_ids = torch.tensor(self.input_ids[idx])
        mask = torch.tensor(self.attn_mask[idx])
        return sample_id, token_ids, mask

In [None]:
# Wrap the reloaded ids, token ids and attention masks for batched loading.
dataset = Dataset(x_id, x_input_id, x_attn_mask)
# Batches of 256 samples, loaded by 60 worker processes.
# NOTE(review): num_workers=60 is hard-coded; confirm the host has that many cores to spare.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=60)

In [None]:
# Enable cuDNN benchmark mode (PyTorch documents this as letting cuDNN try several
# convolution algorithms and keep the fastest).
# NOTE(review): confirm it has any effect for this model.
torch.backends.cudnn.benchmark = True

In [None]:
def save_features(path, file_name, data):
    """Serialize ``data`` with ``torch.save`` to the file ``path + file_name``.

    The two name parts are joined by plain string concatenation, so ``path``
    must already end with a path separator.
    """
    target = path+file_name
    torch.save(data, target)
    
def extract_features(model, dataloader):
    """Run ``model`` over every batch and collect one embedding per sample.

    For each batch, the last four entries of the model's ``hidden_states``
    are summed, and the result is averaged over dimension 1.

    Args:
        model: ``torch.nn.Module`` called with ``input_ids`` and
            ``attention_mask``; its output must expose ``hidden_states``
            (for a Hugging Face model, load it with
            ``output_hidden_states=True``).
        dataloader: Iterable yielding ``(ids, input_ids, attention_mask)``
            batches.

    Returns:
        ``(id_list, feature_list)``: the sample ids and, in the same order,
        one CPU tensor per sample.

    Raises:
        ValueError: If the model output carries no hidden states.
    """
    start = time.time()
    feature_list = []
    id_list = []

    # Send inputs to wherever the model's weights live instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        for idx, (batch_id, batch_input_ids, batch_attn_mask) in enumerate(dataloader):
            if idx % 4096 == 0:
                print(idx)
            batch_input_ids = torch.as_tensor(batch_input_ids).to(device)
            batch_attn_mask = torch.as_tensor(batch_attn_mask).to(device)
            outputs = model(input_ids=batch_input_ids, attention_mask=batch_attn_mask)

            hidden_states = outputs.hidden_states
            if hidden_states is None:
                raise ValueError(
                    "model output has no hidden_states; "
                    "load the model with output_hidden_states=True"
                )

            # Sum the last four hidden-state tensors, then average over dim 1.
            # NOTE(review): the mean runs over every position, so any padded
            # positions are averaged in too; confirm that is intended.
            word_embed = torch.stack(hidden_states[-4:]).sum(0)
            sent_embed = torch.mean(word_embed, dim=1)

            # Extending with a 2-D tensor appends one row tensor per sample.
            feature_list.extend(sent_embed.detach().cpu())
            id_list.extend(batch_id)
        print("Time taken: ", time.time() - start)
    return id_list, feature_list

In [None]:
# Extract features for the 'combined' text, then save the ids and the
# features as two separate files in the results directory.
ids_list, feature_list = extract_features(model, dataloader)
save_features('/common/home/gg676/536/data/text_data/results/', 'combined_id.pth', ids_list)
save_features('/common/home/gg676/536/data/text_data/results/', 'combined_features.pth', feature_list)