In [1]:
# !pip install plot-keras-history transformers babel

In [2]:
# !git clone https://github.com/OmarSayedMostafa/Nuanced-Arabic-Dialect-Identification.git

In [3]:
# cd Nuanced-Arabic-Dialect-Identification/

In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

In [2]:
import pandas as pd
import re
import random
import string
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.text import text_to_word_sequence
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch 
import torch.nn as nn
from sklearn.metrics import classification_report, f1_score
from utilities import clean_arabic_tweet
# optimizer from hugging face transformers
from transformers import AdamW
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

## load data

In [7]:
pd.set_option('display.max_colwidth', -1)

train_data_path = './clean_/DA_train_labeled.tsv'
test_data_path = './data/DA_dev_labeled.tsv'

train_dataframe = pd.read_csv(train_data_path, sep='\t')
test_dataframe = pd.read_csv(test_data_path, sep='\t')

In [8]:
del train_dataframe['#4_province_label']
del test_dataframe['#4_province_label']

## clean data

In [43]:
tarin_tweets_cleaned = (train_dataframe['#2_tweet'].apply(clean_arabic_tweet))
test_tweets_cleaned = (test_dataframe['#2_tweet'].apply(clean_arabic_tweet))


train_dataframe['#2_tweet_clean']= tarin_tweets_cleaned
test_dataframe['#2_tweet_clean']= test_tweets_cleaned

In [44]:
train_dataframe.head(10)

Unnamed: 0,#1_tweetid,#2_tweet,#3_country_label,#2_tweet_clean,classes_id
0,TRAIN_0,حاجة حلوة اكيد,Egypt,حاجة حلوة اكيد,0
1,TRAIN_1,عم بشتغلوا للشعب الاميركي اما نحن يكذبوا ويغشوا ويسرقوا ويقتلو شعوبهم ويعملوا لصالح اعدائهم,Iraq,عم بشتغلوا لشعب الاميركي اما نحن يكذبوا ويغشوا ويسرقوا ويقتلو شعوبهم ويعملوا لصالح اعدائهم,1
2,TRAIN_2,ابشر طال عمرك,Saudi_Arabia,ابشر طال عمرك,2
3,TRAIN_3,منطق 2017: أنا والغريب علي إبن عمي وأنا والغريب وإبن عمي علي أخويا. #قطع_العلاقات_مع_قطر #موريتانيا_مع_قطر,Mauritania,منطق انا والغريب علي ابن عمي وانا والغريب وابن عمي علي اخويا قطع العلاقات مع قطر موريتانيا مع قطر,3
4,TRAIN_4,شهرين وتروح والباقي غير صيف ملينا,Algeria,شهرين وتروح والباقي غير صيف ملينا,4
5,TRAIN_5,يابنتى والله ما حد متغاظ ولا مفروس منك ولا بيحسدك انتى عره اساسا.,Syria,يابنتى واله ما حد متغاظ ولا مفروس منك ولا بيحسدك انتى عره اساسا,5
6,TRAIN_6,نفس الوقت بأكد على صاحبتي ان اي هدف هتحطه وتخططله هيبوظ والأفضل التشاؤم واننا نتوقع الأسوء دايما والفشل عشان منعشمش نفسنا ع الفاضي,Egypt,نفس الوقت باكد على صاحبتي ان اي هدف هتحطه وتخطله هيبوظ والافضل التشاؤم وانا نتوقع الاسوء دايما والفشل عشان منعشمش نفسنا ع الفاضي,0
7,TRAIN_7,م تبطلي خرا بقا علشان مطلعهوش عليكي احترمي نفسك URL …,Egypt,م تبطلي خرا بقا علشان مطلعهوش عليكي احترمي نفسك,0
8,TRAIN_8,ما يله دخل !,Oman,ما يله دخل,6
9,TRAIN_9,هو حلو بس يتخربط ع طلاب المدراس ليك مايغيرونه عدنا,Iraq,هو حلو بس يتخربط ع طلاب المدراس ليك مايغيرونه عدنا,1


## convert class name to class id

In [46]:
classes_names = train_dataframe['#3_country_label'].unique().tolist()
classes_map={}
for i,class_name in enumerate(classes_names):
    classes_map[class_name]=i


def find_class_id_from_name(class_name):
    return classes_map[class_name]

In [47]:
train_dataframe['classes_id']=train_dataframe['#3_country_label'].apply(find_class_id_from_name)
test_dataframe['classes_id']=test_dataframe['#3_country_label'].apply(find_class_id_from_name)


train_x = train_dataframe['#2_tweet_clean'].tolist()
train_y = train_dataframe['classes_id'].tolist()

test_x = test_dataframe['#2_tweet_clean'].tolist()
test_y = test_dataframe['classes_id'].tolist()

## toknize the data for bert input

In [48]:
tokenizer = AutoTokenizer.from_pretrained("bashar-talafha/multi-dialect-bert-base-arabic")
# tokenize and encode sequences in the training set
tokens_train = tokenizer.batch_encode_plus(
    train_x,
    max_length = 50,
    pad_to_max_length=True,
    truncation=True,
    return_token_type_ids=False
)

