In [1]:
import re
from itertools import combinations

import numpy as np
import pandas as pd
import torch
from numpy.linalg import norm
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run the sentence encoder on the GPU when CUDA is available, otherwise on CPU.
# A device *string* is used in both branches (previously the GPU branch
# produced an integer device index while the CPU branch produced a string).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained sentence-embedding model used to embed the dictionary definitions.
retriever = SentenceTransformer(
    "paraphrase-MiniLM-L6-v2",
    device=device,
)

In [3]:
# Load the dictionary table; later cells read its "Word" and "Definition" columns.
DICTIONARY_CSV = "data/dictionary.csv"
dict_df = pd.read_csv(DICTIONARY_CSV)

dict_df

Unnamed: 0,Word,POS,Definition
0,A,,The first letter of the English and of many ot...
1,A,,The name of the sixth tone in the model major ...
2,A,,An adjective commonly called the indefinite ar...
3,A,,"In each; to or for each; as """"""""twenty leagues..."
4,A,prep.,In; on; at; by.
...,...,...,...
175718,Zymotic,a.,Of pertaining to or caused by fermentation.
175719,Zymotic,a.,Designating or pertaining to a certain class o...
175720,Zythem,n.,See Zythum.
175721,Zythepsary,n.,A brewery.


In [4]:
# The 16 candidate words, entered four per line. Later cells filter the
# dictionary down to these words and score every 4-word combination of them.
words = (
    ["CRAB", "RAY", "SPONGE", "SQUID"]
    + ["CIRCLE", "DIAMOND", "SQUARE", "TRIANGLE"]
    + ["BOB", "CROSS", "HOOK", "WEAVE"]
    + ["FEAST", "FREE", "PANTS", "THAT"]
)

In [5]:
# Upper-case the headwords so they can be matched against the upper-case `words` list.
dict_df = dict_df.assign(Word=dict_df["Word"].str.upper())

In [6]:
# Keep only the dictionary rows whose headword is one of the target words.
in_word_list = dict_df["Word"].isin(words)

# reset_index() renumbers the rows 0..n-1 and keeps the old row labels as a column.
# NOTE(review): because the old labels are kept (no drop=True), re-running this
# cell on the already-reset frame inserts the index as a column again — confirm
# whether the preserved labels are needed before switching to drop=True.
dict_df = dict_df[in_word_list].reset_index()
dict_df

Unnamed: 0,index,Word,POS,Definition
0,17498,BOB,n.,Anything that hangs so as to play loosely or w...
1,17499,BOB,n.,A knot of worms or of rags on a string used in...
2,17500,BOB,n.,A small piece of cork or light wood attached t...
3,17501,BOB,n.,The ball or heavy part of a pendulum; also the...
4,17502,BOB,n.,A small wheel made of leather with rounded edg...
...,...,...,...,...
202,171769,WEAVE,v. t.,To unite as threads of any kind in such a mann...
203,171770,WEAVE,v. t.,To form as cloth by interlacing threads; to co...
204,171771,WEAVE,v. i.,To practice weaving; to work with a loom.
205,171772,WEAVE,v. i.,To become woven or interwoven.


In [9]:
# Embed every definition with the sentence encoder.
# The definitions are handed over as a plain Python list rather than a pandas
# Series, so the encoder's positional indexing can never be confused with the
# Series' label-based indexing (which only lines up while the index is 0..n-1).
embeddings = retriever.encode(dict_df["Definition"].tolist())

embeddings.shape

(207, 384)

In [10]:
# Alias for the definition embeddings; the pairwise-similarity loop reads its rows as matrix[i].
matrix = embeddings

