# In the name of Allah

#### Install and include requirements

In [1]:
!pip -q install hazm
!pip -q install fasttext

In [2]:
import pandas as pd
import re
import fasttext
import pandas as pd
import numpy as np
import json
from hazm import Normalizer

### Define fastText model and IDF file paths

In [5]:
# Root directory (relative path) under which the fastText models and the IDF
# tables used by this notebook are stored.
BASE_DIR = '../milvus/Milvus-Deploy/Milvus-Deploy/server/src'

# Binary fastText models; these are the files passed to fasttext.load_model below.
en_bin = f'{BASE_DIR}/model/model_en.bin'
fa_bin = f'{BASE_DIR}/model/model_fa.bin'

# .vec counterparts of the models; not referenced in the visible part of this notebook.
en_vec = f'{BASE_DIR}/model/model_en.vec'
fa_vec = f'{BASE_DIR}/model/model_fa.vec'

# CSV files holding the IDF tables (read with pd.read_csv below).
en_idf_addr = f'{BASE_DIR}/idf/en_idf.csv'
fa_idf_addr = f'{BASE_DIR}/idf/fa_idf.csv'

### Load idf files

In [6]:
# Read each IDF table and key it by its 'word' column, so a word's weight can be
# looked up by label later on (the encode_* functions use `<table>.idf.get`).
# NOTE(review): pd.read_csv's default NA parsing turns cell values such as
# "null", "nan" or "NA" into NaN -- confirm whether such words occur in the
# vocabulary and whether keep_default_na=False is needed.
en_idf = pd.read_csv(en_idf_addr).set_index('word')
fa_idf = pd.read_csv(fa_idf_addr).set_index('word')

### Load fasttext models

In [7]:
# Load the English and Persian fastText models from their .bin files; the
# encode_* functions below query them with get_word_vector.
model_en = fasttext.load_model(en_bin)
model_fa = fasttext.load_model(fa_bin)



### Normalization functions

In [8]:
# Translation table mapping every punctuation character handled by
# en_remove_punc to a single space. It is built once instead of on every call.
# The backslash is spelled "\\" explicitly: the previous "\]" was an invalid
# escape sequence, which newer Python versions warn about, although it denoted
# the same two characters (backslash and "]").
_EN_PUNC_TABLE = str.maketrans(dict.fromkeys(
    '"#\'*+,-/:;<=>@[\\]^_`{|}~\'●,•()»«–‑-،؛−٫—', ' '))


def en_remove_punc(s):
    """Replace punctuation in English text with spaces and squeeze whitespace.

    Every character in the punctuation set (ASCII symbols including "/" and the
    backslash, bullets, guillemets, dashes, and the Arabic comma, semicolon and
    decimal separator) becomes a space. Periods, "!" and "?" are left alone.
    Runs of whitespace are then collapsed to single spaces and leading or
    trailing whitespace is dropped.

    Args:
        s: The string to clean.

    Returns:
        The cleaned string.
    """
    return ' '.join(s.translate(_EN_PUNC_TABLE).split())


# Translation table mapping every punctuation character handled by
# fa_remove_punc to a single space. It is built once instead of on every call.
# The backslash is spelled "\\" explicitly: the previous "\]" was an invalid
# escape sequence, which newer Python versions warn about, although it denoted
# the same two characters (backslash and "]").
_FA_PUNC_TABLE = str.maketrans(dict.fromkeys(
    '"#\'*+,-:;<=>@[\\]^_`{|}~\'●,•()»«–‑-،؛−—', ' '))


def fa_remove_punc(s):
    """Replace punctuation in Persian text with spaces and squeeze whitespace.

    Uses the same character set as en_remove_punc except that "/" and the
    Arabic decimal separator are kept. Periods, "!" and "?" are left alone as
    well. Runs of whitespace are then collapsed to single spaces and leading or
    trailing whitespace is dropped.

    Args:
        s: The string to clean.

    Returns:
        The cleaned string.
    """
    return ' '.join(s.translate(_FA_PUNC_TABLE).split())

