In [21]:
import json
import time
import os
import torch
from collections import defaultdict
# Candidate vocabulary. recursive_explorer scans it in order and stops after
# num_per_iter matches, so list order matters (presumably most-frequent first,
# per the file name — verify).
with open('word_freq_sorted.json') as _word_file:
    all_word_list = json.load(_word_file)

In [None]:
from tqdm.notebook import tqdm
# On-disk cache of BERT embeddings, shaped {model_name: {word: embedding tensor}}
# (see inference). Start empty when no cache file exists yet.
bert_cache_dir = 'bert_cache_dir.pt'
if os.path.exists(bert_cache_dir):
    bert_cache_dict = torch.load(bert_cache_dir)
else:
    bert_cache_dict = {}

In [None]:
def inference(model,tokenizer,word, model_name):
    """Return the mean-pooled last hidden state of `model` for `word`, memoised per model.

    Results are stored in the global `bert_cache_dict[model_name][word]`, so each
    (model, word) pair is run through the model at most once per cache lifetime.

    Args:
        model: a transformers model whose output has `last_hidden_state`.
        tokenizer: the tokenizer matching `model`.
        word: text to embed.
        model_name: cache namespace for this model.

    Returns:
        The embedding tensor `last_hidden_state.mean(dim=1)`
        (presumably shape (1, hidden_size) since a single string is encoded — confirm).
    """
    cache = bert_cache_dict.setdefault(model_name, {})
    if word in cache:
        return cache[word]

    # Send the inputs wherever the model actually lives instead of hard-coding 'mps',
    # so the function also works when the model is on CUDA or CPU.
    device = next(model.parameters()).device
    inputs = tokenizer(word, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1)
    cache[word] = embedding
    return embedding

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

def calculate_similarity(model, tokenizer, word1, word2, model_name):
    """Cosine similarity (as a Python float) between the BERT embeddings of two strings.

    Embeddings come from `inference`, so repeated words are served from the cache.
    """
    vec_a, vec_b = (inference(model, tokenizer, text, model_name) for text in (word1, word2))
    return F.cosine_similarity(vec_a, vec_b).item()

# Initialize model and tokenizer.
model_name = 'bert-large-uncased'
# Prefer Apple's MPS backend (what this notebook was written for), then CUDA, then CPU,
# instead of failing outright on machines without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModel.from_pretrained(model_name).to(device)


In [None]:
import numpy as np
from gensim.models import KeyedVectors
from gensim.downloader import load

# Pre-trained Word2Vec vectors fetched through gensim's downloader.
W2V_MODEL_NAME = 'word2vec-google-news-300'
w2v_model = load(W2V_MODEL_NAME)
    
def calculate_similarity_w2v(model, text1, text2):
    """Cosine similarity between the averaged word vectors of two texts.

    Each text is split on whitespace and tokens missing from `model.key_to_index`
    are ignored; the remaining vectors are averaged.

    Args:
        model: a KeyedVectors-like object supporting `model[token]` and `key_to_index`.
        text1, text2: texts to compare.

    Returns:
        A float in [-1, 1]. Returns 0.0 when either text has no in-vocabulary token,
        or when an averaged vector has zero norm (cosine is undefined there and
        would otherwise produce NaN).
    """
    def get_average_vector(text):
        vectors = [model[token] for token in text.split() if token in model.key_to_index]
        if not vectors:
            return None
        return np.mean(vectors, axis=0)

    vec1 = get_average_vector(text1)
    vec2 = get_average_vector(text2)
    if vec1 is None or vec2 is None:
        return 0.0

    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denom == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / denom)


### Strategy
1. Hint priority: longest word first.
2. Work in 2 × 2 pairs: two puzzles can be solved in parallel.

In [None]:
import ipywidgets as widgets
from IPython.display import display

