In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os,torch
import bitsandbytes as bnb
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from datasets import Dataset
import pyarrow as pa
import pyarrow.dataset as ds
import pandas as pd
import re
import wandb
from huggingface_hub import login
# Authenticate against the Hugging Face Hub.
# SECURITY: never commit an access token to a notebook. The token is read from
# the HF_TOKEN environment variable; when it is unset, `login` receives None
# and falls back to its interactive prompt. Any token that was previously
# hard-coded here must be considered leaked and revoked in the Hub settings.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Illustration of the conversational ("messages") record layout: a list of
# role/content turns in system -> user -> assistant order. The dict below is
# only displayed as the cell output; it is not assigned to any variable.
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

In [None]:
#from datasets import load_dataset
import pandas as pd
from datasets import Dataset

# System prompt inserted as the first ("system") turn of every training
# conversation built by create_conversation below.
# NOTE(review): the saved output of this cell shows a different system message,
# so that output appears to come from an earlier run - re-run to refresh it.
system_prompt = """
You are a Palestinian historian and geographer specialising in the land, people, and cultural heritage of Palestine.

General task
Answer questions or carry out instructions accurately and neutrally while preserving the authentic Palestinian narrative.
Use Modern Standard Arabic unless the user asks for another language.

Narrative-integrity rules
Give the Arabic place-name first; add other names in brackets when needed.
Avoid terms that erase Palestinian identity (e.g. “Judea and Samaria”) unless requested for comparison.
Present the Palestinian account first, then mention other viewpoints briefly if relevant.
Replace every standalone occurrence of “Israel” with “the Israeli occupation” or “the occupying entity”.

Accuracy rules
Admit when you don't know the answer. If the information is not in the provided context or your internal knowledge, state that you cannot provide a factual answer.

Output format
If asked to “extract a span”, return that text only.
Otherwise:
• Optional one-line intro, then short bullets or mini paragraphs.
Do not exceed 150 words unless the user asks for more.

Uncertainty
If no trustworthy data exist, state: “No reliable data available as of the last update.”
"""

def create_conversation(sample):
    """Build one chat-format training record from a question/answer row.

    Args:
        sample: Mapping that provides the "Question" and "Answer" fields.

    Returns:
        A dict with a single "messages" key holding three role/content turns
        in order: system (the module-level ``system_prompt``), user (the
        question) and assistant (the answer).
    """
    system_turn = {"role": "system", "content": system_prompt}
    user_turn = {"role": "user", "content": sample["Question"]}
    assistant_turn = {"role": "assistant", "content": sample["Answer"]}
    return {"messages": [system_turn, user_turn, assistant_turn]}

# 1. Load the raw question/answer rows from the spreadsheet.
Dset = pd.read_excel("train_split.xlsx")
hf_dataset = Dataset.from_pandas(Dset)

# 2. Convert every row to the OAI "messages" format. Without this step the
#    "messages" column does not exist and the spot-check below raises KeyError.
#    The raw spreadsheet columns are dropped so only "messages" remains.
hf_dataset = hf_dataset.map(
    create_conversation,
    remove_columns=hf_dataset.column_names,
    batched=False,
)

# 3. Spot-check one converted sample.
print(hf_dataset[345]["messages"])

# 4. Hold out a fixed fraction of the rows for evaluation. The seed makes the
#    split identical on every Restart & Run All.
eval_fraction = 0.2
eval_size = int(len(hf_dataset) * eval_fraction)
dataset_split = hf_dataset.train_test_split(test_size=eval_size, seed=42)

# 5. Resulting datasets (sizes are printed below rather than hard-coded in comments).
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(eval_dataset)}")

Map: 100%|██████████| 20505/20505 [00:01<00:00, 15194.06 examples/s]


[{'content': 'You are a helpful assistant designed to answer questions based on a given context passage. Your task is to extract the most accurate answer span directly from the context that best answers the question. Return only the answer text.', 'role': 'system'}, {'content': 'كيف وزعت أراضيها عام 1945؟', 'role': 'user'}, {'content': '539 دونمًا مزروعة ومروية، 2,107 للحبوب، و29 دونمًا مبنية.', 'role': 'assistant'}]
Train samples: 16404
Test samples: 4101


