**word2_vec_demo with tfidf wighted average for document vectors**

In [None]:
import pandas as pd
import numpy as np
import json
from zipfile import ZipFile
import nltk
import re
from nltk.corpus import stopwords
from google.colab import files
from nltk.corpus import wordnet
import gensim.downloader
from gensim.models import KeyedVectors
from sklearn.neighbors import NearestNeighbors
from sklearn.feature_extraction.text import TfidfVectorizer
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('wordnet')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...


True

In [None]:
# Please upload the required_data folder here
required_data = files.upload()

In [None]:
# Extract files

with ZipFile('/content/Required_data.zip', 'r') as zip:
  zip.extractall()

NameError: ignored

In [None]:
which_model = input('Type 0 for the google news pre-trained model or 1 for the Law2Vec model: ')

if which_model == '0':
  w2v_model = gensim.downloader.load('word2vec-google-news-300')
  dim = 300
else:
  w2v_model = KeyedVectors.load_word2vec_format('/content/Required_data/Law2Vec.200d.txt', binary=False)
  dim = 200

Type 0 for the google news pre-trained model or 1 for the Law2Vec model: 0


In [None]:
# Load csv table

with ZipFile('/content/Required_data/us_code.csv.zip') as z:
  with z.open('us_code.csv') as f:
    documents_df = pd.read_csv(f)
documents_df = documents_df.drop(documents_df.columns[0], axis=1)

In [None]:
print(documents_df.head())

    Index                                          documents  \
0  01.1.1  In determining the meaning of any Act of Congr...   
1  01.1.2  The word “ county ” includes a parish , or any...   
2  01.1.3  The word “ vessel ” includes every description...   
3  01.1.4  The word “ vehicle ” includes every descriptio...   
4  01.1.5  The word “ company ” or “ association ” , when...   

                                   documents_cleaned  
0  determining meaning act congress   unless cont...  
1  word   county   includes parish   equivalent s...  
2  word   vessel   includes every description wat...  
3  word   vehicle   includes every description ca...  
4  word   company     association     used refere...  


In [None]:
# Define the function to get the vector of a document with the simple average of all word vectors

def get_mean_vector(sentence):
    vectors = [w2v_model[word] for word in sentence.split() if word in w2v_model]
    return sum(vectors) / len(vectors)

In [None]:
# Find the embeddings for all the documents, i.e. all sections of the USC

document_embeddings = [get_mean_vector(doc) for doc in documents_df['documents_cleaned']]

In [None]:
# Define and train the NearestNeighbors model

neighborhood = NearestNeighbors(n_neighbors=10,metric="cosine",algorithm="auto",radius=1.0)
neighborhood.fit(document_embeddings)

In [None]:
# Define the function to clean text

def clean_text(txt):
    return " ".join(re.sub(r'[^a-zA-Z]',' ',w).lower() for w in txt.split() if re.sub(r'[^a-zA-Z]',' ',w).lower() not in stopwords.words('english'))

In [None]:
# Define the function to find similarity with a given query that returns just the 10 most similar predicted indexes

def find_similarity(query):

  query_cleaned = clean_text(query)

  query_vec = get_mean_vector(query_cleaned)

  similar_indx = neighborhood.kneighbors([query_vec],return_distance=False)
  return list(documents_df.iloc[similar_indx[0],0])

In [None]:
# Define the function to calculate accuracy as described in the paper

def accuracy_score(y_pred, correct_label):

  acc = right =  0

  if correct_label in y_pred:
    return (acc + 1-0.05*y_pred.index(correct_label), right + 1)

  else:
    correct_label = correct_label.split('.')
    for pred_label in y_pred:
      buffer = 0

      if int(correct_label[0]) == int(pred_label.split('.')[0]):
        buffer += 0.2

        if correct_label[1] == pred_label.split('.')[0]:
          buffer += 0.3

      acc += buffer/10
    return (acc, right)

In [None]:
with open('/content/Required_data/queries.json', 'r') as f:
  test_data = json.load(f)

In [None]:
# Go over every test query, predict the labels and campute the accuracy
def test(test_data, similarity_func, accuracy = []):

  for k,v in test_data.items():
    if len(k.split('.')[0]) == 1:
      k = '0' + k

    if int(k.split('.')[0]) > 0:
      for q in v:
        y_pred = similarity_func(q)
        a = accuracy_score(y_pred, k)
        #print('predicted: ', y_pred)
        #print('accuracy: ', a)
        #print('correct: ', k)
        #print("\n")
        accuracy.append(a)
  return accuracy

In [None]:
accuracy = test(test_data, find_similarity)

In [None]:
# Compute accuracy values and print them
def print_accuracy(accuracy):
  accuracy_value = sum([x[0] for x in accuracy])/len(accuracy)
  total_correct = sum([x[1] for x in accuracy])/len(accuracy)
  print('Total Accuracy Score: ' + str(accuracy_value))
  print('Total percentage of correct: ' + str(total_correct))

In [None]:
print_accuracy(accuracy)

Total Accuracy Score: 0.5773809523809524
Total percentage of correct: 0.6041666666666666


In [None]:
# Here you can try with a single custom query

