### Notes 

T5 Paper: https://arxiv.org/pdf/1910.10683.pdf

T5 Tokenizer: https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_t5.py

Important Tasks: https://docs.google.com/document/d/1weIZM6QTlnitpPQmpg-WeV2RW70TnYmDuogBQPr5mB0/edit

In [61]:
#installation step
!pip install transformers
!pip install t5
!pip install sentencepiece
#creating the folders 
!mkdir data/
!mkdir data/AD_NMT-master
!mkdir data/train/
!mkdir data/test/
!mkdir data/val/
!mkdir data/model/
!mkdir data/config/
#fetching the pkl files
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1V9crCmqvgQcv0Sx2MCNWB9AET2j6M6FW' -O data/AD_NMT-master/english-Arabic-both.pkl
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1UzL4cOWTMCee83KBUh2QO_H62AFVpDQV' -O data/AD_NMT-master/LAV-MSA-2-both.pkl
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1UjDX7cCG2S23SPfSHxSPdVayMTxB5Y16' -O data/AD_NMT-master/Magribi_MSA-both.pkl
# !wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1fEVj9jCxvcKn9zg8lO43i2sWZquegg5H' -O data/operative_config.gin
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1UGKswXSqHSxWpx57cEDzvNeJaqbAuyt8' -O data/padic.xml

mkdir: cannot create directory ‘data/’: File exists
mkdir: cannot create directory ‘data/AD_NMT-master’: File exists
mkdir: cannot create directory ‘data/train/’: File exists
mkdir: cannot create directory ‘data/test/’: File exists
mkdir: cannot create directory ‘data/val/’: File exists
mkdir: cannot create directory ‘data/model/’: File exists
mkdir: cannot create directory ‘data/config/’: File exists
--2020-08-29 02:21:28--  https://docs.google.com/uc?export=download&id=1V9crCmqvgQcv0Sx2MCNWB9AET2j6M6FW
Resolving docs.google.com (docs.google.com)... 108.177.111.101, 108.177.111.138, 108.177.111.113, ...
Connecting to docs.google.com (docs.google.com)|108.177.111.101|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://doc-10-2s-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4ci9a7dmo6m4t96egeu5uei6o7qm0ken/1598667675000/16970776037313924126/*/1V9crCmqvgQcv0Sx2MCNWB9AET2j6M6FW?e=download [following]
--2020-08-29 0

In [62]:
#James Chartouni
#Joey Park
#Raef Khan

import torch
from torch.optim import SGD
import pandas as pd
import numpy as np
import pickle
import os, io, glob
import functools

import sentencepiece as spm

import transformers
import t5
from t5.data import preprocessors
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split

## Prepare Datasets

We take the sentence pairs loaded from the pkl files, split them into training and validation sets, and write each split to a .csv file that `tf.data.TextLineDataset` can read line by line.

### PADIC Dataset Parsing

In [63]:
import xml.etree.ElementTree as ET

# Parse the PADIC corpus and build one [dialect sentence, MSA sentence] list per dialect.
padic_tree = ET.parse('data/padic.xml')

padic_alg_msa, padic_ann_msa, padic_syr_msa, padic_pal_msa, padic_mor_msa = [], [], [], [], []

# XML tag of each dialect paired with the list that collects its sentence pairs.
_padic_dialects = (
    ('ALGIERS', padic_alg_msa),
    ('ANNABA', padic_ann_msa),
    ('SYRIAN', padic_syr_msa),
    ('PALESTINIAN', padic_pal_msa),
    ('MOROCCAN', padic_mor_msa),
)

for sentence in padic_tree.getroot():
    # The MSA rendering of a sentence is shared by all five dialect pairs.
    msa_text = sentence.find('MODERN-STANDARD-ARABIC').text.strip()
    for dialect_tag, dialect_pairs in _padic_dialects:
        dialect_pairs.append([sentence.find(dialect_tag).text.strip(), msa_text])

In [64]:
# Show the first [dialect, MSA] pair of every PADIC dialect list.
for dialect_pairs in (padic_alg_msa, padic_ann_msa, padic_syr_msa, padic_pal_msa, padic_mor_msa):
    print(dialect_pairs[0])

