In [None]:
# Install (or upgrade) simpletransformers, with tokenizers pinned to version 0.9.4.
!pip install --upgrade simpletransformers tokenizers==0.9.4

In [None]:
# Clone the arabert repository into ./arabert so that `arabert.preprocess`
# can be imported below, and install pyarabic
# (presumably a dependency of the arabert preprocessor — confirm).
!git clone https://github.com/aub-mind/arabert
!pip install pyarabic

In [None]:
from arabert.preprocess import ArabertPreprocessor
# Text preprocessor configured for the 'aragpt2-medium' model; every example
# added to the train/test/eval splits is passed through it first.
arabert_prep = ArabertPreprocessor(model_name='aragpt2-medium')

In [None]:
# Download the shuffled multi-dialect training CSV from the UBC-NLP aoc_id
# repository and preview its first three lines.
!wget "https://github.com/UBC-NLP/aoc_id/raw/master/data/train/MultiTrain.Shuffled.csv"
!head -n 3 MultiTrain.Shuffled.csv

In [None]:
import csv
import random

# Accumulators for the three splits. Later cells append to these same lists.
# NOTE: `eval` shadows the Python builtin of the same name; the name is kept
# because later cells refer to this list as `eval`.
train = []
test = []
eval = []

# Seed the generator so the random split below is reproducible when the
# notebook is restarted and re-run from the top.
random.seed(42)

# Maps the label found in CSV column 1 to the dialect tag that is prefixed
# to every example of that dialect.
dialect = {
    'MSA': '[MSA] ',
    'DIAL_EGY': "[EGYPTIAN] ",
    'DIAL_LEV': "[LEVANTINE] ",
    'DIAL_GLF': "[GULF] "
}

# The encoding is given explicitly so reading the Arabic text does not depend
# on the platform default; newline="" is what the csv module asks for when it
# is handed a file object.
with open("./MultiTrain.Shuffled.csv", "r", encoding="utf-8", newline="") as prompts:
  rdr = csv.reader(prompts)
  for line in rdr:
    # Keep only well-formed rows: exactly three columns and a non-empty first column.
    if len(line) == 3 and len(line[0]) > 0:
      # "<dialect tag> <preprocessed text><|endoftext|>"
      example = dialect[line[1]] + arabert_prep.preprocess(line[2]) + '<|endoftext|>'
      # Two independent draws: ~60% of rows go to train; of the remaining
      # ~40%, about 80% (~32% overall) go to test and the rest (~8%) to eval.
      if random.random() > 0.4:
        train.append(example)
      elif random.random() > 0.2:
        test.append(example)
      else:
        eval.append(example)

In [None]:
# Download and unpack the DART archive, then preview the first five lines of
# the Egyptian file to inspect its layout.
!wget https://www.dropbox.com/s/jslg6fzxeu47flu/DART.zip
!unzip DART.zip
!head -n 5 DART/cf-data/EGY.txt

In [None]:
# Append the DART examples to the existing train/test/eval lists.
# Each entry pairs a file stem under ./DART/cf-data/ with the dialect tag to
# prefix. The pair is unpacked into its own names so the `dialect` label dict
# defined in an earlier cell is not overwritten by this loop.
for source_name, tag in [['EGY', '[EGYPTIAN] '], ['GLF', '[GULF] '], ['LEV', '[LEVANTINE] '], ['MGH', '[MAGHREBI] ']]:
  with open("./DART/cf-data/" + source_name + ".txt", "r", encoding="utf-8") as prompts:
    rdr = csv.reader(prompts, delimiter='\t')
    # NOTE(review): csv.reader uses default quoting here, so a stray '"' in a
    # row may be interpreted as a quoted field — confirm this suits the data.
    for line_no, line in enumerate(rdr, start=1):
      # Skip the first row of each file (treated as a header). The original
      # check compared the row itself to 1, which is never true, so the
      # first row was never skipped.
      if line_no == 1:
        continue
      # Only rows with exactly three tab-separated fields are used; the text
      # is taken from the third field.
      if len(line) == 3:
        example = tag + arabert_prep.preprocess(line[2])  + '<|endoftext|>'
        # Two independent draws: ~60% train, ~32% test, ~8% eval.
        if random.random() > 0.4:
          train.append(example)
        elif random.random() > 0.2:
          test.append(example)
        else:
          eval.append(example)

