## Start


In [None]:
# Step 1: Install Required Libraries
!pip install pandas datasets --quiet

In [None]:
# Step 2: Import and Load Data
import pandas as pd
import re
from datasets import Dataset

# CSV locations on the mounted Google Drive (swap in your own path, or upload the files manually)
file_path = "/content/drive/MyDrive/Kenya_Quantised_LLM/data/train.csv"  # Replace with your path or use file upload
t_file_path = "/content/drive/MyDrive/Kenya_Quantised_LLM/data/test.csv"  # Replace with your path or use file upload

# Read the training frame and the test frame, in that order.
raw_df, t_raw_df = (pd.read_csv(csv_path) for csv_path in (file_path, t_file_path))

In [None]:
# Step 3: Define Cleaning & Formatting Function
def restructure_prompt(prompt):
    """Rewrite a raw nurse prompt into the <|system|>/<|user|>/<|assistant|> chat layout.

    The leading "I am a nurse ... in Kenya." self-introduction is dropped, the
    remaining text is split at the first "Questions:" (or bare "Questions")
    marker into a vignette and a question block, and both are embedded in a
    fixed template. When no marker is present the question block is empty.

    Args:
        prompt: Raw prompt text (str).

    Returns:
        The templated prompt string, ending with the "<|assistant|>" tag.
    """
    # Drop the self-introduction sentence (non-greedy, may span newlines).
    body = re.sub(r"^I am a nurse.*?in Kenya\.\s*", "", prompt, flags=re.DOTALL)

    # Prefer the explicit "Questions:" marker, then fall back to bare "Questions".
    for marker in ("Questions:", "Questions"):
        if marker in body:
            head, _, tail = body.partition(marker)
            break
    else:
        head, tail = body, ""

    vignette = head.strip()
    questions = tail.strip()

    pieces = [
        "<|system|>\n",
        "You are a highly experienced Kenyan clinical nurse. Provide clear, concise, and empathetic answers.\n",
        "<|user|>\n",
        f"{vignette}\n\nQuestions:\n{questions}\n",
        "<|assistant|>",
    ]
    return "".join(pieces)

In [None]:

# Step 4: Apply to Dataset
# Training frame: templated prompt plus the clinician's reference answer.
raw_df["input_text"] = [restructure_prompt(text) for text in raw_df["Prompt"]]
raw_df["output_text"] = raw_df["Clinician"]

# Test frame: templated prompt only (there is no reference answer column).
t_raw_df["input_text"] = [restructure_prompt(text) for text in t_raw_df["Prompt"]]

In [None]:
# Step 5: Create HuggingFace Dataset for Fine-tuning
# Build both datasets first, then persist each one as CSV on Drive.
hf_dataset = Dataset.from_pandas(raw_df[["input_text", "output_text"]])
hf_t_dataset = Dataset.from_pandas(t_raw_df[["input_text"]])

hf_dataset.to_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/data/finetune_dataset.csv")
hf_t_dataset.to_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/data/test_dataset.csv")

## Training

In [None]:
# Step 1: Install Required Libraries
!pip install transformers peft bitsandbytes accelerate datasets --quiet

In [None]:
# Step 2: Imports
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model, TaskType
import pandas as pd
from datasets import Dataset
import torch

In [None]:
# Step 3: Load Local Fine-tune and Test Datasets
train_path = "/content/drive/MyDrive/Kenya_Quantised_LLM/data/finetune_dataset.csv"
test_path = "/content/drive/MyDrive/Kenya_Quantised_LLM/data/test_dataset.csv"

train_df = pd.read_csv(train_path)
test_df = pd.read_csv(test_path)

# Convert to HuggingFace Datasets
train_dataset_full = Dataset.from_pandas(train_df[["input_text", "output_text"]])
test_dataset = Dataset.from_pandas(test_df[["input_text"]])

# Split train into train/validation.
# A fixed seed keeps the split identical across kernel restarts, so validation
# numbers from different runs are comparable.
train_test_split = train_dataset_full.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test_split["train"]
val_dataset = train_test_split["test"]

In [None]:
# Step 4: Load Tokenizer and Model (TinyLLaMA example)
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Step 6 pads every example to a fixed length, which needs a pad token.
# Fall back to EOS when the checkpoint does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantisation with double quantisation; matmuls run in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)

In [None]:
# Step 5: Apply LoRA
# Rank-8 adapters on the attention query/value projections only.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
)

# Wrap the quantised base model with the adapter configuration.
model = get_peft_model(model, lora_config)