def en_normalizer(text):
    """Normalize English text before it is split into tokens.

    Steps, in order: lowercase the text; delete non-breaking spaces (they are
    removed, not replaced by a space); drop bracketed groups made only of
    digits, "|" and spaces (e.g. "[12]"); strip punctuation with
    en_remove_punc; put spaces around a period that follows two or more word
    characters (or a space) and is not followed by a digit; put spaces around
    "!", "?" and the Arabic question mark; collapse runs of spaces; trim.

    Args:
        text: The raw English string.

    Returns:
        The normalized string.
    """
    cleaned = text.lower().replace('\xa0', '')
    cleaned = re.sub(r"\[[\d| ]+\]", " ", cleaned)
    cleaned = en_remove_punc(cleaned)

    # (pattern, replacement) pairs applied one after another, in this order.
    spacing_rules = (
        (r"(\w{2,}| )\.([^0-9]|\n|$)", r"\1 . \2"),
        (r"!", " ! "),
        (r"\?", " ? "),
        (r"؟", " ؟ "),
        (r" +", " "),
    )
    for pattern, replacement in spacing_rules:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

# hazm Normalizer shared by all fa_normalizer calls. It is created lazily on
# first use, so it is no longer re-instantiated for every input string.
_fa_hazm_normalizer = None


def fa_normalizer(text):
    """Normalize Persian text before it is split into tokens.

    Steps, in order: map Arabic letter and digit variants to Persian ones
    (arToPersianChar, arToPersianNumb); delete non-breaking spaces and the
    combining hamza-above mark; strip punctuation with fa_remove_punc; run
    hazm's Normalizer; turn zero-width non-joiners and zero-width spaces into
    plain spaces; put spaces around a period that follows two or more word
    characters (or a space) and is not followed by a digit; rewrite "." or "/"
    between two digits as the Arabic decimal separator; put spaces around "!"
    and "?"; collapse runs of spaces; trim.

    Args:
        text: The raw Persian string (converted with str() by the helpers).

    Returns:
        The normalized string.
    """
    global _fa_hazm_normalizer

    text = arToPersianChar(text)
    text = arToPersianNumb(text)
    text = text.replace('\xa0','')
    text = text.replace('ٔ', '')
    text = fa_remove_punc(text)

    if _fa_hazm_normalizer is None:
        # NOTE(review): these keyword names depend on the installed hazm
        # version -- confirm they are still accepted.
        _fa_hazm_normalizer = Normalizer(persian_style = False, punctuation_spacing = False, affix_spacing = False)
    text = _fa_hazm_normalizer.normalize(text)

    text = text.replace('\u200c',' ')
    text = text.replace('\u200b',' ')
    text = re.sub(r"(\w{2,}| )\.([^0-9]|\n|$)", r"\1 . \2", text)
    # The previous pattern used the character class [\d+], which also matched a
    # literal "+" (so "+.+" was rewritten). Only real digits on both sides now
    # trigger the decimal-separator rewrite.
    text = re.sub(r'(\d)\.(\d)', r'\1٫\2', text)
    text = re.sub(r'(\d)/(\d)', r'\1٫\2', text)
    text = re.sub(r"!", " ! ", text)
    text = re.sub(r"\?", " ? ", text)
    # One pass is enough; the original repeated this identical substitution.
    text = re.sub(r" +", " ", text)
    return text.strip()


def arToPersianNumb(number):
    """Convert Arabic-Indic digits in `number` to their Persian counterparts.

    The value is passed to multiple_replace, which converts it with str()
    first, so non-string input is accepted. Characters that are not one of the
    ten listed digits are left untouched.

    Args:
        number: Text (or any value convertible with str()) to convert.

    Returns:
        The string with the digits replaced.
    """
    digit_pairs = (
        ('١', '۱'),
        ('٢', '۲'),
        ('٣', '۳'),
        ('٤', '۴'),
        ('٥', '۵'),
        ('٦', '۶'),
        ('٧', '۷'),
        ('٨', '۸'),
        ('٩', '۹'),
        ('٠', '۰'),
    )
    return multiple_replace(dict(digit_pairs), number)


