<a href="https://colab.research.google.com/github/Chiamakac/TRAININGS/blob/main/Alignment/Practical_Awesome_align_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AWESOME: Aligning Word Embedding Spaces of Multilingual Encoders

[``awesome-align``](https://github.com/neulab/awesome-align) is a tool that can extract word alignments from multilingual BERT (mBERT) and allows you to fine-tune mBERT on parallel corpora for better alignment quality (see [our paper](https://arxiv.org/abs/2101.08231) for more details).

This is a simple demo of how `awesome-align` extracts word alignments from mBERT.

First, install and import the following packages. (Note that the original `awesome-align` tool does not require the `transformers` package.)

In [1]:
!pip install transformers==3.1.0
import torch
import transformers
import itertools



Load the multilingual BERT model and its tokenizer.

In [2]:
model = transformers.BertModel.from_pretrained('bert-base-multilingual-cased')
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-multilingual-cased')

### Prepare input sentences


1.   Read the English text file and store its sentences in a list: `ensents = [en_sent1, en_sent2, ..., en_sentn]`.
2.   Do the same for the Igbo text file: `igsents = [ig_sent1, ig_sent2, ..., ig_sentn]`.
3.   Store each (English, Igbo) sentence pair as a tuple in a list: `en_ig_sents = [(en_sent1, ig_sent1), (en_sent2, ig_sent2), ..., (en_sentn, ig_sentn)]`.
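The three steps above can be sketched as follows. This is a minimal sketch that assumes each file holds one sentence per line, with line *i* of the English file translating line *i* of the Igbo file. `read_lines`, `load_parallel` and the file names in the comment are hypothetical.

```python
# Hypothetical helpers: the "one sentence per line" layout and the file names
# are assumptions, not something awesome-align prescribes.

def read_lines(path):
    """Return every line of a UTF-8 text file with surrounding whitespace removed."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]


def load_parallel(en_path, ig_path):
    """Pair line i of the English file with line i of the Igbo file."""
    ensents = read_lines(en_path)   # step 1
    igsents = read_lines(ig_path)   # step 2
    if len(ensents) != len(igsents):
        raise ValueError(
            f"files are not parallel: {len(ensents)} English vs {len(igsents)} Igbo lines"
        )
    # step 3: keep (English, Igbo) tuples, skipping pairs where either side is empty
    return [(en, ig) for en, ig in zip(ensents, igsents) if en and ig]


# Example (hypothetical file names):
# en_ig_sents = load_parallel("data.en", "data.ig")
```

The line-count check comes before `zip` on purpose. `zip` silently drops the extra lines of the longer file, so without the check a missing line would go unnoticed.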



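Input *tokenized* source and target sentences, i.e. plain strings whose tokens are separated by whitespace (the pre-processing code below splits them with `str.split()`). With the list of sentence pairs from the previous step, run the alignment once per pair: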

```
for src, tgt in en_ig_sents:
  # run the pre-processing and alignment code below on this (src, tgt) pair
  ...
```
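The loop above only marks where the alignment happens. One way to fill it in, sketched here under assumptions, is to wrap the pre-processing and alignment cell below in a function that takes `(src, tgt)` and returns its `align_words` set, then collect the aligned word strings for every pair. `collect_word_pairs`, `align_fn` and `AlignFn` are hypothetical names.

```python
from typing import Callable, List, Set, Tuple

# Hypothetical type: a function that aligns one (src, tgt) pair and returns
# (source word index, target word index) pairs, like `align_words` below.
AlignFn = Callable[[str, str], Set[Tuple[int, int]]]


def collect_word_pairs(
    en_ig_sents: List[Tuple[str, str]], align_fn: AlignFn
) -> List[Tuple[str, str]]:
    """Align every sentence pair and return the aligned (source word, target word) strings."""
    word_pairs = []
    for src, tgt in en_ig_sents:
        # Same whitespace split as the pre-processing step below.
        sent_src, sent_tgt = src.split(), tgt.split()
        for i, j in sorted(align_fn(src, tgt)):
            word_pairs.append((sent_src[i], sent_tgt[j]))
    return word_pairs
```

Returning word strings rather than indices makes it possible to pool alignments across sentences of different lengths.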

In [3]:
# keep both sentences as plain strings; the alignment cell splits them on whitespace
src = 'My name is Jessica, I am French and I am thirteen years old, I go to school in Nice, but I live in Cagnes-Sur-Mer..'
tgt = 'Aha m bụ Jessica, abụ m French na adị m afọ iri na atọ, aga m akwụkwọ na Nice, mana m bi na Cagnes-Sur-Mer.'

In [4]:
tgt.split()

['Aha',
 'm',
 'bụ',
 'Jessica,',
 'abụ',
 'm',
 'French',
 'na',
 'adị',
 'm',
 'afọ',
 'iri',
 'na',
 'atọ,',
 'aga',
 'm',
 'akwụkwọ',
 'na',
 'Nice,',
 'mana',
 'm',
 'bi',
 'na',
 'Cagnes-Sur-Mer.']
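The whitespace split above keeps punctuation attached to words (`'Jessica,'`, `'atọ,'`, `'Cagnes-Sur-Mer.'`), so the word-level alignments pair those composite tokens. If you prefer punctuation as separate tokens before aligning, a minimal regex pre-tokenizer could look like the one below. The pattern and the `pretokenize` name are assumptions and are not part of awesome-align.

```python
import re
import unicodedata

# Assumed pattern (not part of awesome-align): a word is a run of word
# characters optionally joined by hyphens or apostrophes ("Cagnes-Sur-Mer");
# every other non-space character becomes a token of its own.
TOKEN_RE = re.compile(r"\w+(?:[-']\w+)*|[^\w\s]")


def pretokenize(text: str) -> str:
    """Separate punctuation from words and return a space-separated string."""
    # NFC turns "u" + combining dot below into the single letter "ụ",
    # which \w matches; otherwise the combining mark would be split off.
    text = unicodedata.normalize("NFC", text)
    return " ".join(TOKEN_RE.findall(text))
```

Apply it to each raw sentence string before it is split. The result is still a space-separated string, so `str.split()` continues to work on it unchanged.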

In [4]:
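# Alternative pair with Igbo as the source and English as the target; uncomment both lines to align it instead.
# src= 'onye ahụ nke nwụrụ bụ nke a mụrụ na ụbọchị nke iri na anọ na ọnwa Febụwarị na afọ 1923 na ezinaụlọ nke Pa/Odoziakụ Anakonwa laworo mmụọ nke ime obodo Oranto.'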
# tgt= 'The deceased was born on Feb. 14, 1923 into family of late Pa/Mrs Anakonwa of Oranto Village, married to late Pius Amatu of Obinagu Akpu Village'

Run the model and print the resulting alignments.

In [11]:
# pre-processing: split each sentence into words, then each word into mBERT subwords
sent_src, sent_tgt = src.strip().split(), tgt.strip().split()
token_src, token_tgt = [tokenizer.tokenize(word) for word in sent_src], [tokenizer.tokenize(word) for word in sent_tgt]
wid_src, wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_src], [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]
# flatten the subword ids, add [CLS]/[SEP] and truncate to the model's maximum length
ids_src = tokenizer.prepare_for_model(list(itertools.chain(*wid_src)), return_tensors='pt', model_max_length=tokenizer.model_max_length, truncation=True)['input_ids']
ids_tgt = tokenizer.prepare_for_model(list(itertools.chain(*wid_tgt)), return_tensors='pt', model_max_length=tokenizer.model_max_length, truncation=True)['input_ids']
# map every subword position back to the index of the word it came from
sub2word_map_src = []
for i, word_list in enumerate(token_src):
  sub2word_map_src += [i for x in word_list]
sub2word_map_tgt = []
for i, word_list in enumerate(token_tgt):
  sub2word_map_tgt += [i for x in word_list]

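# alignment
align_layer = 8     # hidden layer whose subword embeddings are compared
threshold = 1e-3    # minimum probability required in both directions
model.eval()
with torch.no_grad():
  # [2] is the tuple of hidden states; [0, 1:-1] drops the batch axis and the [CLS]/[SEP] positions
  out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
  out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]

  # similarity between every source subword and every target subword
  dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))

  # source-to-target probabilities (over targets) and target-to-source probabilities (over sources)
  softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
  softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)

  # keep subword pairs whose probability exceeds the threshold in both directions
  softmax_inter = (softmax_srctgt > threshold)*(softmax_tgtsrc > threshold)

# convert aligned subword index pairs into word index pairs
align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
align_words = set()
for i, j in align_subwords:
  align_words.add( (sub2word_map_src[i], sub2word_map_tgt[j]) )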

# printing
class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'

for i, j in sorted(align_words):
  print(f'{color.BOLD}{color.BLUE}{sent_src[i]}{color.END}==={color.BOLD}{color.RED}{sent_tgt[j]}{color.END}')

[1m[94mMy[0m===[1m[91mAha[0m
[1m[94mJessica,[0m===[1m[91mJessica,[0m
[1m[94mI[0m===[1m[91mAha[0m
[1m[94mam[0m===[1m[91mm[0m
[1m[94mFrench[0m===[1m[91mFrench[0m
[1m[94mthirteen[0m===[1m[91mbụ[0m
[1m[94mold,[0m===[1m[91matọ,[0m
[1m[94mI[0m===[1m[91mm[0m
[1m[94mI[0m===[1m[91mm[0m
[1m[94mI[0m===[1m[91mm[0m
[1m[94mgo[0m===[1m[91miri[0m
[1m[94mschool[0m===[1m[91miri[0m
[1m[94min[0m===[1m[91mna[0m
[1m[94mNice,[0m===[1m[91mNice,[0m
[1m[94mI[0m===[1m[91mm[0m
[1m[94mlive[0m===[1m[91mbi[0m
[1m[94min[0m===[1m[91mna[0m
[1m[94mCagnes-Sur-Mer..[0m===[1m[91mCagnes-Sur-Mer.[0m


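Next, collect every target word aligned to each source word across all sentence pairs, and keep the most frequent one as its translation:

```
alignment_dict = {
  'my': ['aha', 'm', 'mu'],
  'i': ['aha', 'm', 'm', 'm', 'm'],  # aha=1, m=4 => i -> m
  'microsoft': ['microsoft', ...],
}
```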
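The dictionary sketch above can be built from the aligned word pairs of all sentences and then reduced by majority vote. This is a minimal sketch under three assumptions. `build_alignment_dict` and `most_frequent_translation` are hypothetical names. Both sides are lowercased. Ties go to whichever target word was seen first, because `Counter.most_common` orders equal counts by first occurrence. The example pairs are the `My` and `I` rows of the printed alignments.

```python
from collections import Counter, defaultdict


def build_alignment_dict(word_pairs):
    """Group aligned (source word, target word) pairs by source word.

    Lowercasing both sides is an assumption that mirrors the lowercase keys
    in the sketch above; punctuation is left untouched.
    """
    alignment_dict = defaultdict(list)
    for src_word, tgt_word in word_pairs:
        alignment_dict[src_word.lower()].append(tgt_word.lower())
    return dict(alignment_dict)


def most_frequent_translation(alignment_dict):
    """Majority vote: map each source word to its most frequently aligned target word."""
    return {src: Counter(tgts).most_common(1)[0][0] for src, tgts in alignment_dict.items()}


# The "My" and "I" rows from the printed alignments.
pairs = [("My", "Aha"), ("I", "Aha"), ("I", "m"), ("I", "m"), ("I", "m"), ("I", "m")]
alignment_dict = build_alignment_dict(pairs)
print(alignment_dict["i"])                        # ['aha', 'm', 'm', 'm', 'm']
print(most_frequent_translation(alignment_dict))  # {'my': 'aha', 'i': 'm'}
```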