<a href="https://colab.research.google.com/github/UMWordLab/multilingual_amaze/blob/main/Group_Enabled_Multilingual_A_maze_Alternative_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Multilingual A-maze Alternative Generation

## 1. Preliminaries
Please run the following cells to install and import the necessary libraries.

In [1]:
%%capture

!pip install minicons
!pip install transformers


In [2]:
%%capture

from minicons import scorer
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
%pdb on

Automatic pdb calling has been turned ON


## 2. Selecting a Minicons language model
Please run the following cell and input the language model you would like to use for the experiment.


In [27]:
langmodel = input("What minicons language model would you like to use?\nYou can select any from this list: https://huggingface.co/models\nThe name of the model can be copied using the clipboard icon next to the name on the webpage.\n")
print(langmodel, "selected as model.")
model = scorer.IncrementalLMScorer(langmodel, 'cpu')
tokenizer = AutoTokenizer.from_pretrained(langmodel)

What minicons language model would you like to use?
You can select any from this list: https://huggingface.co/models
The name of the model can be copied using the clipboard icon next to the name on the webpage.
vgaraujov/bart-base-spanish
vgaraujov/bart-base-spanish selected as model.


tokenizer_config.json:   0%|          | 0.00/277 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/4.78M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/191 [00:00<?, ?B/s]

##3. A few things before we begin...
Please run the following cells and answer the questions before moving onto the main functions. They will be asking for input CSV files for frequency mapping and stimuli. Please consult the documentation for specifications regarding these files.

In [29]:
from google.colab import files
freq_window = int(input("What is the window of frequency that you would like to use to consider words 'similar' frequency? "))
print("Please upload your file that contains the word-to-frequency mapping.")
uploaded = files.upload()
freq_file = next(iter(uploaded))


import csv
import io

# process input file
def process_frequency_file(filename):
  res = []
  with open(filename, mode='r', encoding='utf-8-sig') as csv_file:
      csv_reader = csv.DictReader(csv_file)
      for i, row in enumerate(csv_reader, 1):
            word = row['word']
            res.append(word)
  return res

freq_dict = process_frequency_file(freq_file)

What is the window of frequency that you would like to use to consider words 'similar' frequency? 5
Please upload your file that contains the word-to-frequency mapping.


Saving frequency.csv to frequency (2).csv


In [30]:
print("Please upload your file that contains the stimuli sentences to be used for alternative generation")
uploaded = files.upload()
stim_file = next(iter(uploaded))

def process_stimuli_file(filename):
  global groupMode
  with open(filename, mode='r', encoding='utf-8-sig') as csv_file:
      csv_reader = csv.DictReader(csv_file)
      for row in csv_reader:
        if 'group' in row:
          groupMode = True
        else:
          groupMode = False
        break
      csv_file.seek(0)
      if groupMode == True:
        res = {}
        for row in csv_reader:
            if row['group'] in res:
             res[row['group']].append((row['sentence'], row['item_labels']))
            else:
             g = []
             g.append((row['sentence'], row['item_labels']))
             res[row['group']] = g
        return res

      if groupMode == False:
        res = []
        for row in csv_reader:
          sent = row['sentence']
          res.append(sent)
        return res


sentence_groups = process_stimuli_file(stim_file)

print("Stimuli saved. ")

Please upload your file that contains the stimuli sentences to be used for alternative generation


Saving sp-stim.csv to sp-stim.csv
Stimuli saved. 


##4. Main Functions
- find_similar_frequency
- tokenization
- calculate_surprisal
- find_alternative

In [31]:
import random


def find_words_in_tokens(sentence, tokens):
    split = sentence.split()
    map = {}
    word_index = 0
    word_start = 0
    t = 0
    current_word = ''
    while t < len(tokens):
      current_word += tokens[t]
      t += 1
      if current_word == split[word_index]:
        for i in range(word_start, t+1):
          map[i] = word_index
        word_index += 1
        word_start = t+1
        current_word = ''

    return map


def calculate_surprisal_tiebreaker(sentence, word, candidates):

  split_sentence = sentence.split()
  similar = candidates
  surprisal_list = []


  for index, similar_word in enumerate(similar):
    split_sentence[word] = similar_word
    sentence = ' '.join(split_sentence)

      #we tokenize the sentence
  #and get a string split into readable tokens
    tokens = tokenizer.tokenize(sentence)
    tokenized_sentence = {}
    for i, t, in enumerate(tokens):
      tokenized_sentence[i] = tokenizer.convert_tokens_to_string([t]).removeprefix(" ")
    token_map = find_words_in_tokens(sentence, tokenized_sentence)

    tokens_in_word = []
    for key in token_map.keys():
      if key == word:
        tokens_in_word.append(key)
    w = ''
    surp = 0
    for i in tokens_in_word:
      w = w + tokenized_sentence[token_map[i]]
      surp += model.token_score(sentence, surprisal = True)[0][i][1]
    surprisal = (w, surp)
    surprisal_list.append(surprisal)
    print(surprisal)


  max_val = max(surprisal_list)
  print(max_val)
  max_index = surprisal_list.index(max_val)
  print(similar[max_index], max_val)
  return similar[max_index]