In [None]:
# Step 6: Tokenize
def tokenize(example, max_length=512):
    """Build causal-LM training features from prompt/response pairs.

    A decoder-only model must see the response as part of its input, so each
    example is tokenized as one sequence ``prompt + response + EOS``. ``labels``
    is a copy of ``input_ids`` in which prompt tokens and padding are replaced
    by -100 so that only response tokens contribute to the loss.

    Args:
        example: Mapping with "input_text" and "output_text"; either single
            strings or lists of strings (``batched=True``).
        max_length: Fixed sequence length after truncation/padding.

    Returns:
        dict with "input_ids", "attention_mask" and "labels".
    """
    single = isinstance(example["input_text"], str)
    prompts = [example["input_text"]] if single else list(example["input_text"])
    responses = [example["output_text"]] if single else list(example["output_text"])

    texts = [p + r + tokenizer.eos_token for p, r in zip(prompts, responses)]
    batch = tokenizer(texts, truncation=True, padding="max_length", max_length=max_length)
    # Prompt-only pass: its token count tells us how many leading tokens to mask.
    # (Token boundaries at the prompt/response junction can shift by a token.)
    prompt_ids = tokenizer(prompts, truncation=True, max_length=max_length)["input_ids"]

    labels = []
    for ids, mask, p_ids in zip(batch["input_ids"], batch["attention_mask"], prompt_ids):
        row = [tok if keep else -100 for tok, keep in zip(ids, mask)]
        # Locate the first real token so this works for left- or right-padding.
        start = mask.index(1) if 1 in mask else len(mask)
        stop = min(start + len(p_ids), len(row))
        row[start:stop] = [-100] * (stop - start)
        labels.append(row)
    # NOTE: a prompt that alone fills max_length leaves no supervised tokens.

    features = {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
        "labels": labels,
    }
    return {key: value[0] for key, value in features.items()} if single else features

train_dataset = train_dataset.map(tokenize, batched=True)
val_dataset = val_dataset.map(tokenize, batched=True)
# "labels" must be exposed as well; without it the Trainer cannot compute a loss.
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

In [None]:
# Step 7: TrainingArguments
# bf16 is only available on some CUDA GPUs (e.g. not on a T4); requesting it
# elsewhere makes TrainingArguments raise. Fall back to fp16 on other GPUs.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir="qlora-tinyllama-clinician",
    per_device_train_batch_size=2,
    num_train_epochs=5,
    learning_rate=5e-5,
    bf16=use_bf16,
    fp16=torch.cuda.is_available() and not use_bf16,
    logging_steps=10,
    save_strategy="epoch",
    report_to="none"
)

In [None]:
# Step 8: Trainer
# Assemble the Trainer from the adapter-wrapped model and the tokenized splits.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

# Run fine-tuning.
trainer.train()

In [None]:
# Step 9: Save Final Model
# Write the model and the tokenizer to the same local directory
# (relative to the notebook's working directory).
model.save_pretrained("qlora-tinyllama-clinician-final")
tokenizer.save_pretrained("qlora-tinyllama-clinician-final")
print("✅ Fine-tuned model saved to qlora-tinyllama-clinician-final")


## Basic Model

In [None]:
!pip install transformers datasets peft accelerate bitsandbytes --quiet


In [None]:
import pandas as pd
from datasets import Dataset

In [None]:
# Load the raw training CSV, rename the prompt/answer columns to the names used
# by the tokenization step, and keep only those two columns.
df = (
    pd.read_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/data/train.csv")
    .rename(columns={"Prompt": "input_text", "Clinician": "output_text"})
    .loc[:, ["input_text", "output_text"]]
)
df.head()

In [None]:
dataset = Dataset.from_pandas(df)
# Fixed seed: the 90/10 split is then reproducible across kernel restarts.
dataset = dataset.train_test_split(test_size=0.1, seed=42)

train_dataset = dataset["train"]
test_dataset = dataset["test"]

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # or your custom quantized model

# Tokenizer: reuse EOS as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantisation with double quantisation; compute in float16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
    load_in_4bit=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
)


In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# LoRA adapter settings; no explicit target_modules are given here.
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Wrap the base model and report how many parameters are trainable.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