['EAdw AlnAs ytbAkAw bdyt nhdr mn qlby tqwl nhdy fAlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]
['EAdwA AlnAs ytbAkAw bdyt nhdr bg$ w qwl ElyA nhdy fy AlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]
['Ely Swt AlnAs bAlbky w bl~$t >Hky bESbyp w k>ny Em Ahdy bAlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]
['SArwA AlnAs ySyHwA bSwt EAly wbdyt AHky wAnA mnfEl wk>ny bhdy fy AlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]
['nAs bdAw tytbAkAw wbdyt tnhdr b nfEl bHAl <lY tnhdy AlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]


In [65]:
# Hold out 15% of each PADIC dialect/MSA pair list for validation.
# A fixed random_state makes the split reproducible across kernel restarts. All five
# lists are built from the same sentences in the same order, so sharing one seed also
# holds out the same sentence positions for every dialect.
PADIC_SPLIT_SEED = 42
alg_msa_train, alg_msa_val = train_test_split(padic_alg_msa, test_size=.15, random_state=PADIC_SPLIT_SEED)
ann_msa_train, ann_msa_val = train_test_split(padic_ann_msa, test_size=.15, random_state=PADIC_SPLIT_SEED)
syr_msa_train, syr_msa_val = train_test_split(padic_syr_msa, test_size=.15, random_state=PADIC_SPLIT_SEED)
pal_msa_train, pal_msa_val = train_test_split(padic_pal_msa, test_size=.15, random_state=PADIC_SPLIT_SEED)
mor_msa_train, mor_msa_val = train_test_split(padic_mor_msa, test_size=.15, random_state=PADIC_SPLIT_SEED)

In [66]:
# Every PADIC dialect list has the same number of examples, so one pair of sizes is enough.
for padic_split in (alg_msa_train, alg_msa_val):
    print(len(padic_split))

6131
1082


### Initial Loading from Pickle

In [67]:
ls data/AD_NMT-master

english-Arabic-both.pkl  LAV-MSA-2-both.pkl  Magribi_MSA-both.pkl


In [68]:
file_path = 'data/AD_NMT-master/'


def _load_pickle(name):
    """Unpickle and return the object stored in the file `file_path + name`."""
    # NOTE(security): pickle.load can execute arbitrary code from the file;
    # only unpickle corpora obtained from a trusted source.
    with open(file_path + name, 'rb') as handle:
        return pickle.load(handle)


# Each corpus is a list of two-element sentence pairs.
data_MSA_English_both = _load_pickle("english-Arabic-both.pkl")
data_LAV_MSA_both = _load_pickle("LAV-MSA-2-both.pkl")
data_Magribi_MSA_both = _load_pickle("Magribi_MSA-both.pkl")
    

In [69]:
# A few example pairs: head and tail of the English/MSA corpus, head of the two dialect corpora.
for corpus_sample in (
    data_MSA_English_both[:5],
    data_MSA_English_both[-5:],
    data_LAV_MSA_both[:5],
    data_Magribi_MSA_both[:5],
):
    print(corpus_sample)

[['Tom was also there', 'كان توم هنا ايضا'], ['That old woman lives by herself', 'تلك المراة العجوز تسكن بمفردها'], ['He went abroad for the purpose of studying English', 'سافر خارج البلد ليتعلم الانجليزية'], ['There is a fork missing', 'هناك شوكة ناقصة'], ["I don't know this game", 'لا اعرف هذه اللعبة']]
[['Please send us more information', 'ارسل الينا المزيد من المعلومات اذا تكرمت'], ['I am an only child', 'انا طفل وحيد ابي و امي'], ['Make good use of your time', 'استفد من وقتك جيدا'], ["Fighting won't settle anything", 'لن يحل القتال اي شيء'], ['Practice makes perfect', 'الممارسة هي الطريق الى الاتقان']]
[['لا انا بعرف وحدة راحت ع فرنسا و معا شنتا حطت فيها الفرش', 'لا اعرف واحدة ذهبت الى فرنسا و لها غرفة و ضعت فيها الافرشة'], ['روح بوشك و فتول عاليسار', 'اذهب تقدم و استدر يسارا'], ['لا لا لازم انه يكون عندك موضوع ما في اشي', ' لا لا يجب ان يكون لديك موضوع هذا ضروري'], ['اوعي تبعدي من هون بلاش تضيعي ', 'لا تبتعد عن هنا حتى لا تفقد الطريق '], ['قصدي صراحة يما انا كمان كرهته من يوم ما 

In [70]:
# Split each corpus into training (80%) and validation (20%) pairs.
# No separate test split is created here.
# A fixed random_state keeps the split, and therefore the CSV files and the
# SentencePiece training input derived from it, reproducible across kernel restarts.
CORPUS_SPLIT_SEED = 42

msa_en_train, msa_en_val = train_test_split(data_MSA_English_both, test_size=.2, random_state=CORPUS_SPLIT_SEED)

lav_msa_train, lav_msa_val = train_test_split(data_LAV_MSA_both, test_size=.2, random_state=CORPUS_SPLIT_SEED)

mag_msa_train, mag_msa_val = train_test_split(data_Magribi_MSA_both, test_size=.2, random_state=CORPUS_SPLIT_SEED)

In [71]:
# Train/validation sizes for the English/MSA, Levantine/MSA and Maghrebi/MSA corpora, in that order.
for corpus_split in (
    msa_en_train, msa_en_val,
    lav_msa_train, lav_msa_val,
    mag_msa_train, mag_msa_val,
):
    print(len(corpus_split))

8000
2001
12644
3161
15788
3948


In [72]:
file_path = 'data/'


def _csv_field(text):
    """Return `text` with commas and line breaks replaced by spaces.

    The rows are later parsed as unquoted, comma-delimited lines, so a literal
    comma or newline inside a sentence would corrupt the row.
    """
    return text.replace(',', ' ').replace('\r', ' ').replace('\n', ' ')


def list_to_csv(ds, src='msa', trg='en', datatype='', base_dir=None):
    """Write sentence pairs to `<base_dir>/<datatype>/<datatype>_<src>_<trg>.csv`.

    Each pair becomes one line, `pair[1],pair[0]`: the second element of the
    pair is written as the first column.

    Args:
        ds: iterable of two-element sequences of strings.
        src: source-language tag used in the file name.
        trg: target-language tag used in the file name.
        datatype: split name; used both as sub-directory and file-name prefix.
        base_dir: root output directory; defaults to the module-level `file_path`.
    """
    root = file_path if base_dir is None else base_dir
    out_path = os.path.join(root, datatype, datatype + '_' + src + '_' + trg + '.csv')

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Explicit UTF-8: the platform default encoding may not be able to encode Arabic text.
    with open(out_path, 'wt', encoding='utf-8') as out_file:
        for pair in ds:
            out_file.write(_csv_field(pair[1]) + ',' + _csv_field(pair[0]) + '\n')

In [73]:
# Write the train and validation CSV of every corpus.
# NOTE(review): list_to_csv writes element [1] of each pair as the first CSV column.
# The Levantine and Maghrebi pairs appear to be [dialect, MSA], so the first column
# (later read as "source") would hold the MSA sentence although the files are named
# lav_msa / mag_msa - confirm the intended translation direction.
for src_tag, trg_tag, train_pairs, val_pairs in (
    ('msa', 'en', msa_en_train, msa_en_val),
    ('lav', 'msa', lav_msa_train, lav_msa_val),
    ('mag', 'msa', mag_msa_train, mag_msa_val),
):
    list_to_csv(train_pairs, src_tag, trg_tag, 'train')
    list_to_csv(val_pairs, src_tag, trg_tag, 'val')

## Training SentencePiece Model

In [74]:
# Pool the training pairs of all three corpora as input for the SentencePiece model.
spm_input_ds = [*msa_en_train, *mag_msa_train, *lav_msa_train]

In [75]:
def list_to_input(ds, base_dir=None):
    """Write both sentences of every pair to `<base_dir>/spm_input.txt`, one per line.

    Args:
        ds: iterable of two-element sequences of strings.
        base_dir: output directory; defaults to the module-level `file_path`.
    """
    root = file_path if base_dir is None else base_dir
    out_path = os.path.join(root, 'spm_input.txt')

    # Explicit UTF-8: the platform default encoding may not be able to encode Arabic text.
    with open(out_path, 'wt', encoding='utf-8') as sentencelinefile:
        for pair in ds:
            sentencelinefile.write(pair[0] + '\n' + pair[1] + '\n')

In [76]:
# Write every sentence of the pooled training pairs to the SentencePiece input file.
list_to_input(spm_input_ds)

In [77]:
VOCAB_SIZE = 32128

# Flags for the SentencePiece trainer: input text, output prefix, vocabulary size and
# the ids of the special tokens (pad=0, eos=1, unk=2, no bos).
# NOTE(review): hard_vocab_limit=False presumably lets the trainer produce fewer
# pieces than VOCAB_SIZE - confirm against the SentencePiece options.
spm_train_flags = [
    '--input=data/spm_input.txt',
    '--model_prefix=data/model/spm',
    '--vocab_size=' + str(VOCAB_SIZE),
    '--unk_id=2',
    '--bos_id=-1',
    '--eos_id=1',
    '--pad_id=0',
    '--hard_vocab_limit=False',
]
spm.SentencePieceTrainer.train(' '.join(spm_train_flags))

In [78]:
# Path of the SentencePiece model file passed to SentencePieceVocabulary below
# (presumably the file produced by the trainer's --model_prefix=data/model/spm).
# Not to be confused with `file_path`, the data root directory.
filepath = 'data/model/spm.model'

## Tensor Processing + Add to TaskRegistry

### English to Arabic Task

In [79]:
# CSV file and number of examples for each split of the MSA/English task.
msa_en_split_csv_path = dict(
    train="data/train/train_msa_en.csv",
    validation="data/val/val_msa_en.csv",
)
msa_en_example_count = dict(
    train=len(msa_en_train),
    validation=len(msa_en_val),
)

In [80]:
def msa_en_translation_dataset_fn(split, shuffle_files=False):
  """Return a tf.data.Dataset of {"source", "target"} dicts for the given split.

  Each line of the split's CSV file is parsed as two unquoted, comma-separated
  string fields. `shuffle_files` is accepted but not used.
  """
  parse_row = functools.partial(
      tf.io.decode_csv,
      record_defaults=["", ""],
      field_delim=",",
      use_quote_delim=False,
  )

  def to_example(*fields):
    # First CSV column becomes "source", second becomes "target".
    return dict(zip(["source", "target"], fields))

  lines = tf.data.TextLineDataset(msa_en_split_csv_path[split])
  rows = lines.map(parse_row, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return rows.map(to_example)

# Preview the first five training examples.
msa_en_preview = msa_en_translation_dataset_fn("train").take(5)
for example in tfds.as_numpy(msa_en_preview):
    print(example)

{'source': b'\xd9\x87\xd8\xb0\xd8\xa7 \xd8\xa7\xd9\x85\xd8\xb1 \xd9\x8a\xd8\xa7 \xd8\xaa\xd9\x88\xd9\x85', 'target': b"That's an order Tom"}
{'source': b'\xd8\xa7\xd9\x86\xd8\xa7 \xd8\xac\xd8\xa7\xd8\xa6\xd8\xb9 \xd9\x84\xd9\x84\xd8\xba\xd8\xa7\xd9\x8a\xd8\xa9 \xd8\xa7\xd9\x84\xd8\xa7\xd9\x86', 'target': b"I'm very hungry now"}
{'source': b'\xd8\xa7\xd8\xb3\xd8\xaa\xd9\x85\xd8\xb1 \xd8\xa7\xd9\x84\xd9\x85\xd8\xb7\xd8\xb1 \xd8\xae\xd9\x85\xd8\xb3\xd8\xa9 \xd8\xa7\xd9\x8a\xd8\xa7\xd9\x85', 'target': b'The rain lasted five days'}
{'source': b'\xd9\x84\xd9\x82\xd8\xaf \xd8\xb3\xd9\x85\xd8\xb9\xd8\xaa \xd9\x87\xd8\xb0\xd9\x87 \xd8\xa7\xd9\x84\xd8\xa7\xd8\xba\xd9\x86\xd9\x8a\xd8\xa9 \xd9\x85\xd9\x86 \xd9\x82\xd8\xa8\xd9\x84', 'target': b"I've heard this song before"}
{'source': b'\xd8\xaf\xd8\xb1\xd8\xa7\xd8\xac\xd8\xaa\xd9\x87 \xd8\xb2\xd8\xb1\xd9\x82\xd8\xa7\xd8\xa1', 'target': b'His bicycle is blue'}


In [81]:
#map each {source, target} example to the {inputs, targets} keys the model expects, prefixing the source with the task prompt
def msa_en_translation_preprocessor(ds):
  """Map {"source", "target"} examples to {"inputs", "targets"}.

  "inputs" is the source text prefixed with the task prompt; "targets" is the
  target text unchanged.
  """
  prompt = "translate MSA to English: "

  def to_inputs_and_targets(ex):
    prompted_source = tf.strings.join([prompt, ex["source"]])
    return {"inputs": prompted_source, "targets": ex["target"]}

  return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [82]:
# Drop any earlier registration of this task name before adding it again
# (presumably so the cell can be re-run in the same kernel).
t5.data.TaskRegistry.remove("translation_msa_en")

# Token-level preprocessing of the "targets" feature: select a random chunk,
# concatenate, split to the input length, then denoise with 15% i.i.d. noise
# (noise spans become sentinels in the inputs, non-noise spans in the targets).
# NOTE(review): this appears to replace the translation inputs with a denoising
# objective over the target text - confirm this is the intended pre-training setup.
msa_en_token_preprocessors = [
    functools.partial(
        preprocessors.select_random_chunk,
        feature_key="targets",
        max_length=65536,
    ),
    functools.partial(
        preprocessors.reduce_concat_tokens,
        feature_key="targets",
        batch_size=128,
    ),
    preprocessors.split_tokens_to_inputs_length,
    functools.partial(
        preprocessors.denoise,
        inputs_fn=preprocessors.noise_span_to_unique_sentinel,
        targets_fn=preprocessors.nonnoise_span_to_unique_sentinel,
        noise_density=0.15,
        noise_mask_fn=preprocessors.iid_noise_mask,
    ),
]

t5.data.TaskRegistry.add(
    # Name of the Task.
    "translation_msa_en",
    # Function returning the tf.data.Dataset of a split.
    dataset_fn=msa_en_translation_dataset_fn,
    splits=["train", "validation"],
    # Text-level preprocessing: builds the "inputs"/"targets" features.
    text_preprocessor=[msa_en_translation_preprocessor],
    # Target lower-casing is left disabled:
    #postprocess_fn=t5.data.postprocessors.lower_text,
    # BLEU is the evaluation metric for this task.
    metric_fns=[t5.evaluation.metrics.bleu],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=msa_en_example_count,
    # Vocabulary built from the trained SentencePiece model.
    output_features=t5.data.Feature(vocabulary=t5.data.SentencePieceVocabulary(filepath)),
    token_preprocessor=msa_en_token_preprocessors,
)

### Levantine to MSA Task

In [83]:
# CSV file and number of examples for each split of the Levantine/MSA task.
lav_msa_split_csv_path = dict(
    train="data/train/train_lav_msa.csv",
    validation="data/val/val_lav_msa.csv",
)
lav_msa_example_count = dict(
    train=len(lav_msa_train),
    validation=len(lav_msa_val),
)

In [84]:
def lav_msa_translation_dataset_fn(split, shuffle_files=False):
  """Return a tf.data.Dataset of {"source", "target"} dicts for the given split.

  Each line of the split's CSV file is parsed as two unquoted, comma-separated
  string fields. `shuffle_files` is accepted but not used.
  """
  parse_row = functools.partial(
      tf.io.decode_csv,
      record_defaults=["", ""],
      field_delim=",",
      use_quote_delim=False,
  )

  def to_example(*fields):
    # First CSV column becomes "source", second becomes "target".
    return dict(zip(["source", "target"], fields))

  lines = tf.data.TextLineDataset(lav_msa_split_csv_path[split])
  rows = lines.map(parse_row, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return rows.map(to_example)

# Preview the first five training examples.
lav_msa_preview = lav_msa_translation_dataset_fn("train").take(5)
for example in tfds.as_numpy(lav_msa_preview):
    print(example)

{'source': b'\xd8\xa8\xd8\xb5\xd9\x81\xd9\x87 \xd8\xb9\xd8\xa7\xd9\x85\xd9\x87 \xd8\xa7\xd9\x84\xd8\xb3\xd8\xad\xd8\xb1 \xd8\xa7\xd9\x84\xd8\xa7\xd8\xb3\xd9\x88\xd8\xaf \xd9\x81\xd9\x89 \xd9\x87\xd8\xb0\xd9\x87 \xd8\xa7\xd9\x84\xd9\x85\xd9\x86\xd8\xa7\xd8\xb7\xd9\x82 \xd9\x8a\xd8\xb9\xd8\xaa\xd9\x85\xd8\xaf \xd8\xb9\xd9\x84\xd9\x89 \xd8\xa7\xd9\x84\xd9\x83\xd8\xab\xd9\x8a\xd8\xb1 \xd9\x85\xd9\x86 \xd8\xa7\xd9\x84\xd8\xa7\xd8\xb4\xd9\x8a\xd8\xa7\xd8\xa1 ', 'target': b'\xd8\xa8\xd8\xb4\xd9\x83\xd9\x84 \xd8\xb9\xd8\xa7\xd9\x85 \xd9\x88\xd9\x83\xd8\xa7\xd9\x86 \xd8\xa7\xd9\x84\xd8\xb3\xd8\xad\xd8\xb1 \xd8\xa7\xd9\x84\xd8\xa7\xd8\xb3\xd9\x88\xd8\xaf \xd9\x81\xd9\x8a \xd8\xaa\xd9\x84\xd9\x83 \xd8\xa7\xd9\x84\xd9\x85\xd9\x86\xd8\xa7\xd8\xb7\xd9\x82 \xd9\x8a\xd8\xb9\xd8\xaa\xd9\x85\xd8\xaf \xd8\xb9\xd9\x84\xd9\x89 \xd9\x83\xd8\xab\xd9\x8a\xd8\xb1 \xd9\x85\xd9\x86 \xd8\xa7\xd9\x84\xd9\x85\xd9\x85\xd8\xa7\xd8\xb1\xd8\xb3\xd8\xa7\xd8\xaa'}
{'source': b'\xd9\x81\xd9\x87\xd9\x85\xd8\xaa \xd9\x83\xd

In [85]:
#map each {source, target} example to the {inputs, targets} keys the model expects, prefixing the source with the task prompt
def lav_msa_translation_preprocessor(ds):
  """Map {"source", "target"} examples to {"inputs", "targets"}.

  "inputs" is the source text prefixed with the task prompt; "targets" is the
  target text unchanged.
  """
  prompt = "translate Levantine to MSA: "

  def to_inputs_and_targets(ex):
    prompted_source = tf.strings.join([prompt, ex["source"]])
    return {"inputs": prompted_source, "targets": ex["target"]}

  return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [86]:
# Drop any earlier registration of this task name before adding it again
# (presumably so the cell can be re-run in the same kernel).
t5.data.TaskRegistry.remove("translation_lav_msa")

# Token-level preprocessing of the "targets" feature: select a random chunk,
# concatenate, split to the input length, then denoise with 15% i.i.d. noise
# (noise spans become sentinels in the inputs, non-noise spans in the targets).
# NOTE(review): this appears to replace the translation inputs with a denoising
# objective over the target text - confirm this is the intended pre-training setup.
lav_msa_token_preprocessors = [
    functools.partial(
        preprocessors.select_random_chunk,
        feature_key="targets",
        max_length=65536,
    ),
    functools.partial(
        preprocessors.reduce_concat_tokens,
        feature_key="targets",
        batch_size=128,
    ),
    preprocessors.split_tokens_to_inputs_length,
    functools.partial(
        preprocessors.denoise,
        inputs_fn=preprocessors.noise_span_to_unique_sentinel,
        targets_fn=preprocessors.nonnoise_span_to_unique_sentinel,
        noise_density=0.15,
        noise_mask_fn=preprocessors.iid_noise_mask,
    ),
]

t5.data.TaskRegistry.add(
    # Name of the Task.
    "translation_lav_msa",
    # Function returning the tf.data.Dataset of a split.
    dataset_fn=lav_msa_translation_dataset_fn,
    splits=["train", "validation"],
    # Text-level preprocessing: builds the "inputs"/"targets" features.
    text_preprocessor=[lav_msa_translation_preprocessor],
    # Target lower-casing is left disabled:
    #postprocess_fn = t5.data.postprocessors.lower_text,
    # No metric functions are registered for this task; BLEU is left disabled:
    #metric_fns=[t5.evaluation.metrics.bleu],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=lav_msa_example_count,
    # Vocabulary built from the trained SentencePiece model.
    output_features=t5.data.Feature(vocabulary=t5.data.SentencePieceVocabulary(filepath)),
    token_preprocessor=lav_msa_token_preprocessors,
)

### Maghrib to MSA Task

In [87]:
# CSV file and number of examples for each split of the Maghrebi/MSA task.
mag_msa_split_csv_path = dict(
    train="data/train/train_mag_msa.csv",
    validation="data/val/val_mag_msa.csv",
)
mag_msa_example_count = dict(
    train=len(mag_msa_train),
    validation=len(mag_msa_val),
)

In [88]:
def mag_msa_translation_dataset_fn(split, shuffle_files=False):
  """Return a tf.data.Dataset of {"source", "target"} dicts for the given split.

  Each line of the split's CSV file is parsed as two unquoted, comma-separated
  string fields. `shuffle_files` is accepted but not used.
  """
  parse_row = functools.partial(
      tf.io.decode_csv,
      record_defaults=["", ""],
      field_delim=",",
      use_quote_delim=False,
  )

  def to_example(*fields):
    # First CSV column becomes "source", second becomes "target".
    return dict(zip(["source", "target"], fields))

  lines = tf.data.TextLineDataset(mag_msa_split_csv_path[split])
  rows = lines.map(parse_row, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return rows.map(to_example)

# Preview the first five training examples: raw record, decoded text and byte lengths.
mag_msa_preview = mag_msa_translation_dataset_fn("train").take(5)
for example in tfds.as_numpy(mag_msa_preview):
    source_bytes, target_bytes = example['source'], example['target']
    print(example)
    print(source_bytes.decode())
    print(target_bytes.decode())
    print(len(source_bytes))
    print(len(target_bytes))

{'source': b'\xd9\x85\xd9\x8a\xd9\x85\xd9\x8a \xd8\xaf\xd8\xb9\xd9\x8a \xd9\x87\xd8\xa7\xd8\xaa\xd9\x81\xd9\x83 \xd9\x88 \xd8\xb4\xd8\xa7\xd9\x86\xd9\x87 \xd9\x84\xd9\x82\xd8\xaf \xd8\xac\xd9\x86 \xd8\xa7\xd9\x84\xd9\x85\xd8\xb3\xd9\x83\xd9\x8a\xd9\x86', 'target': b'\xd9\x85\xd9\x8a\xd9\x85\xd9\x8a \xd8\xae\xd9\x84\xd9\x8a \xd8\xa7\xd9\x84\xd8\xaa\xd9\x84\xd9\x8a\xd9\x81\xd9\x88\xd9\x86 \xd9\x88 \xd8\xae\xd9\x84\xd9\x8a\xd9\x87 \xd8\xb9\xd9\x84\xd9\x89 \xd8\xae\xd8\xa7\xd8\xb7\xd8\xb1\xd9\x88 \xd8\xb1\xd8\xa7\xd9\x87 \xd9\x85\xd8\xb3\xd9\x83\xd9\x8a\xd9\x86 \xd8\xad\xd9\x85\xd9\x82\xd8\xaa\xd9\x8a\xd9\x87'}
ميمي دعي هاتفك و شانه لقد جن المسكين
ميمي خلي التليفون و خليه على خاطرو راه مسكين حمقتيه
65
93
{'source': b'\xd9\x83\xd9\x85\xd8\xa7 \xd9\x82\xd9\x84\xd8\xaa \xd9\x87\xd8\xb0\xd8\xa7 \xd9\x83\xd8\xa7\xd9\x81 \xd8\xac\xd8\xaf\xd8\xa7 \xd9\x84\xd8\xa7\xd9\x86\xd9\x83 \xd8\xb3\xd8\xaa\xd8\xb1\xd8\xa7\xd9\x81\xd9\x82\xd9\x8a\xd9\x86 \xd8\xa7\xd9\x84\xd8\xb9\xd8\xb1\xd9\x88\xd8\xb3  \xd9

In [89]:
#map each {source, target} example to the {inputs, targets} keys the model expects, prefixing the source with the task prompt
def mag_msa_translation_preprocessor(ds):
    """Map {"source", "target"} examples to {"inputs", "targets"}.

    "inputs" is the source text prefixed with the task prompt; "targets" is the
    target text unchanged.
    """
    prompt = "translate Maghrib to MSA: "

    def to_inputs_and_targets(ex):
        prompted_source = tf.strings.join([prompt, ex["source"]])
        return {"inputs": prompted_source, "targets": ex["target"]}

    return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [90]:
# Drop any earlier registration of this task name before adding it again
# (presumably so the cell can be re-run in the same kernel).
t5.data.TaskRegistry.remove("translation_mag_msa")

# Token-level preprocessing of the "targets" feature: select a random chunk,
# concatenate, split to the input length, then denoise with 15% i.i.d. noise
# (noise spans become sentinels in the inputs, non-noise spans in the targets).
# NOTE(review): this appears to replace the translation inputs with a denoising
# objective over the target text - confirm this is the intended pre-training setup.
mag_msa_token_preprocessors = [
    functools.partial(
        preprocessors.select_random_chunk,
        feature_key="targets",
        max_length=65536,
    ),
    functools.partial(
        preprocessors.reduce_concat_tokens,
        feature_key="targets",
        batch_size=128,
    ),
    preprocessors.split_tokens_to_inputs_length,
    functools.partial(
        preprocessors.denoise,
        inputs_fn=preprocessors.noise_span_to_unique_sentinel,
        targets_fn=preprocessors.nonnoise_span_to_unique_sentinel,
        noise_density=0.15,
        noise_mask_fn=preprocessors.iid_noise_mask,
    ),
]

t5.data.TaskRegistry.add(
    # Name of the Task.
    "translation_mag_msa",
    # Function returning the tf.data.Dataset of a split.
    dataset_fn=mag_msa_translation_dataset_fn,
    splits=["train", "validation"],
    # Text-level preprocessing: builds the "inputs"/"targets" features.
    text_preprocessor=[mag_msa_translation_preprocessor],
    # Target lower-casing is left disabled:
    #postprocess_fn = t5.data.postprocessors.lower_text,
    # No metric functions are registered for this task; BLEU is left disabled:
    #metric_fns=[t5.evaluation.metrics.bleu],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=mag_msa_example_count,
    # Vocabulary built from the trained SentencePiece model.
    output_features=t5.data.Feature(vocabulary=t5.data.SentencePieceVocabulary(filepath)),
    token_preprocessor=mag_msa_token_preprocessors,
)

## Dataset Mixture

In [91]:
# Drop any earlier registration of the mixture, then register the three tasks
# under one name, each with the same default rate.
t5.data.MixtureRegistry.remove("translation_msa")

translation_msa_task_names = [
    "translation_msa_en",
    "translation_lav_msa",
    "translation_mag_msa",
]
t5.data.MixtureRegistry.add("translation_msa", translation_msa_task_names, default_rate=1.0)

## Pre-Training

In [92]:
# Download the Hugging Face T5-base config JSON into data/config/; it is read with
# transformers.T5Config.from_json_file when the model is built.
# NOTE(review): the original note also mentioned adding the new tasks' task params
# to this config; no visible cell does that - confirm whether it is still needed.
!wget "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json" -O data/config/t5-base-config.json

--2020-08-29 02:21:35--  https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.240.142
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.240.142|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1199 (1.2K) [application/json]
Saving to: ‘data/config/t5-base-config.json’


2020-08-29 02:21:35 (64.2 MB/s) - ‘data/config/t5-base-config.json’ saved [1199/1199]



In [93]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model from the downloaded Hugging Face T5-base config.
# "/tmp/hft5/" is the model directory; the recorded output shows an existing
# checkpoint being loaded from it, so earlier runs carry over.
config = transformers.T5Config.from_json_file(json_file="data/config/t5-base-config.json")
model = t5.models.HfPyTorchModel(config, "/tmp/hft5/", device)

INFO:absl:Loading from /tmp/hft5/model-10000.checkpoint


In [94]:
ls /tmp/hft5

events.out.tfevents.1598661619.6f0b5dc6b701.111.0  model-2000.checkpoint
events.out.tfevents.1598667702.6f0b5dc6b701.111.1  model-4000.checkpoint
example_predictions.txt                            model-6000.checkpoint
model-0.checkpoint                                 model-8000.checkpoint
model-10000.checkpoint                             [0m[01;34mvalidation_eval[0m/


In [95]:
# !wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1fEVj9jCxvcKn9zg8lO43i2sWZquegg5H' -O data/operative_config.gin

In [96]:
# import gin
# with gin.unlock_config():
#   gin.parse_config_file("data/operative_config.gin")

In [98]:
STEPS = 1000 #@param {type: "integer"}
# Train on the three-task mixture, checkpointing five times over the run.
# Integer division keeps save_steps an int: `STEPS/5` would pass a float, and a
# STEPS value not divisible by 5 would give a fractional checkpoint interval.
# max(1, ...) guards against a zero interval when STEPS < 5.
model.train(
    mixture_or_task_name="translation_msa",
    steps=STEPS,
    save_steps=max(1, STEPS // 5),
    sequence_length={"inputs": 32, "targets": 32},
    split="train",
    batch_size=32,
    optimizer=functools.partial(transformers.AdamW, lr=1e-4),
)

  return dataset.map(my_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
INFO:absl:Saving checkpoint for step 12045
INFO:absl:Saving checkpoint for step 12245
INFO:absl:Saving checkpoint for step 12445
INFO:absl:Saving checkpoint for step 12645
INFO:absl:Saving checkpoint for step 12845
INFO:absl:Saving final checkpoint for step 13045


In [99]:
# Inspect what the model actually sees: the fully preprocessed "inputs" token ids
# of a few examples from the training split of the MSA/English task.
review_task = t5.data.TaskRegistry.get("translation_msa_en")
ds = review_task.get_dataset(split="train", sequence_length={"inputs": 128, "targets": 32})
# The message names the split that is really loaded (it previously said "validation").
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex['inputs'])

  return dataset.map(my_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed validation examples...
[   61     5  5998    32    36   309   184    51   699  2093   232 24734
    39  2324 24733  1109   116  8061   190 24732 11311  4913   388 24731
   208 24730  1101 24729  4969 24728     5  3396  1638   727    36    39
 24727  3437    51 10185   321   303   487 24726  1293   121  1027  3013
    62   160    36   487  2188 24725    61  3612   260   135    36    82
   256 24724  4950  2881 24723   121  1558   827    39 24722  4677  1793
   303 24721  1314   696   150  5001 15869   170    61  1911   336 24720
  1813    88  2787 24719   184   256   987    51  4182     5  1393   573
    39  1028  2407  6634  2707    62  1616 24718    32   568    72  2150
 24717  9990   170 24716  5258 24715  1028   247    72 11275  5568    62
     5     2  5556 17237 10085    88     1]
[24734  6665   303 24733   597    51  4675  5956   292  3331  6246   528
    39   830    88  2093 24732   154    51  2256   330   242    54  2725
    62   260    54   827  6357    62  

## Evaluation

In [100]:
# Evaluate the task mixture from a saved checkpoint.
# NOTE(review): 13045 is the final step logged by the recorded training run; it is
# hard-coded, so a fresh run that ends on a different step will not find it.
# NOTE(review): the recorded output of this cell ends in a KeyError - unresolved.
EVAL_CHECKPOINT_STEP = 13045
eval_sequence_length = {"inputs": 32, "targets": 32}

model.eval(
    mixture_or_task_name="translation_msa",
    checkpoint_steps=EVAL_CHECKPOINT_STEP,
    sequence_length=eval_sequence_length,
    batch_size=32,
)

INFO:absl:Loading from /tmp/hft5/model-13045.checkpoint


KeyError: ignored

## Predictions

In [None]:
# Two prompted example sentences to decode.
# NOTE(review): the second prompt says "MSA to English" but the sentence is already
# English - confirm whether an MSA sentence was intended.
inputs = [
    "translate Levantine to MSA: هلا والله",
    "translate MSA to English: Please help me fix this",
]

# Predictions are written to a text file inside the model directory.
predictions_path = "/tmp/hft5/example_predictions.txt"
prediction_vocabulary = t5.data.SentencePieceVocabulary(filepath)

model.predict(
    inputs,
    sequence_length={"inputs": 32},
    batch_size=2,
    output_file=predictions_path,
    vocabulary=prediction_vocabulary,
)