<a href="https://colab.research.google.com/github/Karthick47v2/mcq-generator/blob/base-dev/false_ans_gen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install sense2vec
!pip install sentence_transformers==2.2.0

In [None]:
import numpy as np
import random

from sense2vec import Sense2Vec
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
sentence_model = SentenceTransformer('all-MiniLM-L12-v2')

!wget https://github.com/explosion/sense2vec/releases/download/v1.0.0/s2v_reddit_2015_md.tar.gz
!tar -xf s2v_reddit_2015_md.tar.gz

In [6]:
# The archive is unpacked with `tar -xf` into the current working directory, so load the
# vectors relative to it instead of hard-coding Colab's /content prefix.
s2v = Sense2Vec().from_disk("s2v_old")

In [7]:
## HELPER FUNCTIONS FOR FALSE ANSWERS
# generate false answers from correct answer
def false_answers(query):
  """Build the answer options (correct answer plus distractors) for a correct answer.

  Args:
    query: the correct answer as text; it is lower-cased and its spaces are replaced
      with '_' before the sense2vec lookup.

  Returns:
    For CARDINAL senses, a list holding the query followed by the first three formatted
    neighbours. For every other sense, whatever filter_output returns for the query plus
    all formatted neighbours. None when sense2vec has no usable sense for the query or
    when generating the distractors fails.
  """
  # get the best sense for given word (like NOUN, PRONOUN, VERB...)
  query_al = s2v.get_best_sense(query.lower().replace(' ', '_'))

  # sometimes the word won't be in sense2vec; in that case we can't produce any output
  # TODO: drop that question
  if query_al is None or query_al not in s2v:
    return None

  try:
    # get most similar 20 words (if any)
    temp = s2v.most_similar(query_al, n=20)
    formatted_string = change_format(query_al, temp)
    formatted_string.insert(0, query)
    # if answers are numbers then we don't need to filter
    if query_al.split('|')[1] == 'CARDINAL':
      return formatted_string[:4]
    # else filter, because similar words are often variants of the same thing
    # (US, U.S, USA, AMERICA, ...)
    return filter_output(query, formatted_string)
  except Exception:
    # best effort: any failure while generating distractors means "no answer options";
    # KeyboardInterrupt / SystemExit are deliberately not swallowed
    return None

# change s2v format to a readable form
def change_format(query, distractors):
  """Turn sense2vec results into display strings.

  Args:
    query: not used by the conversion; kept so existing callers keep working.
    distractors: iterable of results whose first element is a key such as
      "new_york|GPE" (phrase with '_' for spaces, then '|' and a sense tag).

  Returns:
    A list with one string per usable result: the sense tag removed, underscores
    replaced with spaces, and the first character upper-cased. Results whose phrase
    is empty are skipped.
  """
  output = []
  for result in distractors:
    # rsplit drops only the trailing sense tag, so a '|' inside the phrase survives
    word = result[0].rsplit('|', 1)[0].replace('_', ' ')
    if not word:
      # nothing to show; indexing word[0] below would raise IndexError
      continue
    output.append(word[0].upper() + word[1:])
  return output

# generate embeddings
def return_embedding(answer, distractors):
  """Encode the answer and the distractors with the shared sentence model.

  The answer is wrapped in a one-element list before encoding; the distractors are
  passed to the encoder as given.

  Returns:
    A 2-tuple: (encoding of [answer], encoding of distractors).
  """
  answer_encoded = sentence_model.encode([answer])
  distractors_encoded = sentence_model.encode(distractors)
  return answer_encoded, distractors_encoded

# filter false answers
def filter_output(orig, dummies):
  """Pick the final answer options out of the candidate list.

  Args:
    orig: the correct answer text.
    dummies: list of candidate strings to choose from.

  Returns:
    The words chosen by mmr, in the order mmr returned them (scores are dropped).
  """
  answer_vec, candidate_vecs = return_embedding(orig, dummies)
  # filter using MMR; each entry is (word, confidence / probability)
  ranked = mmr(answer_vec, candidate_vecs, dummies)
  return [entry[0] for entry in ranked]

# Diversity using MMR - Maximal Marginal Relevance
def mmr(doc_embedding, word_embedding, words, top_n=4, diversity=0.9):
  """Select up to top_n words that are relevant to the document yet differ from each other.

  Args:
    doc_embedding: embedding of the reference text. Expected to hold a single row,
      because the flat argmax of the word/document similarity is used directly as an
      index into words.
    word_embedding: one embedding row per entry of words.
    words: candidate strings, aligned with word_embedding.
    top_n: number of words to select; fewer are returned when there are fewer candidates.
    diversity: weight of the penalty for similarity to already selected words;
      (1 - diversity) weights similarity to the document.

  Returns:
    A list of (word, similarity-to-document rounded to 4 decimals) tuples in selection
    order; the first entry is the word most similar to the document. An empty list
    when words is empty.
  """
  if len(words) == 0:
    return []

  # extract similarity between words and docs
  word_doc_similarity = cosine_similarity(word_embedding, doc_embedding)
  word_similarity = cosine_similarity(word_embedding)

  # start from the word closest to the document
  kw_idx = [int(np.argmax(word_doc_similarity))]
  dist_idx = [i for i in range(len(words)) if i != kw_idx[0]]

  # never try to pick more words than there are candidates left; an unbounded loop
  # ends up calling argmax on an empty array
  for _ in range(min(top_n - 1, len(words) - 1)):
    dist_similarities = word_doc_similarity[dist_idx, :]
    target_similarities = np.max(word_similarity[dist_idx][:, kw_idx], axis=1)

    # calculate MMR
    mmr_scores = (1 - diversity) * dist_similarities - diversity * target_similarities.reshape(-1, 1)
    mmr_idx = dist_idx[int(np.argmax(mmr_scores))]

    # update kw
    kw_idx.append(mmr_idx)
    dist_idx.remove(mmr_idx)

  doc_scores = word_doc_similarity.reshape(-1)
  return [(words[idx], round(float(doc_scores[idx]), 4)) for idx in kw_idx]

In [None]:
## MAIN
query = "7"
results = false_answers(query)

# false_answers returns None when no options can be produced, so check before shuffling
# (random.shuffle(None) raises TypeError)
if results is None:
  print("Sorry input is wrong")
else:
  # shuffle so the correct answer is not always in the same position
  random.shuffle(results)
  print(results)