### Notes 

T5 Paper: https://arxiv.org/pdf/1910.10683.pdf

T5 Tokenizer: https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_t5.py

Important Tasks: https://docs.google.com/document/d/1weIZM6QTlnitpPQmpg-WeV2RW70TnYmDuogBQPr5mB0/edit

In [1]:
#installation step
!pip install transformers
!pip install t5
!pip install sentencepiece
#creating the folders 
!mkdir data/
!mkdir data/AD_NMT-master
!mkdir data/train/
!mkdir data/test/
!mkdir data/val/
!mkdir data/model/
!mkdir data/config/
#fetching the pkl files
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1V9crCmqvgQcv0Sx2MCNWB9AET2j6M6FW' -O data/AD_NMT-master/english-Arabic-both.pkl
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1UzL4cOWTMCee83KBUh2QO_H62AFVpDQV' -O data/AD_NMT-master/LAV-MSA-2-both.pkl
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1UjDX7cCG2S23SPfSHxSPdVayMTxB5Y16' -O data/AD_NMT-master/Magribi_MSA-both.pkl
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1fEVj9jCxvcKn9zg8lO43i2sWZquegg5H' -O data/operative_config.gin
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1UGKswXSqHSxWpx57cEDzvNeJaqbAuyt8' -O data/padic.xml

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)
[K     |████████████████████████████████| 675kB 2.7MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 8.4MB/s 
Collecting tokenizers==0.7.0
[?25l  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)
[K     |████████████████████████████████| 3.8MB 16.1MB/s 
Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |██████████

--2020-06-22 03:05:49--  https://docs.google.com/uc?export=download&id=1V9crCmqvgQcv0Sx2MCNWB9AET2j6M6FW
Resolving docs.google.com (docs.google.com)... 108.177.125.100, 108.177.125.102, 108.177.125.138, ...
Connecting to docs.google.com (docs.google.com)|108.177.125.100|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://doc-10-2s-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/fil581jfqd654rk1p68dm1n99idm4600/1592795100000/16970776037313924126/*/1V9crCmqvgQcv0Sx2MCNWB9AET2j6M6FW?e=download [following]
--2020-06-22 03:05:50--  https://doc-10-2s-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/fil581jfqd654rk1p68dm1n99idm4600/1592795100000/16970776037313924126/*/1V9crCmqvgQcv0Sx2MCNWB9AET2j6M6FW?e=download
Resolving doc-10-2s-docs.googleusercontent.com (doc-10-2s-docs.googleusercontent.com)... 108.177.125.132, 2404:6800:4008:c01::84
Connecting to doc-10-2s-docs.googleusercontent.com (doc-10

In [2]:
#James Chartouni
#Joey Park
#Raef Khan

import torch
from torch.optim import SGD
import pandas as pd
import numpy as np
import pickle
import os, io, glob
import functools

import sentencepiece as spm

import transformers
import t5
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split

## Prepare Datasets

We need to take our training and test sets from the pkl files and create new .txt files that are formatted so that the standard torchtext Dataset class can read them

### PADIC Dataset Parsing

In [3]:
import xml.etree.ElementTree as ET

padic_tree = ET.parse('data/padic.xml')

padic_alg_msa = []
padic_ann_msa = []
padic_syr_msa = []
padic_pal_msa = []
padic_mor_msa = [] 

for sentence in padic_tree.getroot():
  padic_alg_msa.append([sentence.find('ALGIERS').text.strip(), sentence.find('MODERN-STANDARD-ARABIC').text.strip()])
  padic_ann_msa.append([sentence.find('ANNABA').text.strip(), sentence.find('MODERN-STANDARD-ARABIC').text.strip()])
  padic_syr_msa.append([sentence.find('SYRIAN').text.strip(), sentence.find('MODERN-STANDARD-ARABIC').text.strip()])
  padic_pal_msa.append([sentence.find('PALESTINIAN').text.strip(), sentence.find('MODERN-STANDARD-ARABIC').text.strip()])
  padic_mor_msa.append([sentence.find('MOROCCAN').text.strip(), sentence.find('MODERN-STANDARD-ARABIC').text.strip()])

In [4]:
print(padic_alg_msa[0])
print(padic_ann_msa[0])
print(padic_syr_msa[0])
print(padic_pal_msa[0])
print(padic_mor_msa[0])

['EAdw AlnAs ytbAkAw bdyt nhdr mn qlby tqwl nhdy fAlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]
['EAdwA AlnAs ytbAkAw bdyt nhdr bg$ w qwl ElyA nhdy fy AlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]
['Ely Swt AlnAs bAlbky w bl~$t >Hky bESbyp w k>ny Em Ahdy bAlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]
['SArwA AlnAs ySyHwA bSwt EAly wbdyt AHky wAnA mnfEl wk>ny bhdy fy AlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]
['nAs bdAw tytbAkAw wbdyt tnhdr b nfEl bHAl <lY tnhdy AlnAs', "tEAlt >SwAt AlnAs bAlbkA'،  bd>t >tHdv bAnfEAl w k>nny >hdy fy AlnAs"]


In [5]:
alg_msa_train, alg_msa_val = train_test_split(padic_alg_msa, test_size=.15)
ann_msa_train, ann_msa_val = train_test_split(padic_ann_msa, test_size=.15)
syr_msa_train, syr_msa_val = train_test_split(padic_syr_msa, test_size=.15)
pal_msa_train, pal_msa_val = train_test_split(padic_pal_msa, test_size=.15)
mor_msa_train, mor_msa_val = train_test_split(padic_mor_msa, test_size=.15)

In [6]:
#all the translations have equal amt. of examples
print(len(alg_msa_train))
print(len(alg_msa_val))

6131
1082


###Initial Loading from Pickle

In [7]:
ls data/AD_NMT-master

english-Arabic-both.pkl  LAV-MSA-2-both.pkl  Magribi_MSA-both.pkl


In [8]:
file_path = 'data/AD_NMT-master/'

with open(file_path + "english-Arabic-both.pkl", 'rb') as handle:
    data_English_MSA_both = pickle.load(handle) 

with open(file_path + "LAV-MSA-2-both.pkl", 'rb') as handle:
    data_LAV_MSA_both = pickle.load(handle) 

with open(file_path + "Magribi_MSA-both.pkl", 'rb') as handle:
    data_Magribi_MSA_both = pickle.load(handle) 
    

In [9]:
#few dataset examples
print(data_English_MSA_both[0:5])
print(data_English_MSA_both[-5:])
print(data_LAV_MSA_both[0:5])
print(data_Magribi_MSA_both[0:5])

[['Tom was also there', 'كان توم هنا ايضا'], ['That old woman lives by herself', 'تلك المراة العجوز تسكن بمفردها'], ['He went abroad for the purpose of studying English', 'سافر خارج البلد ليتعلم الانجليزية'], ['There is a fork missing', 'هناك شوكة ناقصة'], ["I don't know this game", 'لا اعرف هذه اللعبة']]
[['Please send us more information', 'ارسل الينا المزيد من المعلومات اذا تكرمت'], ['I am an only child', 'انا طفل وحيد ابي و امي'], ['Make good use of your time', 'استفد من وقتك جيدا'], ["Fighting won't settle anything", 'لن يحل القتال اي شيء'], ['Practice makes perfect', 'الممارسة هي الطريق الى الاتقان']]
[['لا انا بعرف وحدة راحت ع فرنسا و معا شنتا حطت فيها الفرش', 'لا اعرف واحدة ذهبت الى فرنسا و لها غرفة و ضعت فيها الافرشة'], ['روح بوشك و فتول عاليسار', 'اذهب تقدم و استدر يسارا'], ['لا لا لازم انه يكون عندك موضوع ما في اشي', ' لا لا يجب ان يكون لديك موضوع هذا ضروري'], ['اوعي تبعدي من هون بلاش تضيعي ', 'لا تبتعد عن هنا حتى لا تفقد الطريق '], ['قصدي صراحة يما انا كمان كرهته من يوم ما 

In [10]:
#splits the train dataset into train and validation sets, define test set as datafile
en_msa_train, en_msa_val = train_test_split(data_English_MSA_both, test_size=.2)

lav_msa_train, lav_msa_val = train_test_split(data_LAV_MSA_both, test_size=.2)

mag_msa_train, mag_msa_val = train_test_split(data_Magribi_MSA_both, test_size=.2)

In [11]:
print(len(en_msa_train))
print(len(en_msa_val))

print(len(lav_msa_train))
print(len(lav_msa_val))

print(len(mag_msa_train))
print(len(mag_msa_val))

8000
2001
12644
3161
15788
3948


In [12]:
file_path = 'data/'

def list_to_csv(ds, src='en', trg='msa', datatype=''):
    src_formatted = datatype + '_' + src + '_' + trg + '.' + 'csv'
    
    with open(file_path + datatype + "/" + src_formatted, 'wt') as csv:
        for i, arr in enumerate(ds):
            csv.write(arr[1] + ',' + arr[0] + '\n')

In [13]:
list_to_csv(en_msa_train, 'en', 'msa', 'train')
list_to_csv(en_msa_val, 'en', 'msa', 'val')

list_to_csv(lav_msa_train, 'lav', 'msa', 'train')
list_to_csv(lav_msa_val, 'lav', 'msa', 'val')

list_to_csv(mag_msa_train, 'mag', 'msa', 'train')
list_to_csv(mag_msa_val, 'mag', 'msa', 'val')

## Training SentencePiece Model

In [14]:
#combine all the training lines of all three languages
spm_input_ds = en_msa_train + mag_msa_train + lav_msa_train

In [15]:
def list_to_input(ds):
    src_formatted = 'spm_input' + '.' + 'txt'

    with open(file_path + "/" + src_formatted, 'wt') as sentencelinefile:
        for i, arr in enumerate(ds):
            sentencelinefile.write(arr[0] + '\n' + arr[1] + '\n')

In [16]:
list_to_input(spm_input_ds)

In [17]:
VOCAB_SIZE = 32128
spm.SentencePieceTrainer.train('--input=data/spm_input.txt --model_prefix=data/model/spm --vocab_size=' + str(VOCAB_SIZE) + ' --unk_id=2 --bos_id=-1 --eos_id=1 --pad_id=0 --hard_vocab_limit=False')

In [18]:
filepath = 'data/model/spm.model'

##Tensor Processing + Add to TaskRegistry

### English to Arabic Task

In [19]:
en_msa_split_csv_path = {
    "train": "data/train/train_en_msa.csv",
    "validation": "data/val/val_en_msa.csv"
}
en_msa_example_count = {
    "train": len(en_msa_train),
    "validation": len(en_msa_val)
}

In [20]:
def en_msa_translation_dataset_fn(split, shuffle_files=False):
  ds = tf.data.TextLineDataset(en_msa_split_csv_path[split])
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["",""],
                        field_delim=",", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  ds = ds.map(lambda *example: dict(zip(["source", "target"], example)) )
  return ds

for example in tfds.as_numpy(en_msa_translation_dataset_fn("train").take(5)):
    print(example)

{'source': b'\xd8\xa8\xd9\x85\xd8\xa7\xd8\xb0\xd8\xa7 \xd8\xa7\xd8\xac\xd8\xa8\xd8\xaa\xd8\x9f', 'target': b'What did you answer?'}
{'source': b'\xd9\x85\xd9\x84\xd8\xa7\xd8\xa8\xd8\xb3 \xd8\xaa\xd9\x88\xd9\x85 \xd9\x85\xd9\x88\xd8\xb6\xd8\xa9 \xd9\x82\xd8\xaf\xd9\x8a\xd9\x85\xd8\xa9', 'target': b"Tom's clothes are out of fashion"}
{'source': b'\xd8\xa7\xd9\x8a\xd9\x86 \xd8\xa7\xd9\x84\xd8\xa7\xd8\xae\xd8\xb1\xd9\x8a\xd9\x86\xd8\x9f', 'target': b'Where are the others?'}
{'source': b'\xd9\x84\xd9\x8a\xd8\xb3\xd8\xaa \xd8\xac\xd8\xaf\xd9\x8a\xd8\xaf\xd8\xa9', 'target': b"It isn't new"}
{'source': b'\xd9\x84\xd8\xa7 \xd9\x8a\xd9\x85\xd9\x83\xd9\x86\xd9\x83 \xd8\xa7\xd9\x84\xd8\xb9\xd9\x8a\xd8\xb4 \xd9\x85\xd9\x86 \xd8\xaf\xd9\x88\xd9\x86 \xd8\xa7\xd9\x83\xd8\xb3\xd8\xac\xd9\x8a\xd9\x86', 'target': b"You can't live without oxygen"}


In [21]:
#turn the ds of dictionaries and change the keys to inputs and targets that the model
def en_msa_translation_preprocessor(ds):
  def to_inputs_and_targets(ex):
    return{
        "inputs": tf.strings.join(["translate English to MSA: ",ex["source"]]),
        "targets": ex["target"]
    }
  return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [22]:
t5.data.TaskRegistry.remove("translation_en_msa")
t5.data.TaskRegistry.add(
    #name of the Task
    "translation_en_msa",
    #Supply a function which returns a tf.data.Dataset
    dataset_fn=en_msa_translation_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[en_msa_translation_preprocessor],
    # Lowercase targets before computing metrics.
    postprocess_fn=t5.data.postprocessors.lower_text, 
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.bleu],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=en_msa_example_count,
    # output_features
    output_features=t5.data.Feature(vocabulary=t5.data.SentencePieceVocabulary(filepath))
)

###Levantine to MSA Task

In [23]:
lav_msa_split_csv_path = {
    "train": "data/train/train_lav_msa.csv",
    "validation": "data/val/val_lav_msa.csv"
}
lav_msa_example_count = {
    "train": len(lav_msa_train),
    "validation": len(lav_msa_val)
}

In [24]:
def lav_msa_translation_dataset_fn(split, shuffle_files=False):
  ds = tf.data.TextLineDataset(lav_msa_split_csv_path[split])
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["",""],
                        field_delim=",", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  ds = ds.map(lambda *example: dict(zip(["source", "target"], example)) )
  return ds

for example in tfds.as_numpy(lav_msa_translation_dataset_fn("train").take(5)):
    print(example)

{'source': b'\xd9\x82\xd9\x84\xd8\xaa \xd9\x84\xd9\x87\xd8\xa7 \xd9\x85\xd9\x86\xd8\xb0 \xd9\x82\xd9\x84\xd9\x8a\xd9\x84: \xd9\x84\xd8\xa7 \xd8\xaa\xd8\xb9\xd8\xb7\xd9\x8a\xd9\x87\xd8\xa7 \xd9\x84\xd8\xa7\xd8\xad\xd8\xaf \xd8\xb9\xd9\x86\xd8\xaf\xd9\x85\xd8\xa7 \xd8\xa7\xd8\xad\xd8\xaa\xd8\xa7\xd8\xac\xd9\x87\xd8\xa7 \xd8\xa7\xd8\xb3\xd8\xaa\xd8\xb9\xd9\x85\xd9\x84\xd9\x87\xd8\xa7', 'target': b'\xd9\x82\xd9\x84\xd8\xaa\xd9\x84\xd9\x87\xd8\xa7 \xd9\x82\xd8\xa8\xd9\x84 \xd8\xb4\xd9\x88\xd9\x8a \xd9\x85\xd8\xa7 \xd8\xaa\xd8\xb9\xd8\xb7\xd9\x8a\xd9\x87 \xd9\x84\xd8\xad\xd8\xaf\xd8\xa7 \xd9\x84\xd9\x85\xd8\xa7 \xd8\xa7\xd8\xad\xd8\xaa\xd8\xa7\xd8\xac\xd9\x87\xd8\xa7 \xd8\xa8\xd8\xb7\xd9\x84\xd8\xb9\xd9\x87\xd8\xa7'}
{'source': b' \xd8\xb9\xd9\x86\xd8\xaf\xd9\x85\xd8\xa7 \xd9\x8a\xd9\x83\xd9\x88\xd9\x86 \xd8\xa7\xd9\x84\xd8\xa7\xd9\x86\xd8\xb3\xd8\xa7\xd9\x86 \xd9\x85\xd8\xaa\xd9\x81\xd8\xb1\xd8\xba\xd8\xa7   ', 'target': b'\xd9\x84\xd9\x85\xd8\xa7 \xd9\x8a\xd9\x83\xd9\x88\xd9\x86 \xd8\xa7\x

In [25]:
#turn the ds of dictionaries and change the keys to inputs and targets that the model
def lav_msa_translation_preprocessor(ds):
  def to_inputs_and_targets(ex):
    return{
        "inputs": tf.strings.join(["translate Levantine to MSA: ",ex["source"]]),
        "targets": ex["target"]
    }
  return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [26]:
t5.data.TaskRegistry.remove("translation_lav_msa")
t5.data.TaskRegistry.add(
    #name of the Task
    "translation_lav_msa",
    #Supply a function which returns a tf.data.Dataset
    dataset_fn=lav_msa_translation_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[lav_msa_translation_preprocessor],
    # Lowercase targets before computing metrics.
    postprocess_fn = t5.data.postprocessors.lower_text, 
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.bleu],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=lav_msa_example_count,
    # output_features
    output_features=t5.data.Feature(vocabulary=t5.data.SentencePieceVocabulary(filepath))
)

###Maghrib to MSA Task

In [27]:
mag_msa_split_csv_path = {
    "train": "data/train/train_mag_msa.csv",
    "validation": "data/val/val_mag_msa.csv"
}
mag_msa_example_count = {
    "train": len(mag_msa_train),
    "validation": len(mag_msa_val)
}

In [28]:
def mag_msa_translation_dataset_fn(split, shuffle_files=False):
  ds = tf.data.TextLineDataset(mag_msa_split_csv_path[split])
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["",""],
                        field_delim=",", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  ds = ds.map(lambda *example: dict(zip(["source", "target"], example)) )
  return ds

for example in tfds.as_numpy(mag_msa_translation_dataset_fn("train").take(5)):
    print(example)
    print(example['source'].decode())
    print(example['target'].decode())
    print(len(example['source']))
    print(len(example['target']))

{'source': b' \xd9\x84\xd8\xa7 \xd8\xa7\xd8\xb9\xd9\x84\xd9\x85 \xd8\xa7\xd9\x84\xd8\xa7\xd9\x86 \xd9\x87\xd8\xb0\xd8\xa7 \xd8\xa7\xd9\x84\xd8\xb7\xd8\xa8\xd9\x8a\xd8\xa8 \xd8\xa7\xd8\xb0\xd9\x87\xd8\xa8 \xd8\xa7\xd9\x84\xd9\x8a\xd9\x87 \xd9\x85\xd8\xa8\xd8\xa7\xd8\xb4\xd8\xb1\xd8\xa9 \xd8\xa7\xd9\x85 \xd8\xa7\xd8\xae\xd8\xaf \xd9\x85\xd9\x88\xd8\xb9\xd8\xaf\xd8\xa7 \xd9\x85\xd8\xb3\xd8\xa8\xd9\x82\xd8\xa7   ', 'target': b' \xd9\x85\xd8\xa7 \xd8\xb9\xd9\x84\xd9\x89 \xd8\xa8\xd8\xa7\xd9\x84\xd9\x8a \xd8\xaf\xd8\xb1\xd9\x83 \xd8\xa7\xd9\x84\xd8\xb7\xd8\xa8\xd9\x8a\xd8\xa8 \xd9\x87\xd8\xaf\xd8\xa7 \xd9\x84\xd8\xa7\xd8\xb2\xd9\x85 \xd9\x86\xd8\xb1\xd9\x88\xd8\xad \xd9\x84\xd9\x87 \xd8\xaf\xd9\x8a\xd8\xb1\xd9\x83\xd8\xaa \xd9\x88\xd9\x84\xd8\xa7 \xd9\x86\xd8\xaf\xd9\x8a  \xd8\xb1\xd9\x88\xd9\x86\xd8\xaf\xd9\x8a\xda\xa2\xd9\x88 \xd9\x82\xd8\xa8\xd9\x84 '}
 لا اعلم الان هذا الطبيب اذهب اليه مباشرة ام اخد موعدا مسبقا   
 ما على بالي درك الطبيب هدا لازم نروح له ديركت ولا ندي  رونديڢو قبل 
111
1

In [29]:
  #turn the ds of dictionaries and change the keys to inputs and targets that the model
def mag_msa_translation_preprocessor(ds):
    def to_inputs_and_targets(ex):
      return{
          "inputs": tf.strings.join(["translate Maghrib to MSA: ",ex["source"]]),
          "targets": ex["target"]
      }
    return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [30]:
t5.data.TaskRegistry.remove("translation_mag_msa")
t5.data.TaskRegistry.add(
    #name of the Task
    "translation_mag_msa",
    #Supply a function which returns a tf.data.Dataset
    dataset_fn=mag_msa_translation_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[mag_msa_translation_preprocessor],
    # Lowercase targets before computing metrics.
    postprocess_fn = t5.data.postprocessors.lower_text, 
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.bleu],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=mag_msa_example_count,
    # output_features
    output_features=t5.data.Feature(vocabulary=t5.data.SentencePieceVocabulary(filepath))
)

##Dataset Mixture

In [31]:
t5.data.MixtureRegistry.remove("translation_msa")
t5.data.MixtureRegistry.add(
    "translation_msa",
    ["translation_en_msa", "translation_lav_msa", "translation_mag_msa"],
     default_rate=1.0
)

##Pre-Training

In [32]:
#gotta get the base config and add the new tasks' task params
!wget "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json" -O data/config/t5-base-config.json

--2020-06-22 03:06:32--  https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.18.6
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.18.6|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1199 (1.2K) [application/json]
Saving to: ‘data/config/t5-base-config.json’


2020-06-22 03:06:33 (64.0 MB/s) - ‘data/config/t5-base-config.json’ saved [1199/1199]



In [33]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

#Using the base config from Huggingface T5 Model
config = transformers.T5Config.from_json_file(json_file="data/config/t5-base-config.json")
model = t5.models.HfPyTorchModel(config, "/tmp/hft5/", device)

In [34]:
ls /tmp/hft5

events.out.tfevents.1592795201.1978a2c727cd.137.0


In [35]:
# !wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1fEVj9jCxvcKn9zg8lO43i2sWZquegg5H' -O data/operative_config.gin

In [36]:
import gin
with gin.unlock_config():
  gin.parse_config_file("data/operative_config.gin")

In [37]:
STEPS = 10000 #@param {type: "integer"}
model.train(
    mixture_or_task_name="translation_msa",
    steps=STEPS,
    save_steps=STEPS/5,                                                   
    sequence_length={"inputs": 32, "targets": 32},
    split="train",
    batch_size=32,
    optimizer=functools.partial(transformers.AdamW, lr=1e-4),
)

INFO:absl:Saving checkpoint for step 0
INFO:absl:Saving checkpoint for step 2000
INFO:absl:Saving checkpoint for step 4000
INFO:absl:Saving checkpoint for step 6000
INFO:absl:Saving checkpoint for step 8000
INFO:absl:Saving final checkpoint for step 10000


##Evaluation

In [38]:
# Evaluate after fine-tuning
model.eval(
    "translation_msa",
    checkpoint_steps=STEPS,
    sequence_length={"inputs": 32, "targets": 32},
    batch_size=32,
)

INFO:absl:Loading from /tmp/hft5/model-10000.checkpoint
INFO:absl:eval/translation_en_msa/bleu at step 10000: 9.195
INFO:absl:eval/translation_lav_msa/bleu at step 10000: 2.157
INFO:absl:eval/translation_mag_msa/bleu at step 10000: 1.149


##Predictions

In [39]:
inputs = [
    "translate English to MSA: Tom was also there.",
    "translate English to MSA: Please send us more information",
]
model.predict(
    inputs,
    sequence_length={"inputs": 32},
    batch_size=2,
    output_file="/tmp/hft5/example_predictions.txt",
)

INFO:absl:translate English to MSA: Tom was also there.
  -> . ⁇  “ “ “ “ “ “ “ “ “ “ “ “ “ “ “ “ “
INFO:absl:translate English to MSA: Please send us more information
  -> b dieishe toce youhehehehe""""""""