# Top controls. Accessed by index elsewhere: general[0] topic, general[2] prompt,
# general[4] the row of word-list buttons.
general = [
    widgets.Text(description="Topic: ",disabled=False),
    widgets.Label('Type your general prompt here.'),
    widgets.Text(''),
    widgets.Label('Add your target words.'),
    widgets.HBox([
        widgets.Button(description='Add word'),
        widgets.Button(description='Remove last word'),
        widgets.Button(description='Fix word list'),
        widgets.Button(description='check validity')
    ]),
]
# Candidate search controls; calc.children[2] holds the 1-based word id read by explore_candidates.
calc = widgets.HBox([
    widgets.Button(description='Calculate candidates'),
    widgets.Button(description='Explore candidates for: '),
    widgets.IntText(description='Word id:',disabled=False),
])
words = widgets.Tab()   # one tab per target word
word_values = []        # the per-word VBox widgets shown in `words`
word_relations = [] # (i,p,j,q) : pth position of word i overlaps with qth position of word j
# Suggestions the user already ticked, as (0-based word index, word) tuples.
# Must be a set: calculate_candidates calls histories.add(...), which a dict does not have.
histories = set()

NUM_SUGGESTION_ROWS = 30
# Each suggestion row is [Label "word_no\tword\tscore", Checkbox marking it as tried].
suggestions = widgets.VBox([
    widgets.HBox([
        widgets.Label(''),
        # `description` is the Checkbox caption; `label=` is not a Checkbox trait and was ignored.
        widgets.Checkbox(value=False, description='checked?'),
    ]) for _ in range(NUM_SUGGESTION_ROWS)
])

# Save/Load buttons plus the JSON file name they read from / write to.
saving = [
    widgets.HBox([    widgets.Button(description='Save'),
    widgets.Button(description='Load')]),
    widgets.HBox([
    widgets.Label('Save/Load file name'),
    widgets.Text('puzzle_1_backup.json')])
]
output = widgets.Output()

def state2dict():
    """Snapshot the puzzle UI into a JSON-serialisable dict.

    Keys: 'topic', 'prompt', 'words' (one dict per word tab with 'len', 'hint',
    'ans' and 'rules' — the raw values of each relation row) and 'histories'
    (a list of [word index, word] pairs).
    """
    words_out = []
    for box in word_values:
        fields = box.children
        words_out.append({
            'len': fields[0].value,
            'hint': fields[1].value,
            'ans': fields[-1].value,
            'rules': [[cell.value for cell in row.children] for row in fields[3:-1]],
        })
    return {
        'topic': general[0].value,
        'prompt': general[2].value,
        'words': words_out,
        'histories': [[idx, w] for idx, w in histories],
    }

def dict2state(statedict):
    """Rebuild the puzzle UI from a dict produced by state2dict().

    All word tabs are recreated via add_word() and then populated; the history
    set is replaced. Because the recreated relation inputs start enabled, the
    validation lock flag `valid_fix` is reset too.
    """
    global histories, word_values, valid_fix
    general[0].value = statedict['topic']
    general[2].value = statedict['prompt']
    word_values = []
    for word_dict in statedict['words']:
        add_word(None)
        fields = word_values[-1].children
        fields[0].value = word_dict['len']
        fields[1].value = word_dict['hint']
        fields[-1].value = word_dict['ans']
        for rule, rule_obj in zip(word_dict['rules'], fields[3:-1]):
            for child, r in zip(rule_obj.children, rule):
                child.value = r
    # Assign after the loop so that a file with zero words also clears the
    # tabs left over from the previously loaded puzzle.
    words.children = word_values
    histories = {(x[0], x[1]) for x in statedict.get('histories', [])}
    # New widgets are unlocked, so the "check validity" toggle must start unlocked as well.
    valid_fix = False
    display_all()
    
def progress_save(b):
    """Button callback: write the UI state as JSON to the file named in the Save/Load box,
    then persist the BERT embedding cache to `bert_cache_dir`."""
    file_name = saving[1].children[1].value
    statedict = state2dict()
    # Context manager guarantees the JSON is flushed and the handle closed.
    with open(file_name, 'w') as f:
        json.dump(statedict, f)
    torch.save(bert_cache_dict, bert_cache_dir)