def find_alt_given_label(label, group):
  #first we find a suitable alt for each [label] word in each sentence
  temp_alts = []

  for tup in group:
    if label in tup[1].split():
     sentence = tup[0]
     split = sentence.split()
     index = tup[1].split().index(label)

     alt = calculate_surprisal(sentence, index, verbose_mode=True)
     temp_alts.append(alt)

  #now we have a list of alts for [label] in sentence order
  #we will repeat the process of finding alternatives, using temp_alts as
  #candidate alternatives

  score = {}
  for a in temp_alts:
    score[a] = 0

  for tup in group:
    if label in tup[1].split():
     sentence = tup[0]
     index = tup[1].split().index(label)

     alternative = calculate_surprisal_tiebreaker(sentence, index, temp_alts)
     score[alternative] = score[alternative] + 1

  labelalt = max(zip(score.keys(), score.values()))[0]
  print(label, ", ", labelalt)
  return labelalt

# instead of random selection, provide a window for frequency selection
def find_similar_frequency(word, window):
  res = []
  if word in freq_dict:
    print('word exists in dict')
    rank = freq_dict.index(word)
    start_index = max(0, rank - window)
    end_index = min(len(freq_dict), (rank + window))
    res = freq_dict[start_index:end_index+1]
  elif word.lower() in freq_dict:
    print('word exists in dict')
    rank = freq_dict.index(word.lower())
    start_index = max(0, rank - window)
    end_index = min(len(freq_dict), (rank + window))
    res = freq_dict[start_index:end_index+1]
  elif word.upper() in freq_dict:
    rank = freq_dict.index(word.upper())
    start_index = max(0, rank - window)
    end_index = min(len(freq_dict), (rank + window))
    res = freq_dict[start_index:end_index+1]
  else:
    # error handling - word doesn't exist in given frequency list
    # complete random selection?
    print('Frequency not found in list. Next word will be chosen through random selection.')
    res.append(random.choice(freq_dict))
  return res

def calculate_surprisal(sentence, word, verbose_mode):

  period = False
  split_sentence = sentence.split()

  if split_sentence[word][len(split_sentence[word])-1] == '.':
    split_sentence[word] = split_sentence[word].removesuffix('.')
    period = True


  similar = find_similar_frequency(split_sentence[word], freq_window)

  surprisal_list = []


  for index, similar_word in enumerate(similar):
    split_sentence[word] = similar_word
    sentence = ' '.join(split_sentence)

      #we tokenize the sentence
  #and get a string split into readable tokens
    tokens = tokenizer.tokenize(sentence)
    tokenized_sentence = {}
    for i, t, in enumerate(tokens):
      tokenized_sentence[i] = tokenizer.convert_tokens_to_string([t]).removeprefix(" ")
    token_map = find_words_in_tokens(sentence, tokenized_sentence)

    tokens_in_word = []
    for key in token_map.keys():
      if key == word:
        tokens_in_word.append(key)
    w = ''
    surp = 0
    for i in tokens_in_word:
      w = w + tokenized_sentence[token_map[i]]
      surp += model.token_score(sentence, surprisal = True)[0][i][1]
    surprisal = (w, surp)
    surprisal_list.append(surprisal)
    if verbose_mode:
      print(surprisal)


  max_val = max(surprisal_list)
  print(max_val)
  max_index = surprisal_list.index(max_val)
  print(similar[max_index], max_val)
  if period:
    return similar[max_index] + '.'
  return similar[max_index]


def find_alternative(sentence, split, labels, alts):

  result = []
  result.append((split[0], None))
  if(labels == None):
    print("no labels provided")
    for i in range(1, len(split)):
      alternative = calculate_surprisal(sentence, i, verbose_mode=True)
      result.append((split[i], alternative))
    return result

  print("using labels")
  split_labels = labels.split()

  for i in range(1, len(split)):
    if split_labels[i] in alts.keys():
      alternative = alts[split_labels[i]]
    else:
      alternative = calculate_surprisal(sentence, i, verbose_mode=False)
    result.append((split[i], alternative))
  print(result)
  return result

##5. Alternative Generation

This block runs the alternate generation and creates an output file under the name of your choosing.

NOTE: EXPECT THIS CELL TO TAKE A LONG TIME.

Make sure you check the sentences at the end after this block executes.

Run the following cell to enable groups

In [32]:
if groupMode:

  group_alts = {}
  for group in sentence_groups:
    #creating the alts for each word label
    alts = {}
    labels = []
    #we start by getting a list of all the labels

    g = sentence_groups[group]
    #g is a list of tuples. the first element of the tuple is the sentence
    #the second is the item labels
    for tup in g:
      split_labels = tup[1].split()
      for l in split_labels:
        if l not in labels:
         labels.append(l)


  #now we have the list of labels. we iterate over each one and generate an alt
  #for that label
    for l in labels:
      alts[l] = find_alt_given_label(l, g)
    group_alts[group] = alts