In [None]:
def tokenize_function(example, max_length=256):
    """Build causal-LM training features from prompt/response pairs.

    A decoder-only model must see the response as part of its input, so each
    example is tokenized as one sequence ``prompt + response + EOS``. ``labels``
    is a copy of ``input_ids`` in which prompt tokens and padding are replaced
    by -100 so that only response tokens contribute to the loss.

    Args:
        example: Mapping with "input_text" and "output_text"; either single
            strings or lists of strings (``batched=True``).
        max_length: Fixed sequence length after truncation/padding.

    Returns:
        dict with "input_ids", "attention_mask" and "labels".
    """
    single = isinstance(example["input_text"], str)
    prompts = [example["input_text"]] if single else list(example["input_text"])
    responses = [example["output_text"]] if single else list(example["output_text"])

    texts = [p + r + tokenizer.eos_token for p, r in zip(prompts, responses)]
    batch = tokenizer(texts, padding="max_length", truncation=True, max_length=max_length)
    # Prompt-only pass: its token count tells us how many leading tokens to mask.
    prompt_ids = tokenizer(prompts, truncation=True, max_length=max_length)["input_ids"]

    labels = []
    for ids, mask, p_ids in zip(batch["input_ids"], batch["attention_mask"], prompt_ids):
        row = [tok if keep else -100 for tok, keep in zip(ids, mask)]
        # Locate the first real token so this works for left- or right-padding.
        start = mask.index(1) if 1 in mask else len(mask)
        stop = min(start + len(p_ids), len(row))
        row[start:stop] = [-100] * (stop - start)
        labels.append(row)
    # NOTE: a prompt that alone fills max_length leaves no supervised tokens;
    # raise max_length if the vignettes are long.

    features = {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
        "labels": labels,
    }
    return {key: value[0] for key, value in features.items()} if single else features

train_dataset = train_dataset.map(tokenize_function, batched=True)
eval_dataset = test_dataset.map(tokenize_function, batched=True)


In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

# One epoch, evaluated and checkpointed at the end of every epoch; keep a single checkpoint.
training_args = TrainingArguments(
    output_dir="./results",
    logging_dir="./logs",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    fp16=True,
    report_to="none",
)

# Collator that batches input_ids / attention_mask / labels together.
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer)

trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)

trainer.train()


### Inference

In [None]:
test_df = pd.read_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/data/test.csv")

test_df.rename({"Prompt": "input_text"}, inplace=True, axis=1)

def tokenize_for_inference(example):
    """Tokenize one test prompt to a fixed length of 256 tokens.

    Args:
        example: Mapping with an "input_text" string.

    Returns:
        The tokenizer output as plain Python lists (no tensors).
    """
    return tokenizer(
        example["input_text"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors=None  # Must return dict, not tensors
    )

# Run inference on the test file loaded above. `test_dataset` is the 10%
# validation split of the training data, not the real test set.
tokenized_test_dataset = Dataset.from_pandas(test_df[["input_text"]]).map(tokenize_for_inference)


In [None]:
from tqdm import tqdm

import torch  # local import: this section does not import torch elsewhere

model.eval()
# The 4-bit model was placed by device_map="auto"; reuse that device instead of
# calling .to("cuda") on a quantised model.
device = next(model.parameters()).device

generated_output = []

for example in tqdm(tokenized_test_dataset):
    encoded = tokenizer(
        example["input_text"], return_tensors="pt", truncation=True, max_length=256
    ).to(device)

    # Greedy decoding: `temperature` only applies when sampling, so it is not
    # passed alongside do_sample=False. The attention mask is passed explicitly
    # because pad and EOS share the same token id.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_new_tokens=100,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # The decoded text contains the prompt followed by the continuation.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    generated_output.append(output_text)

In [None]:
# Display the list of decoded generations collected by the loop above.
generated_output

## Trial 3

In [None]:
import pandas as pd
from datasets import Dataset
import re

In [None]:
# Load the raw training CSV from Drive and preview the first rows.
df = pd.read_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/data/train_raw.csv")
df.head()

In [None]:
def restructure_prompt(prompt):
    """Reformat a raw prompt into "### SYSTEM / ### PROMPT / ### RESPONSE" sections.

    If the prompt begins with the nurse self-introduction ("I am a nurse ...
    county in Kenya."), that sentence becomes the SYSTEM section and the rest
    becomes the PROMPT section. Otherwise a generic system message is used and
    the whole prompt becomes the PROMPT section.

    Args:
        prompt: Raw prompt text (str).

    Returns:
        The templated string, ending with an empty RESPONSE section header.
    """
    import re

    # Intro sentence (case-insensitive, may span newlines) followed by the remainder.
    intro_pattern = re.compile(
        r"(I am a nurse.*?in .*?county in Kenya\.)\s*(.*)",
        re.IGNORECASE | re.DOTALL,
    )
    found = intro_pattern.match(prompt)

    if found is None:
        # No recognisable introduction: use the generic system message.
        system_msg = "You are a highly experienced Kenyan clinical nurse."
        main_prompt = prompt.strip()
    else:
        system_msg, main_prompt = (part.strip() for part in found.groups())

    sections = [
        f"### SYSTEM\n{system_msg}\n\n",
        f"### PROMPT\n{main_prompt}\n\n",
        f"### RESPONSE\n",
    ]
    return "".join(sections)

# Derive the templated prompt and the target answer, then keep only those two columns.
df = df.assign(
    input_text=df["Prompt"].map(restructure_prompt),
    output_text=df["Clinician"],
)[["input_text", "output_text"]]
df.head()


In [None]:
# Print the first ten templated prompts (index labels 0-9) for a visual check.
for i in range(10):
    case_text = df.loc[i, "input_text"]
    print(f"case {i}: \n{case_text}\n\n")

In [None]:
dataset = Dataset.from_pandas(df)
# Fixed seed: the 90/10 split is then reproducible across kernel restarts.
dataset = dataset.train_test_split(test_size=0.1, seed=42)

train_dataset = dataset["train"]
test_dataset = dataset["test"]

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # or your custom quantized model

# Tokenizer first; EOS doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Quantisation settings: 4-bit NF4, double quantisation, float16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
)