def progress_load(b):
    """Button callback: read the JSON file named in the Save/Load box and restore the UI from it."""
    file_name = saving[1].children[1].value
    with open(file_name) as f:
        statedict = json.load(f)
    dict2state(statedict)

# Wire the Save / Load buttons to their callbacks.
save_button, load_button = saving[0].children
save_button.on_click(progress_save)
load_button.on_click(progress_load)

# Number of relation rows (crossing constraints) offered on each word tab.
_NUM_RELATION_ROWS = 6

def _make_relation_row():
    """One relation row: [char idx in this word (-1 = disabled), other word's 1-based id, index in that word]."""
    return widgets.HBox([
        widgets.IntText(value=-1, description='Char idx:', disabled=False),
        widgets.IntText(description='Word id:', disabled=False),
        widgets.IntText(description='Target idx:', disabled=False),
    ])

def add_word(b):
    """Button callback: append a new word tab and redraw the UI.

    The tab's children are laid out as: [0] length, [1] hint, [2] help label,
    [3:-1] relation rows, [-1] answer. Other functions rely on these positions.
    The widget gets an `.id` equal to its 0-based position in `word_values`.
    """
    global word_values

    # Length, hint, relation rows, answer.
    box = widgets.VBox([
        widgets.IntText(description='Length:', disabled=False),
        widgets.Text(description='Hint:', disabled=False),
        widgets.Label('Relations. idx starts from 0. Set idx to negative to disable.'),
        *[_make_relation_row() for _ in range(_NUM_RELATION_ROWS)],
        widgets.Text(description='Answer:', disabled=False),
    ])
    box.id = len(word_values)
    word_values.append(box)
    words.children = word_values
    words.set_title(len(word_values) - 1, f"word #{len(word_values)}")
    display_all()
    
def remove_word(b):
    """Button callback: drop the most recently added word tab (no-op when none exist) and redraw."""
    global word_values
    if not word_values:
        return
    word_values = word_values[:-1]
    words.children = word_values
    display_all()

# True while the "Add word" / "Remove last word" buttons are locked.
fix = False

def disable(b):
    """Button callback: toggle the lock on the Add/Remove word buttons (general[4] children 0 and 1)."""
    global fix
    fix = not fix
    for button in general[4].children[:2]:
        button.disabled = fix

def display_all():
    """Clear the shared `output` widget and redraw every part of the UI into it.

    The word tabs are shown only when at least one word exists.
    """
    with output:
        output.clear_output()
        display(*general)
        if words.children:
            display(words)
        display(calc, progress_bar, *saving)
            
sim_scores = defaultdict(dict) # {1:{'word':0.8,...,'hihi':0.6}}

def recursive_explorer(fixed_rules,sim_scores,word_values,topic,num_per_iter = 20,skip_tuples=None):
    """Score up to `num_per_iter` new candidate words for every unsolved word.

    Despite the name, this function is not recursive. A word counts as unsolved when its
    answer length differs from its target length. For each unsolved word it scans
    `all_word_list` in order and keeps words that:
      * have the target length;
      * are not in `skip_tuples`;
      * are not already scored in `sim_scores`;
      * satisfy every `fixed_rules` entry for that word.
    It then stores the word2vec similarity between the word's hint (or `topic` when the
    hint is empty) and each kept candidate. Topic-only scores are multiplied by 0.8 so
    that hint-driven candidates rank higher.

    Args:
        fixed_rules: iterable of (word_idx, char_pos, char) constraints.
        sim_scores: {word_idx: {candidate: score}}, updated in place. Missing inner
            dicts must be created on access (e.g. a defaultdict(dict)).
        word_values: word VBox widgets to consider (each carries an `.id`).
        topic: fallback query text for words without a hint.
        num_per_iter: maximum number of new candidates scored per word.
        skip_tuples: (word_idx, candidate) pairs to ignore; defaults to none.
    """
    # None instead of a shared mutable default set.
    if skip_tuples is None:
        skip_tuples = set()
    for word in word_values:
        i = word.id
        target_len = word.children[0].value
        hint = word.children[1].value
        if len(word.children[-1].value) == target_len:
            continue  # answer already filled in

        # Hoisted out of the vocabulary scan: these do not change per candidate.
        already_scored = sim_scores[i] if i in sim_scores else {}
        word_rules = [(pos, ch) for idx, pos, ch in fixed_rules if idx == i]

        filtered_keys = []
        for w in all_word_list:
            if (len(w) == target_len
                    and (i, w) not in skip_tuples
                    and w not in already_scored
                    and all(w[pos] == ch for pos, ch in word_rules)):
                filtered_keys.append(w)
            if len(filtered_keys) >= num_per_iter:
                break

        # If a hint exists, prioritize it over the general topic.
        query = hint if hint != '' else topic
        for k in filtered_keys:
            sim = calculate_similarity_w2v(w2v_model, query, k)
            # sim = calculate_similarity(model, tokenizer, query, k)
            # Give an advantage to hint-driven scores.
            sim_scores[i][k] = sim if hint != '' else sim * 0.8