In [11]:
def cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D vectors.

    NOTE(review): this definition shadows the ``cosine_similarity`` imported
    from ``sklearn.metrics.pairwise`` in the notebook's import cell; the name
    is kept because the pairwise loop calls it with two single vectors.

    Args:
        a: First vector (array-like of numbers).
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (norm(a) * norm(b))``. When the product of the norms is
        zero (at least one all-zero vector) the result is ``0.0`` instead of
        the NaN that a 0/0 division would produce.
    """
    denominator = norm(a) * norm(b)
    if denominator == 0:
        # A zero vector has no direction; report "no similarity" rather than NaN.
        return 0.0
    return np.dot(a, b) / denominator

In [12]:
# Cosine similarity between every pair of definitions that belong to two
# different words: one output row per definition pair.

# Fetch the word labels once instead of calling dict_df.iloc[...] inside the loop.
word_labels = dict_df["Word"].tolist()

# Row i of `matrix` is paired with row i of dict_df below, so the two must line up.
assert len(word_labels) == len(matrix), "expected one embedding row per dict_df row"

# combinations() yields index pairs with i < j in the same order as two nested
# loops, so each unordered definition pair is visited exactly once.
similarities = [
    [word_i, word_j, cosine_similarity(matrix[i], matrix[j])]
    for (i, word_i), (j, word_j) in combinations(enumerate(word_labels), 2)
    if word_i != word_j  # skip pairs of definitions of the same word
]

df = pd.DataFrame(similarities, columns=["word_1", "word_2", "similarity"])

df

Unnamed: 0,word_1,word_2,similarity
0,BOB,CIRCLE,0.143793
1,BOB,CIRCLE,0.322873
2,BOB,CIRCLE,0.325041
3,BOB,CIRCLE,0.359586
4,BOB,CIRCLE,0.199129
...,...,...,...
19235,TRIANGLE,WEAVE,-0.127972
19236,TRIANGLE,WEAVE,-0.064480
19237,TRIANGLE,WEAVE,-0.034874
19238,TRIANGLE,WEAVE,-0.044646


In [13]:
# Collapse the definition-level rows to one row per (word_1, word_2) pair,
# keeping the highest similarity found for that pair.
df = df.groupby(["word_1", "word_2"], as_index=False)["similarity"].max()

df

Unnamed: 0,word_1,word_2,similarity
0,BOB,CIRCLE,0.454816
1,BOB,CRAB,0.546538
2,BOB,CROSS,0.585206
3,BOB,DIAMOND,0.401919
4,BOB,FEAST,0.437472
...,...,...,...
86,SQUARE,TRIANGLE,0.711578
87,SQUARE,WEAVE,0.427148
88,SQUID,TRIANGLE,0.309317
89,SQUID,WEAVE,0.157068


In [14]:
# Remove duplicate rows and rows with missing values, then rank the word
# pairs from most to least similar.
df = (
    df
    .drop_duplicates()
    .dropna()
    .sort_values("similarity", ascending=False)
)

df

Unnamed: 0,word_1,word_2,similarity
86,SQUARE,TRIANGLE,0.711578
53,DIAMOND,TRIANGLE,0.695270
42,CROSS,SQUARE,0.664227
51,DIAMOND,SQUARE,0.626042
38,CROSS,FREE,0.612566
...,...,...,...
68,FREE,TRIANGLE,0.210935
67,FREE,SQUID,0.182668
61,FEAST,TRIANGLE,0.166279
60,FEAST,SQUID,0.159227


In [15]:
# Symmetric lookup table: both (word_1, word_2) and (word_2, word_1) map to
# the similarity stored for that pair in `df`.
relation_dict = {}

for first, second, score in zip(df["word_1"], df["word_2"], df["similarity"]):
    relation_dict[(first, second)] = score
    relation_dict[(second, first)] = score

relation_dict

{('SQUARE', 'TRIANGLE'): 0.7115780115127563,
 ('TRIANGLE', 'SQUARE'): 0.7115780115127563,
 ('DIAMOND', 'TRIANGLE'): 0.6952704787254333,
 ('TRIANGLE', 'DIAMOND'): 0.6952704787254333,
 ('CROSS', 'SQUARE'): 0.6642268300056458,
 ('SQUARE', 'CROSS'): 0.6642268300056458,
 ('DIAMOND', 'SQUARE'): 0.6260420083999634,
 ('SQUARE', 'DIAMOND'): 0.6260420083999634,
 ('CROSS', 'FREE'): 0.6125659346580505,
 ('FREE', 'CROSS'): 0.6125659346580505,
 ('BOB', 'SQUID'): 0.6010774374008179,
 ('SQUID', 'BOB'): 0.6010774374008179,
 ('BOB', 'SQUARE'): 0.5906329154968262,
 ('SQUARE', 'BOB'): 0.5906329154968262,
 ('BOB', 'CROSS'): 0.5852063894271851,
 ('CROSS', 'BOB'): 0.5852063894271851,
 ('CIRCLE', 'HOOK'): 0.5789260864257812,
 ('HOOK', 'CIRCLE'): 0.5789260864257812,
 ('CRAB', 'SPONGE'): 0.5782073736190796,
 ('SPONGE', 'CRAB'): 0.5782073736190796,
 ('CRAB', 'SQUID'): 0.5733674168586731,
 ('SQUID', 'CRAB'): 0.5733674168586731,
 ('CRAB', 'HOOK'): 0.5635383129119873,
 ('HOOK', 'CRAB'): 0.5635383129119873,
 ('CROSS

In [18]:
# Container for the ([a, b, c, d], score) results produced by this cell's scoring loop.
sim_4 = []

def similarity_4(a, b, c, d):
    """Return the sum of the six pairwise similarities among four words.

    Each pair is looked up in the module-level ``relation_dict``; a pair
    without an entry raises ``KeyError``.
    """
    pairs = [(a, b), (a, c), (a, d), (b, c), (b, d), (c, d)]
    total = relation_dict[pairs[0]]
    for pair in pairs[1:]:
        total = total + relation_dict[pair]
    return total

# Score every unordered 4-word combination of `words` and rank the
# combinations from highest to lowest total pairwise similarity.
# combinations() emits the groups in the same order as four nested
# index loops with i < j < k < l.
scored_groups = []
for group in combinations(words, 4):
    try:
        score = similarity_4(*group)
    except KeyError:
        # At least one pair in this group has no entry in relation_dict
        # (e.g. a word with no rows left in dict_df), so the group cannot be
        # scored and is left out. Only KeyError is skipped; any other error
        # now surfaces instead of being silently swallowed.
        continue
    scored_groups.append((list(group), score))

# One sort replaces the manual insertion loop. sorted() is stable, so groups
# with exactly equal scores stay in generation order.
sim_4 = sorted(scored_groups, key=lambda entry: entry[1], reverse=True)

sim_4

[(['DIAMOND', 'SQUARE', 'TRIANGLE', 'CROSS'], 3.652441620826721),
 (['CIRCLE', 'DIAMOND', 'SQUARE', 'TRIANGLE'], 3.583052098751068),
 (['SQUARE', 'TRIANGLE', 'BOB', 'CROSS'], 3.4814471900463104),
 (['DIAMOND', 'SQUARE', 'TRIANGLE', 'BOB'], 3.4717105627059937),
 (['SQUARE', 'BOB', 'CROSS', 'FREE'], 3.45873162150383),
 (['RAY', 'DIAMOND', 'SQUARE', 'TRIANGLE'], 3.4487803876399994),
 (['CRAB', 'SQUARE', 'BOB', 'CROSS'], 3.42920458316803),
 (['CRAB', 'DIAMOND', 'SQUARE', 'TRIANGLE'], 3.427803486585617),
 (['SQUARE', 'BOB', 'CROSS', 'HOOK'], 3.423436254262924),
 (['CRAB', 'SQUARE', 'TRIANGLE', 'CROSS'], 3.415434867143631),
 (['CIRCLE', 'SQUARE', 'TRIANGLE', 'CROSS'], 3.4085545241832733),
 (['DIAMOND', 'SQUARE', 'TRIANGLE', 'HOOK'], 3.40796959400177),
 (['RAY', 'SQUARE', 'TRIANGLE', 'CROSS'], 3.402939200401306),
 (['RAY', 'SQUARE', 'BOB', 'CROSS'], 3.365562856197357),
 (['DIAMOND', 'SQUARE', 'BOB', 'CROSS'], 3.33981654047966),
 (['CRAB', 'SQUARE', 'TRIANGLE', 'BOB'], 3.337059199810028),
 (['