In [27]:
import warnings
warnings.filterwarnings("ignore")
import openai
from starlette.config import Config
from openai.embeddings_utils import distances_from_embeddings
import pandas as pd
import spacy
from collections import Counter
from string import punctuation

config = Config(".env")
openai.api_key = config("OPENAI_KEY", cast=str)

In [32]:
data = {'tag':['温かい, 苦い', '温かい, 酸っぱい',
               '温かい, 甘い', '温かい, 辛い',
               '冷たい, 酸っぱい',
               '冷たい, 甘い', '冷たい, 辛い'],
        'recipe':['ゴーヤチャンプルー', '手羽元の黒酢煮',
                  'フォンダンショコラ','チゲ鍋',
                  'フォー風うどん',
                  'かき氷','冷やし中華']}
recipe = pd.DataFrame(data)
recipe['embeddings'] = recipe['tag'].apply(lambda x: openai.Embedding.create(input=x, engine='text-embedding-ada-002')['data'][0]['embedding'])
recipe

Unnamed: 0,tag,recipe,embeddings
0,"温かい, 苦い",ゴーヤチャンプルー,"[0.018563181161880493, -0.01882537268102169, 0..."
1,"温かい, 酸っぱい",手羽元の黒酢煮,"[0.020437199622392654, -0.00019845727365463972..."
2,"温かい, 甘い",フォンダンショコラ,"[0.019048359245061874, -0.012244462966918945, ..."
3,"温かい, 辛い",チゲ鍋,"[0.016103975474834442, -0.008901920169591904, ..."
4,"冷たい, 酸っぱい",フォー風うどん,"[0.0163579024374485, -6.880174623802304e-05, -..."
5,"冷たい, 甘い",かき氷,"[0.014005844481289387, -0.010488653555512428, ..."
6,"冷たい, 辛い",冷やし中華,"[0.010120761580765247, -0.009949223138391972, ..."


In [36]:
def get_answers(question, df):
    q_embeddings = openai.Embedding.create(input=question, engine='text-embedding-ada-002')['data'][0]['embedding']
    
    vec_q = q_embeddings
    vec_a = df['embeddings'].values

    df['distances'] = distances_from_embeddings(vec_q, vec_a, distance_metric='cosine')
    return df.sort_values(by=['distances'])

def get_hotwords(text):
    result = []
    pos_tag = ['PROPN', 'ADJ', 'NOUN'] 
    doc = nlp(text.lower()) 
    for token in doc:
        if(token.text in nlp.Defaults.stop_words or token.text in punctuation):
            continue
        if(token.pos_ in pos_tag):
            result.append(token.text)
    return result

In [37]:
nlp = spacy.load("ja_core_news_sm")

new_text = """
寒いから温かくて辛いものが食べたい．
"""
output = set(get_hotwords(new_text))
common = Counter(output).most_common()

keywords = []
for item in common:
    keywords.append(item[0])
keywords = '，'.join(keywords)
keywords

'温かく，寒い，辛い'

In [38]:
get_answers(keywords, recipe.copy())

Unnamed: 0,tag,recipe,embeddings,distances
3,"温かい, 辛い",チゲ鍋,"[0.016103975474834442, -0.008901920169591904, ...",0.054743
0,"温かい, 苦い",ゴーヤチャンプルー,"[0.018563181161880493, -0.01882537268102169, 0...",0.073621
6,"冷たい, 辛い",冷やし中華,"[0.010120761580765247, -0.009949223138391972, ...",0.080765
2,"温かい, 甘い",フォンダンショコラ,"[0.019048359245061874, -0.012244462966918945, ...",0.087933
1,"温かい, 酸っぱい",手羽元の黒酢煮,"[0.020437199622392654, -0.00019845727365463972...",0.090033
5,"冷たい, 甘い",かき氷,"[0.014005844481289387, -0.010488653555512428, ...",0.092103
4,"冷たい, 酸っぱい",フォー風うどん,"[0.0163579024374485, -6.880174623802304e-05, -...",0.103872
