In [1]:
!pip install transformers[torch]



In [2]:
!pip install accelerate -U



In [3]:
!pip install -q transformers
!pip install -q datasets

In [21]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering,DataCollatorWithPadding
from transformers import TrainingArguments, Trainer

In [23]:
model_checkpoint = "HooshvareLab/bert-fa-base-uncased"
max_length = 512 # The maximum length of a feature (question and context)
doc_stride = 256 # The authorized overlap between two part of the context when splitting it is needed.
batch_size = 8
lr = 3e-5
epoch = 10

In [24]:
datasets = load_dataset("SeyedAli/Persian-Text-QA")
datasets['train'][0]

Repo card metadata block was not found. Setting CardData to empty.


{'id': 1,
 'title': 'شرکت فولاد مبارکه اصفهان',
 'context': 'شرکت فولاد مبارکۀ اصفهان، بزرگ\u200cترین واحد صنعتی خصوصی در ایران و بزرگ\u200cترین مجتمع تولید فولاد در کشور ایران است، که در شرق شهر مبارکه قرار دارد. فولاد مبارکه هم\u200cاکنون محرک بسیاری از صنایع بالادستی و پایین\u200cدستی است. فولاد مبارکه در ۱۱ دوره جایزۀ ملی تعالی سازمانی و ۶ دوره جایزۀ شرکت دانشی در کشور رتبۀ نخست را بدست آورده\u200cاست و همچنین این شرکت در سال ۱۳۹۱ برای نخستین\u200cبار به عنوان تنها شرکت ایرانی با کسب امتیاز ۶۵۴ تندیس زرین جایزۀ ملی تعالی سازمانی را از آن خود کند. شرکت فولاد مبارکۀ اصفهان در ۲۳ دی ماه ۱۳۷۱ احداث شد و اکنون بزرگ\u200cترین واحدهای صنعتی و بزرگترین مجتمع تولید فولاد در ایران است. این شرکت در زمینی به مساحت ۳۵ کیلومتر مربع در نزدیکی شهر مبارکه و در ۷۵ کیلومتری جنوب غربی شهر اصفهان واقع شده\u200cاست. مصرف آب این کارخانه در کمترین میزان خود، ۱٫۵٪ از دبی زاینده\u200cرود برابر سالانه ۲۳ میلیون متر مکعب در سال است و خود یکی از عوامل کم\u200cآبی زاینده\u200cرود شناخته می\u200cشود.',
 'questio

In [25]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
data_collector = DataCollatorWithPadding(tokenizer=tokenizer)

In [26]:
def prepare_train_features(examples):
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True)

    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []
    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1
            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1
            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [27]:
# the datasets library does cashing itself, batched is multitreading for fast-tokenizer
tokenized_ds = datasets.map(prepare_train_features, batched=True, remove_columns=datasets["train"].column_names)

In [28]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at HooshvareLab/bert-fa-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
!pip install accelerate>=0.20.1

In [29]:
from sklearn.metrics import precision_recall_fscore_support,accuracy_score

In [31]:
args = TrainingArguments(
    f"result",
    evaluation_strategy = "epoch",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epoch,
    weight_decay=0.0001)

In [37]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['validation'],
    tokenizer=tokenizer,
    data_collator=data_collector
    )

In [38]:
# start training
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.66,1.929703
2,0.8676,2.104479
3,0.4174,2.490557
4,0.2623,3.002457
5,0.1855,3.686428
6,0.1199,4.284048
7,0.0642,4.813411
8,0.042,4.994059
9,0.0195,4.952233
10,0.0085,5.187868


TrainOutput(global_step=11260, training_loss=0.2798460031275639, metrics={'train_runtime': 5491.64, 'train_samples_per_second': 16.403, 'train_steps_per_second': 2.05, 'total_flos': 1.2852930249848928e+16, 'train_loss': 0.2798460031275639, 'epoch': 10.0})

In [39]:
trainer.evaluate()

{'eval_loss': 5.187867641448975,
 'eval_runtime': 13.3813,
 'eval_samples_per_second': 69.5,
 'eval_steps_per_second': 8.744,
 'epoch': 10.0}

In [44]:
from transformers import pipeline

qa_pipeline = pipeline("question-answering", model='/content/drive/MyDrive/PersianQA', tokenizer='/content/drive/MyDrive/PersianQA')

text = r"""سلام من سید علی میر محمد حسینی هستم 25 سالمه و به پردازش زبان طبیعی علاقه دارم """
questions = ["اسمم چیه؟", "چند سالمه؟", "به چی علاقه دارم؟"]

for question in questions:
    print(qa_pipeline({"context": text, "question": question}))

{'score': 0.3693717122077942, 'start': 8, 'end': 30, 'answer': 'سید علی میر محمد حسینی'}
{'score': 3.54418041581539e-08, 'start': 36, 'end': 44, 'answer': '25 سالمه'}
{'score': 0.9994200468063354, 'start': 50, 'end': 67, 'answer': 'پردازش زبان طبیعی'}


In [40]:
trainer.save_model('/content/drive/MyDrive/PersianQA')

In [41]:
!zip -r /content/drive/MyDrive/PersianQA/PersianQA.zip '/content/drive/MyDrive/PersianQA'

  adding: content/drive/MyDrive/PersianQA/ (stored 0%)
  adding: content/drive/MyDrive/PersianQA/config.json (deflated 47%)
  adding: content/drive/MyDrive/PersianQA/pytorch_model.bin (deflated 8%)
  adding: content/drive/MyDrive/PersianQA/tokenizer_config.json (deflated 45%)
  adding: content/drive/MyDrive/PersianQA/special_tokens_map.json (deflated 42%)
  adding: content/drive/MyDrive/PersianQA/vocab.txt (deflated 62%)
  adding: content/drive/MyDrive/PersianQA/tokenizer.json (deflated 72%)
  adding: content/drive/MyDrive/PersianQA/training_args.bin (deflated 49%)