In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# LoRA adapter settings (rank 8, alpha 32, 10% dropout, biases left untouched).
peft_config = LoraConfig(
    bias="none",
    lora_dropout=0.1,
    lora_alpha=32,
    r=8,
    task_type=TaskType.CAUSAL_LM,
)

# Attach the adapters and print the trainable-parameter summary.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


In [None]:
def tokenize_function(example, max_length=512):
    """Build causal-LM training features from prompt/response pairs.

    A decoder-only model must see the response as part of its input, so each
    example is tokenized as one sequence ``prompt + response + EOS``. ``labels``
    is a copy of ``input_ids`` in which prompt tokens and padding are replaced
    by -100 so that only response tokens contribute to the loss.

    Args:
        example: Mapping with "input_text" and "output_text"; either single
            strings or lists of strings (``batched=True``).
        max_length: Fixed sequence length after truncation/padding.

    Returns:
        dict with "input_ids", "attention_mask" and "labels".
    """
    single = isinstance(example["input_text"], str)
    prompts = [example["input_text"]] if single else list(example["input_text"])
    responses = [example["output_text"]] if single else list(example["output_text"])

    texts = [p + r + tokenizer.eos_token for p, r in zip(prompts, responses)]
    batch = tokenizer(texts, padding="max_length", truncation=True, max_length=max_length)
    # Prompt-only pass: its token count tells us how many leading tokens to mask.
    prompt_ids = tokenizer(prompts, truncation=True, max_length=max_length)["input_ids"]

    labels = []
    for ids, mask, p_ids in zip(batch["input_ids"], batch["attention_mask"], prompt_ids):
        row = [tok if keep else -100 for tok, keep in zip(ids, mask)]
        # Locate the first real token so this works for left- or right-padding.
        start = mask.index(1) if 1 in mask else len(mask)
        stop = min(start + len(p_ids), len(row))
        row[start:stop] = [-100] * (stop - start)
        labels.append(row)
    # NOTE: a prompt that alone fills max_length leaves no supervised tokens.

    features = {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
        "labels": labels,
    }
    return {key: value[0] for key, value in features.items()} if single else features

train_dataset = train_dataset.map(tokenize_function, batched=True)
eval_dataset = test_dataset.map(tokenize_function, batched=True)


In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

# Five epochs, evaluated and checkpointed every epoch; only the latest checkpoint is kept.
training_args = TrainingArguments(
    output_dir="./results",
    logging_dir="./logs",
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    fp16=True,
    report_to="none",
)

# Collator that batches input_ids / attention_mask / labels together.
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer)

trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)

trainer.train()


### Inference

In [None]:
test_df = pd.read_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/data/test.csv")

# Format the test prompts with the same template as the training prompts, so the
# model sees the "### SYSTEM / ### PROMPT / ### RESPONSE" layout it was tuned on.
test_df["input_text"] = test_df["Prompt"].apply(restructure_prompt)