def arToPersianChar(userInput):
    """Replace Arabic letter variants in `userInput` with Persian letters.

    Arabic kaf and both Arabic yeh forms are mapped to the Persian kaf and yeh.
    For the six letters listed with a trailing vowel mark, the mark is dropped
    and only the bare letter is kept. The value is passed to multiple_replace,
    which converts it with str() first.

    Args:
        userInput: Text (or any value convertible with str()) to convert.

    Returns:
        The string with the replacements applied.
    """
    letter_pairs = (
        ('ك', 'ک'),
        ('دِ', 'د'),
        ('بِ', 'ب'),
        ('زِ', 'ز'),
        ('ذِ', 'ذ'),
        ('شِ', 'ش'),
        ('سِ', 'س'),
        ('ى', 'ی'),
        ('ي', 'ی'),
    )
    return multiple_replace(dict(letter_pairs), userInput)

def multiple_replace(dic, text):
    """Replace every occurrence of each key of `dic` in `text` with its value.

    All keys are matched literally in a single left-to-right pass, so a
    replacement is never re-scanned. Longer keys are tried first, so a key that
    is a prefix of another key (e.g. "a" and "ab") cannot shadow the longer
    one.

    Args:
        dic: Mapping from literal substrings to their replacements.
        text: The value to process; it is converted with str() first.

    Returns:
        The resulting string. With an empty mapping the text is returned
        unchanged (the empty pattern previously raised KeyError).
    """
    text = str(text)
    if not dic:
        return text
    # sorted() is stable, so keys of equal length keep their mapping order.
    keys = sorted(dic, key=len, reverse=True)
    pattern = "|".join(map(re.escape, keys))
    return re.sub(pattern, lambda m: dic[m.group()], text)

### Functions that take a paragraph and return its embedding vector

In [9]:
# IDF weight used for words that are missing from the IDF tables; it is the
# default passed to `idf.get` in encode_en / encode_fa.
non_vocab_idf_weight = 3

In [10]:
def encode_en(par, use_idf=False):
    """Embed an English paragraph as the mean of its word vectors.

    The paragraph is normalized with en_normalizer and split on single spaces.
    Each token is looked up in `model_en`. With `use_idf`, each vector is first
    multiplied by the token's weight from `en_idf` (or `non_vocab_idf_weight`
    if the token is not in the table). The result is the unweighted mean over
    tokens, i.e. the sum is divided by the token count, not by the sum of the
    IDF weights.

    Args:
        par: The English paragraph.
        use_idf: Whether to scale each word vector by its IDF weight.

    Returns:
        The averaged vector as a NumPy array.
    """
    tokens = en_normalizer(par).split(' ')
    if use_idf:
        word_vectors = [
            en_idf.idf.get(token, non_vocab_idf_weight) * model_en.get_word_vector(token)
            for token in tokens
        ]
    else:
        word_vectors = [model_en.get_word_vector(token) for token in tokens]
    return np.average(np.array(word_vectors), axis = 0)

In [11]:
def encode_fa(par, use_idf=False):
    """Embed a Persian paragraph as the mean of its word vectors.

    The paragraph is normalized with fa_normalizer and split on single spaces.
    Each token is looked up in `model_fa`. With `use_idf`, each vector is first
    multiplied by the token's weight from `fa_idf` (or `non_vocab_idf_weight`
    if the token is not in the table). The result is the unweighted mean over
    tokens, i.e. the sum is divided by the token count, not by the sum of the
    IDF weights.

    Args:
        par: The Persian paragraph.
        use_idf: Whether to scale each word vector by its IDF weight.

    Returns:
        The averaged vector as a NumPy array.
    """
    tokens = fa_normalizer(par).split(' ')
    if use_idf:
        word_vectors = [
            fa_idf.idf.get(token, non_vocab_idf_weight) * model_fa.get_word_vector(token)
            for token in tokens
        ]
    else:
        word_vectors = [model_fa.get_word_vector(token) for token in tokens]
    return np.average(np.array(word_vectors), axis = 0)

# Load the pair data and train a cross-lingual projection

