<a href="https://colab.research.google.com/github/UMWordLab/multilingual_amaze/blob/main/Mandarin_A_maze_Alternative_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mandarin Maze Generation

Lisa Levinson<sup>1</sup>, Yizhi Tang<sup>2</sup>, Lucy Yu-Chuan Chiang<sup>1</sup>, Wei-Jie Zhou<sup>1</sup>, Sohee Chung<sup>1</sup>

(<sup>1</sup>University of Michigan, <sup>2</sup>Columbia University)

**Summary.** We introduce an extension of the automatic maze task described in A-Maze (Boyce et al. 2020) for Mandarin.

This script currently works for **simplified Mandarin** texts. We plan to share a variant for traditional Mandarin orthography in the near future.

## Script Setup
Run the following cell and ignore any warnings. You must authorize Google Drive access for the script to work.

In [None]:
%%capture

!pip install --upgrade -q gspread
!pip install -q transformers
!pip install -q spacy
!pip install -q tqdm
!pip install -q urllib3
# Download the Mandarin spaCy model only after spaCy itself is installed
!python -m spacy download zh_core_web_md

from tqdm.notebook import tqdm
import spacy
import torch
import csv
import random
import gspread
import pandas as pd
import re
import requests
import urllib.request

from spacy import displacy
from collections import Counter
from transformers import BertTokenizer, BertForMaskedLM
from google.colab import drive
from google.colab import auth
from google.auth import default
from google.colab import files

tokenizer = BertTokenizer.from_pretrained("hfl/chinese-bert-wwm")
model = BertForMaskedLM.from_pretrained("hfl/chinese-bert-wwm")
nlp = spacy.load('zh_core_web_md')

# Authenticate the Colab user first so that default() can find credentials
auth.authenticate_user()
creds, _ = default()
gc = gspread.authorize(creds)
drive.mount('/content/drive')


WINDOW_SIZE = 25
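
`WINDOW_SIZE` controls how many neighbouring entries of the word list are considered as candidate distractors: `find_list_of_words` (defined in the Methods section) takes the rows from `WINDOW_SIZE` before the target word up to `WINDOW_SIZE` after it. The sketch below shows that slice on a made-up list; the helper name `frequency_window` and the toy words are hypothetical, and it assumes the list is ordered by frequency.

```python
WINDOW_SIZE = 25

def frequency_window(words, target, window=WINDOW_SIZE):
    """Return the entries within `window` positions of `target` (hypothetical helper)."""
    index = words.index(target)
    # Clamp the slice so it never runs past either end of the list.
    return words[max(0, index - window):min(len(words), index + window)]

# Toy "frequency list": w0 is the most frequent word, w199 the least.
ranked = [f"w{i}" for i in range(200)]

middle = frequency_window(ranked, "w100")
near_top = frequency_window(ranked, "w3")
print(len(middle), middle[0], middle[-1])
print(len(near_top), near_top[0], near_top[-1])
```

A larger window gives the model more candidates to choose from, at the cost of scoring more words per position.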

## File Inputs and Variables

As input, you will need a Google Sheet with one row per stimulus sentence, with the words or phrases of the sentence in separate columns (such as "phrase1", "phrase2", etc.). Each row should also contain an item label. To use the matching tool, which matches words across sentences, the items you want to match must share the same item label, and the sheet needs an additional column containing the matching labels.
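
As a rough illustration of this layout, the sketch below builds the kind of list of rows that `worksheet.get_all_values()` returns and splits it into labels, phrases, and sentences. The header names, item labels, and example sentences are hypothetical, and the optional matching-label column is left out.

```python
# Hypothetical sheet contents: a header row, then one row per sentence.
rows = [
    ["item", "phrase1", "phrase2", "phrase3", "phrase4"],
    ["item1", "小明", "喜欢", "吃", "苹果"],
    ["item2", "老师", "来了", "", ""],   # shorter sentence, trailing cells empty
]

row_start = 1      # zero-based index of the first stimulus row
label_col = 0      # zero-based index of the item-label column
col_start = 1      # zero-based index of the first phrase column
col_stop = 5       # exclusive upper bound of the phrase columns

labels = [r[label_col] for r in rows[row_start:]]
# Empty cells are skipped, so sentences may have different numbers of phrases.
splits = [[cell for cell in r[col_start:col_stop] if cell != ""]
          for r in rows[row_start:]]
sentences = ["".join(s) for s in splits]

print(labels)
print(splits)
print(sentences)
```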

Once your stimuli are prepared, run the following cell and follow the prompts. After running, you will see a preview of how your input stimuli have been parsed. (See the documentation for more detail.)

In [None]:
file_name = input("What is the url for the Google Sheet containing the stimuli?\n")
sheet_name = input("What is the name of the sheet containing the stimuli (e.g., Sheet1)?\n")
row_num = input("What row do your phrases start at? (Enter an integer - usually row 2 if you have a header row)\n")
col_num = input("What column do your phrases start at? (Enter an integer)\n")
col_end_num = input("What column do your phrases end at? (Enter an integer)\n")
lab_num = input("What column are your item labels in? (Enter an integer)\n")
match_lab = input("What column are your matching labels in? (Enter an integer; if none, enter z)\n")
outfile_name = input("What would you like your output file name to be?\n")


row = int(row_num) - 1
col = int(col_num) - 1
lab = int(lab_num) - 1
end = int(col_end_num)  # used as an exclusive bound, so the end column itself is included

worksheet = gc.open_by_url(file_name).worksheet(sheet_name)
rows = worksheet.get_all_values()
data = pd.DataFrame.from_records(rows)

data

## Methods

Run the following cell to define the necessary functions for generating alternatives and writing output.

In [None]:
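def tokenization(sentence, split):
  # Build one copy of the input ids per phrase, with that phrase's characters
  # replaced by [MASK]. Assumes the tokenizer yields one token per character.
  masked_list = []
  inputs = tokenizer(sentence, add_special_tokens=True, return_tensors="pt")
  mask_index = 0  # offset of the current phrase within the sentence

  for i in range(len(split)):
    masked_list.append(inputs['input_ids'].clone())

    for j in range(len(split[i])):
      # +1 skips the leading [CLS] token
      masked_list[i][0][mask_index + 1 + j] = tokenizer.mask_token_id

    mask_index = mask_index + len(split[i])

  return masked_list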

def find_list_of_words(word):
  # Return (word, second CSV column, POS) triples for words of the same length
  # that sit within WINDOW_SIZE rows of the target in the word list.
  path = "https://raw.githubusercontent.com/UMWordLab/multilingual_maze_old/main/"
  filenames = {1: "one_character.csv", 2: "two_character.csv",
               3: "three_character.csv", 4: "four_character.csv"}
  filename = filenames.get(len(word), "long.csv")

  response = urllib.request.urlopen(path + filename)
  lines = [l.decode('utf-8') for l in response.readlines()]
  rows = list(csv.reader(lines))

  index = -1
  for i, row in enumerate(rows):
    if row[0] == word:
      index = i

  if index == -1:
    print("Warning: cannot find this word in our frequency list")
    # Fall back to a random sample with as many rows as a full window
    selected = random.sample(rows, WINDOW_SIZE * 2)
  else:
    selected = rows[max(0, index - WINDOW_SIZE):min(len(rows), index + WINDOW_SIZE)]

  result = []
  for item in selected:
    for token in nlp(item[0]):
      result.append((item[0], item[1], token.pos_))

  return result

def calculate_surprisal(sentence, word, token, start_position, verbose_mode):
    # Score every candidate at the masked position and return the one with the
    # highest surprisal whose part of speech differs from the target word's.
    inputs = tokenizer(sentence, add_special_tokens=True, return_tensors="pt")
    inputs['input_ids'] = token
    with torch.no_grad():  # inference only, so no gradients are needed
      outputs = model(**inputs)
    similar = find_list_of_words(word)
    surprisal_list = []

    for triple in similar:
      prob = 1.0
      for i, character in enumerate(triple[0]):
        char_id = tokenizer.convert_tokens_to_ids(character)
        actual_position = start_position + i + 1  # +1 skips the [CLS] token
        char_probs = torch.softmax(outputs[0][0][actual_position], dim=-1)
        # Word probability is the product of its character probabilities
        prob = prob * char_probs[char_id]

      if verbose_mode:
        print(triple[0], prob)
      surprisal_list.append(-1 * torch.log2(prob))

    max_index = surprisal_list.index(max(surprisal_list))
    # Returned only if every candidate shares the target's part of speech
    fallback = similar[max_index]

    while True:
      count = 0
      flag = False

      for token in nlp(word):
        if surprisal_list and similar[max_index][2] == token.pos_:
          # Same POS as the target: drop this candidate and take the next best
          count += 1
          surprisal_list.pop(max_index)
          similar.pop(max_index)
          if surprisal_list:
            max_index = surprisal_list.index(max(surprisal_list))

        else:
          flag = True

      if flag or len(surprisal_list) == 0 or count > 10:
        break

    return similar[max_index] if surprisal_list else fallback

def find_alternative(sentence, split):
  token_list = tokenization(sentence, split)
  start_position = len(split[0])
  result = []

  for i in range(1, len(split)):
    alter = calculate_surprisal(sentence, split[i], token_list[i], start_position, verbose_mode=False)
    result.append(alter)
    start_position = start_position + len(split[i])

  return result

def df_into_list(row_start, input_data, col_start, col_end):
  list_of_splits = []
  list_of_sentences = []
  list_of_labels = []

  for i in range(row_start, len(input_data)):
    # Collect the non-empty phrase cells of this row
    lst = []

    for j in range(col_start, col_end):
      if input_data[j][i] != '':
        lst.append(input_data[j][i])

    list_of_splits.append(lst)
    list_of_sentences.append(''.join(lst))
    # Read the label from the label column chosen in the prompts (global lab)
    list_of_labels.append(input_data[lab][i])

  print(len(list_of_sentences))

  return list_of_sentences, list_of_splits, list_of_labels

def write_to_csv(sentences_list, splits_list, labels_list, output_file):
  # Append mode lets several batches accumulate in the same output file
  with open(output_file, 'a', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)

    for i in tqdm(range(len(sentences_list))):
      result = find_alternative(sentences_list[i], splits_list[i])
      # The first phrase gets no distractor, so it is copied as-is
      alt_splits = [splits_list[i][0]] + [alternative[0] for alternative in result]

      # Build the rows without modifying splits_list, which later cells reuse
      writer.writerow([labels_list[i]] + splits_list[i])
      writer.writerow([labels_list[i] + '.alt'] + alt_splits)

      print("finished sentence " + str(i + 1) + " out of " + str(len(sentences_list)))

def matcher(alternatives_splits_list, label_sentence_list, labels_used_list=('alt', 'verb')):
  # The first sentence of the group supplies the distractor for each matched label
  distractors = {}
  supplier_labels = label_sentence_list[0].split()

  for i, label in enumerate(supplier_labels):
    if label in labels_used_list:
      distractors[label] = alternatives_splits_list[0][i]

  # Reuse those distractors wherever the same label occurs; the first word is never replaced
  for i, words in enumerate(alternatives_splits_list):
    labels = label_sentence_list[i].split()

    for word_pos in range(1, len(words)):
      if labels[word_pos] in distractors:
        words[word_pos] = distractors[labels[word_pos]]

  print(distractors)
  return alternatives_splits_list

def find_longest_list(lists):
    # Despite its name, this returns the length of the longest list (0 if there are none)
    return max((len(lst) for lst in lists), default=0)

def find_largest_number(string):
    numbers = []
    current_number = ""

    for char in string:
        if char.isdigit():
            current_number += char
        else:
            if current_number != "":
                numbers.append(int(current_number))
                current_number = ""

    if current_number != "":
        numbers.append(int(current_number))

    if len(numbers) == 0:
        return None

    return max(numbers)

def get_unique_numbers(df, column):
    unique_numbers = set()
    
    for value in df[column]:
        if isinstance(value, (int, float)):
            unique_numbers.add(int(value))
        elif isinstance(value, str):
            numbers = re.findall(r'\d+', value)
            for number in numbers:
                unique_numbers.add(int(number))

    return list(unique_numbers)

def find_rows_with_integer(df, column, integer):
    # Compare whole numbers so that 1 does not also match labels containing 10 or 21
    mask = df[column].apply(
        lambda value: integer in [int(n) for n in re.findall(r'\d+', str(value))]
    )
    return df[mask]

def flatten_list_of_list_of_lists(list_of_list_of_lists):
    flattened_list = [item for sublist in list_of_list_of_lists for item in sublist]
    return flattened_list

def select_rows(df, match_strings, output_column):
    selected_rows = df.loc[df.iloc[:, lab].isin(match_strings), output_column].tolist()
    return selected_rows

def write_list_of_lists_to_csv(list_of_lists, filename):
    with open(filename, 'w', newline='') as file:
        writer = csv.writer(file)

        for row in list_of_lists:
            writer.writerow(row)

def find_common_words(sentences):
    if not sentences:
        return []

    # Labels are lowercased and must occur in every sentence to count as common
    word_sets = [set(sentence.lower().split()) for sentence in sentences]
    common_words = set.intersection(*word_sets)
    print("Labels being matched are: " + str(common_words))

    return sorted(common_words)

def match_labels_exist(num):
  if num == "z":
    return
  else:
    return int(num)
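
As a rough illustration of the scoring in `calculate_surprisal`: a candidate's probability is the product of the model's probabilities for each of its characters at the masked positions, and its surprisal is the negative base-2 logarithm of that product. The sketch below uses made-up probabilities rather than BERT output, and the helper name `surprisal_bits` is hypothetical.

```python
import math

def surprisal_bits(char_probs):
    """Surprisal in bits of a word, given one probability per character."""
    prob = 1.0
    for p in char_probs:
        prob *= p
    return -math.log2(prob)

# Made-up per-character probabilities for two two-character candidates.
candidates = {
    "苹果": [0.5, 0.5],      # fits the context well, so surprisal is low
    "桌子": [0.125, 0.25],   # fits the context poorly, so surprisal is high
}

scores = {word: surprisal_bits(probs) for word, probs in candidates.items()}
best_distractor = max(scores, key=scores.get)
print(scores)
print(best_distractor)
```

The candidate with the highest surprisal is the least expected continuation, which is what makes it a useful distractor in the maze task.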

Run the following cell and check that the number it prints matches the number of sentences in your sheet.

In [None]:
sentence_list, split_list, labeled_list = df_into_list(row, data, col, end)

## Generate Alternatives

This section is where the alternatives are actually generated, which is the computationally intensive portion. **This step will take approximately 4 minutes PER SENTENCE. Please plan accordingly.** If you have a large set of sentences, you may want to generate alternatives in batches so that the server does not time out or run out of memory before completing the job.

Run the following cell to write the alternatives to the output file specified in the **File Inputs and Variables** step. To find the file, click the folder icon on the left edge of the Colab window, which opens the "Files" tab. You can download the file by highlighting it and clicking the three-dot menu on the right side. **Files are deleted from the Colab Notebook when the session disconnects or times out, so make sure to save the output to your computer or a cloud drive.**



In [None]:
write_to_csv(sentence_list, split_list, labeled_list, outfile_name)
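
If you do split a large stimulus set into batches, one option is to slice the three lists in parallel and call `write_to_csv` once per batch; because the output file is opened in append mode, each batch adds to the same file. The `batches` helper and the batch size below are hypothetical, and the call to `write_to_csv` is left as a comment so that the snippet runs on its own.

```python
def batches(n_items, batch_size):
    """Yield (start, stop) index pairs that together cover range(n_items)."""
    for start in range(0, n_items, batch_size):
        yield start, min(start + batch_size, n_items)

# Toy stand-in for sentence_list, just to show which rows each batch covers.
demo_sentences = [f"sentence {i}" for i in range(7)]

for start, stop in batches(len(demo_sentences), batch_size=3):
    print(f"processing rows {start} to {stop - 1}")
    # In the notebook, each batch could be processed with, for example:
    # write_to_csv(sentence_list[start:stop], split_list[start:stop],
    #              labeled_list[start:stop], outfile_name)
```

Downloading the output file after each batch limits how much work is lost if the session disconnects.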

## Within-Item Alternative Matching (Optional)
You should now see the output file you named earlier in the Files tab of Google Colab. This optional step uses the generated alternatives and the matching labels you supplied: within an item, words that share a matching label across sentences are given the same alternative.

Run the following cell and enter a name for the new output file, different from the name of your original output. The new output with matched alternatives will also be saved to the Files tab.

In [None]:
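# Read back the alternatives written in the Generate Alternatives step
distractor_wks = pd.read_csv(outfile_name, header = None)
matched_alts_outfile = input("What do you want the outfile to be named? \n")
match_lab_col = match_labels_exist(match_lab)
if match_lab_col is None:
  raise ValueError("No matching-label column was given; this optional section needs one.")
match_lab_col = match_lab_col - 1
distractor_list = list(data[match_lab_col])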

Run the following cells in order.

In [None]:
unique_labels = get_unique_numbers(distractor_wks, 0)
alternatives_groups = []

for i in unique_labels:
  alternatives_groups.append(find_rows_with_integer(distractor_wks, 0, i))

print(unique_labels)
print(alternatives_groups)

In [None]:
alts_list = []
alts_label_list = []

for i in range(len(alternatives_groups)):
  distractor_wks_temp = alternatives_groups[i][alternatives_groups[i][0].str.contains('alt')]
  alts_labels = distractor_wks_temp.iloc[:, 0].tolist()
  new_list = [row[1:].tolist() for index, row in distractor_wks_temp.iterrows()]
  alts_labels = [word[:-4] for word in alts_labels]

  matching_label_list = find_common_words(select_rows(data, alts_labels, match_lab_col))
  alts_list_group = matcher(new_list, select_rows(data, alts_labels, match_lab_col), matching_label_list)
  alts_list.append(alts_list_group)
  alts_label_list.append(alts_labels)

alts_list = flatten_list_of_list_of_lists(alts_list)
alts_label_list = flatten_list_of_list_of_lists(alts_label_list)

for i in range(len(alts_label_list)):
  alts_list[i].insert(0, alts_label_list[i] + ".matchedalt")

In [None]:
result = []

for i in range(len(alts_list)):
  # Prepend the label without modifying split_list, so this cell can be re-run safely
  result.append([labeled_list[i]] + split_list[i])
  result.append(alts_list[i])

write_list_of_lists_to_csv(result, matched_alts_outfile)
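
To make the matching step concrete, the sketch below mirrors the idea behind `matcher` on toy data: the first sentence in a group supplies a distractor for each matching label, and later sentences reuse that distractor wherever they carry the same label. The function name `share_distractors`, the labels, and the words are hypothetical, and the sketch leaves out details of the notebook's version.

```python
def share_distractors(alt_rows, label_rows, labels_to_match):
    """Copy the first row's distractors onto every row with the same label."""
    supplier = dict(zip(label_rows[0], alt_rows[0]))
    shared = {label: word for label, word in supplier.items()
              if label in labels_to_match}
    # Words whose label is not being matched keep their own distractor.
    return [
        [shared.get(label, word) for label, word in zip(labels, row)]
        for labels, row in zip(label_rows, alt_rows)
    ]

# Two sentences from the same item, one label per word position.
alts = [["小明", "桌子", "跑", "天空"],
        ["小红", "电脑", "飞", "河流"]]
labels = [["subj", "adv", "verb", "obj"],
          ["subj", "adv", "verb", "obj"]]

matched = share_distractors(alts, labels, {"verb"})
print(matched)
```

Here only the word labelled `verb` is matched, so the second sentence takes its verb distractor from the first while its other distractors stay unchanged.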