def tokenize_for_inference(example):
    """Tokenize one test prompt to a fixed length of 256 tokens.

    Args:
        example: Mapping with an "input_text" string.

    Returns:
        The tokenizer output as plain Python lists (no tensors).
    """
    return tokenizer(
        example["input_text"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors=None  # Must return dict, not tensors
    )

# Run inference on the test file loaded above. `test_dataset` is the 10%
# validation split of the training data, not the real test set.
tokenized_test_dataset = Dataset.from_pandas(test_df[["input_text"]]).map(tokenize_for_inference)

In [None]:
from tqdm import tqdm

import torch  # local import: this section does not import torch elsewhere

model.eval()
# The 4-bit model was placed by device_map="auto"; reuse that device instead of
# calling .to("cuda") on a quantised model.
device = next(model.parameters()).device

generated_output = []

for example in tqdm(tokenized_test_dataset):
    encoded = tokenizer(
        example["input_text"], return_tensors="pt", truncation=True, max_length=256
    ).to(device)

    # Greedy decoding: `temperature` only applies when sampling, so it is not
    # passed alongside do_sample=False. The attention mask is passed explicitly
    # because pad and EOS share the same token id.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_new_tokens=100,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # The decoded text contains the prompt followed by the continuation.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    generated_output.append(output_text)

In [None]:
# Display the list of decoded generations collected by the loop above.
generated_output

## Trial 4

In [None]:
!pip install trl transformers datasets peft accelerate bitsandbytes evaluate --quiet

In [None]:
import pandas as pd
from datasets import Dataset
import re

# Load the raw training CSV from Drive and preview the first rows.
df = pd.read_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/data/train_raw.csv")
df.head()

In [None]:
def restructure_prompt(prompt):
    """Split a raw prompt into SYSTEM and PROMPT sections and append a RESPONSE header.

    The nurse self-introduction ("I am a nurse ... county in Kenya."), when it
    opens the prompt, is used as the SYSTEM text and the remainder as the
    PROMPT text. Without it, a generic system message is used and the stripped
    prompt is the PROMPT text.

    Args:
        prompt: Raw prompt text (str).

    Returns:
        "### SYSTEM\\n...\\n\\n### PROMPT\\n...\\n\\n### RESPONSE\\n" as one string.
    """
    import re

    # Defaults cover the "no introduction found" case.
    system_msg = "You are a highly experienced Kenyan clinical nurse."
    main_prompt = prompt.strip()

    # Case-insensitive; DOTALL lets the intro and the remainder span newlines.
    intro = re.match(
        r"(I am a nurse.*?in .*?county in Kenya\.)\s*(.*)",
        prompt,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if intro:
        system_msg = intro.group(1).strip()
        main_prompt = intro.group(2).strip()

    return f"### SYSTEM\n{system_msg}\n\n" + f"### PROMPT\n{main_prompt}\n\n" + f"### RESPONSE\n"


# Derive the templated prompt and the target answer, then keep only those two columns.
df = df.assign(
    input_text=df["Prompt"].map(restructure_prompt),
    output_text=df["Clinician"],
)[["input_text", "output_text"]]
df.head()


In [None]:
dataset = Dataset.from_pandas(df)
# Fixed seed: the 90/10 split is then reproducible across kernel restarts.
dataset = dataset.train_test_split(test_size=0.1, seed=42)

train_dataset = dataset["train"]
eval_dataset = dataset["test"]

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # or your custom quantized model

# Tokenizer first; EOS doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Quantisation settings: 4-bit NF4, double quantisation, float16 compute.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
)


In [None]:
# Define the LoRA config and training arguments for SFTTrainer (QLoRA)
from peft import LoraConfig
from trl import SFTTrainer
from transformers import TrainingArguments

# LoRA adapter: rank 16, alpha 64, 5% dropout, biases untouched.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
)

# 15 epochs; 2 examples per device x 4 accumulation steps per optimizer update.
# Evaluate and checkpoint every epoch, keeping the two most recent checkpoints.
training_args = TrainingArguments(
    output_dir="./results",
    logging_dir="./logs",
    num_train_epochs=15,
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    fp16=True,
    report_to="none",
)

In [None]:
!pip install rouge_score --quiet

In [None]:
# Train with SFTTrainer

# Each training text is the templated prompt immediately followed by the clinician answer.
trainer = SFTTrainer(
    args=training_args,
    model=model,
    peft_config=peft_config,
    formatting_func=lambda ex: f"{ex['input_text']}{ex['output_text']}",
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
    # compute_metrics=compute_metrics
)

trainer.train()

### Inference

In [None]:
# Load the competition test CSV. The raw "Prompt" column is kept as-is; the
# rename below is intentionally left disabled.
test_df = pd.read_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/data/test.csv")
# test_df.rename(columns={"Prompt": "input_text"}, inplace=True)