# tokenize and encode sequences in the validation set
tokens_val = tokenizer.batch_encode_plus(
    test_x,
    max_length = 50,
    pad_to_max_length=True,
    truncation=True,
    return_token_type_ids=False
)

In [49]:
# for train set
train_seq = torch.tensor(tokens_train['input_ids'])
train_mask = torch.tensor(tokens_train['attention_mask'])
train_y = torch.tensor(train_y)

# for validation set
val_seq = torch.tensor(tokens_val['input_ids'])
val_mask = torch.tensor(tokens_val['attention_mask'])
val_y = torch.tensor(test_y)

## prepare data generator for pytorch training

In [15]:
#define a batch size
batch_size = 32
# wrap tensors
train_data = TensorDataset(train_seq, train_mask, train_y)
# sampler for sampling the data during training
train_sampler = RandomSampler(train_data)
# dataLoader for train set
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
# wrap tensors
val_data = TensorDataset(val_seq, val_mask, val_y)
# sampler for sampling the data during training
val_sampler = SequentialSampler(val_data)
# dataLoader for validation set
val_dataloader = DataLoader(val_data, sampler = val_sampler, batch_size=batch_size)

## load arabert pretrained model and freeze its layer to avoid retraining

In [16]:
bert = AutoModel.from_pretrained("bashar-talafha/multi-dialect-bert-base-arabic")
# freeze all the parameters
for param in bert.parameters():
    param.requires_grad = False

In [None]:
## create the top layers for fine tuning arabert for classification task

In [17]:
class BERT_Arch(nn.Module):
    def __init__(self, bert):
        super(BERT_Arch, self).__init__()
        self.bert = bert 
        # dropout layer
        self.dropout = nn.Dropout(0.3)
        # relu activation function
        self.relu =  nn.ReLU()
        # dense layer 1
        self.fc1 = nn.Linear(768,512)
        # dense layer 2 (Output layer)
        self.fc2 = nn.Linear(512,21)
        #softmax activation function
        self.softmax = nn.LogSoftmax(dim=1)

    #define the forward pass
    def forward(self, sent_id, mask):
        #pass the inputs to the model  
        _, cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False)
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        # output layer
        x = self.fc2(x)
        # apply softmax activation
        x = self.softmax(x)
        return x

In [18]:
device = 'cuda'
model = BERT_Arch(bert)
model = model.to(device)

In [22]:
# define the optimizer
optimizer = AdamW(model.parameters(), lr = 1e-3)
#compute the class weights
y = train_dataframe['#3_country_label']
class_weights = compute_class_weight('balanced', np.unique(y), y)

print(class_weights)

[0.5527916  4.65116279 4.65116279 0.2334812  0.36643459 2.33100233
 2.33100233 1.55279503 0.77760498 4.65116279 1.16550117 0.66622252
 2.3364486  4.65116279 0.46728972 5.81395349 4.65116279 0.77700078
 1.16414435 1.5576324  2.33100233]


In [40]:
# function for evaluating the model
def evaluate():
    print("\nEvaluating...")
    # deactivate dropout layers
    model.eval()
    total_loss, total_accuracy = 0, 0
    # empty list to save the model predictions
    total_preds = []
    # iterate over batches
    for step,batch in enumerate(val_dataloader):
        # Progress update every 50 batches.
        if step % 50 == 0 and not step == 0:
        # Calculate elapsed time in minutes.
        # elapsed = format_time(time.time() - t0)
        # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(val_dataloader)))
        # push the batch to gpu
        batch = [t.to(device) for t in batch]
        sent_id, mask, labels = batch
        # deactivate autograd
        with torch.no_grad():
            # model predictions
            preds = model(sent_id, mask)
            # all_predictions.append(preds)
            # compute the validation loss between actual and predicted values
            loss = cross_entropy(preds,labels)
            total_loss = total_loss + loss.item()
        preds = preds.detach().cpu().numpy()
        total_preds.append(preds)
    # compute the validation loss of the epoch
    avg_loss = total_loss / len(val_dataloader) 
    # reshape the predictions in form of (number of samples, no. of classes)
    total_preds  = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds

In [41]:
# function to train the model
def train():
    global model
    model.train()
    total_loss, total_accuracy = 0, 0
    # empty list to save model predictions
    total_preds=[]
    # iterate over batches
    for step,batch in enumerate(train_dataloader):
        # progress update after every 50 batches.
        if step % 50 == 0 and not step == 0:
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(train_dataloader)))
        # push the batch to gpu
        batch = [r.to(device) for r in batch]
        sent_id, mask, labels = batch
        # clear previously calculated gradients 
        model.zero_grad()        
        # get model predictions for the current batch
        preds = model(sent_id, mask)
        # compute the loss between actual and predicted values
        loss = cross_entropy(preds, labels)
        # add on to the total loss
        total_loss = total_loss + loss.item()
        # backward pass to calculate the gradients
        loss.backward()
        # clip the the gradients to 1.0. It helps in preventing the exploding gradient problem
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # update parameters
        optimizer.step()
        # model predictions are stored on GPU. So, push it to CPU
        preds=preds.detach().cpu().numpy()
        # append the model predictions
        total_preds.append(preds)
    # compute the training loss of the epoch
    avg_loss = total_loss / len(train_dataloader)
  
    # predictions are in the form of (no. of batches, size of batch, no. of classes).
    # reshape the predictions in form of (number of samples, no. of classes)
    total_preds  = np.concatenate(total_preds, axis=0)

    #returns the loss and predictions
    return avg_loss, total_preds