In [None]:
# Base checkpoint to fine-tune.
model_id = "mistralai/Mistral-7B-Instruct-v0.3" #24 JUNE 2024

# 4-bit NF4 quantisation with double quantisation; matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantised base model; device placement is left to `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)

# Load the matching fast tokenizer and pad on the right.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tokenizer.padding_side = 'right' # to prevent warnings

# Reuse the tokenizer's unknown token (unk_token) as the padding token.
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id

# setup_chat_format(model, tokenizer) is not applied in this cell; the
# tokenizer's own chat template is left untouched.

In [None]:

# Hyperparameter configuration.
# NOTE(review): QUANTIZE_4BIT and USE_FLASH_ATTENTION are not read by any of
# the cells visible in this section - confirm whether they are still needed.

# --- Memory / precision switches ---
QUANTIZE_4BIT = True
USE_GRAD_CHECKPOINTING = True
USE_FLASH_ATTENTION = True

# --- Batching ---
# Sequences per optimizer step per device = TRAIN_BATCH_SIZE * GRAD_ACC_STEPS = 256.
TRAIN_BATCH_SIZE = 4
GRAD_ACC_STEPS = 64
PER_DEVICE_EVAL_BATCH_SIZE = 4
# Backward-compatible alias: the constant was originally spelled with a
# lower-case "ize"; keep the old name so existing references still resolve.
PER_DEVICE_EVAL_BATCH_Size = PER_DEVICE_EVAL_BATCH_SIZE
TRAIN_MAX_SEQ_LENGTH = 1024
GROUP_BY_LENGTH = True

# --- LoRA adapter ---
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.1

# --- Optimisation schedule ---
NUM_TRAIN_EPOCHS = 1
LEARNING_RATE = 4e-4
OPTIM = "paged_adamw_8bit"
WARMUP_RATIO = 0.08
WEIGHT_DECAY = 0.001

In [None]:


# Training configuration for supervised fine-tuning.
# Every tunable value comes from the hyperparameter constants so the config
# cell is the single source of truth (previously group_by_length was
# hard-coded and the eval batch size constant was never passed, so evaluation
# silently used the library default).
training_arguments = SFTConfig(
    output_dir="results",
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_Size,
    gradient_accumulation_steps=GRAD_ACC_STEPS,
    gradient_checkpointing=USE_GRAD_CHECKPOINTING,
    optim=OPTIM,
    save_steps=50,
    logging_steps=10,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    # Use bf16 when the GPU supports it, otherwise fall back to fp16.
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    warmup_ratio=WARMUP_RATIO,
    group_by_length=GROUP_BY_LENGTH,
    lr_scheduler_type="linear",
    report_to="none",
    max_seq_length=TRAIN_MAX_SEQ_LENGTH,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; confirm against the installed version before upgrading.
    evaluation_strategy="epoch",
)

# LoRA adapter definition: rank, scaling and dropout come from the
# hyperparameter constants; adapters are attached to the four attention
# projection modules listed in `target_modules`, with no bias terms trained.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# Assemble the supervised fine-tuning trainer from the objects created in the
# previous cells: quantised base model, tokenizer, LoRA config, training
# arguments and the train/eval splits.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_arguments,
    peft_config=peft_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [None]:

# Start training. Checkpoints are written to the configured output directory
# ("results") every `save_steps` steps. No push_to_hub option is set in the
# training arguments, so nothing is uploaded to the Hub by this call.
trainer.train()


In [None]:
# Save the trained LoRA adapter together with the tokenizer.
# The adapter must NOT be written to a folder named after `model_id`: a local
# directory "mistralai/Mistral-7B-Instruct-v0.3" would take precedence over the
# Hub repository on any later `from_pretrained(model_id)` call, and it contains
# adapter weights only, not the base model.
adapter_dir = os.path.join(trainer.args.output_dir, "final_adapter")
trainer.model.save_pretrained(save_directory=adapter_dir)
tokenizer.save_pretrained(adapter_dir)

# Switch the model to evaluation mode for any inference that follows.
model.eval()