#group_alts is a dictionary with groups as keys and label-alt dicts as values




Run this cell to generate the alternatives

In [53]:
outfile_name = input("What is the name of your output file? ")
f = open(outfile_name, mode='a', encoding='utf-8-sig')
writer = csv.writer(f)
counter = 1
if groupMode == True:
  print("Using provided groups")
  for group in sentence_groups.keys():
    if group == 'group':
      continue
    counter = 1
    writer.writerow(["group: " + group])
    alts_dict = group_alts[group]
    results = []
    for tup in sentence_groups[group]:
     writer.writerow([counter, 0])
     counter +=1
     sentence = tup[0]
     r = (find_alternative(sentence, sentence.split(), labels = tup[1], alts = alts_dict))
     for pair in r:
      print(pair)
      writer.writerow(pair)


if groupMode == False:
    result = []
    for i in range(len(sentence_groups)):
      if sentence_groups[i] == "sentence":
        continue
      writer.writerow([counter, 0])
      sentence = sentence_groups[i]
      result = find_alternative(sentence, sentence.split(), labels = None, alts = None)
      counter += 1
      for pair in result:
        writer.writerow(pair)

f.close()

What is the name of your output file? sp-test-3
no labels provided
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
('Dale', 16.714624404907227)
pero ('Dale', 16.714624404907227)
word exists in dict
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
('al', 11.56169319152832)
guardián ('al', 11.56169319152832)
word exists in dict
('balón', 19.392223358154297)
('balón', 19.392223358154297)
('balón', 19.392223358154297)
('balón', 19.392223358154297)
('balón', 19.392223358154297)
('balón', 19.3

In [49]:
files.download(outfile_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [56]:
def find_similar_frequency(word, window):
  res = []
  if word in freq_dict:
    print('word exists in dict')
    rank = freq_dict.index(word)
    start_index = max(0, rank - window)
    end_index = min(len(freq_dict), (rank + window))
    res = freq_dict[start_index:end_index+1]
  elif word.lower() in freq_dict:
    print('word exists in dict')
    rank = freq_dict.index(word.lower())
    start_index = max(0, rank - window)
    end_index = min(len(freq_dict), (rank + window))
    res = freq_dict[start_index:end_index+1]
  elif word.upper() in freq_dict:
    rank = freq_dict.index(word.upper())
    start_index = max(0, rank - window)
    end_index = min(len(freq_dict), (rank + window))
    res = freq_dict[start_index:end_index+1]
  else:
    # error handling - word doesn't exist in given frequency list
    # complete random selection?
    print('Frequency not found in list. Next word will be chosen through random selection.')
    res.append(random.choice(freq_dict))
  return res

word = "fria"
print(find_similar_frequency(word, 5))

word exists in dict
['cobalto', 'codear', 'deena', 'esposito', 'explicara', 'fria', 'heck', 'letrina', 'misionero', 'obeso', 'paxton']


In [60]:
from transformers import pipeline

model_checkpoint = "wietsedv/xlm-roberta-base-ft-udpos28-en"

pos_tag = pipeline(
    "token-classification", model=model_checkpoint, aggregation_strategy="simple"
)




config.json:   0%|          | 0.00/1.37k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

[{'entity_group': 'PRON', 'score': 0.9990953, 'word': 'I', 'start': 0, 'end': 1}, {'entity_group': 'VERB', 'score': 0.9953033, 'word': 'went', 'start': 2, 'end': 6}, {'entity_group': 'ADP', 'score': 0.9976915, 'word': 'to', 'start': 7, 'end': 9}, {'entity_group': 'DET', 'score': 0.9992933, 'word': 'the', 'start': 10, 'end': 13}, {'entity_group': 'NOUN', 'score': 0.9954781, 'word': 'store', 'start': 14, 'end': 19}, {'entity_group': 'ADV', 'score': 0.992121, 'word': 'quickly', 'start': 20, 'end': 27}, {'entity_group': 'PUNCT', 'score': 0.9752043, 'word': '.', 'start': 27, 'end': 28}]


In [64]:
tk = AutoTokenizer.from_pretrained(model_checkpoint)

In [87]:
print(tk.is_fast)
for i in range(len(sentence_groups)):
  if i == 0:
    continue
  categories = pos_tag(sentence_groups[i])
  encoding = tk(sentence_groups[i])
  counter = 1
  words = encoding.word_ids()
  tokens = encoding.tokens()
  for w in categories:
    print(tokens[words[counter]].removeprefix("▁"))
    if w["word"] == tokens[words[counter]].removeprefix("▁"):
      print(w["word"] + ",", w["entity_group"])
    counter +=1

True
<s>
<s>
Da
le
le
al
bal
ón
con
la
punta
de
<s>
El
té
no
se
hace
con
agua
agua
