In [8]:
!pip install transformers datasets accelerate peft



In [12]:
from transformers import BertTokenizerFast, BertForQuestionAnswering, TrainingArguments, Trainer, EvalPrediction, AutoModelForQuestionAnswering
from datasets import load_dataset
from peft import PeftModelForQuestionAnswering, get_peft_config
from collections import Counter
import re
import string

# Load the SQuAD v2 dataset
dataset = load_dataset('squad_v2')
train_dataset = dataset['train']
eval_dataset = dataset['validation']

# Load the fast tokenizer
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Tokenization function
def tokenize_function(examples):
    encodings = tokenizer(
        examples['question'],
        examples['context'],
        truncation=True,
        padding='max_length',
        max_length=512,
        return_offsets_mapping=True
    )

    start_positions = []
    end_positions = []

    for i, (context, answer) in enumerate(zip(examples['context'], examples['answers'])):
        start_position = None
        end_position = None

        if answer['answer_start']:
            start_idx = answer['answer_start'][0]
            end_idx = start_idx + len(answer['text'][0])

            offset_mapping = encodings['offset_mapping'][i]

            for j, (offset_start, offset_end) in enumerate(offset_mapping):
                if offset_start <= start_idx and offset_end > start_idx:
                    start_position = j
                if offset_start < end_idx and offset_end >= end_idx:
                    end_position = j
                    break

            if start_position is not None and end_position is not None:
                start_positions.append(start_position)
                end_positions.append(end_position)
            else:
                start_positions.append(0)
                end_positions.append(0)
        else:
            start_positions.append(0)
            end_positions.append(0)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
    return encodings



small_train_dataset = train_dataset.select(range(5000))
small_eval_dataset = eval_dataset.select(range(1000))

# Tokenize the dataset first
train_dataset = small_train_dataset.map(tokenize_function, batched=True)
eval_dataset = small_eval_dataset.map(tokenize_function, batched=True)



# LoRA Configuration
config = {
    "peft_type": "LORA",
    "task_type": "QUESTION_ANS",
    "inference_mode": False,
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "fan_in_fan_out": False,
    "bias": "none",
    # Added the specific target modules as per your previous config
    "target_modules": [
        "bert.encoder.layer.0.attention.self.query",
        "bert.encoder.layer.0.attention.self.key",
        "bert.encoder.layer.0.attention.self.value",
    ]
}

peft_config = get_peft_config(config)
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
peft_model = PeftModelForQuestionAnswering(model, peft_config)
peft_model.print_trainable_parameters()


# Training arguments
training_args = TrainingArguments(
    evaluation_strategy="epoch",
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
)



# Normalization Function
def normalize_answer(s):
    """Lower text and remove punctuation, articles, and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

# F1 Score Calculation
def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if len(prediction_tokens) == 0 or len(ground_truth_tokens) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(prediction_tokens == ground_truth_tokens)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

# Exact Match Score Calculation
def exact_match_score(prediction, ground_truth):
    return int(normalize_answer(prediction) == normalize_answer(ground_truth))

def compute_metrics(eval_pred):
    predictions_tuple, labels_tuple = eval_pred.predictions, eval_pred.label_ids
    start_logits, end_logits = predictions_tuple
    start_positions, end_positions = labels_tuple

    f1 = 0.0
    exact_match = 0

    for i in range(len(start_positions)):
        start_pred = start_logits[i].argmax()
        end_pred = end_logits[i].argmax()

        pred_ans = tokenizer.decode(eval_dataset[i]['input_ids'][start_pred:end_pred + 1])
        true_ans = tokenizer.decode(eval_dataset[i]['input_ids'][start_positions[i]:end_positions[i] + 1])

        f1 += f1_score(pred_ans, true_ans)
        exact_match += exact_match_score(pred_ans, true_ans)

    return {'f1': f1/len(start_positions), 'exact_match': exact_match/len(start_positions)}

# Trainer
trainer = Trainer(
    model=lora_model,  # Use the LoRA-enhanced model
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics
)

# Fine-tune
trainer.train()


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 76,804 || all params: 108,968,452 || trainable%: 0.0704827852376943


OutOfMemoryError: ignored