In [42]:
# converting list of class weights to a tensor
weights= torch.tensor(class_weights,dtype=torch.float)
# push to GPU
weights = weights.to(device)
# define the loss function
cross_entropy  = nn.NLLLoss(weight=weights) 
# number of training epochs
epochs = 150
# set initial loss to infinite
best_valid_loss = float('inf')
# empty lists to store training and validation loss of each epoch
train_losses=[]
valid_losses=[]
#for each epoch
for epoch in range(epochs):
    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))
    #train model
    train_loss, _ = train()
    #evaluate model
    valid_loss, all_prediction = evaluate()
    print(classification_report(val_y, np.argmax(all_prediction, axis=1), target_names=classes_names))
    #save the best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'saved_weights.pt')
    # append training and validation loss
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')


 Epoch 1 / 150
  Batch    50  of    657.
  Batch   100  of    657.
  Batch   150  of    657.
  Batch   200  of    657.
  Batch   250  of    657.
  Batch   300  of    657.
  Batch   350  of    657.
  Batch   400  of    657.
  Batch   450  of    657.
  Batch   500  of    657.
  Batch   550  of    657.
  Batch   600  of    657.
  Batch   650  of    657.

Evaluating...
  Batch    50  of    157.
  Batch   100  of    157.
  Batch   150  of    157.


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                      precision    recall  f1-score   support

               Egypt       0.76      0.60      0.67      1041
                Iraq       0.25      0.66      0.37       664
        Saudi_Arabia       0.26      0.68      0.38       520
          Mauritania       0.00      0.00      0.00        53
             Algeria       0.00      0.00      0.00       430
               Syria       0.20      0.10      0.14       278
                Oman       0.20      0.01      0.01       355
             Tunisia       0.31      0.02      0.04       173
             Lebanon       0.00      0.00      0.00       157
             Morocco       0.14      0.20      0.17       207
            Djibouti       0.00      0.00      0.00        27
United_Arab_Emirates       0.00      0.00      0.00       157
              Kuwait       0.00      0.00      0.00       105
               Libya       0.15      0.29      0.20       314
             Bahrain       0.00      0.00      0.00        52
       

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                      precision    recall  f1-score   support

               Egypt       0.69      0.76      0.72      1041
                Iraq       0.29      0.63      0.40       664
        Saudi_Arabia       0.26      0.70      0.38       520
          Mauritania       0.00      0.00      0.00        53
             Algeria       0.00      0.00      0.00       430
               Syria       0.20      0.13      0.16       278
                Oman       0.56      0.01      0.03       355
             Tunisia       1.00      0.01      0.01       173
             Lebanon       0.00      0.00      0.00       157
             Morocco       0.15      0.17      0.16       207
            Djibouti       0.00      0.00      0.00        27
United_Arab_Emirates       0.00      0.00      0.00       157
              Kuwait       0.00      0.00      0.00       105
               Libya       0.19      0.36      0.25       314
             Bahrain       0.00      0.00      0.00        52
       

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                      precision    recall  f1-score   support

               Egypt       0.73      0.72      0.72      1041
                Iraq       0.26      0.67      0.38       664
        Saudi_Arabia       0.24      0.65      0.35       520
          Mauritania       0.00      0.00      0.00        53
             Algeria       0.00      0.00      0.00       430
               Syria       0.20      0.16      0.18       278
                Oman       0.00      0.00      0.00       355
             Tunisia       0.21      0.02      0.04       173
             Lebanon       0.00      0.00      0.00       157
             Morocco       0.21      0.17      0.19       207
            Djibouti       0.00      0.00      0.00        27
United_Arab_Emirates       0.00      0.00      0.00       157
              Kuwait       0.00      0.00      0.00       105
               Libya       0.20      0.27      0.23       314
             Bahrain       0.00      0.00      0.00        52
       

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                      precision    recall  f1-score   support

               Egypt       0.74      0.69      0.71      1041
                Iraq       0.31      0.58      0.41       664
        Saudi_Arabia       0.25      0.75      0.37       520
          Mauritania       0.00      0.00      0.00        53
             Algeria       0.00      0.00      0.00       430
               Syria       0.20      0.17      0.18       278
                Oman       1.00      0.01      0.02       355
             Tunisia       0.50      0.01      0.02       173
             Lebanon       0.50      0.01      0.01       157
             Morocco       0.14      0.20      0.16       207
            Djibouti       0.00      0.00      0.00        27
United_Arab_Emirates       0.00      0.00      0.00       157
              Kuwait       0.00      0.00      0.00       105
               Libya       0.20      0.38      0.26       314
             Bahrain       0.00      0.00      0.00        52
       

KeyboardInterrupt: 