def flatten_sim_scores(sim_scores):
    """Flatten {word_idx: {candidate: score}} into (word_idx, candidate, score) triples, best first.

    The sort is stable, so equal scores keep their insertion order.
    """
    triples = [
        (idx, candidate, score)
        for idx, scores in sim_scores.items()
        for candidate, score in scores.items()
    ]
    return sorted(triples, key=lambda triple: -triple[2])

def _rules_implied_by(word_widget, answer):
    """Constraints that placing `answer` in `word_widget` imposes on the words it crosses.

    Each active relation row (char idx >= 0) yields a (other_word_idx, target_idx, char)
    tuple. It means: position `target_idx` of word `Word id` must hold `answer[char_idx]`.
    `Word id` is 1-based in the UI and converted to 0-based here.
    """
    rules = []
    for row in word_widget.children[3:-1]:
        char_idx, word_id, target_idx = (cell.value for cell in row.children)
        if char_idx >= 0:
            rules.append((word_id - 1, target_idx, answer[char_idx]))
    return rules


def _greedy_complete(seed_id, seed_word, seed_sim, solved, fixed_rules, skip_tuples):
    """Greedily extend one seed candidate into a full solution.

    Each step adds the constraints implied by the last placed word and rescores the
    still-unsolved words (at most 20 new candidates each). It then commits the single
    best-scoring candidate. Every committed (word_idx, word) pair is added to
    `skip_tuples`, which is mutated in place, so later seeds explore different words.

    Returns:
        {word_idx: (word, score)} covering the seed and every word it filled, or None
        when some unsolved word runs out of consistent candidates.
    """
    virtual_solved = set(solved)
    virtual_solved.add(seed_id)
    virtual_rules = list(fixed_rules)
    solution = {seed_id: (seed_word, seed_sim)}
    prev_id, prev_word = seed_id, seed_word
    while len(virtual_solved) < len(word_values):
        virtual_rules.extend(_rules_implied_by(word_values[prev_id], prev_word))
        virtual_sim_scores = defaultdict(dict)
        unsolved = [x for j, x in enumerate(word_values) if j not in virtual_solved]
        recursive_explorer(virtual_rules, virtual_sim_scores, unsolved,
                           general[0].value, 20, skip_tuples)
        candidates = flatten_sim_scores(virtual_sim_scores)
        if not candidates:
            return None
        prev_id, prev_word, score = candidates[0]
        virtual_solved.add(prev_id)
        solution[prev_id] = (prev_word, score)
        skip_tuples.add((prev_id, prev_word))
    return solution