In [None]:
import re
# Add "input_text": each raw test prompt rewritten by restructure_prompt.
test_df["input_text"] = test_df["Prompt"].apply(restructure_prompt)


In [None]:
import re

def format_prompt(row):
    """Build the inference prompt for one test row.

    The raw "Prompt" text is preferred. If it is missing, "input_text" is used;
    when that text is already in the "### SYSTEM / ### PROMPT / ### RESPONSE"
    layout it is unwrapped first, so the template is never nested and the
    "### RESPONSE" marker appears exactly once.

    Args:
        row: Mapping-like object supporting ``.get`` (dict or pandas Series)
            with "Prompt" and/or "input_text", and optionally
            "Nursing Competency" and "Clinical Panel".

    Returns:
        The templated prompt string, ending with "### RESPONSE\\n".
    """
    def _text_or(value, default):
        # Missing cells arrive as NaN (a float), which would render as 'nan'.
        return value.strip() if isinstance(value, str) and value.strip() else default

    prompt = row.get("Prompt")
    if not isinstance(prompt, str):
        prompt = row.get("input_text", "")
    prompt = prompt.strip()

    competency = _text_or(row.get("Nursing Competency"), "general practice")
    panel = _text_or(row.get("Clinical Panel"), "general clinical care")

    # Text that is already templated: recover its SYSTEM and PROMPT sections.
    templated = re.match(
        r"### SYSTEM\n(.*?)\n\n### PROMPT\n(.*?)(?:\n\n### RESPONSE\n?)?\Z",
        prompt,
        flags=re.DOTALL,
    )
    match = re.match(r"(I am a nurse.*?in .*?county in Kenya\.)\s*(.*)", prompt, flags=re.IGNORECASE | re.DOTALL)
    if templated:
        system_intro = templated.group(1).strip()
        main_prompt = templated.group(2).strip()
    elif match:
        system_intro = match.group(1).strip()
        main_prompt = match.group(2).strip()
    else:
        system_intro = "You are a highly experienced Kenyan clinical nurse."
        main_prompt = prompt

    system_message = (
        f"{system_intro} I specialize in '{competency}' and I am to provide a response for the '{panel}' clinical panel."
    )

    # system_message already ends with a period, so none is added before "Provide".
    return (
        f"### SYSTEM\n{system_message} Provide clear, structured, and comprehensive answers. Include diagnosis, recommended investigations, and step-by-step management.\n\n"
        f"### PROMPT\n{main_prompt}\n\n"
        f"### RESPONSE\n"
    )


In [None]:
from tqdm import tqdm
import torch

# Inference mode on the GPU.
device = torch.device("cuda")
model.eval()
model.to(device)

generated_output = []

# Gradients are never needed during generation, so disable them for the whole loop.
with torch.no_grad():
    for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
        inputs = tokenizer(
            format_prompt(row),
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(device)

        # Low-temperature nucleus sampling with a mild repetition penalty.
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=250,
            do_sample=True,
            temperature=0.1,
            top_p=0.9,
            repetition_penalty=1.1,
        )

        # The decoded text contains the prompt followed by the continuation.
        generated_output.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))


In [None]:
 # Display the list of decoded generations collected by the loop above.
 generated_output

In [None]:
# Store the decoded generations next to their test rows (one entry per row, in
# iteration order) and preview the result.
test_df["Generated_Response"] = generated_output
# test_df.to_csv("test_predictions.csv", index=False)

test_df.head()

In [None]:
def extract_response(text):
    """Return the text following the first "### RESPONSE" marker.

    Args:
        text: Decoded model output (str).

    Returns:
        Everything after the first marker, stripped of surrounding whitespace;
        the whole stripped text when the marker is absent.
    """
    _, marker, tail = text.partition("### RESPONSE")
    if not marker:
        return text.strip()  # fallback if not formatted
    return tail.strip()

# Keep only the answer part of each generation in a new "Cleaned_Response" column.
test_df["Cleaned_Response"] = test_df["Generated_Response"].apply(extract_response)


In [None]:
# The submission file expects the answers in a column named "Clinician".
test_df.rename(columns={"Cleaned_Response": "Clinician"}, inplace=True)
test_df.head()

In [None]:
# Print the first ten prompt/answer pairs (index labels 0-9) for a visual check.
for i in range(10):
    case_text = test_df.loc[i, "input_text"]
    answer = test_df.loc[i, "Clinician"]
    print(f"case {i}: ###CASE\n{case_text}\n###RESPONSE\n{answer}\n\n")