In [12]:
# Read the (English, Persian) pairs. The encoding is given explicitly because
# the file holds non-ASCII text and open() would otherwise fall back to the
# platform's default encoding, which is not UTF-8 everywhere.
with open('pairs_merged.json', 'r', encoding='utf-8') as f:
    pairs = json.load(f)

In [13]:
# Encode both sides of every pair without IDF weighting. Position i of en_vecs
# and of fa_vecs comes from the same entry of `pairs`.
en_vecs = [encode_en(en_text, use_idf=False) for en_text, _ in pairs]
fa_vecs = [encode_fa(fa_text, use_idf=False) for _, fa_text in pairs]

In [14]:
# Stack the per-pair vectors into arrays with one row per pair; row i of en_mat
# and row i of fa_mat come from the same pair.
en_mat = np.array(en_vecs)
fa_mat = np.array(fa_vecs)

## Let's use PyTorch

In [15]:
import torch
from torch.autograd import Variable

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [16]:
def train(X, Y, model, device, epochs=20000, lr=0.00003, log_interval=1000):
    """Fit `model` to map X onto Y by full-batch MSE regression with Adam.

    Every iteration runs the whole of X through the model, computes the mean
    squared error against Y and takes one Adam step. The model is updated in
    place and is expected to already live on `device`.

    Args:
        X: Inputs, one row per example. A NumPy array, a nested list or a
            tensor; converted to a float32 tensor on `device`.
        Y: Targets aligned row by row with X; converted the same way.
        model: The torch.nn.Module to optimize.
        device: Device the input and target tensors are placed on.
        epochs: Number of optimization steps.
        lr: Adam learning rate.
        log_interval: Print the loss every this many iterations, starting at
            iteration 0. A value of 0 disables logging.

    Returns:
        None.
    """
    # torch.as_tensor replaces the deprecated autograd.Variable wrapper and
    # converts dtype and device in one step. Unlike torch.from_numpy it also
    # accepts lists and tensors.
    inputs = torch.as_tensor(X, dtype=torch.float32, device=device)
    targets = torch.as_tensor(Y, dtype=torch.float32, device=device)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for i in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Guard against log_interval == 0, which used to raise ZeroDivisionError.
        if log_interval and i % log_interval == 0:
            print(f"loss at iteration {i} is: {loss.item():.4f}")

In [17]:
import numpy as np

# Embedding sizes are taken from the encoded matrices (second axis).
fa_dim = fa_mat.shape[1]
mid_dim = fa_dim  # hidden width; only needed if a hidden layer is added to the projection
en_dim = en_mat.shape[1]

# Single linear projection from the Persian space to the English space.
# The output size must be en_dim: the earlier Linear(fa_dim, mid_dim) produced
# fa_dim outputs, which only matches the targets when fa_dim == en_dim.
model = torch.nn.Sequential(
    torch.nn.Linear(fa_dim, en_dim),
)
model.to(device)

train(fa_mat, en_mat, model, device)

loss at iteration 0 is: 0.0833
loss at iteration 1000 is: 0.0487
loss at iteration 2000 is: 0.0431
loss at iteration 3000 is: 0.0408
loss at iteration 4000 is: 0.0399
loss at iteration 5000 is: 0.0395
loss at iteration 6000 is: 0.0394
loss at iteration 7000 is: 0.0393
loss at iteration 8000 is: 0.0393
loss at iteration 9000 is: 0.0393
loss at iteration 10000 is: 0.0393
loss at iteration 11000 is: 0.0393
loss at iteration 12000 is: 0.0393
loss at iteration 13000 is: 0.0393
loss at iteration 14000 is: 0.0393
loss at iteration 15000 is: 0.0393
loss at iteration 16000 is: 0.0393
loss at iteration 17000 is: 0.0393
loss at iteration 18000 is: 0.0393
loss at iteration 19000 is: 0.0393


### Save model for future use

In [18]:
# Save only the learned parameters (the state_dict), not the module object
# itself; to reuse them, rebuild the same Sequential architecture and call
# load_state_dict on it.
torch.save(model.state_dict(), './cross_ling.pt')