def calculate_candidates(b):
    """Button callback: propose full-puzzle solutions grown greedily from the best single candidates.

    1. Every ticked suggestion row is added to `histories` so it is never proposed again.
    2. Words whose answer is filled in count as solved. Their relation rows become
       fixed character constraints for the other words.
    3. recursive_explorer scores up to 40 new candidates per unsolved word into the
       global `sim_scores`, using word2vec similarity (BERT is not used here).
    4. The 40 best (word, candidate) pairs overall seed _greedy_complete.
    5. Complete solutions are ranked by total score. Each one covers the same set of
       unsolved words, so this is also ranking by average. Their entries fill the
       suggestion rows, and leftover rows are cleared.
    """
    start_time = time.time()
    progress_bar.value = 0

    for row in suggestions.children:
        label, checked = row.children
        if checked.value and label.value != '':
            parts = label.value.split('\t')  # "word_no\tword\tscore"
            histories.add((int(parts[0]) - 1, parts[1]))

    num_per_iter = 40
    skip_tuples = set(histories)

    # Fixed constraints from already-answered words: (word_idx, char_pos, char).
    fixed_rules = []
    solved = set()
    for i, word in enumerate(word_values):
        answer = word.children[-1].value
        if len(answer) == word.children[0].value:
            solved.add(i)
            fixed_rules.extend(_rules_implied_by(word, answer))

    progress_bar.value = 1
    recursive_explorer(fixed_rules, sim_scores, word_values, general[0].value,
                       num_per_iter, skip_tuples)

    for s in solved:
        sim_scores.pop(s, None)
    top_scores = flatten_sim_scores(sim_scores)[:num_per_iter]  # (i, word, score)
    # set.union returns a new set, so the original `skip_tuples.union(...)` was a no-op.
    # Update in place so that seeds are not re-picked while completing other seeds.
    skip_tuples.update((i, word) for i, word, _ in top_scores)

    solutions = []
    for i, word, sim in top_scores:
        progress_bar.value += 1
        solution = _greedy_complete(i, word, sim, solved, fixed_rules, skip_tuples)
        if solution is not None:
            solutions.append(solution)

    solutions.sort(key=lambda sol: -sum(score for _, score in sol.values()))

    rows = suggestions.children
    entries = ((i, word, sim) for sol in solutions for i, (word, sim) in sol.items())
    filled = 0
    for row, (i, word, sim) in zip(rows, entries):
        row.children[0].value = f"{i+1}\t{word}\t{sim}"
        row.children[1].value = False
        filled += 1
    for row in rows[filled:]:
        row.children[0].value = ""

    display_all()
    with output:
        print("--- %s seconds ---" % (time.time() - start_time))
        display(suggestions)
        
def explore_candidates(b): # exploring with bert.
    """Button callback: rank candidates for the single word whose 1-based id is in calc's "Word id" box.

    Candidates must meet all of these conditions:
      * they come from `all_word_list` and have the target length;
      * they are not in `histories`;
      * they agree with the letters fixed by already-answered crossing words.
    Candidates are first ranked with word2vec against "<topic>: <hint>" (just the topic
    when the word has no hint). The top 4000 are then re-ranked with BERT, and the best
    ones fill the suggestion rows. Rows left over from an earlier, longer result are cleared.
    """
    progress_bar.value = 0
    start_time = time.time()
    target_id = calc.children[2].value - 1
    if not 0 <= target_id < len(word_values):
        # An id of 0 used to silently select the last word (index -1); larger ids raised IndexError.
        with output:
            print(f'Word id must be between 1 and {len(word_values)}.')
        return

    # Letters imposed on the target word: (position in target, char).
    fixed_rules = []
    for word in word_values:
        answer = word.children[-1].value
        if len(answer) != word.children[0].value:
            continue
        for rule in word.children[3:-1]:
            char_idx, word_id, target_idx = (c.value for c in rule.children)
            if char_idx >= 0 and word_id - 1 == target_id:
                fixed_rules.append((target_idx, answer[char_idx]))

    target_len = word_values[target_id].children[0].value
    filtered_keys = [
        w for w in all_word_list
        if len(w) == target_len
        and (target_id, w) not in histories
        and all(w[pos] == ch for pos, ch in fixed_rules)
    ]

    topic = general[0].value
    hint = word_values[target_id].children[1].value
    # Include the hint when there is one. The original condition was inverted: it dropped
    # an existing hint and produced "topic: " when there was none.
    combined_prompt = f'{topic}: {hint}' if hint else topic
    w2v_query = combined_prompt.replace(':', ' ')
    word_w2v_scores = [(k, calculate_similarity_w2v(w2v_model, w2v_query, k)) for k in filtered_keys]
    word_w2v_scores.sort(key=lambda x: -x[1])

    # BERT re-ranking is expensive; only the 4000 best word2vec candidates are scored.
    word_bert_scores = []
    for n, (k, _) in enumerate(word_w2v_scores[:4000]):
        if n % 100 == 0:
            progress_bar.value += 1
        word_bert_scores.append((k, calculate_similarity(bert_model, tokenizer, combined_prompt, k, model_name)))
    word_bert_scores.sort(key=lambda x: -x[1])

    rows = suggestions.children
    top = word_bert_scores[:len(rows)]
    for row, (word, sim) in zip(rows, top):
        row.children[0].value = f"{target_id+1}\t{word}\t{sim}"
        row.children[1].value = False
    for row in rows[len(top):]:
        row.children[0].value = ""

    display_all()
    with output:
        print("--- %s seconds ---" % (time.time() - start_time))
        display(suggestions)