In [None]:
# Write the submission file: row identifier plus generated answer, without the DataFrame index.
test_df[["Master_Index", "Clinician"]].to_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/submissions/qlora_submission_1.csv", index=False)

## Trial 5

In [None]:
!pip install trl transformers datasets peft accelerate bitsandbytes evaluate --quiet

In [None]:
import pandas as pd
from datasets import Dataset
import re

# Load the raw training CSV from Drive and preview the first rows.
df = pd.read_csv("/content/drive/MyDrive/Kenya_Quantised_LLM/data/train_raw.csv")
df.head()

In [None]:
def restructure_prompt(prompt):
    """Reformat a raw prompt into "### SYSTEM / ### PROMPT / ### RESPONSE" sections.

    If the prompt begins with the nurse self-introduction ("I am a nurse ...
    county in Kenya."), that sentence becomes the SYSTEM text and the rest
    becomes the PROMPT text. Otherwise a generic system message is used. The
    SYSTEM text is followed by an instruction to match diagnosis and symptoms.

    Args:
        prompt: Raw prompt text (str).

    Returns:
        The templated string, ending with an empty RESPONSE section header.
    """
    import re

    # Step 1: Match the nurse intro sentence (up to first period after "Kenya")
    match = re.match(r"(I am a nurse.*?in .*?county in Kenya\.)\s*(.*)", prompt, flags=re.IGNORECASE | re.DOTALL)

    if match:
        system_msg = match.group(1).strip()
        main_prompt = match.group(2).strip()
    else:
        # Fallback if pattern isn't found
        system_msg = "You are a highly experienced Kenyan clinical nurse."
        main_prompt = prompt.strip()

    # Both branches yield a system message that already ends with a period;
    # drop it before appending ". Ensure ..." so the text does not contain "..".
    system_msg = system_msg.rstrip(".")

    return (
        f"### SYSTEM\n{system_msg}. Ensure you match the diagnosis with symptoms\n\n"
        f"### PROMPT\n{main_prompt}\n\n"
        f"### RESPONSE\n"
    )


# Derive the templated prompt and the target answer, then keep only those two columns.
df = df.assign(
    input_text=df["Prompt"].map(restructure_prompt),
    output_text=df["Clinician"],
)[["input_text", "output_text"]]
df.head()



In [None]:
# Print the first ten templated prompts (index labels 0-9) for a visual check.
for i in range(10):
    case_text = df.loc[i, "input_text"]
    print(f"case {i}: \n{case_text}\n\n")

In [None]:
dataset = Dataset.from_pandas(df)
# Fixed seed: the 90/10 split is then reproducible across kernel restarts.
dataset = dataset.train_test_split(test_size=0.1, seed=42)

train_dataset = dataset["train"]
eval_dataset = dataset["test"]

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # or your custom quantized model

# Tokenizer first; EOS doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Quantisation settings: 4-bit NF4, double quantisation, float16 compute.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
    load_in_4bit=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
)


In [None]:
# Define the LoRA config and training arguments for SFTTrainer (QLoRA)
from peft import LoraConfig
from trl import SFTTrainer
from transformers import TrainingArguments

import torch  # local import: needed here for the precision check below

# LoRA adapters on all attention and MLP projection layers.
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "v_proj", "k_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ]
)

# bf16 is only available on some CUDA GPUs (e.g. not on a T4); requesting it
# elsewhere makes TrainingArguments raise. Fall back to fp16 on other GPUs.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    num_train_epochs=3,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    logging_dir="./logs",
    save_total_limit=1,
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    save_strategy="epoch",
    push_to_hub=False,
    bf16=use_bf16,
    fp16=torch.cuda.is_available() and not use_bf16,
    report_to="none",  # same as the other sections: no external experiment trackers
)

In [None]:
!pip install rouge_score

In [None]:
import evaluate

rouge = evaluate.load("rouge")

def compute_metrics(eval_preds):
    """Compute ROUGE between decoded predictions and decoded labels.

    Args:
        eval_preds: ``(predictions, labels)`` pair. Predictions may be token
            ids, logits of shape (batch, seq, vocab), or a tuple whose first
            element is one of those. Labels are token ids and may contain
            negative "ignore" values such as -100.

    Returns:
        dict mapping ROUGE metric names to scores rounded to 4 decimals.
    """
    import numpy as np

    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    if preds.ndim == 3:
        # Logits: take the most likely token per position. The logit at
        # position i predicts token i + 1, so shift both sides to line them up
        # (assumes labels are aligned with input_ids — TODO confirm).
        preds = preds.argmax(axis=-1)[:, :-1]
        labels = labels[:, 1:]

    # Negative ids (e.g. -100) cannot be decoded; map them to the pad token,
    # which skip_special_tokens then removes.
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds < 0, pad_id, preds)
    labels = np.where(labels < 0, pad_id, labels)

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip and align
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    # Compute ROUGE
    results = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # Optional: Round results
    results = {k: round(v, 4) for k, v in results.items()}
    return results


