In [1]:
import copy
from collections import namedtuple
import torch
import pandas as pd
import torch.nn as nn
import numpy as np
from conlleval import eval_f1score
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, AutoModel, BertPreTrainedModel, BertModel, AdamW

I0510 17:03:46.905205 139825465988864 file_utils.py:41] PyTorch version 1.3.0 available.
  from ._conv import register_converters as _register_converters


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
PATH_TRAIN = "data/train/train_dataset.txt"
PATH_VAL = "data/val/val_dataset.txt"
PATH_AQMAR_TEST = "data/test/aqmar_test_dataset.txt"
PATH_NEWS_TEST = "data/test/news_test_dataset.txt"
PATH_TWEETS_TEST = "data/test/tweets_test_dataset.txt"
PATH_SEMILABELED = "data/semi_labeled/semi_labeled_dataset.txt"
PATH_STUDENT = "data/student_dataset.txt"

FULL_FINETUNE = True

In [4]:
label_to_id = {"O":0, "B-ORG":1, "I-ORG":2, "B-PER":3, "I-PER":4, "B-LOC":5, "I-LOC":6}
id_to_label = {value: key for key, value in label_to_id.items()}

In [5]:
arabert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv01",do_lower_case=False)

I0510 17:03:48.630679 139825465988864 configuration_utils.py:283] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/aubmindlab/bert-base-arabertv01/config.json from cache at /home/chadi/.cache/torch/transformers/edefbd57b711b1796edd80ad0058293ec6e302f92fba0fcdd7138805dc6164ab.f6fc50854095aaf1023a82f7d5210b2df75a0334997d2daf64453496246d7b2d
I0510 17:03:48.631953 139825465988864 configuration_utils.py:319] Model config BertConfig {
  "_num_labels": 2,
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "directionality": "bidi",
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": f

In [6]:
def clean_label(label):
    if "B-ORG" in label:
        return "B-ORG"
    elif "I-ORG" in label:
        return "I-ORG"
    elif "B-PER" in label:
        return "B-PER"
    elif "I-PER" in label:
        return "I-PER"
    elif "B-LOC" in label:
        return "B-LOC"
    elif "I-LOC" in label:
        return "I-LOC"
    elif "O" in label:
        return "O"

In [7]:
def preprocess_data(PATH_DATASET, tokenizer, max_length=512):
    data = pd.read_csv(PATH_DATASET, encoding="utf-8", delim_whitespace=True, usecols=[0,1], header=None, skip_blank_lines=False)
    Instance = namedtuple("Instance", ["tokenized_text", "input_ids", "input_mask", "labels", "label_ids"])
    dataset = []
    text = ["[CLS]"]
    labels = ["O"]
    for w, l in zip(data[0], data[1]):
        if str(w) == "nan" and str(l) == "nan":
            text.append("[SEP]")
            labels.append("O")
            
            str_text = " ".join(text)
            tokenized_text = arabert_tokenizer.tokenize(str_text)
            
            cnt = 0 
            new_labels = []
            label_ids = []
            for i in tokenized_text:
                if "##" in i:
                    tok_label = labels[cnt - 1]
                    if "B-" in tok_label:
                        tok_label = tok_label.replace("B-", "I-")
                        
                    tok_label = clean_label(tok_label)
                    new_labels.append(tok_label)
                    label_ids.append(label_to_id[tok_label])
                else:
                    new_labels.append(labels[cnt])
                    label_ids.append(label_to_id[clean_label(labels[cnt])])
                    cnt += 1
                                    
            input_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
                
            input_mask = [1] * len(input_ids)
            
            while len(input_ids) < max_length:
                input_ids.append(0)
                input_mask.append(0)
                label_ids.append(label_to_id["O"])
            
            dataset.append(Instance(tokenized_text, input_ids,
                            input_mask, new_labels, label_ids))

            text = ["[CLS]"]
            labels = ["O"]
            continue
            
        
        text.append(str(w))
        labels.append(str(l))
        
        
    return dataset

In [8]:
def preprocess_student_data(PATH_DATASET, tokenizer, max_length=512):
    data = pd.read_csv(PATH_DATASET, encoding="utf-8", delim_whitespace=True, header=None, skip_blank_lines=False, error_bad_lines=False)
    Instance = namedtuple("Instance", ["tokenized_text", "input_ids", "input_mask", "labels", "label_ids"])
    dataset = []
    text = ["[CLS]"]
    labels = ["O"]
    for w, l in zip(data[0], data[1]):
        if str(w) == "nan" and str(l) == "nan":
            text.append("[SEP]")
            labels.append("O")
            
            
            label_ids = []
            cnt = 0

            for i in text:
                tok_label = labels[cnt]
                label_ids.append(label_to_id[tok_label])
                cnt+=1
                                    
            input_ids = tokenizer.convert_tokens_to_ids(text)
                
            input_mask = [1] * len(input_ids)
            
            while len(input_ids) < max_length:
                input_ids.append(0)
                input_mask.append(0)
                label_ids.append(label_to_id["O"])
            
            dataset.append(Instance(text, input_ids,
                            input_mask, labels, label_ids))

            text = ["[CLS]"]
            labels = ["O"]
            continue
            
        
        text.append(str(w))
        labels.append(str(l))
        
        
    return dataset

In [9]:
def preprocess_semilabeled(PATH_DATASET, tokenizer, max_length=512):
    data = pd.read_csv(PATH_DATASET, encoding="utf-8", delim_whitespace=True, header=None, skip_blank_lines=False, error_bad_lines=False)
    Instance = namedtuple("Instance", ["tokenized_text", "input_ids", "input_mask", "labels", "label_ids", "proba_classes"])
    dataset = []
    text = ["[CLS]"]
    labels = ["O"]
    proba_classes = [[0.0, 0.0, 0.0]]
    cnt_error = 0
    for ins in data.values:
        if str(ins[0]) == "nan" and str(ins[1]) == "nan":
            text.append("[SEP]")
            labels.append("O")
            proba_classes.append([0.0, 0.0, 0.0])
            
            str_text = " ".join(text)
            tokenized_text = arabert_tokenizer.tokenize(str_text)
            
            cnt = 0 
            new_labels = []
            new_proba_classes = []
            label_ids = []
            
            if len(tokenized_text) < 512:
                try:
                    for i in tokenized_text:
                        if "##" in i:
                            tok_label = labels[cnt - 1]
                            if "B-" in tok_label:
                                tok_label = tok_label.replace("B-", "I-")

                            tok_label = clean_label(tok_label)
                            new_labels.append(tok_label)
                            new_proba_classes.append(proba_classes[cnt - 1])
                            label_ids.append(label_to_id[tok_label])
                        else:
                            new_labels.append(labels[cnt])
                            new_proba_classes.append(proba_classes[cnt])
                            label_ids.append(label_to_id[clean_label(labels[cnt])])
                            cnt += 1

                    input_ids = tokenizer.convert_tokens_to_ids(tokenized_text)

                    input_mask = [1] * len(input_ids)

                    while len(input_ids) < max_length:
                        input_ids.append(0)
                        input_mask.append(0)
                        label_ids.append(label_to_id["O"])

                    dataset.append(Instance(tokenized_text, input_ids,
                                    input_mask, new_labels, label_ids, new_proba_classes))
                
                except Exception as e:
                    print("An error occured")
                    cnt_error += 1
                    
            text = ["[CLS]"]
            labels = ["O"]
            proba_classes = [[0.0, 0.0, 0.0]]
            continue
            
        
        text.append(str(ins[0]))
        labels.append(str(ins[1]))
        
        temp_proba_classes = [str(p) for p in ins[2:]]
        temp_proba_classes = " ".join(temp_proba_classes).replace("[","").replace("]","")
        temp_proba_classes = [float(p) for p in temp_proba_classes.split(" ")[1:]]
        proba_classes.append(temp_proba_classes)

    return dataset



In [10]:
def preprocess_semilabeled_without_proba(PATH_DATASET, tokenizer, max_length=512):
    data = pd.read_csv(PATH_DATASET, encoding="utf-8", delim_whitespace=True, usecols=[0,1], header=None, skip_blank_lines=False, error_bad_lines=False)
    Instance = namedtuple("Instance", ["tokenized_text", "input_ids", "input_mask", "labels", "label_ids"])
    dataset = []
    text = ["[CLS]"]
    labels = ["O"]
    cnt_error = 0
    for ins in data.values:
        if str(ins[0]) == "nan" and str(ins[1]) == "nan":
            text.append("[SEP]")
            labels.append("O")
            str_text = " ".join(text)
            tokenized_text = arabert_tokenizer.tokenize(str_text)

            cnt = 0 
            new_labels = []
            label_ids = []
            
            if len(tokenized_text) < 512:
                try:
                    for i in tokenized_text:
                        if "##" in i:
                            tok_label = labels[cnt - 1]
                            if "B-" in tok_label:

                                tok_label = tok_label.replace("B-", "I-")

                            tok_label = clean_label(tok_label)
                            new_labels.append(tok_label)
                            label_ids.append(label_to_id[tok_label])
                        else:
                            new_labels.append(labels[cnt])
                            label_ids.append(label_to_id[clean_label(labels[cnt])])
                            cnt += 1

                    input_ids = tokenizer.convert_tokens_to_ids(tokenized_text)

                    input_mask = [1] * len(input_ids)

                    while len(input_ids) < max_length:
                        input_ids.append(0)
                        input_mask.append(0)
                        label_ids.append(label_to_id["O"])

                    dataset.append(Instance(tokenized_text, input_ids,
                                        input_mask, new_labels, label_ids))
                except Exception as e:
                    print("An error occured")
                    cnt_error += 1

            text = ["[CLS]"]
            labels = ["O"]
            continue
                
        text.append(str(ins[0]))
        labels.append(str(ins[1]))
        
    print("Errors:{}".format(cnt_error))
    return dataset

In [11]:
def transform_to_tensors(dataset):
    tensors_input_ids = []
    tensors_input_mask = []
    tensors_label_ids = []
    for i in dataset:
        tensors_input_ids.append(i.input_ids)
        tensors_input_mask.append(i.input_mask)
        tensors_label_ids.append(i.label_ids)
        
    return torch.tensor(tensors_input_ids), torch.tensor(tensors_input_mask), torch.tensor(tensors_label_ids)

In [12]:
def transform_to_tensors_semilabeled(dataset):
    tensors_input_ids = []
    tensors_input_mask = []
    for i in dataset:
        tensors_input_ids.append(i.input_ids)
        tensors_input_mask.append(i.input_mask)
        
    return torch.tensor(tensors_input_ids), torch.tensor(tensors_input_mask)

In [13]:
class ModifiedBertForTokenClassification(BertPreTrainedModel):
    def __init__(self, config, num_labels=7):
        super().__init__(config)
        self.num_labels = num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs =  logits # (logits,) + outputs[2:] add hidden states and attention if they are here
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs =  loss # (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


In [14]:
def predict(model, filename, dataset, predict_dataloader, device="cpu"):
    global id_to_label
    model.eval()
    
    with torch.no_grad():
        fw =  open("{}".format(filename), "w", encoding="utf-8")
        cnt = 0
        for batch in tqdm(predict_dataloader):
            input_ids, input_mask = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            output = model(input_ids=input_ids, attention_mask=input_mask)
            length = len(dataset[cnt].tokenized_text)
            
            for w in range(length):
                word = dataset[cnt].tokenized_text[w]
                if word != "[CLS]" and word != "[SEP]":
                    if dataset[cnt].labels[w] != "O":
                        pred_label = dataset[cnt].labels[w]
                    else:
                        pred_label = id_to_label[torch.argmax(output.squeeze(0)[w]).item()]
                    
                    fw.write("{} {} \n".format(word, pred_label))
            fw.write("\n")
            cnt += 1
        fw.close()

In [15]:
def evaluate(model, filename, dataset, dataloader):
    global id_to_label
    model.eval()
    f1_score = 0
    
    with torch.no_grad():
        fw =  open("{}".format(filename), "w")
        cnt = 0
        for batch in tqdm(dataloader):
            input_ids, input_mask, _ = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            output = model(input_ids=input_ids, attention_mask=input_mask)

            length = len(dataset[cnt].tokenized_text)
            for w in range(length):
                word = dataset[cnt].tokenized_text[w]
                true_label = clean_label(dataset[cnt].labels[w])
                pred_label = id_to_label[torch.argmax(output.squeeze(0)[w]).item()]
                fw.write("{} {} {}\n".format(word, true_label, pred_label))
            fw.write("\n")
            cnt += 1
        fw.close()

        _, f1_score_arr = eval_f1score("{}".format(filename))
        
    return f1_score_arr
    

In [16]:
def train(model, optimizer, train_dataloader, val_dataloader, dataset_val, accumulation_steps=32, epochs=1, device="cpu"):
    model.to(device)
    best_f1_score = 0
    best_model = None
    
    for epoch in range(epochs):
        training_loss = 0.0
        val_loss = 0.0

        model.train()
        cnt_step = 0
        for batch in tqdm(train_dataloader):
            
            input_ids, input_mask, label_ids = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            label_ids = label_ids.to(device)
            
            loss = model(input_ids=input_ids, attention_mask=input_mask, labels=label_ids)
            training_loss += loss.data.item()
            
            loss = loss / accumulation_steps
            loss.backward()
            
            if (cnt_step + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            cnt_step += 1

        training_loss /= cnt_step
        
        model.eval()
        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                input_ids, input_mask, label_ids = batch
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                label_ids = label_ids.to(device)

                loss = model(input_ids=input_ids, attention_mask=input_mask, labels=label_ids)
                val_loss += loss.data.item()

            val_loss /= len(val_dataloader)

            print("epoch {}: training loss {}, val loss {}".format(epoch, training_loss, val_loss))
            
        f1_score_arr = evaluate(model, "val.txt", dataset_val, val_dataloader)
        
        if f1_score_arr[3] > best_f1_score:
                best_f1_score = f1_score_arr[3]
                best_model = copy.deepcopy(model)
                print("We have a better model with an F1 Score: {}".format(best_f1_score))
            
    return best_model

In [17]:
def train_loss(model, optimizer, train_dataloader, val_dataloader, dataset_val, accumulation_steps=32, epochs=1, device="cpu"):
    model.to(device)
    best_loss = 10000
    best_model = None
    
    for epoch in range(epochs):
        training_loss = 0.0
        val_loss = 0.0

        model.train()
        cnt_step = 0
        for batch in tqdm(train_dataloader):
            
            input_ids, input_mask, label_ids = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            label_ids = label_ids.to(device)
            
            loss = model(input_ids=input_ids, attention_mask=input_mask, labels=label_ids)
            training_loss += loss.data.item()
            
            loss = loss / accumulation_steps
            loss.backward()
            
            if (cnt_step + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            cnt_step += 1

        training_loss /= cnt_step
        
        model.eval()
        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                input_ids, input_mask, label_ids = batch
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                label_ids = label_ids.to(device)

                loss = model(input_ids=input_ids, attention_mask=input_mask, labels=label_ids)
                val_loss += loss.data.item()

            val_loss /= len(val_dataloader)

            print("epoch {}: training loss {}, val loss {}".format(epoch, training_loss, val_loss))
            
        if best_loss > val_loss:
            best_loss = val_loss
            best_model = copy.deepcopy(model)
            print("We have a better model")
            
    return best_model

## Train Teacher Model

In [18]:
teacher_arabert_model = ModifiedBertForTokenClassification.from_pretrained("aubmindlab/bert-base-arabertv01")

I0510 17:04:14.496253 139825465988864 configuration_utils.py:283] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/aubmindlab/bert-base-arabertv01/config.json from cache at /home/chadi/.cache/torch/transformers/edefbd57b711b1796edd80ad0058293ec6e302f92fba0fcdd7138805dc6164ab.f6fc50854095aaf1023a82f7d5210b2df75a0334997d2daf64453496246d7b2d
I0510 17:04:14.498865 139825465988864 configuration_utils.py:319] Model config BertConfig {
  "_num_labels": 2,
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "directionality": "bidi",
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": f

In [19]:
dataset_train = preprocess_data(PATH_TRAIN, arabert_tokenizer)
dataset_val = preprocess_data(PATH_VAL, arabert_tokenizer)
dataset_aqmar_test = preprocess_data(PATH_AQMAR_TEST, arabert_tokenizer)
dataset_news_test = preprocess_data(PATH_NEWS_TEST, arabert_tokenizer)
dataset_tweets_test = preprocess_data(PATH_TWEETS_TEST, arabert_tokenizer)

In [20]:
train_tensors_input_ids, train_tensors_input_mask, train_tensors_label_ids = transform_to_tensors(dataset_train)
val_tensors_input_ids, val_tensors_input_mask, val_tensors_label_ids = transform_to_tensors(dataset_val)
test_aqmar_tensors_input_ids, test_aqmar_tensors_input_mask, test_aqmar_tensors_label_ids = transform_to_tensors(dataset_aqmar_test)
test_news_tensors_input_ids, test_news_tensors_input_mask, test_news_tensors_label_ids = transform_to_tensors(dataset_news_test)
test_tweets_tensors_input_ids, test_tweets_tensors_input_mask, test_tweets_tensors_label_ids = transform_to_tensors(dataset_tweets_test)

In [21]:
train_tensor_dataset = TensorDataset(train_tensors_input_ids, train_tensors_input_mask, train_tensors_label_ids)
val_tensor_dataset = TensorDataset(val_tensors_input_ids, val_tensors_input_mask, val_tensors_label_ids)
test_aqmar_tensor_dataset = TensorDataset(test_aqmar_tensors_input_ids, test_aqmar_tensors_input_mask, test_aqmar_tensors_label_ids)
test_news_tensor_dataset = TensorDataset(test_news_tensors_input_ids, test_news_tensors_input_mask, test_news_tensors_label_ids)
test_tweets_tensor_dataset = TensorDataset(test_tweets_tensors_input_ids, test_tweets_tensors_input_mask, test_tweets_tensors_label_ids)

In [22]:
train_dataloader = DataLoader(train_tensor_dataset, batch_size=1)
val_dataloader = DataLoader(val_tensor_dataset, batch_size=1)
test_aqmar_dataloader = DataLoader(test_aqmar_tensor_dataset, batch_size=1)
test_news_dataloader = DataLoader(test_news_tensor_dataset, batch_size=1)
test_tweets_dataloader = DataLoader(test_tweets_tensor_dataset, batch_size=1)

In [23]:
optimizer_grouped_parameters = None
param_optimizer = list(teacher_arabert_model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']

if FULL_FINETUNE:
    print('ALL FINETUNE')
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    print('NO ALL FINETUNE')
    optimizer_grouped_parameters = [
        {'params': teacher_arabert_model.classifier.parameters(),
         'weight_decay_rate': 0.01}
    ]

optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)

ALL FINETUNE


In [24]:
trained_teacher_model = train_loss(teacher_arabert_model, optimizer, train_dataloader, val_dataloader, dataset_val, epochs=4, device=device)

100%|██████████| 4147/4147 [04:59<00:00, 13.84it/s]
100%|██████████| 2555/2555 [01:03<00:00, 40.07it/s]
  0%|          | 2/4147 [00:00<03:56, 17.56it/s]

epoch 0: training loss 0.28163901283193493, val loss 0.11237545299599833
We have a better model


100%|██████████| 4147/4147 [05:08<00:00, 13.43it/s]
100%|██████████| 2555/2555 [01:04<00:00, 39.29it/s]
  0%|          | 2/4147 [00:00<04:07, 16.73it/s]

epoch 1: training loss 0.06832305082777003, val loss 0.12113542770089786


100%|██████████| 4147/4147 [05:07<00:00, 13.48it/s]
100%|██████████| 2555/2555 [01:04<00:00, 39.48it/s]
  0%|          | 2/4147 [00:00<04:12, 16.41it/s]

epoch 2: training loss 0.041857792907891905, val loss 0.1278511223064188


100%|██████████| 4147/4147 [05:08<00:00, 13.45it/s]
100%|██████████| 2555/2555 [01:05<00:00, 39.29it/s]

epoch 3: training loss 0.02915042966799278, val loss 0.144541788735899





In [25]:
evaluate(trained_teacher_model, "val.txt", dataset_val, val_dataloader)
evaluate(trained_teacher_model, "test_aqmar.txt", dataset_aqmar_test, test_aqmar_dataloader)
evaluate(trained_teacher_model, "test_news.txt", dataset_news_test, test_news_dataloader)
evaluate(trained_teacher_model, "test_tweets.txt", dataset_tweets_test, test_tweets_dataloader)

100%|██████████| 2555/2555 [01:07<00:00, 38.03it/s]
  0%|          | 4/2456 [00:00<01:03, 38.74it/s]

processed 94131 tokens with 3118 phrases; found: 3031 phrases; correct: 2317.
accuracy:  97.35%; precision:  76.44%; recall:  74.31%; FB1:  75.36
              LOC: precision:  83.33%; recall:  93.78%; FB1:  88.25  1104
              ORG: precision:  54.79%; recall:  50.58%; FB1:  52.60  553
              PER: precision:  79.62%; recall:  71.13%; FB1:  75.14  1374
[88.24940047961631, 52.604166666666664, 75.13736263736264, 75.36184745487071]


100%|██████████| 2456/2456 [01:04<00:00, 37.86it/s]
  1%|▏         | 4/292 [00:00<00:07, 37.42it/s]

processed 88841 tokens with 2886 phrases; found: 2673 phrases; correct: 1710.
accuracy:  95.65%; precision:  63.97%; recall:  59.25%; FB1:  61.52
              LOC: precision:  70.76%; recall:  59.78%; FB1:  64.81  1067
              ORG: precision:  26.58%; recall:  37.98%; FB1:  31.27  523
              PER: precision:  75.35%; recall:  64.92%; FB1:  69.74  1083
[64.80686695278969, 31.2710911136108, 69.74358974358975, 61.521856449001625]


100%|██████████| 292/292 [00:07<00:00, 38.23it/s]
  1%|          | 5/982 [00:00<00:23, 41.09it/s]

processed 17655 tokens with 1195 phrases; found: 1125 phrases; correct: 849.
accuracy:  94.31%; precision:  75.47%; recall:  71.05%; FB1:  73.19
              LOC: precision:  82.96%; recall:  70.88%; FB1:  76.44  311
              ORG: precision:  56.62%; recall:  52.12%; FB1:  54.28  325
              PER: precision:  83.23%; recall:  85.15%; FB1:  84.18  489
[76.44444444444444, 54.27728613569322, 84.17786970010341, 73.18965517241381]


100%|██████████| 982/982 [00:25<00:00, 38.36it/s]


processed 22133 tokens with 513 phrases; found: 418 phrases; correct: 269.
accuracy:  96.45%; precision:  64.35%; recall:  52.44%; FB1:  57.79
              LOC: precision:  79.03%; recall:  47.12%; FB1:  59.04  124
              ORG: precision:  34.94%; recall:  29.00%; FB1:  31.69  83
              PER: precision:  67.30%; recall:  69.27%; FB1:  68.27  211
[59.036144578313255, 31.693989071038253, 68.26923076923077, 57.787325456498394]


[59.036144578313255, 31.693989071038253, 68.26923076923077, 57.787325456498394]

In [26]:
torch.save(trained_teacher_model, "best_teacher_model.h5")
del trained_teacher_model
del teacher_arabert_model

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


## Predict Semi-Labeled Dataset 

In [27]:
teacher_model = torch.load("best_teacher_model.h5",map_location=torch.device(device))

In [28]:
dataset_predict = preprocess_semilabeled_without_proba(PATH_SEMILABELED, arabert_tokenizer)

An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occur

An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occur

An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occur

An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occur

An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occured
An error occur

In [29]:
predict_tensors_input_ids, predict_tensors_input_mask = transform_to_tensors_semilabeled(dataset_predict)

In [30]:
predict_tensor_dataset = TensorDataset(predict_tensors_input_ids, predict_tensors_input_mask)

In [31]:
predict_dataloader = DataLoader(predict_tensor_dataset, batch_size=1)

In [32]:
predict(teacher_model, PATH_STUDENT, dataset_predict, predict_dataloader, device)

100%|██████████| 53711/53711 [23:45<00:00, 37.69it/s]


## Train Student Model

In [33]:
student_arabert_model = ModifiedBertForTokenClassification.from_pretrained("aubmindlab/bert-base-arabertv01")

I0510 17:56:17.824970 139825465988864 configuration_utils.py:283] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/aubmindlab/bert-base-arabertv01/config.json from cache at /home/chadi/.cache/torch/transformers/edefbd57b711b1796edd80ad0058293ec6e302f92fba0fcdd7138805dc6164ab.f6fc50854095aaf1023a82f7d5210b2df75a0334997d2daf64453496246d7b2d
I0510 17:56:17.825569 139825465988864 configuration_utils.py:319] Model config BertConfig {
  "_num_labels": 2,
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "directionality": "bidi",
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": f

In [34]:
student_dataset_train = preprocess_student_data(PATH_STUDENT, arabert_tokenizer)

In [35]:
student_train_tensors_input_ids, student_train_tensors_input_mask, student_train_tensors_label_ids = transform_to_tensors(student_dataset_train)

In [36]:
student_train_tensor_dataset = TensorDataset(student_train_tensors_input_ids, student_train_tensors_input_mask, student_train_tensors_label_ids)

In [37]:
student_train_dataloader = DataLoader(student_train_tensor_dataset, batch_size=1)

In [38]:
optimizer_grouped_parameters = None
param_optimizer = list(student_arabert_model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']

if FULL_FINETUNE:
    print('ALL FINETUNE')
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    print('NO ALL FINETUNE')
    optimizer_grouped_parameters = [
        {'params': student_arabert_model.classifier.parameters(),
         'weight_decay_rate': 0.01}
    ]

optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)

ALL FINETUNE


In [39]:
student_teacher_model = train_loss(student_arabert_model, optimizer, student_train_dataloader, val_dataloader, dataset_val, epochs=4, device=device)

100%|██████████| 53711/53711 [1:06:06<00:00, 13.54it/s]
100%|██████████| 2555/2555 [01:05<00:00, 38.96it/s]
  0%|          | 2/53711 [00:00<52:40, 16.99it/s]

epoch 0: training loss 0.24410190692526343, val loss 0.1390509720883968
We have a better model


100%|██████████| 53711/53711 [1:05:52<00:00, 13.59it/s]
100%|██████████| 2555/2555 [01:05<00:00, 39.17it/s]
  0%|          | 2/53711 [00:00<54:34, 16.40it/s]

epoch 1: training loss 0.16812767760870434, val loss 0.1580906663756816


100%|██████████| 53711/53711 [1:05:47<00:00, 13.61it/s]
100%|██████████| 2555/2555 [01:04<00:00, 39.36it/s]
  0%|          | 2/53711 [00:00<55:27, 16.14it/s]

epoch 2: training loss 0.13343738509957978, val loss 0.16918142978833267


100%|██████████| 53711/53711 [1:05:49<00:00, 13.60it/s]
100%|██████████| 2555/2555 [01:05<00:00, 39.22it/s]

epoch 3: training loss 0.1063113289395121, val loss 0.17889171409619556





In [40]:
evaluate(student_teacher_model, "student_val.txt", dataset_val, val_dataloader)
evaluate(student_teacher_model, "student_test_aqmar.txt", dataset_aqmar_test, test_aqmar_dataloader)
evaluate(student_teacher_model, "student_test_news.txt", dataset_news_test, test_news_dataloader)
evaluate(student_teacher_model, "student_test_tweets.txt", dataset_tweets_test, test_tweets_dataloader)

100%|██████████| 2555/2555 [01:07<00:00, 37.91it/s]
  0%|          | 4/2456 [00:00<01:07, 36.41it/s]

processed 94131 tokens with 3118 phrases; found: 3589 phrases; correct: 2352.
accuracy:  96.53%; precision:  65.53%; recall:  75.43%; FB1:  70.14
              LOC: precision:  64.77%; recall:  91.85%; FB1:  75.97  1391
              ORG: precision:  45.35%; recall:  60.27%; FB1:  51.76  796
              PER: precision:  77.75%; recall:  70.87%; FB1:  74.15  1402
[75.96964586846543, 51.75627240143369, 74.14965986394556, 70.13567914119577]


100%|██████████| 2456/2456 [01:04<00:00, 38.83it/s]
  1%|▏         | 4/292 [00:00<00:07, 36.50it/s]

processed 88841 tokens with 2886 phrases; found: 3352 phrases; correct: 1909.
accuracy:  95.33%; precision:  56.95%; recall:  66.15%; FB1:  61.21
              LOC: precision:  62.54%; recall:  69.28%; FB1:  65.74  1399
              ORG: precision:  23.25%; recall:  51.64%; FB1:  32.06  813
              PER: precision:  74.12%; recall:  67.22%; FB1:  70.50  1140
[65.74004507888806, 32.06106870229008, 70.50479766374634, 61.20551458800898]


100%|██████████| 292/292 [00:07<00:00, 37.07it/s]
  0%|          | 4/982 [00:00<00:25, 37.77it/s]

processed 17655 tokens with 1195 phrases; found: 1255 phrases; correct: 914.
accuracy:  94.14%; precision:  72.83%; recall:  76.49%; FB1:  74.61
              LOC: precision:  72.65%; recall:  74.45%; FB1:  73.54  373
              ORG: precision:  56.25%; recall:  61.19%; FB1:  58.62  384
              PER: precision:  85.74%; recall:  89.33%; FB1:  87.50  498
[73.54138398914517, 58.61601085481682, 87.5, 74.61224489795919]


100%|██████████| 982/982 [00:26<00:00, 37.62it/s]


processed 22133 tokens with 513 phrases; found: 571 phrases; correct: 307.
accuracy:  95.83%; precision:  53.77%; recall:  59.84%; FB1:  56.64
              LOC: precision:  60.66%; recall:  53.37%; FB1:  56.78  183
              ORG: precision:  36.21%; recall:  42.00%; FB1:  38.89  116
              PER: precision:  56.62%; recall:  75.12%; FB1:  64.57  272
[56.77749360613811, 38.888888888888886, 64.57023060796645, 56.642066420664214]


[56.77749360613811, 38.888888888888886, 64.57023060796645, 56.642066420664214]

In [41]:
optimizer_grouped_parameters = None
param_optimizer = list(student_teacher_model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']

if FULL_FINETUNE:
    print('ALL FINETUNE')
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    print('NO ALL FINETUNE')
    optimizer_grouped_parameters = [
        {'params': student_teacher_model.classifier.parameters(),
         'weight_decay_rate': 0.01}
    ]

optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)

ALL FINETUNE


In [42]:
finetuned_student_teacher_model = train_loss(student_teacher_model, optimizer, train_dataloader, val_dataloader, dataset_val, epochs=4, device=device)

100%|██████████| 4147/4147 [05:03<00:00, 13.65it/s]
100%|██████████| 2555/2555 [01:04<00:00, 39.55it/s]
  0%|          | 2/4147 [00:00<04:01, 17.16it/s]

epoch 0: training loss 0.0769748305420798, val loss 0.0996225939988074
We have a better model


100%|██████████| 4147/4147 [05:05<00:00, 13.57it/s]
100%|██████████| 2555/2555 [01:04<00:00, 39.49it/s]
  0%|          | 2/4147 [00:00<04:14, 16.30it/s]

epoch 1: training loss 0.04176969075727229, val loss 0.11419789121150101


100%|██████████| 4147/4147 [05:05<00:00, 13.57it/s]
100%|██████████| 2555/2555 [01:04<00:00, 39.38it/s]
  0%|          | 2/4147 [00:00<04:12, 16.40it/s]

epoch 2: training loss 0.026019873369021472, val loss 0.14112530970352088


100%|██████████| 4147/4147 [05:03<00:00, 13.66it/s]
100%|██████████| 2555/2555 [01:04<00:00, 39.45it/s]

epoch 3: training loss 0.018926680339868194, val loss 0.1598931491144531





In [43]:
evaluate(finetuned_student_teacher_model, "finetuned_student_val.txt", dataset_val, val_dataloader)
evaluate(finetuned_student_teacher_model, "finetuned_student_test_aqmar.txt", dataset_aqmar_test, test_aqmar_dataloader)
evaluate(finetuned_student_teacher_model, "finetuned_student_test_news.txt", dataset_news_test, test_news_dataloader)
evaluate(finetuned_student_teacher_model, "finetuned_student_test_tweets.txt", dataset_tweets_test, test_tweets_dataloader)

100%|██████████| 2555/2555 [01:07<00:00, 37.94it/s]
  0%|          | 4/2456 [00:00<01:06, 37.02it/s]

processed 94131 tokens with 3118 phrases; found: 3014 phrases; correct: 2405.
accuracy:  97.58%; precision:  79.79%; recall:  77.13%; FB1:  78.44
              LOC: precision:  87.16%; recall:  95.51%; FB1:  91.15  1075
              ORG: precision:  62.19%; recall:  57.93%; FB1:  59.98  558
              PER: precision:  81.17%; recall:  72.89%; FB1:  76.81  1381
[91.14785992217898, 59.98271391529819, 76.80712572798905, 78.44096542726679]


100%|██████████| 2456/2456 [01:06<00:00, 36.79it/s]
  1%|▏         | 4/292 [00:00<00:08, 34.62it/s]

processed 88841 tokens with 2886 phrases; found: 2653 phrases; correct: 1776.
accuracy:  95.96%; precision:  66.94%; recall:  61.54%; FB1:  64.13
              LOC: precision:  75.27%; recall:  61.20%; FB1:  67.51  1027
              ORG: precision:  26.69%; recall:  40.98%; FB1:  32.33  562
              PER: precision:  80.17%; recall:  67.86%; FB1:  73.50  1064
[67.51091703056768, 32.327586206896555, 73.50280051701851, 64.12709875428779]


100%|██████████| 292/292 [00:07<00:00, 38.26it/s]
  0%|          | 4/982 [00:00<00:24, 39.43it/s]

processed 17655 tokens with 1195 phrases; found: 1134 phrases; correct: 896.
accuracy:  94.81%; precision:  79.01%; recall:  74.98%; FB1:  76.94
              LOC: precision:  84.49%; recall:  73.35%; FB1:  78.53  316
              ORG: precision:  62.50%; recall:  56.66%; FB1:  59.44  320
              PER: precision:  86.14%; recall:  89.75%; FB1:  87.91  498
[78.52941176470587, 59.43536404160476, 87.90983606557377, 76.94289394589954]


100%|██████████| 982/982 [00:25<00:00, 37.78it/s]

processed 22133 tokens with 513 phrases; found: 510 phrases; correct: 299.
accuracy:  95.88%; precision:  58.63%; recall:  58.28%; FB1:  58.46
              LOC: precision:  82.68%; recall:  50.48%; FB1:  62.69  127
              ORG: precision:  38.18%; recall:  42.00%; FB1:  40.00  110
              PER: precision:  55.68%; recall:  74.15%; FB1:  63.60  273
[62.686567164179095, 40.00000000000001, 63.59832635983263, 58.45552297165201]





[62.686567164179095, 40.00000000000001, 63.59832635983263, 58.45552297165201]

In [44]:
torch.save(finetuned_student_teacher_model, "best_finetuned_student_model.h5")

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