# True while word lengths and relation rows are locked after a successful validation.
valid_fix= False

def _set_structure_locked(locked):
    """Set `disabled=locked` on every word's Length field and on all of its relation-row inputs."""
    for word_value in word_values:
        word_value.children[0].disabled = locked
        for rule in word_value.children[3:-1]:
            for ch in rule.children:
                ch.disabled = locked


def validation(b):
    """Button callback: validate the word relations and lock them, or unlock if already locked.

    Validation checks every active relation row from word A to word B:
      * B must exist;
      * B must have a row pointing back to A with the char/target indices swapped;
      * A's char index must lie inside A's length.
    Problems are printed to `output`. If none are found, "Valid!" is printed and the
    lengths and relations are disabled.
    """
    global valid_fix
    if valid_fix:
        valid_fix = False
        _set_structure_locked(False)
        return

    is_valid = True
    word_len = [wv.children[0].value for wv in word_values]
    conditions = {}  # {A: {B: (char idx in A, target idx in B)}}
    for word_value in word_values:
        conditions[word_value.id] = {}
        for rule in word_value.children[3:-1]:
            char_idx, word_id, target_idx = (c.value for c in rule.children)
            if char_idx >= 0:
                conditions[word_value.id][word_id - 1] = (char_idx, target_idx)

    for a, links in conditions.items():
        for other, (char_idx, target_idx) in links.items():
            # Previously `conditions[other]` raised KeyError for a nonexistent word id.
            if other not in conditions:
                is_valid = False
                with output:
                    print(f'Not valid! word {a+1} refers to word {other+1}, which does not exist')
                continue
            if a not in conditions[other]:
                is_valid = False
                with output:
                    print(f'Not valid! word {a+1} and {other+1} have a conflicting condition')
                continue
            back_char_idx, back_target_idx = conditions[other][a]
            if char_idx != back_target_idx or target_idx != back_char_idx:
                is_valid = False
                with output:
                    print(f'Not valid! word {a+1} and {other+1} have a conflicting condition')
            if char_idx >= word_len[a]:
                is_valid = False
                with output:
                    print(f'Not valid! word {a+1} is shorter than its requirements from condition!')

    if is_valid:
        valid_fix = True
        with output:
            print('Valid!')
        _set_structure_locked(True)
    
# Hook every control button up to its callback.
for _button, _handler in [
    (general[4].children[0], add_word),
    (general[4].children[1], remove_word),
    (general[4].children[2], disable),
    (general[4].children[3], validation),
    (calc.children[0], calculate_candidates),
    (calc.children[1], explore_candidates),
]:
    _button.on_click(_handler)

# Shared progress indicator. max=41 fits calculate_candidates (set to 1, then one step
# per seed, up to 40 seeds) and explore_candidates (one step per 100 of at most 4000 words).
progress_bar = widgets.IntProgress(
    value=0,
    min=0,
    max=41,
    step=1,
    description='Calculating:',
    bar_style='',  # 'success', 'info', 'warning', 'danger' or ''
    orientation='horizontal',
)

display_all()
display(output)