<a href="https://colab.research.google.com/github/UMWordLab/multilingual_amaze/blob/main/Multilingual_A_maze_Alternative_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BERT-based Multilingual A-maze Alternative Generation

## 1. Preliminaries
Please run the following cells to install and import the necessary libraries. 

In [None]:
%%capture

!pip install minicons
!pip install transformers

In [None]:
%%capture

from minicons import scorer
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForMaskedLM

## 2. Selecting a Minicons language model
Please run the following cell and input the language model you would like to use for the experiment.


In [None]:
# Ask for a Hugging Face model id. .strip() guards against stray whitespace
# picked up when the name is pasted from the clipboard, which would otherwise
# make from_pretrained() fail to find the model.
langmodel = input("What minicons language model would you like to use?\nYou can select any from this list: https://huggingface.co/models\nThe name of the model can be copied using the clipboard icon next to the name on the webpage.\n").strip()
print(langmodel, "selected as model.")
# This notebook works with BERT-style masked language models, so load the
# checkpoint through minicons' masked-LM scorer rather than the incremental
# (causal/autoregressive) scorer.
# NOTE(review): `model` is re-assigned to a BertForMaskedLM in section 4, so
# this scorer object is not what the alternative-generation functions use.
model = scorer.MaskedLMScorer(langmodel, 'cpu')
# BERT WordPiece tokenizer used by the masking/scoring functions in section 4.
tokenizer = BertTokenizer.from_pretrained(langmodel)

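## 3. A few things before we begin...
Please run the following cells and answer the questions before moving on to the main functions. They will ask you to upload the input CSV files for the frequency mapping and the stimuli: the frequency file needs a header row with `word` and `frequency` columns, and the stimuli file needs a header row with a `sentences` column. Please consult the documentation for further specifications regarding these files.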

In [None]:
from google.colab import files
print("Please upload your file that contains the word-to-frequency mapping.")
# Opens the Colab upload widget. next(iter(...)) takes the first key of the
# result, presumably the uploaded file's name (files.upload() returns a dict
# keyed by file name — confirm against the Colab docs). Only one file is used.
uploaded = files.upload()
freq_file = next(iter(uploaded))
# Half-width of the frequency window passed to find_similar_frequency();
# int() raises ValueError if the answer is not an integer.
freq_window = int(input("What is the window of frequency that you would like to use to consider words 'similar' frequency? "))

import csv
import io

def process_frequency_file(filename):
  """Read a word-to-frequency CSV file into a dict.

  The file must have a header row containing the columns ``word`` and
  ``frequency``; a UTF-8 byte-order mark at the start is removed by the
  ``utf-8-sig`` codec.

  Args:
    filename: path to the CSV file.

  Returns:
    dict mapping each ``word`` to its ``frequency`` exactly as the string read
    from the file. If a word appears in several rows, the last row wins.

  Raises:
    KeyError: if the header has no ``word`` or ``frequency`` column.
  """
  # newline='' is what the csv module requires for file objects it reads,
  # so that quoted fields containing line breaks are parsed correctly.
  with open(filename, mode='r', encoding='utf-8-sig', newline='') as csv_file:
    return {row['word']: row['frequency'] for row in csv.DictReader(csv_file)}

# Global word -> frequency lookup read by find_similar_frequency().
freq_dict = process_frequency_file(freq_file)

In [None]:
print("Please upload your file that contains the stimuli sentences to be used for alternative generation")
# Same Colab upload widget as for the frequency file; only the first key of
# the result (presumably the uploaded file name) is kept.
uploaded = files.upload()
stim_file = next(iter(uploaded))

def process_stimuli_file(filename):
  """Read the stimuli CSV file and return its sentences in file order.

  The file must have a header row containing a ``sentences`` column; a UTF-8
  byte-order mark at the start is removed by the ``utf-8-sig`` codec.

  Args:
    filename: path to the CSV file.

  Returns:
    list of the ``sentences`` values, one string per data row.

  Raises:
    KeyError: if the header has no ``sentences`` column.
  """
  # newline='' is what the csv module requires for file objects it reads,
  # so that quoted sentences containing line breaks are parsed correctly.
  with open(filename, mode='r', encoding='utf-8-sig', newline='') as csv_file:
    return [row['sentences'] for row in csv.DictReader(csv_file)]

# One sentence string per CSV row; iterated by the generation loop in section 5.
sentences = process_stimuli_file(stim_file)
print("Stimuli saved. ")

## 4. Main Functions
- find_similar_frequency
- tokenization
- calculate_surprisal
- find_alternative

In [None]:
import random

# Re-assigns `model` (previously the minicons scorer) to a plain Hugging Face
# masked LM; calculate_surprisal() calls this object directly on token ids.
model = BertForMaskedLM.from_pretrained(langmodel)

def find_similar_frequency(word, window, frequencies=None):
  """Return candidate words whose frequency is close to that of ``word``.

  A candidate ``w`` qualifies when ``|freq(word) - freq(w)| < window``
  (strictly) and ``w`` is not ``word`` itself. Frequencies are parsed with
  ``float()``, so both integer and decimal strings are accepted.

  If ``word`` is not in the mapping, or no other word falls inside the window,
  a single word is chosen at random instead, never ``word`` itself.

  Args:
    word: the target word.
    window: half-width of the open frequency window.
    frequencies: optional word -> frequency mapping; defaults to the
      module-level ``freq_dict`` loaded from the uploaded CSV.

  Returns:
    list of candidate words. It is empty only when the mapping contains no
    word other than ``word``.
  """
  if frequencies is None:
    frequencies = freq_dict
  # Every word except the target is a potential candidate / fallback.
  others = [w for w in frequencies if w != word]

  if word not in frequencies:
    # error handling - word doesn't exist in given frequency list
    print('Frequency not found in list. Next word will be chosen through random selection.')
    return [random.choice(others)] if others else []

  # Parse the target frequency once instead of on every comparison.
  target = float(frequencies[word])
  res = [w for w in others if abs(target - float(frequencies[w])) < window]
  if not res:
    # error handling - word exists in given freq list but no other word is within the window
    print('There weren\'t words that matches the frequency in the given window. Next word will be chosen through random selection.')
    return [random.choice(others)] if others else []
  return res


def tokenization(sentence, separation):
  """Build one masked copy of the tokenized sentence per word.

  ``sentence`` is tokenized with the module-level ``tokenizer`` with special
  tokens added, so index 0 is [CLS] and the last index is [SEP]. For each index
  ``i`` in ``range(len(separation))``, a fresh copy of the input ids is made
  and token position ``i + 1`` (skipping [CLS]) is replaced with the
  tokenizer's mask id. [SEP] is never masked.

  For example, for the words ``AA B CCC`` the copies are
  ``[CLS] [MASK] B CCC [SEP]``, ``[CLS] AA [MASK] CCC [SEP]`` and
  ``[CLS] AA B [MASK] [SEP]``, assuming each word is a single token.

  NOTE(review): exactly ONE token position is masked per word, so word index
  and token position only line up when every whitespace-separated word maps
  to a single WordPiece token. Confirm this for your vocabulary.

  Args:
    sentence: the raw sentence string.
    separation: the sentence's words; only its length is used here.

  Returns:
    list of ``len(separation)`` id tensors (indexed ``[0][position]``, as
    returned with ``return_tensors="pt"``), where element ``i`` has position
    ``i + 1`` masked.
  """
  inputs = tokenizer(sentence, add_special_tokens=True, return_tensors="pt")
  input_ids = inputs['input_ids']
  sep_position = input_ids.shape[-1] - 1  # last index holds [SEP]; never masked
  masked_list = []
  for i in range(len(separation)):
    masked = input_ids.clone()  # independent copy per word
    position = i + 1  # +1 skips the [CLS] token at index 0
    if position < sep_position:
      masked[0][position] = tokenizer.mask_token_id
    masked_list.append(masked)
  return masked_list

def calculate_surprisal(sentence, word, token, start_position, verbose_mode):
    """Return the frequency-matched candidate that is most surprising in context.

    Candidates come from ``find_similar_frequency(word, freq_window)``. The
    module-level masked LM ``model`` is run on the pre-masked ids ``token``.
    Each candidate is scored by its surprisal, ``-log2 p(candidate)``, under
    the model's softmax at token index ``start_position``. The candidate with
    the highest surprisal wins; ties go to the earliest candidate.

    NOTE(review): candidates are looked up with
    ``tokenizer.convert_tokens_to_ids``. A word that is not a single vocabulary
    entry presumably maps to the [UNK] id and is scored as [UNK]. Confirm that
    this is acceptable for your frequency list.

    Args:
      sentence: the raw sentence; it is re-tokenized only to obtain the
        attention mask and token type ids that accompany ``token``.
      word: the original word at the masked position.
      token: input ids with the target position masked, indexed
        ``[0][position]`` as produced by ``tokenization()``.
      start_position: token index whose prediction is scored.
      verbose_mode: if true, print every candidate with its probability.

    Returns:
      the chosen candidate word. Returns ``None`` if there are no candidates,
      and the first candidate (with a warning) if ``start_position`` lies
      outside the tokenized sentence.
    """
    import math

    # find a list of similar frequency words
    similar = find_similar_frequency(word, freq_window)
    if not similar:
        print('No candidate words available for', word)
        return None

    if not 0 <= start_position < token.shape[-1]:
        # Best effort: keep generating rather than crash, but report it
        # instead of silently scoring every candidate as 0.0.
        print('Position', start_position, 'is outside the tokenized sentence; falling back to', similar[0])
        return similar[0]

    inputs = tokenizer(sentence, is_split_into_words=True, add_special_tokens=True, return_tensors="pt")  # placeholder for attention mask / token type ids
    inputs['input_ids'] = token  # replace placeholder ids with the masked sentence
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**inputs)[0][0][start_position]
    # log_softmax is numerically stable, whereas exp(logits) / sum(exp(logits))
    # overflows to inf / inf = nan for large logits.
    log_probs = torch.log_softmax(logits, dim=-1)

    best_word, best_surprisal = None, None
    for candidate in similar:
        candidate_id = tokenizer.convert_tokens_to_ids(candidate)
        log_prob = log_probs[candidate_id].item()
        surprisal = -log_prob / math.log(2)  # natural log -> bits
        if verbose_mode:
            print(candidate, math.exp(log_prob))
        # strict '>' keeps the first candidate on ties
        if best_surprisal is None or surprisal > best_surprisal:
            best_word, best_surprisal = candidate, surprisal
    print(best_word, best_surprisal)
    return best_word

def find_alternative(sentence, split):
  """Choose an alternative word for every word of the sentence except the first.

  Args:
    sentence: the raw sentence string.
    split: the sentence's words, e.g. ``sentence.split()``.

  Returns:
    list of ``(original_word, alternative)`` tuples for ``split[1:]``, in
    order. ``split[0]`` gets no alternative.
  """
  # One masked copy per word; copy i has token position i + 1 masked
  # (see tokenization()).
  token_list = tokenization(sentence, split)
  result = []
  for i in range(1, len(split)):
    # Score the prediction at the masked position itself: word i sits at
    # token index i + 1 because index 0 is [CLS].
    alternative = calculate_surprisal(sentence, split[i], token_list[i], i + 1, verbose_mode=False)
    result.append((split[i], alternative))
  return result

## 5. Alternative Generation

This cell runs the alternative generation and writes the results to an output file with a name of your choosing. If the file already exists, the new results are appended to it.

NOTE: THIS CELL WILL TAKE APPROXIMATELY 3 MINUTES **PER SENTENCE**. 

After this cell finishes, check the generated alternatives in the output file before using them.

Please run the following cell to generate the alternatives.

In [None]:
outfile_name = input("What is the name of your output file? ")
# 'with' closes the file even if generation fails part-way through.
# newline='' is what the csv module requires for writer file objects, so no
# extra '\r' is written on platforms with '\r\n' line endings.
# Mode 'a' appends to an existing file.
with open(outfile_name, mode='a', encoding='utf-8-sig', newline='') as f:
  writer = csv.writer(f)
  # Each sentence is introduced by a [sentence_number, 0] row, followed by
  # one (word, alternative) row per word after the first.
  for counter, sentence in enumerate(sentences, start=1):
    result = find_alternative(sentence, sentence.split())
    writer.writerow([counter, 0])
    writer.writerows(result)