In [None]:
# Clone the arabic_dialect_annotation repository, decompress its
# annotated_data archive, and extract the resulting tar file
# (the `ls` in the next cell expects an ./annotated_data directory).
!git clone https://github.com/ryancotterell/arabic_dialect_annotation
!gunzip arabic_dialect_annotation/annotated_data.tar.gz
!tar -xvf arabic_dialect_annotation/annotated_data.tar

In [None]:
# List the extracted files.
!ls annotated_data

In [None]:
# Preview the first five lines of the gulf file to inspect its layout.
!head -n 5 annotated_data/gulf

In [None]:
# Append the annotated_data examples to the existing train/test/eval lists.
# Each entry pairs a file name under ./annotated_data/ with the dialect tag to
# prefix. The pair is unpacked into its own names so the `dialect` label dict
# defined in an earlier cell is not overwritten by this loop.
for source_name, tag in [['egyptian', '[EGYPTIAN] '], ['gulf', '[GULF] '], ['levantine', '[LEVANTINE] '], ['maghrebi', '[MAGHREBI] '], ['msa', '[MSA] ']]:
  with open("./annotated_data/" + source_name, "r", encoding="utf-8") as prompts:
    rdr = csv.reader(prompts, delimiter='\t')
    for line_no, line in enumerate(rdr, start=1):
      # Skip the first row of each file. The original check compared the row
      # itself to 1, which is never true, so the intended skip never happened.
      # NOTE(review): assumes each file starts with a header row — confirm
      # against the `head` preview of annotated_data/gulf.
      if line_no == 1:
        continue
      # Only rows with exactly two tab-separated fields are used; the text is
      # taken from the second field.
      if len(line) == 2:
        example = tag + arabert_prep.preprocess(line[1]) + '<|endoftext|>'
        # Two independent draws: ~60% train, ~32% test, ~8% eval.
        if random.random() > 0.4:
          train.append(example)
        elif random.random() > 0.2:
          test.append(example)
        else:
          eval.append(example)

In [None]:
# Write each split to its own file, one example per line.
# Fixes relative to the previous version of this cell:
#   * eval.txt is now written from `eval` (it used to receive a second copy
#     of `test`, so the eval split was never saved);
#   * files are opened in a `with` block so they are flushed and closed;
#   * the encoding is explicit so the Arabic text is written as UTF-8
#     regardless of the platform default.
for path, examples in (('./train.txt', train), ('./test.txt', test), ('./eval.txt', eval)):
  with open(path, 'w', encoding='utf-8') as out:
    out.write("\n".join(examples))

# Displayed as the cell output.
# NOTE(review): 117500 is an unexplained constant — presumably a rough
# size check on the training set; confirm or replace with a named constant.
len(train) // 117500

In [None]:
from simpletransformers.language_modeling import LanguageModelingModel

# Options passed to simpletransformers when constructing the model.
train_args = dict(
    reprocess_input_data=True,
    overwrite_output_dir=True,
    train_batch_size=8,  # multiples of 8 are best; 16 currently hits a gpu limit
    num_train_epochs=10,
    fp16=False,
    mlm=False,  # presumably disables masked-LM training — confirm in the simpletransformers docs
)

# GPT-2 type language model initialised from the aubmindlab/aragpt2-medium weights.
ft_model = LanguageModelingModel(
    'gpt2',
    'aubmindlab/aragpt2-medium',
    args=train_args,
)

In [None]:
# Register the dialect tags as tokenizer tokens, then resize the model's
# embedding matrix to match the enlarged vocabulary.
dialect_tokens = ["[EGYPTIAN]", "[MSA]", "[LEVANTINE]", "[GULF]", "[MAGHREBI]"]
ft_model.tokenizer.add_tokens(dialect_tokens)
vocab_size = len(ft_model.tokenizer)
ft_model.model.resize_token_embeddings(vocab_size)

In [None]:
# Fine-tune on train.txt, passing test.txt as the evaluation file.
# NOTE(review): train_args does not set "evaluate_during_training"; check the
# simpletransformers docs for whether eval_file is used without it.
ft_model.train_model("./train.txt", eval_file="./test.txt")

In [None]:
# Evaluate the fine-tuned model on the eval.txt split.
ft_model.eval_model("./eval.txt")

In [None]:
# Save the tokenizer (including the added dialect tokens) to ./dialects.
ft_model.tokenizer.save_pretrained("./dialects")

In [None]:
# Save the model weights to the same ./dialects directory as the tokenizer.
ft_model.model.save_pretrained("./dialects")