In [1]:
filepath = open("data/semi_labeled/semi_labeled_dataset.txt", "r", encoding = "utf-8").read().split("\n")
outputpath = open("semi_labeled_unlabeled.txt", 'w', encoding = 'utf-8')

for line in filepath:
    if(len(line)>1):
        lineTokens = line.split()
        t1 = lineTokens[0]
        outputpath.write("{}\n".format(t1))
    else: outputpath.write('\n')

In [2]:
# !pip install transformers
import transformers
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, AutoModel, BertPreTrainedModel, BertModel, AdamW
import pandas as pd
from collections import namedtuple


In [3]:
def preprocess(filepath, max_length=512):
    arabert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv01",do_lower_case=False)
    data = pd.read_csv(filepath, encoding="utf-8", delim_whitespace=True, header=None, skip_blank_lines=False)
    Instance = namedtuple("Instance", ["tokenized_text", "input_ids", "input_mask"])
    dataset = []
    text = ["[CLS]"]
    for w in data[0]:
        if str(w) == "nan":
            text.append("[SEP]")
            str_text = " ".join(text)
            tokenized_text = arabert_tokenizer.tokenize(str_text)
#             print(len(tokenized_text))
            if len(tokenized_text) > 512: 
                text= ["[CLS]"]
                continue
            cnt = 0 
            new_labels = []

            input_ids = arabert_tokenizer.convert_tokens_to_ids(tokenized_text)

            input_mask = [1] * len(input_ids)

            while len(input_ids) < max_length:
                input_ids.append(0)
                input_mask.append(0)

            dataset.append(Instance(tokenized_text, input_ids,
                        input_mask))
            text = ["[CLS]"]
            continue
    
        text.append(str(w))
    return dataset


In [4]:
def transform_to_tensors(dataset):
    tensors_input_ids = []
    tensors_input_mask = []
    for i in dataset:
        tensors_input_ids.append(i.input_ids)
        tensors_input_mask.append(i.input_mask)
#         print(i.input_ids)
    return torch.tensor(tensors_input_ids), torch.tensor(tensors_input_mask)

In [5]:
dataset_predict = preprocess("semi_labeled_unlabeled.txt")
dataset_predict