query = input('Pls enter the query: ')

# Or put the code above in a comment and try this sample
#query = "how do i know if my marriage is valid?"

Pls enter the query: how do i know if my marriage is valid?


In [None]:
# Clean the query text and predict results

query_cleaned = clean_text(query)
query_vec = get_mean_vector(query_cleaned)
similar_indx = neighborhood.kneighbors([query_vec],return_distance=False)

In [None]:
# Print results
idx_list = list(documents_df.iloc[similar_indx[0],0])
text_list = list(documents_df.iloc[similar_indx[0],1])

for i in range(10):
  print(idx_list[i] + '\n' + text_list[i] + '\n')

01.1.7
For the purposes of any Federal law , rule , or regulation in which marital status is a factor , an individual shall be considered married if that individual ’ s marriage is between 2 individuals and is valid in the State where the marriage was entered into or , in the case of a marriage entered into outside any State , if the marriage is between 2 individuals and is valid in the place where entered into and the marriage could have been entered into in a State . In this section , the term “ State ” means a State , the District of Columbia , the Commonwealth of Puerto Rico , or any other territory or possession of the United States . For purposes of subsection ( a ) , in determining whether a marriage is valid in a State or the place where entered into , if outside of any State , only the law of the jurisdiction applicable at the time the marriage was entered into may be considered .

37.7.423
A payment of an allowance , based on a purported marriage , that is made under this cha

In [None]:
# Get tf-idf vecotrs

tfidfvectoriser = TfidfVectorizer()
tfidfvectoriser.fit(documents_df.documents_cleaned)
tfidf_vectors = tfidfvectoriser.transform(documents_df.documents_cleaned)

# Get the list of all the found words
words = tfidfvectoriser.get_feature_names_out()

In [None]:
# Get the embedding matrix of word2vec vectors

embedding_matrix=np.zeros((len(words),dim))
for i, word in enumerate(words):
    if word in w2v_model:
        embedding_matrix[i] = w2v_model[word]

In [None]:
# Get the document embeddings by weighting the w2v word emebeddings trough the tfidf embeddings
document_embeddings_weighted=(tfidf_vectors@embedding_matrix)

In [None]:
# Train the new NearestNeighbors model

neighborhood_weighted = NearestNeighbors(n_neighbors=10,metric="cosine",algorithm="auto",radius=1.0,)
neighborhood_weighted.fit(document_embeddings_weighted)

In [None]:
# Redefine similarity function

def find_similarity_weighted(query):

  query_cleaned = clean_text(query)
  query_tifidf = tfidfvectoriser.transform([query_cleaned])
  query_vec = query_tifidf @ embedding_matrix

  similar_indx = neighborhood_weighted.kneighbors(query_vec, return_distance=False)

  return list(documents_df.iloc[similar_indx[0],0])

In [None]:
accuracy = test(test_data, find_similarity_weighted)

In [None]:
print_accuracy(accuracy)

Total Accuracy Score: 0.5812499999999995
Total percentage of correct: 0.6041666666666666


In [None]:
def find_similarity_mixed(query):
  query_cleaned = clean_text(query)
  query_tifidf = tfidfvectoriser.transform([query_cleaned])
  query_vec_weighted = query_tifidf@embedding_matrix

  query_list = query_cleaned.split(" ")
  query_mat=np.zeros((len(query_list), dim))
  for i, word in enumerate(query_list):
    if word in w2v_model:
      query_mat[i] = w2v_model[word]

  query_vec_unweighted = np.array([query_mat.mean(axis=0).T])

  query_vec=(query_vec_weighted*0.5+query_vec_unweighted*0.5)

  similar_indx = neighborhood_weighted.kneighbors(query_vec, return_distance=False)
  return list(documents_df.iloc[similar_indx[0],0])

In [None]:
accuracy = test(test_data, find_similarity_mixed)

In [None]:
print_accuracy(accuracy)

Total Accuracy Score: 0.5859623015873011
Total percentage of correct: 0.6101190476190477


In [None]:
def find_synonims(word):
  synonyms = set()
  for syn in wordnet.synsets(word):
      for lm in syn.lemmas():
              synonyms.add(lm.name())
  return synonyms

In [None]:
def find_similarity_syn(query):
  query_cleaned = clean_text(query)
  query_list = query_cleaned.split()

  query_list_syn=[]
  for word in query_list:
    if word in words:
      query_list_syn.append(word)
    else:
      for syn in find_synonims(word):
        if syn in words:
          query_list_syn.append(word)

  query_syn_cleaned = " ".join(re.sub(r'[^a-zA-Z]',' ',w).lower() for w in query_list_syn)

  query_tifidf = tfidfvectoriser.transform([query_syn_cleaned])

  query_vec = query_tifidf @ embedding_matrix

  similar_indx = neighborhood_weighted.kneighbors(query_vec, return_distance=False)
  return list(documents_df.iloc[similar_indx[0],0])


In [None]:
accuracy = test(test_data, find_similarity_syn)

In [None]:
print_accuracy(accuracy)

Total Accuracy Score: 0.5850146842878118
Total percentage of correct: 0.6079295154185022