In [None]:
def tokenize_example(example, max_length=512, ignore_id=None):
    """Build causal-LM training features from prompt/response pairs.

    A decoder-only model must see the response as part of its input, so each
    example is tokenized as one sequence ``prompt + response + EOS``. ``labels``
    is a copy of ``input_ids`` in which prompt tokens and padding are replaced
    by ``ignore_id`` so that only response tokens contribute to the loss.

    Args:
        example: Mapping with "input_text" and "output_text"; either single
            strings or lists of strings (``batched=True``).
        max_length: Fixed sequence length after truncation/padding.
        ignore_id: Label value for positions excluded from the loss. Defaults
            to ``tokenizer.pad_token_id``, the ``ignore_index`` used by this
            notebook's custom loss. Because pad and EOS share an id here, the
            final EOS is excluded from the loss as well.

    Returns:
        dict with "input_ids", "attention_mask" and "labels".
    """
    if ignore_id is None:
        ignore_id = tokenizer.pad_token_id

    single = isinstance(example["input_text"], str)
    prompts = [example["input_text"]] if single else list(example["input_text"])
    responses = [example["output_text"]] if single else list(example["output_text"])

    texts = [p + r + tokenizer.eos_token for p, r in zip(prompts, responses)]
    batch = tokenizer(texts, padding="max_length", truncation=True, max_length=max_length)
    # Prompt-only pass: its token count tells us how many leading tokens to mask.
    prompt_ids = tokenizer(prompts, truncation=True, max_length=max_length)["input_ids"]

    labels = []
    for ids, mask, p_ids in zip(batch["input_ids"], batch["attention_mask"], prompt_ids):
        row = [tok if keep else ignore_id for tok, keep in zip(ids, mask)]
        # Locate the first real token so this works for left- or right-padding.
        start = mask.index(1) if 1 in mask else len(mask)
        stop = min(start + len(p_ids), len(row))
        row[start:stop] = [ignore_id] * (stop - start)
        labels.append(row)
    # NOTE: a prompt that alone fills max_length leaves no supervised tokens.

    features = {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
        "labels": labels,
    }
    return {key: value[0] for key, value in features.items()} if single else features

# Apply to both train and eval
tokenized_train_dataset = train_dataset.map(tokenize_example, batched=True)
tokenized_eval_dataset = eval_dataset.map(tokenize_example, batched=True)

In [None]:
from transformers import DataCollatorForSeq2Seq

# Collator that batches the tokenized features and returns PyTorch tensors.
data_collator = DataCollatorForSeq2Seq(
    model=model,
    tokenizer=tokenizer,
    return_tensors="pt",
    padding=True,
)

In [None]:
import torch
import torch.nn.functional as F

from transformers import Trainer

class CustomSFTTrainer(SFTTrainer):
    """SFTTrainer with an explicit next-token cross-entropy loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the shifted next-token cross-entropy loss for one batch.

        Label positions equal to the tokenizer's pad id or to -100 are excluded
        from the loss, so both masking conventions are accepted.

        Args:
            model: Model to run the forward pass on.
            inputs: Batch dict; "labels" is removed from it before the forward pass.
            return_outputs: When True, also return the model outputs.
            **kwargs: Extra arguments passed by the Trainer; unused.

        Returns:
            ``loss`` or ``(loss, outputs)`` depending on ``return_outputs``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # Fold the pad id into -100 so a single ignore_index covers both.
        # Without this, a -100 label would be an out-of-range class index.
        pad_id = tokenizer.pad_token_id
        if pad_id is not None:
            labels = labels.masked_fill(labels == pad_id, -100)

        # Shift so that tokens <n> predict <n+1>
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        return (loss, outputs) if return_outputs else loss


In [None]:
# Train with SFTTrainer

from trl import SFTTrainer

# Train on the pre-tokenized splits with the custom loss and the ROUGE metric hook.
trainer = CustomSFTTrainer(
    args=training_args,
    model=model,
    peft_config=lora_config,
    data_collator=data_collator,
    compute_metrics=compute_metrics,  # Optional
    eval_dataset=tokenized_eval_dataset,
    train_dataset=tokenized_train_dataset,
)

trainer.train()