[Instance(tokenized_text=['[CLS]', 'كان', 'في', 'بدايته', 'السبعينات', 'كان', 'يقدم', 'بعض', 'العروض', 'المسرحي', '##ه', 'والمسابقات', 'الثقافي', '##ه', 'في', 'نادي', 'الهلال', 'وبعد', 'الانتهاء', 'من', 'عمله', 'في', 'النادي', 'يتجه', 'الي', 'مبني', 'التلفزيون', 'السعودي', 'الذي', 'كان', 'في', 'بداياته', 'في', 'ذلك', 'الوقت', 'ليقوم', 'بت', '##الي', '##ف', 'وتمثيل', 'بعض', 'الفقرات', 'والق', '##ف', '##شات', 'الكوميدي', '##ه', 'واول', 'عمل', 'له', 'عرفه', 'الناس', 'منه', 'هو', 'الضيف', 'الغريب', 'في', 'عام', '[UNK]', 'من', 'اعداد', 'واخراج', 'ابراهيم', 'الحمدان', '[SEP]'], input_ids=[17028, 3176, 660, 35617, 57608, 3176, 13875, 1662, 33420, 46951, 909, 59686, 45206, 909, 660, 11751, 34722, 12430, 52837, 731, 9716, 660, 34556, 13350, 1521, 11083, 57535, 45769, 5935, 3176, 660, 48194, 660, 2296, 19590, 24689, 424, 14759, 903, 40462, 1662, 46312, 12373, 903, 15730, 53622, 909, 12391, 2888, 706, 9602, 19456, 3579, 751, 18979, 33553, 660, 2806, 60122, 731, 18124, 39639, 43718, 45405, 17030, 

In [6]:
predict_tensors_input_ids, predict_tensors_input_mask = transform_to_tensors(dataset_predict)

In [7]:
predict_tensors_input_ids

tensor([[17028,  3176,   660,  ...,     0,     0,     0],
        [17028, 27610, 19580,  ...,     0,     0,     0],
        [17028, 12430,  2020,  ...,     0,     0,     0],
        ...,
        [17028,  4008,  2527,  ...,     0,     0,     0],
        [17028,  2107,   110,  ...,     0,     0,     0],
        [17028,  4008,  2806,  ...,     0,     0,     0]])

In [8]:
class ModifiedBertForTokenClassification(BertPreTrainedModel):
    def __init__(self, config, num_labels=7):
        super().__init__(config)
        self.num_labels = num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs =  logits # (logits,) + outputs[2:] add hidden states and attention if they are here
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs =  loss # (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


In [42]:
from tqdm import tqdm
def predict(model, filename, dataset, predict_dataloader, device="cpu"):
    global id_to_label
    model.eval()
    
    with torch.no_grad():
        fw =  open("{}".format(filename), "w", encoding="utf-8")
        cnt = 0
        for batch in tqdm(predict_dataloader):
            input_ids, input_mask = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            output = model(input_ids=input_ids, attention_mask=input_mask)
#             print(output)
            length = len(dataset[cnt].tokenized_text)
            for w in range(length):
                word = dataset[cnt].tokenized_text[w]
#                 true_label = clean_label(dataset[cnt].labels[w])
                pred_label = id_to_label[torch.argmax(output.squeeze(0)[w]).item()]
                fw.write("{} {} \n".format(word, pred_label))
            fw.write("\n")
            cnt += 1
        fw.close()

In [9]:
PATH = "best_model.h5"

In [11]:
model = torch.load(PATH,map_location=torch.device('cpu'))


In [12]:
predict_tensor_dataset = TensorDataset(predict_tensors_input_ids, predict_tensors_input_mask)

In [14]:
predict_dataloader = DataLoader(predict_tensor_dataset, batch_size=1)

In [29]:
label_to_id = {"O":0, "B-ORG":1, "I-ORG":2, "B-PER":3, "I-PER":4, "B-LOC":5, "I-LOC":6}
id_to_label = {value: key for key, value in label_to_id.items()}

In [43]:
predict(model, "predictoutput.txt", dataset_predict, predict_dataloader)





  0%|                                                                                        | 0/56156 [00:00<?, ?it/s]



  0%|                                                                             | 1/56156 [00:02<42:44:05,  2.74s/it]



  0%|                                                                             | 2/56156 [00:05<43:03:39,  2.76s/it]



  0%|                                                                             | 3/56156 [00:08<44:00:27,  2.82s/it]



  0%|                                                                             | 4/56156 [00:11<44:27:52,  2.85s/it]



  0%|                                                                             | 5/56156 [00:14<45:39:41,  2.93s/it]



  0%|                                                                             | 6/56156 [00:17<45:48:47,  2.94s/it]



  0%|                                                                             | 7/56156 [00:20<46:32:06,  2.98s/it]



  0%|       

  0%|                                                                            | 66/56156 [03:07<43:04:16,  2.76s/it]



  0%|                                                                            | 67/56156 [03:09<42:37:59,  2.74s/it]



  0%|                                                                            | 68/56156 [03:12<42:22:55,  2.72s/it]



  0%|                                                                            | 69/56156 [03:15<42:07:53,  2.70s/it]



  0%|                                                                            | 70/56156 [03:17<42:25:36,  2.72s/it]



  0%|                                                                            | 71/56156 [03:20<43:17:59,  2.78s/it]



  0%|                                                                            | 72/56156 [03:23<44:37:09,  2.86s/it]



  0%|                                                                            | 73/56156 [03:27<46:09:28,  2.96s/it]



  0%|           

  0%|▏                                                                          | 132/56156 [06:11<42:48:52,  2.75s/it]



  0%|▏                                                                          | 133/56156 [06:14<42:30:14,  2.73s/it]



  0%|▏                                                                          | 134/56156 [06:17<42:31:59,  2.73s/it]



  0%|▏                                                                          | 135/56156 [06:20<42:41:51,  2.74s/it]



  0%|▏                                                                          | 136/56156 [06:22<42:40:38,  2.74s/it]



  0%|▏                                                                          | 137/56156 [06:25<43:19:44,  2.78s/it]



  0%|▏                                                                          | 138/56156 [06:28<42:46:44,  2.75s/it]



  0%|▏                                                                          | 139/56156 [06:31<42:12:11,  2.71s/it]



  0%|▏          

  0%|▎                                                                          | 198/56156 [09:08<41:28:35,  2.67s/it]



  0%|▎                                                                          | 199/56156 [09:11<41:18:38,  2.66s/it]



  0%|▎                                                                          | 200/56156 [09:14<41:17:48,  2.66s/it]



  0%|▎                                                                          | 201/56156 [09:16<41:16:55,  2.66s/it]



  0%|▎                                                                          | 202/56156 [09:19<41:09:52,  2.65s/it]



  0%|▎                                                                          | 203/56156 [09:22<41:15:49,  2.65s/it]



  0%|▎                                                                          | 204/56156 [09:24<41:13:15,  2.65s/it]



  0%|▎                                                                          | 205/56156 [09:27<41:12:01,  2.65s/it]



  0%|▎          

  0%|▎                                                                          | 264/56156 [12:06<41:40:55,  2.68s/it]



  0%|▎                                                                          | 265/56156 [12:08<40:26:12,  2.60s/it]



  0%|▎                                                                          | 266/56156 [12:11<40:34:38,  2.61s/it]



  0%|▎                                                                          | 267/56156 [12:13<40:47:30,  2.63s/it]



  0%|▎                                                                          | 268/56156 [12:16<40:43:24,  2.62s/it]



  0%|▎                                                                          | 269/56156 [12:19<41:51:18,  2.70s/it]



  0%|▎                                                                          | 270/56156 [12:22<41:44:28,  2.69s/it]



  0%|▎                                                                          | 271/56156 [12:24<41:34:40,  2.68s/it]



  0%|▎          

KeyboardInterrupt: 